SolveWithPython

Magnitude‑Based Pruning in Python: Removing Weights That Don’t Matter

In Article #1, we introduced sparsity by design using fixed masks.

That raised the obvious next question:

If we want a sparse network, which weights should we remove — and why?

This article answers that question using the simplest pruning strategy ever proposed, and still one of the most effective in practice:

Magnitude‑based pruning

You will:

  • Train a dense neural network
  • Identify unimportant weights using their magnitude
  • Prune aggressively (60–90%)
  • Evaluate what actually changes

Once again, Python code comes first, with just enough math to explain what’s happening.

1. Why Weight Magnitude Matters

During training, the weights of a neural network tend to develop in a very uneven way:

  • A small subset of weights carries most of the signal
  • Many weights shrink toward zero
  • Some weights never meaningfully contribute

This observation leads to a practical heuristic:

If a weight’s absolute value is very small, removing it will barely affect the output.

This is not a theory — it is an empirical fact observed across architectures and datasets.

2. What Is Magnitude‑Based Pruning?

Magnitude‑based pruning removes weights according to this rule:

Remove weights with the smallest |weight|

More formally:

  1. Train a dense model
  2. Collect all weights in a layer (or model)
  3. Compute absolute values
  4. Remove the lowest p%

Where p is the desired pruning ratio.
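
As a quick toy illustration of the rule (separate from the full routine implemented in Section 4), here is the threshold idea applied to a single made-up weight vector:

Python
import torch

# Toy weight vector; prune the smallest 50% by absolute value.
w = torch.tensor([0.03, -1.20, 0.50, -0.01, 0.75, 0.002])
p = 0.5

threshold = torch.quantile(w.abs(), p)           # magnitude cutoff
pruned_w = torch.where(w.abs() > threshold, w, torch.zeros_like(w))

print(threshold)   # the cutoff magnitude
print(pruned_w)    # three of the six weights are zeroed

The surviving weights are exactly the ones with the largest absolute values; everything below the cutoff becomes zero.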

3. Training a Dense Baseline Model

We start with the same dense MLP from Article #1, now with a training loop.

Dense Model

Python
import torch
import torch.nn as nn
import torch.optim as optim


class DenseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

Dummy Dataset (for Demonstration)

Python
def generate_data(n=2048, input_dim=100, num_classes=10):
    X = torch.randn(n, input_dim)
    y = torch.randint(0, num_classes, (n,))
    return X, y

Training Loop

Python
def train(model, X, y, epochs=20):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        logits = model(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

Train the Model

Python
model = DenseMLP(100, 256, 10)
X, y = generate_data()
final_loss = train(model, X, y)
print("Final dense loss:", final_loss)

This trained model is our reference point.

4. Implementing Global Magnitude‑Based Pruning

We now remove a fraction of the smallest‑magnitude weights.

Pruning Function

Python
def global_magnitude_prune(model, pruning_ratio):
    # Collect all weights
    weights = []
    for param in model.parameters():
        if param.dim() > 1:
            weights.append(param.data.view(-1))
    all_weights = torch.cat(weights)
    threshold = torch.quantile(all_weights.abs(), pruning_ratio)

    # Apply pruning
    for param in model.parameters():
        if param.dim() > 1:
            mask = param.data.abs() > threshold
            param.data *= mask

Key points:

  • Biases are left untouched
  • Pruning is applied globally, not layer‑by‑layer (a layer‑wise variant is sketched below)
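
For comparison, a layer‑wise variant (a minimal sketch, not part of the repo code) computes a separate threshold per weight matrix, so every layer ends up with roughly the same sparsity instead of letting one global threshold decide:

Python
def layerwise_magnitude_prune(model, pruning_ratio):
    # Each weight matrix gets its own magnitude threshold.
    for param in model.parameters():
        if param.dim() > 1:
            threshold = torch.quantile(param.data.abs().view(-1), pruning_ratio)
            mask = param.data.abs() > threshold
            param.data *= mask

Global pruning tends to remove more from large, redundant layers, while the layer‑wise version spreads the removal evenly; comparing the two is one of the experiments suggested at the end of this article.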

5. Pruning the Trained Network

Let’s remove 80% of all weights.

Python
global_magnitude_prune(model, pruning_ratio=0.8)

At this point:

  • 80% of weights are exactly zero
  • The architecture is unchanged
  • No retraining has happened yet

6. Measuring Sparsity After Pruning

Python
def measure_sparsity(model):
    total = 0
    zeros = 0
    for param in model.parameters():
        if param.dim() > 1:
            total += param.numel()
            zeros += (param == 0).sum().item()
    return zeros / total

print("Model sparsity:", measure_sparsity(model))

Expected output:

Model sparsity: ~0.8

The measured sparsity closely matches the requested pruning ratio; ties at the threshold can shift it slightly.
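
Because the threshold is global, sparsity is usually not spread evenly across layers. A small optional helper (an addition for inspection, assuming the DenseMLP defined above) makes this visible:

Python
def per_layer_sparsity(model):
    # Report sparsity separately for each weight matrix.
    for name, param in model.named_parameters():
        if param.dim() > 1:
            sparsity = (param == 0).float().mean().item()
            print(f"{name}: {sparsity:.2%} zeros")

per_layer_sparsity(model)

You will often see the individual layers pruned to noticeably different degrees, even though the overall ratio is 80%.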

7. Does the Model Still Work?

Now we evaluate the pruned model without retraining.

Python
with torch.no_grad():
loss_after_prune = nn.CrossEntropyLoss()(model(X), y)
print("Loss after pruning:", loss_after_prune.item())

You will typically observe:

  • Slight degradation
  • Often surprisingly small

This is the key empirical insight behind pruning.
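
Loss alone can be hard to interpret. If you also want accuracy on this toy dataset (a small addition, not part of the original snippet):

Python
with torch.no_grad():
    preds = model(X).argmax(dim=1)
    accuracy = (preds == y).float().mean().item()

print("Accuracy after pruning:", accuracy)

Since the labels here are random, this measures how well the network memorizes the training set rather than true generalization, but it still shows how much capacity survives pruning.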

8. Optional: Fine‑Tuning After Pruning

Pruning works best when followed by brief fine‑tuning.

Python
fine_tuned_loss = train(model, X, y, epochs=5)
print("Loss after fine‑tuning:", fine_tuned_loss)

In many cases:

  • Performance returns close to baseline
  • Even at very high sparsity levels
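
One caveat: the plain train loop above does nothing to keep pruned weights at zero, so gradient updates will gradually regrow them and the sparsity measured in Section 6 will drop. A common remedy is to store the binary masks right after pruning and re‑apply them after every optimizer step. A minimal sketch (the helper names below are our own, not part of the repo code):

Python
def get_masks(model):
    # Record which weights survived pruning (nonzero entries).
    return {name: (param != 0).float()
            for name, param in model.named_parameters() if param.dim() > 1}

def finetune_with_masks(model, X, y, masks, epochs=5):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Re-apply the masks so pruned weights stay exactly zero.
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in masks:
                    param.mul_(masks[name])
    return loss.item()

masks = get_masks(model)
fine_tuned_loss = finetune_with_masks(model, X, y, masks)

If you use this instead of the plain call above, build the masks immediately after pruning, before any fine‑tuning has touched the weights.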

9. Minimal Math (Why This Works)

Let the output of a layer be:

y = Wx + b

If a weight w_ij ≈ 0, then:

Δy_i ≈ w_ij · x_j ≈ 0

Removing it changes the output negligibly.

Magnitude‑based pruning simply exploits this linear sensitivity.
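
You can check this sensitivity numerically with a throwaway experiment on a fresh, standalone layer (separate from the pruned model above): zero out the single smallest‑magnitude weight and compare outputs.

Python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(100, 64)
x = torch.randn(32, 100)

with torch.no_grad():
    out_before = layer(x)
    # Find and zero the single smallest-magnitude weight.
    idx = layer.weight.abs().argmin()
    i, j = divmod(idx.item(), layer.weight.shape[1])
    layer.weight[i, j] = 0.0
    out_after = layer(x)

print("Largest change in any output:", (out_after - out_before).abs().max().item())

The reported change is tiny, which is exactly the linear sensitivity argument in numerical form.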

10. Dense → Sparse Is a Two‑Step Process

What we have done so far:

  1. Train a dense model
  2. Remove unimportant weights

This raises the next question:

Do we really need to train dense models first?

The answer is increasingly no.

11. What Comes Next

In Article #3, we will move beyond pruning and introduce:

Activation sparsity — why most neurons should stay silent

We will measure:

  • Active neuron ratios
  • Compute savings per batch
  • Effects on representation

Code Location

All code from this article lives in:

02_pruning/

Experiment freely:

  • Try 90% pruning
  • Compare global vs layer‑wise pruning
  • Visualize weight histograms (a quick sketch follows below)
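
For that last item, here is a minimal matplotlib sketch (assuming matplotlib is installed; it is not part of the repo code) that plots the magnitude distribution of all prunable weights:

Python
import matplotlib.pyplot as plt

# Gather the absolute values of all weight-matrix entries.
all_weights = torch.cat([p.data.abs().view(-1)
                         for p in model.parameters() if p.dim() > 1])

plt.hist(all_weights.numpy(), bins=100)
plt.xlabel("|weight|")
plt.ylabel("count")
plt.title("Weight magnitude distribution")
plt.show()

Before pruning you will typically see a large peak near zero; after pruning, that peak collapses onto exactly zero.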

Sparsity only becomes convincing once you see how much you can remove — and still learn.