Backpropagation
Backpropagation is the chain rule, applied carefully to a computational graph. It’s how every neural network learns. Frameworks hide the details, but you should understand them — every weird training bug eventually involves gradients.
The setup
A loss L(θ) is a composition of operations:
x → W₁ → σ → W₂ → σ → W₃ → loss → L
Each operation has a local Jacobian (how its output changes with its input/parameters). Backprop computes ∂L/∂θᵢ for every parameter by walking the graph backwards, multiplying local Jacobians.
A worked example: 2-layer MLP
Forward pass:
z₁ = W₁ x + b₁
h₁ = σ(z₁)
z₂ = W₂ h₁ + b₂
ŷ = softmax(z₂)
L = cross_entropy(ŷ, y)
Backward pass (using δ for ∂L/∂(thing)):
δz₂ = ŷ − y # cross-entropy + softmax cancel beautifully
δW₂ = δz₂ ⊗ h₁ # outer product
δb₂ = δz₂
δh₁ = W₂ᵀ δz₂ # gradient flows back through the linear layer
δz₁ = δh₁ ⊙ σ'(z₁) # element-wise multiply with activation derivative
δW₁ = δz₁ ⊗ x
δb₁ = δz₁
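The six lines above are easy to check numerically. Below is a minimal sketch, assuming sigmoid for σ and arbitrary sizes (4 inputs, 8 hidden units, 3 classes), that runs the same backward pass by hand and compares it with autograd:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4)                       # single example, 4 features
y = torch.tensor(2)                      # true class index (3 classes)
W1 = torch.randn(8, 4, requires_grad=True)
b1 = torch.zeros(8, requires_grad=True)
W2 = torch.randn(3, 8, requires_grad=True)
b2 = torch.zeros(3, requires_grad=True)

# Forward pass, matching the derivation above (sigmoid for σ).
z1 = W1 @ x + b1
h1 = torch.sigmoid(z1)
z2 = W2 @ h1 + b2
loss = F.cross_entropy(z2.unsqueeze(0), y.unsqueeze(0))
loss.backward()

# Backward pass by hand.
with torch.no_grad():
    y_hat = torch.softmax(z2, dim=0)
    dz2 = y_hat
    dz2[y] -= 1.0                        # ŷ − y, with y as a one-hot vector
    dW2 = torch.outer(dz2, h1)
    db2 = dz2
    dh1 = W2.T @ dz2
    dz1 = dh1 * h1 * (1 - h1)            # σ'(z1) = σ(z1)(1 − σ(z1))
    dW1 = torch.outer(dz1, x)
    db1 = dz1

print(torch.allclose(dW1, W1.grad), torch.allclose(dW2, W2.grad))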
That’s backprop. Every layer reduces to:
- Receive δ_output from the layer above.
- Compute δ_input to pass below: multiply by the layer's Jacobian.
- Compute δ_parameters for this layer.
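As a sketch of that three-step contract, here is a hand-rolled linear layer with no framework autograd; the class name and the caching convention are just illustrative:

import numpy as np

class LinearLayer:
    """z = W x + b, written to make the three backward steps explicit."""
    def __init__(self, in_dim, out_dim):
        self.W = 0.01 * np.random.randn(out_dim, in_dim)
        self.b = np.zeros(out_dim)

    def forward(self, x):
        self.x = x                              # cache the input for backward
        return self.W @ x + self.b

    def backward(self, d_output):
        # (1) d_output = ∂L/∂z arrives from the layer above.
        # (2) Parameter gradients for this layer.
        self.dW = np.outer(d_output, self.x)
        self.db = d_output
        # (3) Gradient passed to the layer below: ∂L/∂x = Wᵀ ∂L/∂z.
        return self.W.T @ d_output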
Why it works
The chain rule:
∂L/∂W₁ = (∂L/∂z₂)(∂z₂/∂h₁)(∂h₁/∂z₁)(∂z₁/∂W₁)
You can compute this left-to-right (forward-mode) or right-to-left (reverse-mode). For neural nets, reverse-mode is dramatically faster because:
- Networks have a small loss output and many parameters.
- Going backward, each step is a cheap (1, k)×(k, m) product.
- Going forward, you'd need a separate pass (one Jacobian-vector product) per parameter direction.
So gradient computation in deep learning is reverse-mode automatic differentiation.
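To see the asymmetry concretely, here is a small sketch assuming a PyTorch version (2.0 or later) where torch.func exposes vjp (reverse-mode) and jvp (forward-mode):

import torch
from torch.func import vjp, jvp

def f(w):
    return torch.tanh(w).sum()           # scalar "loss" of 1000 parameters

w = torch.randn(1000)

# Reverse-mode: one vector-Jacobian product yields the entire gradient.
loss, pullback = vjp(f, w)
(grad,) = pullback(torch.ones_like(loss))         # seed with ∂L/∂L = 1

# Forward-mode: one Jacobian-vector product yields a single directional
# derivative; recovering the full gradient would need 1000 of these.
direction = torch.zeros(1000)
direction[0] = 1.0
_, dloss_dw0 = jvp(f, (w,), (direction,))

One backward sweep versus one forward sweep per direction is exactly the asymmetry that makes reverse-mode the right tool when the output is a scalar and the parameters number in the millions.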
Computational graphs
PyTorch builds a graph dynamically as you do operations on tensors with requires_grad=True. Each operation is recorded with its inputs and a backward function. When you call loss.backward(), the graph is walked in reverse.
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 2
z = y.sum()
z.backward()
print(x.grad) # [2.0, 4.0]
JAX uses a different model — explicit grad/vmap/jit transformations on pure functions. Same idea, different ergonomics.
Vanishing and exploding gradients
When you multiply many small numbers, you get something tiny. When you multiply many big numbers, you get something huge. Both kill training.
In a deep network with sigmoid activations (max derivative 0.25), gradients shrink geometrically with depth: ten layers multiply in a factor of at most 0.25¹⁰ ≈ 10⁻⁶, so by layer 10 they're effectively zero.
Mitigations:
- Better activations. ReLU has derivative 1 in the active region — gradients pass through unchanged.
- Residual connections. output = layer(x) + x — gradients can flow around the layer via the addition (see the sketch after this list).
- Normalization. BatchNorm/LayerNorm keep activations in a healthy range.
- Initialization. Xavier/He scaling keeps the variance of activations stable layer-to-layer.
- Gradient clipping. Cap exploding gradients to prevent training spikes.
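A minimal sketch of two of these mitigations in code; ResidualBlock and the sizes here are made up for illustration, and the clipping call goes right before the optimizer step:

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        # The "+ x" gives gradients a path that skips the layer entirely.
        return x + self.linear(torch.relu(self.norm(x)))

model = nn.Sequential(*[ResidualBlock(256) for _ in range(30)])
x = torch.randn(16, 256)
loss = model(x).pow(2).mean()
loss.backward()

# Cap the global gradient norm before the optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)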
Detached computation and no_grad
Sometimes you don’t want gradients:
with torch.no_grad():
    predictions = model(x)
Disables graph construction. Saves memory; faster. Use it for inference.
tensor.detach() returns a new tensor that shares data but is excluded from autograd. Useful for “stop gradient” tricks (e.g. in target networks for RL).
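As a hedged sketch of that target-network trick, with made-up networks and shapes: the TD target is detached, so the loss only backpropagates into the online network.

import torch
import torch.nn as nn

online_net = nn.Linear(8, 1)             # illustrative stand-ins for Q-networks
target_net = nn.Linear(8, 1)

state, next_state = torch.randn(32, 8), torch.randn(32, 8)
reward, gamma = torch.randn(32, 1), 0.99

# detach() acts as a stop-gradient: no gradient flows into target_net.
td_target = (reward + gamma * target_net(next_state)).detach()

loss = (online_net(state) - td_target).pow(2).mean()
loss.backward()                          # updates reach online_net only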
Gradient checkpointing
Backprop needs activations from the forward pass to compute gradients. Storing them all costs memory. Gradient checkpointing trades compute for memory: drop intermediate activations during forward, recompute them during backward.
Used routinely in training large models.
from torch.utils.checkpoint import checkpoint

def forward(x):
    # layer1 / layer2 are nn.Module blocks defined elsewhere. Checkpointing
    # frees their activations after the forward pass and recomputes them
    # during backward, trading extra compute for memory.
    h = checkpoint(layer1, x)
    h = checkpoint(layer2, h)
    return h
Higher-order gradients
You can take the gradient of a gradient:
grad, = torch.autograd.grad(loss, x, create_graph=True)   # grad() returns a tuple; unpack it
hessian_vec, = torch.autograd.grad(grad.sum(), x)
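A self-contained version, using a function whose Hessian is easy to compute by hand (the cubic here is an arbitrary example):

import torch

x = torch.randn(3, requires_grad=True)
loss = (x ** 3).sum()                    # gradient 3x², Hessian diag(6x)

grad, = torch.autograd.grad(loss, x, create_graph=True)

v = torch.tensor([1.0, 0.0, 0.0])
# Differentiating (grad · v) w.r.t. x gives the Hessian-vector product H v.
hv, = torch.autograd.grad((grad * v).sum(), x)
print(hv)                                # ≈ [6·x₀, 0, 0]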
Used in:
- Meta-learning (MAML)
- Some second-order optimizers
- Adversarial robustness research
Common bugs
- Forgetting optimizer.zero_grad(): gradients accumulate by default. Train one step → great. Train two steps without zeroing → garbage (see the demo after this list).
- In-place operations on a tensor that needs gradient: x.add_(1) instead of x = x + 1. Sometimes works, sometimes errors. Avoid.
- Detaching too aggressively: blocks gradient flow you actually wanted.
- Not detaching enough: pulls extra graph into your loss and eats memory until you OOM.
- Gradient norm spikes: a single bad batch destabilizes training. Clip gradients.
- Mixed-precision NaNs: fp16 can underflow. Use bf16 or GradScaler.
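The first bullet is worth seeing once; with no zero_grad() in between, the second backward() adds to the gradient left by the first:

import torch

w = torch.tensor(1.0, requires_grad=True)

(2 * w).backward()
print(w.grad)        # tensor(2.)

(2 * w).backward()
print(w.grad)        # tensor(4.): accumulated, not replaced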
Debugging gradients
Print the L2 norm of gradients per layer:
for name, p in model.named_parameters():
    if p.grad is not None:
        print(name, p.grad.norm().item())
If gradients are zero in early layers — vanishing problem. If they explode — clipping needed. If None — check that the parameter is actually used in the forward pass.
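To watch norms during the backward pass itself rather than after it, tensor hooks work too. A small sketch, reusing the model from the snippet above:

# Each hook fires the moment autograd produces that parameter's gradient,
# so the norms print layer by layer while backprop walks the graph.
for name, p in model.named_parameters():
    p.register_hook(lambda g, name=name: print(name, g.norm().item()))

Register these once; every subsequent loss.backward() will print as it goes.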
Exercises
- Manual backprop. For a 2-layer MLP y = W₂ relu(W₁ x) on a single example, write out ∂L/∂W₁ and ∂L/∂W₂. Verify against PyTorch autograd.
- Verify a custom autograd function. Implement a custom op (e.g. gelu) with explicit forward/backward; verify it with torch.autograd.gradcheck (a starter skeleton follows this list).
- Vanishing gradients live. Build a 30-layer MLP with sigmoid activations. Print gradient norms per layer. Watch them collapse.
- Fix it. Replace sigmoid with ReLU + add residual connections + BatchNorm. Re-check.
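For the second exercise, here is a starter skeleton of the pattern; it uses a simple cube op instead of gelu so the exercise itself stays open:

import torch

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)         # stash whatever backward will need
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Chain rule: δ_input = δ_output · ∂output/∂input.
        return grad_output * 3 * x ** 2

x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Cube.apply, (x,)))    # gradcheck wants double precision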