SolveWithPython

Training Sparse Neural Networks From Scratch (No Dense Pretraining)

In the previous articles, we followed a familiar path:

  1. Train a dense model
  2. Remove unimportant weights (pruning)
  3. Fine‑tune

This works — but it raises a deeper question:

Do we really need to start dense at all?

A growing body of research suggests that, in many cases, the answer is no.

In this article, you will:

  • Build a sparse layer that enforces sparsity during training
  • Prevent zeroed weights from updating
  • Train a network that is sparse from initialization
  • Compare it conceptually to prune‑after‑training

As always: Python first. Math only where necessary.

1. Why Avoid Dense Pretraining?

Dense pretraining has three inefficiencies:

• You compute gradients for weights that will later be deleted
• You store parameters you never truly needed
• You waste memory bandwidth during training

If we already believe most weights are unnecessary,
why not start sparse from the beginning?

This approach is called:

Sparse Training From Scratch

2. The Core Idea

We want a layer that:

• Has a fixed sparsity level
• Keeps zero weights permanently zero
• Only updates active connections

We accomplish this using a binary mask that is applied at every forward pass.

The difference from Article #1:

Now we allow training — but only through active weights.
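A "fixed sparsity level" can be enforced exactly or only on average. Thresholding uniform noise (keeping entries where `torch.rand(...) > sparsity`) hits the target on average. The sketch below instead zeroes exactly `round(sparsity * n)` entries by shuffling positions. The helper name `exact_sparsity_mask` is hypothetical and not part of the article's code.

```python
import torch

def exact_sparsity_mask(out_features, in_features, sparsity, generator=None):
    """Hypothetical helper: binary mask with exactly round(sparsity * n) zeros."""
    n = out_features * in_features
    n_zero = int(round(sparsity * n))
    # Shuffle all positions; the first n_zero of them are masked out
    perm = torch.randperm(n, generator=generator)
    mask = torch.ones(n)
    mask[perm[:n_zero]] = 0.0
    return mask.view(out_features, in_features)

# Compare with an independently drawn (Bernoulli-style) mask
bernoulli = (torch.rand(256, 100) > 0.8).float()
exact = exact_sparsity_mask(256, 100, 0.8)
print("Bernoulli mask sparsity:", 1 - bernoulli.mean().item())
print("Exact mask sparsity:    ", 1 - exact.mean().item())
```

For large layers the two are nearly indistinguishable. The exact version mainly matters for small layers or when you want identical parameter budgets across runs.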

3. Implementing a Trainable Sparse Linear Layer

```python
import torch
import torch.nn as nn

class SparseLinear(nn.Module):
    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Initialize weights with small random values
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Create a fixed binary mask: each entry is kept with probability (1 - sparsity)
        mask = torch.rand(out_features, in_features)
        mask = (mask > sparsity).float()
        # A buffer follows .to(device) and is saved in state_dict, but is never trained
        self.register_buffer("mask", mask)

        # Zero the masked weights so the stored tensor is sparse as well
        with torch.no_grad():
            self.weight.mul_(self.mask)

    def forward(self, x):
        # Enforce sparsity on every forward pass
        sparse_weight = self.weight * self.mask
        return x @ sparse_weight.t() + self.bias
```

Key properties:

• The mask is a registered buffer, not a parameter, so it is never trained
• Masked weights never contribute to the output
• Gradients flow only through active connections
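These properties are easy to check directly. The snippet below repeats a compact version of the layer (so it runs on its own), does one backward pass on random data, and inspects the gradient at the masked positions. The squared-output loss is an arbitrary stand-in chosen only to produce gradients.

```python
import torch
import torch.nn as nn

class SparseLinear(nn.Module):
    # Compact copy of the layer above, repeated so this snippet is self-contained
    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.register_buffer("mask", (torch.rand(out_features, in_features) > sparsity).float())

    def forward(self, x):
        return x @ (self.weight * self.mask).t() + self.bias

torch.manual_seed(0)
layer = SparseLinear(100, 256, sparsity=0.8)
x = torch.randn(32, 100)
layer(x).pow(2).mean().backward()   # arbitrary loss, just to get gradients

masked = layer.mask == 0
print("Trainable parameters:", [name for name, _ in layer.named_parameters()])
print("Max |grad| at masked positions:", layer.weight.grad[masked].abs().max().item())
print("Mask registered as a buffer:", "mask" in dict(layer.named_buffers()))
```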

4. Building a Sparse-From-Scratch MLP

```python
class SparseMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, sparsity):
        super().__init__()
        self.fc1 = SparseLinear(input_dim, hidden_dim, sparsity)
        self.fc2 = SparseLinear(hidden_dim, hidden_dim, sparsity)
        self.fc3 = SparseLinear(hidden_dim, output_dim, sparsity)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
```
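PyTorch's pruning utilities offer a related mechanism for standard layers: `torch.nn.utils.prune.custom_from_mask` re-parametrizes a layer's weight as `weight_orig * weight_mask`, storing the mask as a buffer and recomputing `weight` before each forward pass. The sketch below applies a random fixed mask (same rule as `SparseLinear`) to each `nn.Linear` in an equivalent MLP. It is an alternative under those assumptions, not the article's implementation, and `make_pruned_mlp` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_pruned_mlp(input_dim, hidden_dim, output_dim, sparsity):
    model = nn.Sequential(
        nn.Linear(input_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, output_dim),
    )
    for module in model:
        if isinstance(module, nn.Linear):
            # Random fixed mask: keep entries whose uniform draw exceeds `sparsity`
            mask = (torch.rand_like(module.weight) > sparsity).float()
            # Replaces the `weight` parameter with `weight_orig` (trainable) and
            # `weight_mask` (buffer); `weight` becomes weight_orig * weight_mask
            prune.custom_from_mask(module, name="weight", mask=mask)
    return model

torch.manual_seed(0)
model = make_pruned_mlp(100, 256, 10, sparsity=0.8)
model(torch.randn(8, 100)).sum().backward()

first = model[0]
print(sorted(n for n, _ in first.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(n for n, _ in first.named_buffers()))     # ['weight_mask']
```

The gradient of `weight_orig` is zero wherever `weight_mask` is zero, exactly as with `SparseLinear`.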

5. Training the Sparse Model

We now train this network normally.

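```python
import torch
import torch.nn as nn
import torch.optim as optim

def generate_data(n=2048, input_dim=100, num_classes=10):
    # Synthetic data: random inputs with random labels
    X = torch.randn(n, input_dim)
    y = torch.randint(0, num_classes, (n,))
    return X, y

def train(model, X, y, epochs=20):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        # Full-batch training: one gradient step per epoch on all of X
        logits = model(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Loss from the last epoch's forward pass (computed before its final update)
    return loss.item()
```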

Instantiate and train:

```python
model = SparseMLP(100, 256, 10, sparsity=0.8)
X, y = generate_data()
final_loss = train(model, X, y)
print("Final loss (sparse from scratch):", final_loss)
```

This model:

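• Started with roughly 80% of its connections masked out
• Stayed at that sparsity throughout training, because the mask never changes
• Never updated the masked weights: their gradients are exactly zero, and Adam (without weight decay) turns a zero gradient into a zero update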
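The last point depends on the optimizer, so it is worth verifying. The sketch below runs a few Adam steps on a single masked weight matrix and compares it before and after. It assumes `weight_decay=0` (the default): with a nonzero value, `torch.optim.Adam` adds `weight_decay * w` to the gradient, so masked entries that are not already zero would drift, although they still would not affect the output.

```python
import torch

torch.manual_seed(0)
weight = torch.nn.Parameter(torch.randn(4, 6))
mask = (torch.rand(4, 6) > 0.5).float()
before = weight.detach().clone()

opt = torch.optim.Adam([weight], lr=1e-2)   # weight_decay defaults to 0
x = torch.randn(16, 6)
for _ in range(5):
    opt.zero_grad()
    loss = (x @ (weight * mask).t()).pow(2).mean()
    loss.backward()
    opt.step()

changed = weight.detach() != before
print("Masked entries changed:", changed[mask == 0].any().item())
print("Active entries changed:", changed[mask == 1].any().item())
```

Zero gradients keep Adam's running moments at zero, so the update for those entries is 0 / (0 + eps) = 0.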

6. Verifying Sparsity


```python
def measure_sparsity(model):
    # Fraction of zero entries across all SparseLinear masks (biases are not counted)
    total = 0
    zeros = 0

    for module in model.modules():
        if isinstance(module, SparseLinear):
            total += module.mask.numel()
            zeros += (module.mask == 0).sum().item()

    return zeros / total


print("Model sparsity:", measure_sparsity(model))
```


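Expected output (approximate: each mask entry is drawn independently, so the exact value varies slightly between runs):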

Model sparsity: ~0.8

Sparsity is structural — not accidental.
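To put the fraction in absolute terms, the arithmetic below counts the weight entries of `SparseMLP(100, 256, 10, sparsity=0.8)` and the number expected to stay active. Keep in mind that this masked implementation still stores every entry densely (plus a dense mask), so the memory saving only materializes with a sparse storage format.

```python
# (in_features, out_features) of fc1, fc2, fc3 in SparseMLP(100, 256, 10, ...)
shapes = [(100, 256), (256, 256), (256, 10)]
sparsity = 0.8

total = sum(i * o for i, o in shapes)          # 25600 + 65536 + 2560
expected_active = (1 - sparsity) * total       # on average, before rounding
print("Weight entries stored:", total)
print("Expected active weights:", round(expected_active))
```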

7. Minimal Math (Why This Works)

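Standard training updates each weight by gradient descent:

w ← w − η ∇L

With masking, the layer uses the effective weight:

w_masked = w ⊙ m

By the chain rule, every gradient entry is scaled by its mask value:

∂L/∂wᵢⱼ = ∂L/∂(w_masked)ᵢⱼ · mᵢⱼ

So if mᵢⱼ = 0, then:

∂L/∂wᵢⱼ = 0

Those parameters receive no gradient and never affect the output, so they are permanently inactive.

We restrict the effective parameter space from the very first step.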
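Autograd confirms the chain-rule factor numerically. The sketch below builds `w_masked = w ⊙ m` for a small hand-written mask, keeps the gradient of `w_masked` with `retain_grad()`, and checks that the gradient of `w` equals that gradient multiplied elementwise by `m`. The `tanh` loss is arbitrary.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, 4, requires_grad=True)
m = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 1., 0.],
                  [1., 1., 0., 0.]])
x = torch.randn(5, 4)

w_masked = w * m            # w_masked = w ⊙ m
w_masked.retain_grad()      # keep ∂L/∂w_masked for comparison
loss = (x @ w_masked.t()).tanh().sum()
loss.backward()

# Chain rule: ∂L/∂w = ∂L/∂w_masked ⊙ m
print(torch.allclose(w.grad, w_masked.grad * m))
print(w.grad[m == 0].abs().max().item())
```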

8. Dense → Prune vs Sparse From Scratch

Dense → Prune:

• Train full network
• Remove small weights
• Fine‑tune

Sparse From Scratch:

• Start sparse
• Train only active weights
• Maintain constant sparsity

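In principle, the second approach avoids spending compute on weights that will be discarded anyway. Realizing that saving in practice requires sparse storage and sparse kernels: the masked layer above still stores a dense weight matrix and performs a dense matrix multiply.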
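The two recipes differ mainly in how the mask is chosen. Dense → prune picks it from trained weight magnitudes; sparse from scratch picks it before training, independent of any weight values. The sketch below contrasts a magnitude-based mask with a random one. The function names are hypothetical, and a random tensor stands in for a trained dense weight.

```python
import torch

def magnitude_mask(weight, sparsity):
    """Dense -> prune: keep the largest-|w| entries, zero the smallest `sparsity` fraction."""
    n_keep = weight.numel() - int(round(sparsity * weight.numel()))
    keep_idx = torch.topk(weight.abs().flatten(), n_keep).indices
    mask = torch.zeros(weight.numel())
    mask[keep_idx] = 1.0
    return mask.view_as(weight)

def random_mask(shape, sparsity):
    """Sparse from scratch: chosen before training, independent of weight values."""
    return (torch.rand(shape) > sparsity).float()

torch.manual_seed(0)
trained_weight = torch.randn(256, 100)   # stand-in for a trained dense weight
pm = magnitude_mask(trained_weight, 0.8)
rm = random_mask(trained_weight.shape, 0.8)

kept = trained_weight.abs()
print("Mean |w| kept, magnitude mask:", kept[pm == 1].mean().item())
print("Mean |w| kept, random mask:   ", kept[rm == 1].mean().item())
```

The magnitude mask needs the trained dense weights to exist first; the random mask does not, which is exactly what lets sparse-from-scratch training skip the dense phase.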

9. Limitations

Static sparsity has one major weakness:

The mask never changes.

What if the wrong connections were chosen initially?

That leads us to the next major concept:

Dynamic Sparse Training

In dynamic sparse training, connections are pruned and regrown as training proceeds.

10. What Comes Next

In Article #5, we implement:

• Periodic pruning
• Gradient‑based regrowth
• Rewiring sparse networks during learning
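As a rough preview only (the next article's implementation may differ), here is one possible rewiring step in the spirit of RigL (Evci et al., 2020). It drops the smallest-magnitude active weights and regrows the same number of inactive connections with the largest gradient magnitude, so the sparsity level stays constant. The function name and the `fraction` parameter are assumptions. Note that the regrowth score must be the gradient with respect to the effective weight `w_masked`, because the gradient with respect to the stored weight is zero at every masked position.

```python
import torch

def prune_and_regrow(weight, mask, grad, fraction=0.1):
    """Hypothetical rewiring step: swap `fraction` of the active connections.
    `grad` should be the dense gradient w.r.t. the effective weight w * mask."""
    active = mask.bool().flatten()
    n_swap = int(fraction * active.sum().item())
    if n_swap == 0:
        return mask
    flat_w = weight.detach().abs().flatten()
    flat_g = grad.abs().flatten()

    # Drop: among active positions, the n_swap smallest |w| (inactive ones scored +inf)
    drop_scores = torch.where(active, flat_w, torch.full_like(flat_w, float("inf")))
    drop_idx = torch.topk(drop_scores, n_swap, largest=False).indices

    # Grow: among inactive positions, the n_swap largest |grad| (active ones scored -inf)
    grow_scores = torch.where(active, torch.full_like(flat_g, -float("inf")), flat_g)
    grow_idx = torch.topk(grow_scores, n_swap).indices

    new_mask = mask.clone().flatten()
    new_mask[drop_idx] = 0.0
    new_mask[grow_idx] = 1.0
    # RigL also initializes regrown weights to zero; that is left to the caller here
    return new_mask.view_as(mask)

torch.manual_seed(0)
w = torch.randn(64, 32)
m = (torch.rand(64, 32) > 0.8).float()
g = torch.randn(64, 32)                  # stand-in for ∂L/∂w_masked
new_m = prune_and_regrow(w, m, g, fraction=0.2)
print("Active before/after:", int(m.sum().item()), int(new_m.sum().item()))
```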

This is where sparse training starts to become competitive with dense training.

Code Location

All code for this article lives in:

04_sparse_from_scratch/

Suggested experiments:

• Try 90% sparsity
• Compare dense vs sparse training curves
• Combine with activation sparsity
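For the dense-vs-sparse comparison, a minimal sketch is to record the loss at every epoch for a dense baseline and for two sparsity levels, using the same synthetic data and full-batch Adam loop as above. It assumes `nn.Linear` as the dense baseline, which uses PyTorch's default initialization rather than the `* 0.01` scaling, so part of any difference may come from initialization rather than sparsity. With random labels, the loss measures how quickly each model memorizes the training set, not generalization.

```python
import torch
import torch.nn as nn

class SparseLinear(nn.Module):
    # Compact copy of the article's SparseLinear so this snippet runs on its own
    def __init__(self, in_features, out_features, sparsity):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.register_buffer("mask", (torch.rand(out_features, in_features) > sparsity).float())

    def forward(self, x):
        return x @ (self.weight * self.mask).t() + self.bias

def make_mlp(sparsity=None):
    # sparsity=None builds a dense baseline from nn.Linear layers
    def layer(i, o):
        return nn.Linear(i, o) if sparsity is None else SparseLinear(i, o, sparsity)
    return nn.Sequential(layer(100, 256), nn.ReLU(),
                         layer(256, 256), nn.ReLU(),
                         layer(256, 10))

def loss_curve(model, X, y, epochs=20):
    # Same full-batch Adam loop as train(), but records the loss at every epoch
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    for _ in range(epochs):
        loss = loss_fn(model(X), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
X, y = torch.randn(2048, 100), torch.randint(0, 10, (2048,))
curves = {"dense": loss_curve(make_mlp(), X, y)}
for s in (0.8, 0.9):
    curves[f"sparse {s:.0%}"] = loss_curve(make_mlp(s), X, y)

for name, losses in curves.items():
    print(f"{name:>10}: first {losses[0]:.4f}  last {losses[-1]:.4f}")
```

Plotting each list in `curves` against the epoch index gives the training curves.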

Sparsity is no longer a compression trick.

It is a training strategy.