In Article #4, we trained a network that was sparse from initialization.
But there was a limitation:
The mask never changed.
If we chose the wrong connections at the beginning, the model had no way to recover.
This is where Dynamic Sparse Training (DST) enters.
Instead of keeping the sparsity pattern fixed, we:
- Periodically prune weak connections
- Regrow new connections
- Keep total sparsity constant
The network rewires itself during training.
This is the core idea behind methods like RigL and modern sparse scaling strategies.
As always — Python first.
1. Static vs Dynamic Sparsity
Static sparsity:
• Mask chosen once
• Zero weights remain zero forever
• No structural adaptation
Dynamic sparsity:
• Weak connections are removed
• New connections are added
• Total number of active weights remains constant
This turns sparsity into a learning mechanism — not just a constraint.
2. High-Level Algorithm
Every T training steps:
- Compute weight magnitudes
- Remove the weakest fraction of active connections (prune step)
- Activate the same number of new connections (regrow step)
- Continue training
We maintain fixed overall sparsity throughout.
3. Sparse Layer With Rewiring Support
import torch
import torch.nn as nn


class DynamicSparseLinear(nn.Module):
    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Random binary mask: roughly `sparsity` of the entries start at zero
        mask = torch.rand(out_features, in_features)
        mask = (mask > sparsity).float()
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Only masked-in (active) weights contribute to the output
        return x @ (self.weight * self.mask).t() + self.bias
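A quick usage sketch, continuing from the imports and class above (the dimensions and the 90% sparsity level are arbitrary example values):

# Example: a single dynamic sparse layer at 90% sparsity
layer = DynamicSparseLinear(in_features=64, out_features=32, sparsity=0.9)

x = torch.randn(8, 64)        # batch of 8 inputs
out = layer(x)                # shape: (8, 32)

print(f"active fraction: {layer.mask.mean().item():.2f}")   # roughly 0.10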
This is similar to Article #4 — but now we will modify the mask during training.
4. Pruning Step
We remove the smallest active weights.
def prune_weights(layer, prune_fraction):
    weight = layer.weight.data
    mask = layer.mask

    # Only currently active weights define the pruning threshold
    active_weights = weight[mask.bool()]
    threshold = torch.quantile(active_weights.abs(), prune_fraction)

    # Drop active weights whose magnitude falls below the threshold
    num_active_before = int(mask.sum().item())
    prune_mask = (weight.abs() > threshold).float()
    layer.mask *= prune_mask

    # Report how many connections were removed so regrowth can match it
    num_pruned = num_active_before - int(layer.mask.sum().item())
    return num_pruned
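A small check of the pruning step, reusing the layer from section 3 (the 20% prune fraction is just an example value):

layer = DynamicSparseLinear(64, 32, sparsity=0.9)

before = int(layer.mask.sum().item())
num_pruned = prune_weights(layer, prune_fraction=0.2)
after = int(layer.mask.sum().item())

print(before, num_pruned, after)   # after == before - num_pruned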
This ensures:
• Only currently active weights set the pruning threshold
• A fixed fraction of them is removed
• The number of pruned connections is returned, so the regrowth step can add back exactly that many
5. Regrowth Step
We now regrow the same number of connections randomly.
def regrow_weights(layer, num_to_regrow):
    mask = layer.mask

    # Candidate positions: currently inactive connections
    zero_indices = (mask == 0).nonzero(as_tuple=False)
    if len(zero_indices) == 0 or num_to_regrow == 0:
        return

    # Reactivate the same number of connections that were just pruned,
    # chosen uniformly at random
    selected = zero_indices[
        torch.randperm(len(zero_indices))[:num_to_regrow]
    ]
    mask[selected[:, 0], selected[:, 1]] = 1.0
This keeps the total number of active weights constant.
A more advanced approach would use gradient information to decide where to regrow.
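For reference, here is a rough sketch of what gradient-based regrowth could look like, in the spirit of RigL: among inactive positions, activate the ones with the largest gradient magnitude. One subtlety with our layer: because the forward pass multiplies the weight by the mask, `weight.grad` is zero at inactive positions, so the caller has to supply a dense gradient of the loss with respect to the unmasked weight (captured, for example, with a hook or a temporary unmasked forward pass). The function name and the `dense_grad` argument are illustrative assumptions, not part of this article's main implementation.

def regrow_weights_by_gradient(layer, dense_grad, num_to_regrow):
    # RigL-style sketch: among currently inactive positions, activate the
    # ones where the dense gradient magnitude is largest.
    # `dense_grad` must have the same shape as layer.weight and must be
    # computed without the mask zeroing out the gradient (see above).
    mask = layer.mask
    if num_to_regrow == 0:
        return

    scores = dense_grad.abs().clone()
    scores[mask.bool()] = -1.0    # exclude already-active positions

    flat_idx = torch.topk(scores.flatten(), num_to_regrow).indices
    rows = torch.div(flat_idx, mask.shape[1], rounding_mode="floor")
    cols = flat_idx % mask.shape[1]

    mask[rows, cols] = 1.0
    # Common practice (as in RigL): start regrown weights from zero
    layer.weight.data[rows, cols] = 0.0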
6. Training Loop With Rewiring
import torch.optim as optim


def train_dynamic(model, X, y, epochs=20, rewire_every=5):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        logits = model(X)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodically rewire: prune weak connections, regrow the same number
        if epoch % rewire_every == 0 and epoch > 0:
            for module in model.modules():
                if isinstance(module, DynamicSparseLinear):
                    num_pruned = prune_weights(module, prune_fraction=0.2)
                    regrow_weights(module, num_pruned)

    return loss.item()
7. Full Model Definition
class DynamicSparseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, sparsity):
        super().__init__()
        self.fc1 = DynamicSparseLinear(input_dim, hidden_dim, sparsity)
        self.fc2 = DynamicSparseLinear(hidden_dim, hidden_dim, sparsity)
        self.fc3 = DynamicSparseLinear(hidden_dim, output_dim, sparsity)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
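Putting the pieces together, a minimal end-to-end run might look like this (the synthetic data, layer sizes, and 90% sparsity are placeholder choices for illustration):

# Synthetic classification data: 512 samples, 64 features, 10 classes
torch.manual_seed(0)
X = torch.randn(512, 64)
y = torch.randint(0, 10, (512,))

model = DynamicSparseMLP(input_dim=64, hidden_dim=128, output_dim=10, sparsity=0.9)
final_loss = train_dynamic(model, X, y, epochs=50, rewire_every=5)
print(f"final loss: {final_loss:.4f}")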
8. Why Dynamic Sparse Training Matters
Static sparse training can fail if initial masks are poor.
Dynamic sparsity:
• Explores alternative connectivity
• Recovers from bad initial structure
• Approaches dense performance with sparse compute
This is why modern sparse research rarely uses fixed masks.
9. Minimal Math Intuition
We maintain a constant parameter budget:
Total active weights = constant
But we change their positions over time.
This turns training into a search over sparse connectivity graphs.
Instead of optimizing only weights,
we implicitly optimize structure.
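A quick way to see this in code: count the active connections before and after one full prune/regrow cycle, using the layer and helpers defined above.

layer = DynamicSparseLinear(64, 32, sparsity=0.9)

active_before = int(layer.mask.sum().item())
num_pruned = prune_weights(layer, prune_fraction=0.2)
regrow_weights(layer, num_pruned)
active_after = int(layer.mask.sum().item())

print(active_before, active_after)   # equal: the budget is conserved,
                                     # only the positions have changed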
10. What Comes Next
We have now covered:
- Structural sparsity
- Weight pruning
- Activation sparsity
- Sparse training from scratch
- Dynamic sparse training
The next step moves us toward modern large-scale systems:
Sparse Attention and Mixture of Experts (MoE)
This is where sparsity becomes conditional computation at scale.
Code Location
All code for this article lives in:
05_dynamic_sparse_training/
Suggested experiments:
• Compare static vs dynamic sparse models
• Increase rewiring frequency
• Implement gradient-based regrowth
Dynamic sparse training transforms sparsity from a compression trick
into a structural learning strategy.