SolveWithPython

Dynamic Sparse Training in Python: Let the Network Rewire Itself

In Article #4, we trained a network that was sparse from initialization.

But there was a limitation:

The mask never changed.

If we chose the wrong connections at the beginning, the model had no way to recover.

This is where Dynamic Sparse Training (DST) enters.

Instead of keeping the sparsity pattern fixed, we:

  1. Periodically prune weak connections
  2. Regrow new connections
  3. Keep total sparsity constant

The network rewires itself during training.

This is the core idea behind methods like RigL and modern sparse scaling strategies.

As always — Python first.

1. Static vs Dynamic Sparsity

Static sparsity:

• Mask chosen once
• Zero weights remain zero forever
• No structural adaptation

Dynamic sparsity:

• Weak connections are removed
• New connections are added
• Total number of active weights remains constant

This turns sparsity into a learning mechanism — not just a constraint.

2. High-Level Algorithm

Every T training steps:

  1. Compute weight magnitudes
  2. Remove smallest fraction (prune step)
  3. Select new connections (regrow step)
  4. Continue training

We maintain fixed overall sparsity throughout.
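
Before wiring this into a network, here is a tiny self-contained sketch of one prune-and-regrow cycle on a bare weight tensor. The tensor size, the 50% starting sparsity, and the 20% prune fraction are arbitrary choices for illustration.

Python
import torch

torch.manual_seed(0)

weight = torch.randn(8, 8)
mask = (torch.rand(8, 8) > 0.5).float()          # start at roughly 50% sparsity

active_before = int(mask.sum().item())

# Prune: deactivate the 20% of active weights with the smallest magnitude.
threshold = torch.quantile(weight[mask.bool()].abs(), 0.2)
mask *= (weight.abs() > threshold).float()

# Regrow: reactivate exactly as many positions as were just pruned, at random.
num_pruned = active_before - int(mask.sum().item())
zeros = (mask == 0).nonzero(as_tuple=False)
chosen = zeros[torch.randperm(len(zeros))[:num_pruned]]
mask[chosen[:, 0], chosen[:, 1]] = 1.0

print(active_before, int(mask.sum().item()))     # identical active-weight budgets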

3. Sparse Layer With Rewiring Support

Python
import torch
import torch.nn as nn

class DynamicSparseLinear(nn.Module):
    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Random binary mask: a fraction `sparsity` of entries start out inactive.
        mask = torch.rand(out_features, in_features)
        mask = (mask > sparsity).float()
        # A buffer, not a Parameter: the mask is rewired manually, not learned by the optimizer.
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Only unmasked weights contribute to the output.
        return x @ (self.weight * self.mask).t() + self.bias

This is similar to Article #4 — but now we will modify the mask during training.
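
As a quick sanity check (the layer sizes here are arbitrary), we can confirm that roughly the requested fraction of connections starts out inactive:

Python
# With sparsity=0.8, only about 20% of the weights start active.
layer = DynamicSparseLinear(in_features=64, out_features=32, sparsity=0.8)
print(f"active fraction: {layer.mask.mean().item():.2f}")   # ~0.20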

4. Pruning Step

We remove the smallest active weights.

Python
def prune_weights(layer, prune_fraction):
    weight = layer.weight.data
    mask = layer.mask
    # Consider only weights that are currently active.
    active_weights = weight[mask.bool()]
    # Magnitude below which the weakest `prune_fraction` of active weights fall.
    threshold = torch.quantile(active_weights.abs(), prune_fraction)
    prune_mask = (weight.abs() > threshold).float()
    num_active_before = int(mask.sum().item())
    layer.mask *= prune_mask
    # Report how many connections were removed, so regrowth can match it.
    return num_active_before - int(layer.mask.sum().item())

This ensures:

• Only currently active weights are considered
• The weakest fraction of them is removed
• The number of pruned connections is returned, so regrowth can add back exactly as many
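
A quick check of the behaviour, using arbitrary layer sizes: after one call, the number of active weights drops by roughly the requested fraction, and the function reports how many connections were removed.

Python
layer = DynamicSparseLinear(in_features=64, out_features=32, sparsity=0.5)
before = int(layer.mask.sum().item())
num_pruned = prune_weights(layer, prune_fraction=0.2)
after = int(layer.mask.sum().item())
print(before, num_pruned, after)   # `after` is roughly 80% of `before`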

5. Regrowth Step

We now regrow the same number of connections, choosing their positions uniformly at random.

Python
def regrow_weights(layer, num_to_regrow):
    mask = layer.mask
    # Candidate positions: everything that is currently inactive.
    zero_indices = (mask == 0).nonzero(as_tuple=False)
    if len(zero_indices) == 0 or num_to_regrow <= 0:
        return
    # Reactivate `num_to_regrow` positions chosen uniformly at random.
    selected = zero_indices[
        torch.randperm(len(zero_indices))[:num_to_regrow]
    ]
    mask[selected[:, 0], selected[:, 1]] = 1.0

This keeps the total number of active weights, and therefore the overall sparsity, constant.

A more advanced approach would use gradient information to decide where to regrow.
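
The sketch below shows what such a gradient-based (RigL-style) regrowth step could look like. It is not the published algorithm verbatim: the function name regrow_by_gradient is ours, and it assumes that layer.weight.grad holds the dense gradient. With the masked forward used above, pruned positions receive zero gradient, so in practice you would collect dense gradients separately, for example with the straight-through variant shown in the comment.

Python
# A rough sketch of gradient-based regrowth, assuming `layer.weight.grad`
# holds the *dense* gradient. With the masked forward above, pruned entries
# get zero gradient; one (illustrative) workaround is a straight-through
# forward in DynamicSparseLinear:
#   w = self.weight * self.mask + (self.weight - self.weight.detach()) * (1 - self.mask)
def regrow_by_gradient(layer, num_to_regrow):
    grad = layer.weight.grad
    if grad is None or num_to_regrow <= 0:
        return
    num_to_regrow = min(num_to_regrow, int((layer.mask == 0).sum().item()))
    scores = grad.abs().clone()
    scores[layer.mask.bool()] = -1.0              # rank only inactive positions
    flat = torch.topk(scores.flatten(), num_to_regrow).indices
    rows = flat // layer.mask.shape[1]
    cols = flat % layer.mask.shape[1]
    layer.mask[rows, cols] = 1.0
    layer.weight.data[rows, cols] = 0.0           # regrown weights start from zero

The idea, following RigL, is to grow connections where the loss gradient suggests a large immediate benefit, rather than at random.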

6. Training Loop With Rewiring

Python
import torch.optim as optim

def train_dynamic(model, X, y, epochs=20, rewire_every=5):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        # Standard full-batch training step.
        logits = model(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Periodically rewire: prune weak connections, then regrow the same number.
        if epoch % rewire_every == 0 and epoch > 0:
            for module in model.modules():
                if isinstance(module, DynamicSparseLinear):
                    num_pruned = prune_weights(module, prune_fraction=0.2)
                    regrow_weights(module, num_pruned)
    return loss.item()

7. Full Model Definition

Python
class DynamicSparseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, sparsity):
        super().__init__()
        self.fc1 = DynamicSparseLinear(input_dim, hidden_dim, sparsity)
        self.fc2 = DynamicSparseLinear(hidden_dim, hidden_dim, sparsity)
        self.fc3 = DynamicSparseLinear(hidden_dim, output_dim, sparsity)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
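
To tie the pieces together, here is a small end-to-end run on synthetic data. The dimensions, class count, and hyperparameters are placeholders, not a tuned setup.

Python
torch.manual_seed(0)

X = torch.randn(256, 20)            # 256 synthetic samples with 20 features
y = torch.randint(0, 3, (256,))     # 3 arbitrary classes

model = DynamicSparseMLP(input_dim=20, hidden_dim=64, output_dim=3, sparsity=0.8)
final_loss = train_dynamic(model, X, y, epochs=50, rewire_every=5)
print(f"final loss: {final_loss:.4f}")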

8. Why Dynamic Sparse Training Matters

Static sparse training can fail if initial masks are poor.

Dynamic sparsity:

• Explores alternative connectivity
• Recovers from bad initial structure
• Approaches dense performance with sparse compute

This is why much of modern sparse-training research moves beyond fixed masks.

9. Minimal Math Intuition

We maintain a constant parameter budget:

Total active weights = constant

But we change their positions over time.

This turns training into a search over sparse connectivity graphs.

Instead of optimizing only weights,
we implicitly optimize structure.
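
This invariant is easy to verify empirically: count the active weights before and after a training run with rewiring. The helper below is illustrative and mirrors the synthetic setup from the previous section.

Python
def count_active(model):
    # Total number of unmasked weights across all dynamic sparse layers.
    return sum(
        int(m.mask.sum().item())
        for m in model.modules()
        if isinstance(m, DynamicSparseLinear)
    )

X = torch.randn(256, 20)
y = torch.randint(0, 3, (256,))

model = DynamicSparseMLP(input_dim=20, hidden_dim=64, output_dim=3, sparsity=0.8)
before = count_active(model)
train_dynamic(model, X, y, epochs=50, rewire_every=5)
after = count_active(model)
print(before, after)   # equal: the positions change, the budget does not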

10. What Comes Next

We have now covered:

  1. Structural sparsity
  2. Weight pruning
  3. Activation sparsity
  4. Sparse training from scratch
  5. Dynamic sparse training

The next step moves us toward modern large-scale systems:

Sparse Attention and Mixture of Experts (MoE)

This is where sparsity becomes conditional computation at scale.

Code Location

All code for this article lives in:

05_dynamic_sparse_training/

Suggested experiments:

• Compare static vs dynamic sparse models
• Increase rewiring frequency
• Implement gradient-based regrowth

Dynamic sparse training turns sparsity from a compression technique
into a structural learning strategy.