In Article #1, we introduced sparsity by design using fixed masks.
That raised the obvious next question:
If we want a sparse network, which weights should we remove — and why?
This article answers that question using the simplest and most widely used pruning strategy:
Magnitude‑based pruning
You will:
- Train a dense neural network
- Identify unimportant weights using their magnitude
- Prune aggressively (60–90%)
- Evaluate what actually changes
Once again, Python code comes first, with just enough math to explain what’s happening.
1. Why Weight Magnitude Matters
During training, neural networks tend to behave in a very uneven way:
- A small subset of weights carries most of the signal
- Many weights shrink toward zero
- Some weights never meaningfully contribute
This observation leads to a practical heuristic:
If a weight’s absolute value is very small, removing it will barely affect the output.
This is not just a theoretical claim: it is an empirical pattern observed across many architectures and datasets.
2. What Is Magnitude‑Based Pruning?
Magnitude‑based pruning removes weights according to this rule:
Remove weights with the smallest |weight|
More formally:
- Train a dense model
- Collect all weights in a layer (or model)
- Compute absolute values
- Remove the p% of weights with the smallest absolute values
Where p is the desired pruning ratio.
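To make the rule concrete, here is a throwaway example on a single random weight matrix (a standalone sketch, not part of this article's pipeline):

import torch

# Toy weight matrix standing in for one layer
W = torch.randn(4, 4)
p = 0.5  # prune 50% of the weights

# The threshold is the p-th quantile of the absolute values
threshold = torch.quantile(W.abs(), p)

# Keep only weights whose magnitude exceeds the threshold
W_pruned = W * (W.abs() > threshold)

print("Kept fraction:", (W_pruned != 0).float().mean().item())  # ≈ 0.5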
3. Training a Dense Baseline Model
We start with the same dense MLP from Article #1, now with a training loop.
Dense Model
import torch
import torch.nn as nn
import torch.optim as optim

class DenseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)
Dummy Dataset (for Demonstration)
def generate_data(n=2048, input_dim=100, num_classes=10):
    X = torch.randn(n, input_dim)
    y = torch.randint(0, num_classes, (n,))
    return X, y
Training Loop
def train(model, X, y, epochs=20):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        logits = model(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()
Train the Model
model = DenseMLP(100, 256, 10)
X, y = generate_data()
final_loss = train(model, X, y)
print("Final dense loss:", final_loss)
This trained model is our reference point.
4. Implementing Global Magnitude‑Based Pruning
We now remove a fraction of the smallest‑magnitude weights.
Pruning Function
def global_magnitude_prune(model, pruning_ratio):
    # Collect all weight tensors (skip biases, which are 1-D)
    weights = []
    for param in model.parameters():
        if param.dim() > 1:
            weights.append(param.data.view(-1))
    all_weights = torch.cat(weights)

    # The threshold is the p-th quantile of |w| across the whole model
    threshold = torch.quantile(all_weights.abs(), pruning_ratio)

    # Zero out every weight whose magnitude falls below the threshold
    for param in model.parameters():
        if param.dim() > 1:
            mask = param.data.abs() > threshold
            param.data *= mask
Key points:
- Biases are left untouched
- Pruning is applied globally, not layer‑by‑layer
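For contrast, a layer-wise variant computes a separate threshold inside each weight matrix. The sketch below is one possible implementation for your own experiments, not code from this article's repository:

def layerwise_magnitude_prune(model, pruning_ratio):
    # One threshold per weight matrix, so every layer ends up
    # with roughly the same sparsity (unlike the global version above)
    for param in model.parameters():
        if param.dim() > 1:
            threshold = torch.quantile(param.data.abs(), pruning_ratio)
            param.data *= (param.data.abs() > threshold)

Global pruning tends to remove more weights from layers whose weights are small on average; the layer-wise version spreads removal evenly. Comparing the two is one of the experiments suggested at the end of this article.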
5. Pruning the Trained Network
Let’s remove 80% of all weights.
global_magnitude_prune(model, pruning_ratio=0.8)
At this point:
- 80% of weights are exactly zero
- The architecture is unchanged
- No retraining has happened yet
6. Measuring Sparsity After Pruning
def measure_sparsity(model):
    total = 0
    zeros = 0
    for param in model.parameters():
        if param.dim() > 1:
            total += param.numel()
            zeros += (param == 0).sum().item()
    return zeros / total

print("Model sparsity:", measure_sparsity(model))
Expected output:
Model sparsity: ~0.8
The target pruning ratio is enforced almost exactly; tiny deviations are possible when many weights tie at the threshold.
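Because the threshold is global, the zeros are usually not spread evenly across layers. A small diagnostic (a quick sketch, not part of the article's repository) makes this visible:

def per_layer_sparsity(model):
    # Fraction of zeroed weights in each weight matrix
    for name, param in model.named_parameters():
        if param.dim() > 1:
            sparsity = (param == 0).float().mean().item()
            print(f"{name}: {sparsity:.2f}")

per_layer_sparsity(model)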
7. Does the Model Still Work?
Now we evaluate the pruned model without retraining.
with torch.no_grad():
    loss_after_prune = nn.CrossEntropyLoss()(model(X), y)

print("Loss after pruning:", loss_after_prune.item())
You will typically observe:
- Slight degradation
- Often surprisingly small
This is the key empirical insight behind pruning.
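To see how quickly things break down, you can retrain a fresh dense model and sweep several pruning ratios on copies of it. This loop is a sketch for your own exploration, not part of the original code:

import copy

# Start from a freshly trained dense model so the ratios are comparable
dense = DenseMLP(100, 256, 10)
train(dense, X, y)

loss_fn = nn.CrossEntropyLoss()
for ratio in [0.5, 0.7, 0.8, 0.9, 0.95]:
    pruned = copy.deepcopy(dense)  # prune a copy, keep the dense reference intact
    global_magnitude_prune(pruned, ratio)
    with torch.no_grad():
        loss = loss_fn(pruned(X), y).item()
    print(f"ratio={ratio:.2f}  loss={loss:.4f}")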
8. Optional: Fine‑Tuning After Pruning
Pruning works best when followed by brief fine‑tuning.
fine_tuned_loss = train(model, X, y, epochs=5)
print("Loss after fine-tuning:", fine_tuned_loss)
In many cases:
- Performance returns close to baseline
- Even at very high sparsity levels
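One caveat: the plain training loop above updates every parameter, so the optimizer is free to regrow the pruned weights and the sparsity slowly erodes. A common remedy is to cache the pruning masks and re-apply them after each optimizer step. The helper below is a sketch of that idea, not this article's implementation:

def fine_tune_with_masks(model, X, y, epochs=5, lr=1e-3):
    # Freeze the current sparsity pattern: whatever is zero now stays zero
    masks = {name: (param != 0).float()
             for name, param in model.named_parameters() if param.dim() > 1}

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Re-apply the masks so pruned weights cannot come back
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in masks:
                    param.mul_(masks[name])
    return loss.item()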
9. Minimal Math (Why This Works)
Let the output of a layer be:
y = Wx + b
If a weight w_ij ≈ 0, then setting it to zero changes the i-th output component by:
|Δy_i| = |w_ij| · |x_j| ≈ 0
As long as the input x_j stays in a reasonable range, removing the weight changes the output negligibly.
Magnitude‑based pruning simply exploits this linear sensitivity.
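You can sanity-check this numerically: zero out the single smallest-magnitude weight of a linear layer and measure how far its output moves. A throwaway example (not from the article's repository):

layer = nn.Linear(100, 50)
x = torch.randn(1, 100)

with torch.no_grad():
    baseline = layer(x)

    # Zero out the single smallest-magnitude weight
    idx = layer.weight.abs().argmin().item()
    i, j = divmod(idx, layer.weight.shape[1])
    layer.weight[i, j] = 0.0

    delta = (layer(x) - baseline).abs().max().item()

print("Largest output change:", delta)  # typically tiny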
10. Dense → Sparse Is a Two‑Step Process
What we have done so far:
- Train a dense model
- Remove unimportant weights
This raises the next question:
Do we really need to train dense models first?
The answer is increasingly no.
11. What Comes Next
In Article #3, we will move beyond pruning and introduce:
Activation sparsity — why most neurons should stay silent
We will measure:
- Active neuron ratios
- Compute savings per batch
- Effects on representation
Code Location
All code from this article lives in:
02_pruning/
Experiment freely:
- Try 90% pruning
- Compare global vs layer‑wise pruning
- Visualize weight histograms
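For the histogram suggestion, a minimal starting point (assuming matplotlib is installed) is:

import matplotlib.pyplot as plt

# Distribution of the surviving (non-zero) weights after pruning
all_w = torch.cat([p.data.view(-1) for p in model.parameters() if p.dim() > 1])
plt.hist(all_w[all_w != 0].numpy(), bins=100)
plt.title("Non-zero weights after pruning")
plt.xlabel("Weight value")
plt.show()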
Sparsity only becomes convincing once you see how much you can remove — and still learn.