In the previous articles, we followed a familiar path:
- Train a dense model
- Remove unimportant weights (pruning)
- Fine‑tune
This works — but it raises a deeper question:
Do we really need to start dense at all?
Modern research increasingly shows that the answer is no.
In this article, you will:
- Build a sparse layer that enforces sparsity during training
- Prevent zeroed weights from updating
- Train a network that is sparse from initialization
- Compare it conceptually to prune‑after‑training
As always: Python first. Math only where necessary.
1. Why Avoid Dense Pretraining?
Dense pretraining has three inefficiencies:
• You compute gradients for weights that will later be deleted
• You store parameters you never truly needed
• You waste memory bandwidth during training
If we already believe most weights are unnecessary,
why not start sparse from the beginning?
This approach is called:
Sparse Training From Scratch
2. The Core Idea
We want a layer that:
• Has a fixed sparsity level
• Keeps zero weights permanently zero
• Only updates active connections
We accomplish this using a binary mask that is applied at every forward pass.
The difference from Article #1:
Now we allow training — but only through active weights.
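Before building the full layer, here is a minimal sketch of the mechanism on plain tensors (the shapes and the 0.8 threshold are illustrative, not taken from the article's code):

import torch

w = torch.randn(100, 100)                  # dense weights
m = (torch.rand(100, 100) > 0.8).float()   # keep roughly 20% of connections
w_masked = w * m                           # inactive entries become exactly zero
print((w_masked == 0).float().mean())      # ≈ 0.8

The layer in the next section does exactly this, except that w is a trainable parameter and m is stored as a fixed buffer.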
3. Implementing a Trainable Sparse Linear Layer
import torch
import torch.nn as nn


class SparseLinear(nn.Module):
    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Initialize weights normally
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Create fixed binary mask
        mask = torch.rand(out_features, in_features)
        mask = (mask > sparsity).float()
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Enforce sparsity every forward pass
        sparse_weight = self.weight * self.mask
        return x @ sparse_weight.t() + self.bias
Key properties:
• Mask is not trainable
• Zero weights never contribute
• Gradients flow only through active connections (verified in the sketch below)
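As a quick check of the last property (a small sketch using the SparseLinear class above; the sizes are arbitrary), one forward/backward pass shows that the gradient is exactly zero wherever the mask is zero:

layer = SparseLinear(in_features=8, out_features=4, sparsity=0.8)
x = torch.randn(16, 8)

layer(x).sum().backward()

# Gradient at masked positions is exactly zero
inactive_grad = layer.weight.grad[layer.mask == 0]
print(torch.all(inactive_grad == 0))  # tensor(True)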
4. Building a Sparse-From-Scratch MLP
class SparseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, sparsity):
        super().__init__()
        self.fc1 = SparseLinear(input_dim, hidden_dim, sparsity)
        self.fc2 = SparseLinear(hidden_dim, hidden_dim, sparsity)
        self.fc3 = SparseLinear(hidden_dim, output_dim, sparsity)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
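A quick shape check, using the same dimensions as the training run below (this snippet is only a sanity test, not part of the article's pipeline):

mlp = SparseMLP(input_dim=100, hidden_dim=256, output_dim=10, sparsity=0.8)
batch = torch.randn(32, 100)
print(mlp(batch).shape)  # torch.Size([32, 10])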
5. Training the Sparse Model
We now train this network normally.
import torch.optim as optim


def generate_data(n=2048, input_dim=100, num_classes=10):
    X = torch.randn(n, input_dim)
    y = torch.randint(0, num_classes, (n,))
    return X, y


def train(model, X, y, epochs=20):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        logits = model(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()
Instantiate and train:
model = SparseMLP(100, 256, 10, sparsity=0.8)
X, y = generate_data()
final_loss = train(model, X, y)
print("Final loss (sparse from scratch):", final_loss)
This model:
• Started with roughly 80% of its connections masked to zero
• Stayed 80% sparse throughout training
• Never trained the masked-out weights (verified below)
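To verify the last bullet, a small sketch (re-running training on a fresh model so we can snapshot the weights beforehand; check_model and snapshots are names introduced here for illustration):

check_model = SparseMLP(100, 256, 10, sparsity=0.8)

# Snapshot every sparse layer's weights before training
snapshots = [
    (m, m.weight.detach().clone())
    for m in check_model.modules()
    if isinstance(m, SparseLinear)
]

train(check_model, X, y)

for module, w_before in snapshots:
    same = torch.equal(
        module.weight.detach()[module.mask == 0],
        w_before[module.mask == 0],
    )
    print(same)  # True: masked weights get zero gradient every step, so Adam never moves them

Note that this exact-equality check holds only because no weight decay is used; with decay, the hidden parameters would drift even though their masked values stay zero in the forward pass.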
6. Verifying Sparsity
def measure_sparsity(model):
    total = 0
    zeros = 0
    for module in model.modules():
        if isinstance(module, SparseLinear):
            total += module.mask.numel()
            zeros += (module.mask == 0).sum().item()
    return zeros / total
print("Model sparsity:", measure_sparsity(model))
Expected output:
Model sparsity: ~0.8
Sparsity is structural — not accidental.
7. Minimal Math (Why This Works)
Standard training updates:
w ← w − η ∇L
With masking, the forward pass uses:
w_masked = w ⊙ m
so by the chain rule:
∂L/∂wᵢⱼ = mᵢⱼ · ∂L/∂(w ⊙ m)ᵢⱼ
If mᵢⱼ = 0, then:
∂L/∂wᵢⱼ = 0
Those parameters are permanently inactive.
We reduce the parameter space from the beginning.
8. Dense → Prune vs Sparse From Scratch
Dense → Prune:
• Train full network
• Remove small weights
• Fine‑tune
Sparse From Scratch:
• Start sparse
• Train only active weights
• Maintain constant sparsity
The second approach never spends gradient computation or memory on weights that will later be thrown away.
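To make that concrete, a back-of-the-envelope count of how many weights actually receive gradient updates in the model trained above (weights only, biases ignored; the exact active count depends on the random mask):

dense_params = 100 * 256 + 256 * 256 + 256 * 10   # equivalent dense MLP: 93,696 weights
active_params = sum(
    int(m.mask.sum().item())
    for m in model.modules()
    if isinstance(m, SparseLinear)
)
print(dense_params, active_params)  # roughly a 5x reduction at 80% sparsity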
9. Limitations
Static sparsity has one major weakness:
The mask never changes.
What if the wrong connections were chosen initially?
That leads us to the next major concept:
Dynamic Sparse Training
Where connections are pruned and regrown during training.
10. What Comes Next
In Article #5, we implement:
• Periodic pruning
• Gradient‑based regrowth
• Rewiring sparse networks during learning
This is where sparse training becomes competitive with dense scaling.
Code Location
All code for this article lives in:
04_sparse_from_scratch/
Suggested experiments:
• Try 90% sparsity
• Compare dense vs sparse training curves (see the sketch after this list)
• Combine with activation sparsity
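As a starting point for the first two experiments, a minimal sketch (the dense baseline below uses plain nn.Linear layers with the same widths; it is an assumption of this sketch, not code from the repository):

dense = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),
)
sparse_90 = SparseMLP(100, 256, 10, sparsity=0.9)

X, y = generate_data()
print("dense:     ", train(dense, X, y))
print("90% sparse:", train(sparse_90, X, y))

To plot full curves rather than final losses, have train() record loss.item() each epoch and return the list.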
Sparsity is no longer a compression trick.
It is a training strategy.