115 min read · advanced

Backpropagation Part 3: Systems, Stability, Interpretability, Frontiers

Theory assumes infinite precision; hardware delivers float16. Bridge the gap between mathematical backprop and production systems. In this post, we cover a lot of "practical" ground from PyTorch's tape to mixed precision training, from numerical disasters to systematic testing, from gradient monitoring to interpretability. What breaks, why, and how to fix it.

From Concepts to Systems

Parts 1 and 2 gave you the mental model. Computational graphs, adjoints, VJPs. You can trace a gradient by hand through a small network. You know why transposes appear and why residuals are gradient highways.

Then you train something real, and step 247 prints loss = nan.

Nothing in the math predicted this. You stare at it. Add logging. Halve the learning rate. Run it again. The NaN comes back at a different step. Or it doesn't, and you're not sure which is worse.

Or: the model trains fine on toy data, but at full scale with bigger batches, CUDA runs out of memory. You enable mixed precision. Gradient underflow makes training go nowhere. You add loss scaling. One hyperparameter now governs whether your model even runs.

The math of backpropagation assumes infinite precision and unlimited memory. Real training has neither. Most of the time, that's fine. When it isn't, the failure is rarely dramatic. The loss curve still descends. Validation metrics still improve. You're just leaving 10% accuracy on the table because a normalization layer computes variance in a numerically unstable way, or your custom backward pass has a broadcasting error nobody caught.

The bugs are silent. You blame the architecture. You tune the optimizer. The problem was a precision issue two layers back.

This post is about running that math on real hardware. Same chain rule, same VJP patterns, same local rules. Under finite precision and memory constraints that the clean version ignores.

We map the conceptual tape to what PyTorch actually records. Write custom gradients for operations where autodiff gives the numerically wrong answer. Run mixed precision without blowing up. Checkpoint activations when memory runs out. Build a testing habit that catches gradient bugs before they waste a week of compute. At the end, turn gradients into interpretation tools: ask what the network is actually responding to.

The theory holds. We're just running it in conditions it wasn't designed for.

Autodiff Systems in Practice

Parts 1 and 2 gave you the math. You can trace adjoints through a computational graph and understand why transposes appear where they do. Now: what actually happens when you call loss.backward() in PyTorch, or grad(loss) in JAX?

Not conceptually. Literally. What gets recorded, and when? Every framework runs the same chain rule, but the engineering choices around when to build the computational graph determine what models you can express, what your error messages look like, and how much the runtime can optimize behind the scenes.

The central split: does the framework build the graph before your code runs, or while it runs?

Define by Run vs Static Graphs

Here's something that becomes obvious once you see it: PyTorch lets you put a Python if statement inside a model. TensorFlow 1 didn't. You needed tf.cond instead. That API choice wasn't arbitrary. It was forced by a decision made before the model ever ran a single input.

Define-by-Run: The Graph Builds Itself

PyTorch's approach is called define-by-run, or eager execution. Just run the code. The graph builds itself along the way.

Every operation immediately executes and records itself. The graph exists only for this specific forward pass. Run the same code with different inputs, different control flow, and you get a different graph:
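A minimal sketch of that behaviour, assuming PyTorch is installed. The function name `forward` and the branching rule are made up for illustration:

```python
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # Ordinary Python control flow: the branch taken decides which ops get recorded.
    if x.sum() > 0:
        return (x * 2).sum()
    return (x ** 3).sum()

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([-1.0, -2.0], requires_grad=True)

ya, yb = forward(a), forward(b)
print(ya.grad_fn, yb.grad_fn)  # both end in a sum, but the nodes behind them differ

ya.backward()  # walks the "multiply by 2" graph
yb.backward()  # walks the "cube" graph

print(a.grad)  # d/dx (2x) = 2 per element
print(b.grad)  # d/dx (x^3) = 3x^2 per element
```

Same function, two inputs, two different recorded graphs.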

Python control flow just works. An if statement creates a different graph depending on the input. A for loop unrolls into exactly as many nodes as it ran. The graph is a faithful record of what actually happened, not a description of what might happen.

The tradeoff: the framework only ever sees one operation at a time. It can't fuse operations or reorder them for better memory access.

Static Graphs: Describe First, Execute Later

TensorFlow 1.x started from the opposite end. Describe the entire computation first, then execute it.

The framework sees the whole computation before running any of it. That's the entire point: global visibility enables optimizations that operation-by-operation execution can't. Fuse three matrix multiplies into one kernel pass. Distribute work across devices before anything runs. Compile everything to optimized machine code.

The cost was equally real. Python if statements don't mean anything to a graph that isn't running yet. You needed tf.cond for conditionals, tf.while_loop for loops. Debugging was painful: errors surfaced at execution time, not construction time. By the time something failed, the Python code that caused it was long gone.

The Modern Compromise

Modern frameworks didn't pick a side. PyTorch added TorchScript for static optimization when you need it. TensorFlow 2 adopted eager execution by default but kept XLA. JAX went somewhere else entirely:

JAX treats everything as functional transformations. Your forward pass is a pure function. grad turns it into its gradient function. jit compiles it. They compose: you can JIT a gradient, take gradients of JIT-compiled code. The trick is that JAX traces through your function at transformation time, building a static representation from something that looks dynamic. That's why it requires pure functions. Side effects would break the trace.

No universally best approach. Dynamic graphs make debugging tractable: the Python stack trace points to the actual line that failed. Static graphs enable optimizations that eager execution can't touch. Modern frameworks let you reach for either. Debug in eager mode. Ship in compiled mode. Pick based on what you're doing right now, not on ideology.

What Lives on the Tape

Print any tensor that came from a differentiable operation and you'll see something like tensor(6., grad_fn=<SumBackward0>). That grad_fn is a pointer to a record the forward pass wrote. When you call loss.backward(), PyTorch walks backward through these records. Each one receives the upstream gradient, computes its local contribution, and passes results further back.

So what does each record actually store?

The Tape Entry Anatomy

Think about y = x @ w + b. The backward pass needs to produce gradients for x, w, and b. Computing grad_x needs w. Computing grad_w needs x. And grad_b = grad_output.sum(axis=0) requires nothing from the forward pass at all. So the tape stores x and w, skips b. Each entry saves exactly what its backward function needs for the VJP, nothing more.

That principle holds for every op. Some save even less than you'd expect:

The memory difference adds up fast. A ReLU's backward pass only needs to know which inputs were positive. For one million activations, a bit-packed mask costs one bit per element, 125KB. (A torch.bool mask spends a full byte per element, so closer to 1MB.) Storing the full float32 input instead would cost 4MB. Same backward computation, up to 32x the memory. Across a hundred-layer network, that's the difference between fitting on one GPU and not.

The Tape's Lifetime

The tape is created during the forward pass and consumed during backward. Each entry is freed as it's used. Most PyTorch users hit this error at some point:

If you genuinely need multiple backward passes on the same graph, you can opt in:
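A small sketch of both behaviours, assuming PyTorch. The flag name `second_call_failed` is only there to make the outcome visible:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Default: saved tensors are freed as the first backward pass consumes them.
loss = (x ** 2).sum()
loss.backward()
try:
    loss.backward()              # the tape entry for x ** 2 is already gone
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
print("second backward failed:", second_call_failed)

# Opt in: ask the first backward pass to keep the saved tensors alive.
x.grad = None
loss = (x ** 2).sum()
loss.backward(retain_graph=True)
loss.backward()                  # allowed; gradients accumulate into x.grad
print(x.grad)                    # two passes of 2x added together
```

Note that the second pass adds to `x.grad` rather than replacing it, which is one more reason to reach for this sparingly.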

Usually there's a cleaner design that avoids needing this.

No Tape, No Gradient

Not every operation needs to build the tape. Three ways to skip it:
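Three common ways to do that in PyTorch, sketched on a toy `Linear` layer. These are standard options and may not be exactly the three the original listing showed:

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

# 1. Context manager: nothing inside the block is recorded.
with torch.no_grad():
    y_no_grad = model(x)

# 2. detach(): cut a single tensor out of the graph after the fact.
y_detached = model(x).detach()

# 3. Freeze the parameters, so no input to the op requires a gradient.
for p in model.parameters():
    p.requires_grad_(False)
y_frozen = model(x)

# None of the three results carries a grad_fn.
print(y_no_grad.grad_fn, y_detached.grad_fn, y_frozen.grad_fn)
```

Option 2 still builds the tape for the forward call and then drops the reference, so for inference the context manager is usually the one you want.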

During inference, model parameters still have requires_grad=True. Every forward pass builds a tape across the entire network and holds it in memory until garbage collected. Wasted work:

Running inference without no_grad() doesn't surface in tests. No error, no warning. Just extra memory and slower calls.

Custom Gradients: When Autodiff Isn't Enough

Autodiff doesn't compute the gradient of your operation. It computes the gradient of your implementation of it.

Those are usually the same. Sometimes they're not.

Take log-softmax. Write it as torch.log(torch.softmax(x, dim=-1)) and autodiff will differentiate it correctly. It will also produce -inf outputs and NaN gradients for extreme inputs. A naive softmax exponentiates x directly, which overflows before log can pull values back into range. torch.softmax shifts by the max internally, so it doesn't overflow, but small probabilities still underflow to zero and log turns them into negative infinity. The gradient is right by the math, broken in practice. The stable alternative, x - log_sum_exp(x), computes the same function but never materializes the probabilities. Autodiff can't discover this shortcut. It follows your code.

That's the first reason you'd write a custom gradient: numerical stability. There are two others. Sometimes your forward pass calls an external solver that Python can't differentiate through. Autodiff stops at that boundary. And sometimes the automatic gradient is correct but slower than a hand-derived formula. In all three cases, the fix is the same: replace what autodiff would generate with a backward function you write yourself.

The Basic Pattern

PyTorch's mechanism is torch.autograd.Function. Subclass it, define forward and backward as static methods, and use ctx to pass saved tensors between them:
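The smallest version of the pattern, using a toy `Square` op (an illustrative name, not a library class):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)        # stash what backward will need
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x      # VJP: upstream gradient times d(x^2)/dx

x = torch.tensor([1.0, -3.0], requires_grad=True)
Square.apply(x).sum().backward()        # custom Functions are called via .apply
print(x.grad)                           # 2x
```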

When the reverse pass reaches your operation, it calls your backward with the upstream gradient. Your function is just another node in the graph. The only difference is you wrote the backward rule instead of the framework.

A Real Example: Stable Log-Softmax

Back to log-softmax. Here's what it looks like as a custom autograd function. The forward pass subtracts the max before exponentiating to keep values in range. The backward pass computes the gradient from the saved intermediates rather than letting autodiff retrace the naive path:
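A sketch of how that might look, assuming the last dimension is the class dimension. The class name `StableLogSoftmax` is illustrative. The backward rule is the standard log-softmax VJP, grad_in = grad_out - softmax(x) * sum(grad_out):

```python
import torch

class StableLogSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Shift by the row max so the largest exponent is exp(0) = 1.
        shifted = x - x.max(dim=-1, keepdim=True).values
        exp_x = shifted.exp()
        sum_exp = exp_x.sum(dim=-1, keepdim=True)
        ctx.save_for_backward(exp_x, sum_exp)
        return shifted - sum_exp.log()

    @staticmethod
    def backward(ctx, grad_output):
        exp_x, sum_exp = ctx.saved_tensors
        softmax = exp_x / sum_exp
        return grad_output - softmax * grad_output.sum(dim=-1, keepdim=True)

x = torch.tensor([[1000.0, 999.0, 1001.0]], dtype=torch.float64, requires_grad=True)
out = StableLogSoftmax.apply(x)
out[0, 0].backward()
print(out)      # finite log-probabilities despite the huge logits
print(x.grad)   # finite gradient
```

In real code `torch.log_softmax` already does this. The point of the sketch is the shape of the pattern, not a replacement for the built-in.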

Notice what gets saved: exp_x and sum_exp, not the original input x. The backward only needs those two quantities. You could save x and recompute exponentials inside backward, but that redoes work the forward pass already did.

Straight-Through Estimators

The strangest application of custom gradients. Rounding has zero derivative almost everywhere. Backpropagate through torch.round and every upstream gradient dies. But quantization-aware training needs rounding in the forward pass and gradients flowing in the backward pass.

The straight-through estimator does something that sounds wrong: in the backward pass, ignore the rounding entirely. Pass the upstream gradient straight through as if the operation were identity.
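A minimal sketch, with `RoundSTE` as an illustrative name:

```python
import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)       # real rounding in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output          # pretend the op was identity

w = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
q = RoundSTE.apply(w)
(q * torch.tensor([1.0, 2.0, 3.0])).sum().backward()

print(q)        # rounded values used by the forward pass
print(w.grad)   # upstream gradient, passed through untouched
```

With plain `torch.round` the same backward call would leave `w.grad` at all zeros.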

Why does lying to the backward pass work? Because the lie is small enough to be useful. The gradient reaching earlier layers is wrong in precise terms but approximately right in direction, nudging weights toward lower loss. This pattern shows up wherever discrete operations need to participate in gradient-based training: quantized weights, binarized networks, discrete latent variables.

Mixed Precision: The Balancing Act

An A100's FP16 tensor cores peak at 312 TFLOPS. Its standard FP32 throughput? 19.5 TFLOPS. Sixteen times the math, same silicon, just narrower numbers. Every training GPU shipped in the last five years has this asymmetry baked in.

So why not just train in FP16?

Try it. Some models train fine. Many don't. The loss plateaus too early, oscillates without converging, or the network learns but generalizes worse than it should. The failure is quiet. No error, no crash. Just... worse results.

The culprit is almost always gradient underflow.

FP16 can't represent numbers smaller than about 6×10⁻⁸. Gradients deep in a network are products of many local Jacobians, the chain rule accumulating backward through dozens or hundreds of layers. Those products shrink fast. At some depth they drop below the FP16 floor and round to zero. Layers with zero gradients stop updating. Training keeps going. The loss might even decrease, because other layers still learn. But you've silently killed learning in the underflowed layers, and you won't notice until the model underperforms in ways you can't explain.

Mixed precision keeps the FP16 speed without the silent failure. The idea: not every part of training needs the same precision. Use FP16 where it's safe. Use FP32 where it matters.

The Precision Hierarchy

Why this split?

Forward pass arithmetic (matmuls, convolutions, activations) operates on values that typically sit between 0.001 and 1000. FP16 handles that range fine.

Gradients are more sensitive, but what matters is relative precision, not absolute magnitude. A gradient of 1e-5 still carries useful direction in FP16 as long as it doesn't underflow to zero. The goal is to keep gradients alive, not to represent them at full FP32 precision.

Parameter updates are where FP16 breaks. Under Adam, the actual weight change per step might be on the order of 1e-6. If the weight currently sits at 1.0, you're asking FP16 to distinguish 1.000001 from 1.0. It can't. The update rounds away to nothing. This is why you need an FP32 master copy of the parameters: to accumulate small changes that FP16 would silently discard.
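You can watch the update vanish without a GPU. The sketch below uses Python's `struct` module to round values to half and single precision. The helper names `to_fp16` and `to_fp32` are illustrative. Near 1.0, FP16 values are spaced about 0.00098 apart, so a step of 1e-6 never lands on a new value:

```python
import struct

def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]

def to_fp32(value: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", value))[0]

update = 1e-6
fp16_weight = to_fp16(1.0)
fp32_master = to_fp32(1.0)

for _ in range(1000):
    fp16_weight = to_fp16(fp16_weight + update)  # rounds back to 1.0 every time
    fp32_master = to_fp32(fp32_master + update)  # small steps survive

print(fp16_weight)   # still exactly 1.0 after a thousand updates
print(fp32_master)   # has moved by roughly 1e-3
```

The FP32 copy is not exact either, but it drifts in the right direction. The FP16 weight never moves at all.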

Loss Scaling: Preventing Underflow

The underflow fix is almost too simple. Make the gradients bigger.

Multiply the loss by a large constant (say 1024) before calling backward(). The chain rule propagates that factor through every gradient in the network. Values that would have hit the FP16 floor now sit comfortably above it. After the backward pass, divide all gradients by the same constant before the optimizer step. Direction and relative magnitudes are unchanged. You've just temporarily shifted the scale into FP16's representable range.

The catch with a fixed scale: too large and gradients overflow to Inf (worse than underflow). Too small and you haven't solved anything. And the right scale changes over training as gradient magnitudes evolve. Static scaling means manual tuning, and it breaks in both directions.

Dynamic loss scaling solves this by watching for overflow. If any gradient comes back as Inf or NaN, the scale was too aggressive: halve it, skip this update. If several consecutive updates succeed without overflow, the scale has room to grow: double it. Over training, the scaler converges to the largest stable value.

The scaler's strategy:

  1. Start with a large scale (65536)
  2. If gradients overflow (inf/nan), skip the update and halve the scale
  3. If N steps succeed without overflow, double the scale
  4. Converge on the largest safe value automatically
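The four steps fit in a few lines. This is a toy version of the strategy, not PyTorch's `GradScaler`. The class name, the `update` method, and the default `growth_interval` are assumptions for illustration:

```python
import math

class DynamicLossScaler:
    """Toy dynamic loss scaler following the four steps above."""

    def __init__(self, init_scale: float = 65536.0, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval  # the "N" in step 3
        self._good_steps = 0

    def update(self, scaled_grads) -> bool:
        """Return True if the optimizer step should run for these gradients."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale /= 2.0          # overflow: back off and skip the update
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= 2.0          # a clean streak: try a larger scale
            self._good_steps = 0
        return True

scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.update([1.0, float("inf")]), scaler.scale)  # skipped, scale halved
print(scaler.update([0.5, 0.25]), scaler.scale)          # applied, scale unchanged
print(scaler.update([0.5, 0.25]), scaler.scale)          # applied, scale doubled
```

In a real loop you would multiply the loss by `scaler.scale` before `backward()` and divide the gradients by it before the optimizer step.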

In practice, you don't manually pick which operations run in FP16 and which need FP32. PyTorch's autocast handles it:

autocast maintains an internal allowlist. Matmuls and convolutions get FP16. Reductions, softmax, and batch norm get FP32. You write normal code and the framework routes each op to the right dtype.

Memory Optimization: Trading Compute for Space

You scale up. The model trains fine on a toy dataset. Then you switch to the real thing: bigger batches, longer sequences, and CUDA out of memory. Not a warning. Training stops.

The instinct is to shrink the batch size until it fits. That works. It's also the least interesting option.

What's actually eating the memory? Parameters are a fixed cost. A billion parameters in BF16: about 2GB, predictable. What grows with batch size and sequence length is activations: the intermediate tensors your forward pass produces and keeps alive because the backward pass needs them. Double your batch size, double your activation memory. Stack 100 layers and you're keeping 100 intermediate states alive simultaneously.

Every technique in this section works the same trade-off: recompute something cheaply rather than store it, or reduce how much you need to store at once.

Gradient Checkpointing: Recompute to Remember Less

Storing every intermediate activation is expensive. Recomputing any one of them is cheap. The forward pass already did the work once. Doing it again costs a fraction of total training time.

So: mark certain layers as checkpoints and discard activations in between. During the backward pass, when you reach a discarded segment, rerun the forward from the nearest checkpoint. Memory drops substantially. Compute goes up by about a third.

Which layers to checkpoint? Not all equally. The best candidates have large activation footprints and are fast to recompute. The last layer in a segment usually doesn't need checkpointing at all, since its output is consumed immediately.
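A minimal sketch with `torch.utils.checkpoint`, assuming a reasonably recent PyTorch (the `use_reentrant` argument is not available in very old releases). The toy `block` stands in for one segment of a deeper model:

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(
    torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16)
)
x = torch.randn(4, 16, requires_grad=True)

# Plain forward/backward: every intermediate activation inside block is stored.
block(x).sum().backward()
grad_plain = x.grad.clone()

# Checkpointed: activations inside block are dropped after the forward pass
# and recomputed when the backward pass reaches this segment.
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))  # same gradients, less stored
```

The gradients match. What changes is when the intermediate activations exist in memory.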

Gradient Accumulation: Fake Larger Batches

You want batch size 32 for stable gradients. Your GPU can fit 8. These feel like the same constraint. They're not.

Gradient accumulation decouples them. Run several small forward-backward passes without calling optimizer.step(). Gradients accumulate in each parameter's .grad tensor. Once you've processed enough micro-batches to simulate your target batch, take the update. Memory only ever sees the small batch. The gradient behaves like it came from the large one.

Mathematically, accumulated gradients are equivalent to computing on the full batch at once, provided the scaling matches: with a mean-reduced loss, divide each micro-batch loss by the number of accumulation steps, or the summed gradient comes out that many times too large. One exception: batch normalization. BN computes statistics over whatever batch it sees. If your micro-batch is 8 examples, BN normalizes over 8 examples, not 32. For models using LayerNorm (most modern transformers), this doesn't apply. Gradient accumulation is a clean substitute for larger batches.
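A quick check of that equivalence on a toy linear model, assuming a mean-reduced loss. Each micro-batch loss is divided by the number of accumulation steps so the accumulated sum is an average:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()                  # mean reduction
inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

# Reference: one backward pass over the full batch of 32.
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulated: four micro-batches of 8, no optimizer step in between.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(inputs.chunk(accum_steps), targets.chunk(accum_steps)):
    loss = loss_fn(model(xb), yb) / accum_steps   # rescale so the sum is a mean
    loss.backward()                               # adds into .grad
accum_grad = model.weight.grad.clone()
# optimizer.step() and optimizer.zero_grad() would go here

print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```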

Memory Profiling: Find the Leaks

You can't fix what you can't measure. Peak memory during the backward pass is often double what the forward pass used. A few measurements isolate the culprit.

Once you have the numbers, the usual suspects:

The through-line: parameters are a fixed cost, activations are not. Activations scale with batch size, sequence length, and depth. Every technique here either reduces how many you store or makes each one cheaper.

Get these right and the ceiling lifts. The model that wouldn't fit at batch 32 runs at batch 128 after checkpointing and mixed precision. Sometimes that's the difference between a run that finishes and one that never starts.

Numerical Stability and Testing

We've explored how backprop works mathematically, how to implement it, and how to manage gradient pathologies. In practice: even if you get all the math right, numerical computation on finite precision hardware can silently distort gradients. A single overflow in a softmax can cascade into NaN losses. An accumulation of rounding errors can make gradients point in the wrong direction. A mismatch between your mental model and your actual implementation can waste weeks of debugging.

This section is about trust. How do you know your gradients are correct? How do you prevent numerical disasters before they happen? How do you systematically test layers to catch bugs early? These aren't glamorous topics like attention mechanisms or diffusion models, but they're the difference between research that works and research that mysteriously doesn't.

The irony is, in practice, the bugs that waste the most time are the ones that do not crash. Your network trains, loss decreases, but performance plateaus below expectations. It is easy to blame the architecture, the data, or the hyperparameters, while the underlying issue is a subtle numerical problem that degrades gradients just enough to hurt learning without obviously failing.

Here's how to build bulletproof implementations: patterns that are numerically stable, testing strategies that catch bugs early, and monitoring approaches that reveal problems before they become disasters.

Stable Patterns: Computing Without Exploding

Write softmax(x) for $x = [1000, 999, 1001]$. The formula says $p_i = e^{x_i} / \sum_j e^{x_j}$. To evaluate it, you need $e^{1000}$. That's roughly $2 \times 10^{434}$. Float32 maxes out around $10^{38}$.

Overflow. Infinity. NaN. Your training run is dead.

The math isn't wrong. The computation path is. Every fix in this section follows the same principle: find an algebraically equivalent path that keeps intermediate values inside the range your hardware can represent. Same function, different route through the number line.

Log-Sum-Exp: The Universal Stabilizer

The softmax overflow has a one-line fix. Subtract the maximum value from every input before exponentiating:
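Here is the fix in plain Python, next to the version that breaks. Python floats are float64, which overflows later than float32 but still long before $e^{1000}$. The function names are illustrative:

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]      # math.exp raises OverflowError for large x
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    c = max(xs)                           # the one-line fix: shift by the max
    exps = [math.exp(x - c) for x in xs]  # largest term is now exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 999.0, 1001.0]
try:
    naive_softmax(logits)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

probs = stable_softmax(logits)
print(naive_overflowed)   # True
print(probs)              # a valid distribution; same as softmax([0, -1, 1])
```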

Why does this work? The constant cancels. $e^{x_i - c} / \sum_j e^{x_j - c} = e^{x_i} \cdot e^{-c} / (\sum_j e^{x_j} \cdot e^{-c})$. The $e^{-c}$ divides out. But now the largest exponent is $e^0 = 1$. Nothing overflows.

This generalizes. Whenever you compute ratios of exponentials, stay in log space as long as you can:

Log space is where exponentials behave. Leave it only when you must.

Cross-Entropy: Never Compute Log of Probability

Cross-entropy loss is $-\log(p_{\text{target}})$, where $p$ comes from softmax. Two operations chained together. Each one is fine on its own. Together, they create a trap.

The softmax produces probabilities. For the wrong classes, those probabilities can be tiny: $10^{-20}$, $10^{-30}$, smaller. Float32 still represents those, but below about $10^{-38}$ values lose significant digits, and below about $10^{-45}$ they underflow to exactly zero. Then you pass that degraded number into log, and log(0) is negative infinity. Garbage in, garbage out.

The fix: never compute the probability at all. Go from logits to log-probabilities directly:
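In plain Python the direct route is the identity $-\log \text{softmax}(x)_t = \text{logsumexp}(x) - x_t$. A sketch with illustrative function names:

```python
import math

def log_sum_exp(xs):
    c = max(xs)
    return c + math.log(sum(math.exp(x - c) for x in xs))

def cross_entropy_from_logits(logits, target):
    # -log softmax(logits)[target] = logsumexp(logits) - logits[target]
    return log_sum_exp(logits) - logits[target]

def cross_entropy_via_probs(logits, target):
    c = max(logits)
    exps = [math.exp(x - c) for x in logits]
    p = exps[target] / sum(exps)
    return -math.log(p)          # math.log(0.0) raises ValueError

logits, target = [0.0, 800.0], 0   # confidently wrong about class 0
try:
    cross_entropy_via_probs(logits, target)
    via_probs_failed = False
except ValueError:
    via_probs_failed = True

print(via_probs_failed)                           # True: p underflowed to 0.0
print(cross_entropy_from_logits(logits, target))  # 800.0, large but finite
```

Even with a max-shifted softmax, the two-step route dies the moment the probability underflows. The log-space route returns a usable loss.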

PyTorch's F.cross_entropy expects raw logits for this reason. It fuses the computation internally. Pass it softmax outputs instead of logits and you undo the fusion, reintroducing exactly the instability the fused form was designed to avoid.

The Sigmoid-Binary Cross-Entropy Fusion

Binary classification has the same problem. Compute sigmoid, then binary cross-entropy. For large positive inputs, sigmoid(x) rounds to 1.0, and log(1 - 1.0) is negative infinity. For large negative inputs, sigmoid(x) rounds to 0.0, and log(0.0) is negative infinity too. Either way the loss becomes infinite.
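One common fused form is $\max(x, 0) - x y + \log(1 + e^{-|x|})$, which never exponentiates a positive number. A plain-Python sketch with illustrative function names:

```python
import math

def bce_naive(x: float, y: float) -> float:
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

def bce_with_logits(x: float, y: float) -> float:
    # Algebraically equal to bce_naive, but the exponent is never positive.
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

try:
    bce_naive(40.0, 0.0)        # sigmoid(40) rounds to 1.0, then log(0.0)
    naive_failed = False
except ValueError:
    naive_failed = True

print(naive_failed)                  # True
print(bce_with_logits(40.0, 0.0))    # about 40, finite
print(bce_with_logits(0.5, 1.0))     # agrees with the naive form where both work
```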

The pattern is always the same. Accept raw logits. Fuse the computation. Never materialize a probability just to take its log.

Variance: The Catastrophic Cancellation Problem

Here's a formula you might write without thinking twice: $\text{Var}(X) = E[X^2] - E[X]^2$. Mathematically correct. Numerically, it's a bomb.

Say your data is $[1000000.1, 1000000.2, 1000000.3]$. The mean is about $10^6$. The variance is about $0.0067$. The formula computes $E[X^2] \approx 10^{12}$ and $E[X]^2 \approx 10^{12}$, then subtracts them. Float32 gives you 7 significant digits. All 7 are spent tracking the $10^{12}$ part. The $0.0067$ you actually care about lives entirely in the rounding error.

The result can come back negative. Variance. Negative.

This failure mode is called catastrophic cancellation: subtracting two nearly equal large numbers destroys all the significant digits that distinguished them. The fix is to center first. Subtract the mean, then square. The centered values are small, their squares are small, and the computation stays well within float32's precision.

This is why every BatchNorm and LayerNorm implementation centers before computing variance:
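To see the difference without a framework, the sketch below simulates float32 arithmetic in plain Python by rounding every intermediate result with `struct`. The helper names are illustrative, and this is a simulation, not how a normalization layer is actually implemented:

```python
import struct

def f32(value: float) -> float:
    """Round to the nearest float32 to mimic single-precision arithmetic."""
    return struct.unpack("f", struct.pack("f", value))[0]

def mean_f32(xs):
    total = 0.0
    for x in xs:
        total = f32(total + x)
    return f32(total / len(xs))

def naive_variance_f32(xs):
    # E[X^2] - E[X]^2: two numbers near 1e12, subtracted.
    mean = mean_f32(xs)
    return f32(mean_f32([f32(x * x) for x in xs]) - f32(mean * mean))

def centered_variance_f32(xs):
    # Subtract the mean first, then square: every intermediate stays small.
    mean = mean_f32(xs)
    return mean_f32([f32(f32(x - mean) ** 2) for x in xs])

data = [f32(v) for v in (1000000.1, 1000000.2, 1000000.3)]
exact_mean = sum(data) / len(data)                        # float64 reference
reference = sum((x - exact_mean) ** 2 for x in data) / len(data)

print(reference)                     # float64 answer, about 0.006
print(naive_variance_f32(data))      # nowhere near the reference
print(centered_variance_f32(data))   # close to the reference
```

The centered result is not perfect, because the float32 mean is itself slightly off, but it stays the right size. The naive result does not.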

Mixed Precision Pitfalls: When Float16 Kills You

Everything above gets worse in float16. Much worse.

Float32 tops out around $10^{38}$. Float16 tops out at 65,504. That's it. A gradient magnitude that's perfectly comfortable in float32 overflows to infinity in float16 without you doing anything unusual. On the other end, the smallest normalized float16 value is around $6 \times 10^{-5}$. Gradients deep in a network routinely fall below that, where they lose precision as subnormals, and below about $6 \times 10^{-8}$ they round to zero.

The fix is loss scaling, covered above under "Loss Scaling: Preventing Underflow". But in the context of stability patterns, the point is simpler: every overflow and cancellation risk you just read about is amplified when you halve your precision. The log-space tricks, the fused loss computations, the centered variance, all of them go from "good practice" to "required for training to work at all."

Gradient Checking Playbook

You write a custom backward pass. Run it. Numbers come out. Training doesn't crash.

Are the gradients right?

Probably. But "probably" is a bad place to stand when a week of GPU time is on the line. Wrong gradients don't crash. That's the whole problem. A sign error, a missing transpose, a broadcasting assumption that holds for one shape and breaks for another: all produce plausible numbers. Loss still decreases. The model just converges worse than it should, and you blame the architecture for a month while the real issue is a gradient off by 2x.

Gradient checking turns "probably" into "yes." Estimate gradients numerically with finite differences, compare against your analytical implementation. Mismatch means bug, and you find it before it costs anything.

The check itself is simple. What makes it fail is subtle: wrong epsilon, wrong inputs, wrong comparison metric. A check that passes on a broken implementation is worse than no check at all.

The Centered Difference Formula

Two versions exist. The one-sided version nudges forward by $\epsilon$: $(f(x+\epsilon) - f(x))/\epsilon$. Error proportional to $\epsilon$. Halve $\epsilon$, halve the error. Fine, but not great.

The centered version probes both directions: $(f(x+\epsilon) - f(x-\epsilon))/(2\epsilon)$. Error scales with $\epsilon^2$. Halve $\epsilon$, quarter the error. One extra function evaluation. Always use centered:
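Both formulas in plain Python, tested on sin(x), whose derivative cos(x) is known exactly. The function names are illustrative:

```python
import math

def forward_diff(f, x, eps):
    return (f(x + eps) - f(x)) / eps

def centered_diff(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 1.0
exact = math.cos(x)                  # d/dx sin(x) = cos(x)

errors = {}
for eps in (1e-2, 5e-3):
    errors[eps] = (
        abs(forward_diff(math.sin, x, eps) - exact),
        abs(centered_diff(math.sin, x, eps) - exact),
    )
    print(eps, errors[eps])

# Halving eps roughly halves the one-sided error and quarters the centered one.
print(errors[1e-2][0] / errors[5e-3][0], errors[1e-2][1] / errors[5e-3][1])
```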

The Gradient Check Protocol

You have a numerical estimate and an analytical one. How do you compare them?

Not with absolute difference. A discrepancy of 0.001 means nothing if the gradient is 1000, and everything if the gradient is 0.0001. You need relative error, normalized by magnitude. The denominator adds a small epsilon for the near-zero case where both gradients are tiny and the ratio blows up.

This protocol also reports where the error is worst, which matters when you're debugging a layer with thousands of parameters:
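A framework-free sketch of such a protocol, operating on plain lists of floats. The function names, the tolerance, and the dictionary layout are assumptions for illustration:

```python
def numerical_grad(f, x, eps=1e-5):
    """Centered-difference gradient of a scalar function of a list of floats."""
    grad = []
    for i in range(len(x)):
        plus, minus = list(x), list(x)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad

def check_gradient(f, analytic_grad, x, eps=1e-5, tol=1e-6):
    numeric = numerical_grad(f, x, eps)
    analytic = analytic_grad(x)
    # Relative error; the 1e-12 guards the case where both gradients are ~0.
    rel = [abs(a - n) / (abs(a) + abs(n) + 1e-12) for a, n in zip(analytic, numeric)]
    worst = max(range(len(x)), key=rel.__getitem__)
    return {"passed": rel[worst] < tol, "max_rel_error": rel[worst], "worst_index": worst}

def f(x):
    return sum(v ** 3 for v in x)

def good_grad(x):
    return [3 * v ** 2 for v in x]

def buggy_grad(x):
    g = good_grad(x)
    g[2] *= 2.0                      # deliberate "off by 2x" bug in one entry
    return g

point = [1.0, -2.0, 0.5]
print(check_gradient(f, good_grad, point))
print(check_gradient(f, buggy_grad, point))   # fails and points at index 2
```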

Choosing the Right Epsilon

The instinct is "smaller is better." Finer perturbation, better approximation. True up to a point.

The numerator computes $f(x+\epsilon) - f(x-\epsilon)$. When $\epsilon$ shrinks far enough, those two values become nearly identical in floating point. Subtracting them amplifies rounding error. Below about $10^{-7}$ in float64, rounding noise grows faster than the signal. Too large and you're measuring a chord, not a tangent. Too small and rounding eats the answer. The sweet spot sits in the middle.

Usually $10^{-5}$ to $10^{-4}$ in float64. Here's how to find it empirically:
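One way to do that is to sweep epsilon on a function whose derivative you know exactly and watch where the error bottoms out. A sketch in plain Python (float64), using exp(x) as the test function. The names are illustrative:

```python
import math

def centered_diff(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def sweep_epsilon(f, exact, x, exponents=range(-1, -14, -1)):
    """Return (eps, absolute error) pairs for eps = 1e-1 down to 1e-13."""
    results = []
    for k in exponents:
        eps = 10.0 ** k
        results.append((eps, abs(centered_diff(f, x, eps) - exact)))
    return results

results = sweep_epsilon(math.exp, math.exp(1.0), 1.0)   # d/dx exp(x) = exp(x)
for eps, err in results:
    print(f"eps={eps:.0e}  error={err:.3e}")

best_eps, best_err = min(results, key=lambda r: r[1])
print("best:", best_eps, best_err)
```

The error falls as epsilon shrinks, bottoms out in the middle of the range, then climbs again as rounding takes over. The exact minimum depends on the function and the point.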

Special Cases That Need Care

Three operations will fail your gradient check unless you handle them specifically.

ReLU is non-differentiable at $x = 0$. Don't test near it. Dropout is stochastic: the numerical and analytical passes see different random masks unless you freeze the mask. Max pooling's argmax can flip when the perturbation crosses a tie, making the numerical estimate jump discontinuously.

Unit Tests for Layers

You write a layer. Gradient check passes. You move on.

Three weeks later, training is unstable at batch size 32. You halve the learning rate. It helps a little. You halve it again. You try a different optimizer. You start questioning the architecture.

The bug was in the bias gradient. It accumulated across the batch instead of averaging. At batch size 1 (which is what gradient check used), accumulating and averaging are the same thing. At batch size 32, the gradient is 32x too large. The gradient check literally cannot see this. It tested one sample. The math was right. The software was wrong.

That's the gap. Gradient checks verify calculus. They don't verify behavior. Does the layer produce the right output shape for batch size 1 and batch size 64, or just the one you happened to test? What happens when the input is all zeros? Does forward modify the input tensor in place, silently corrupting autograd's record of what happened? Is the output deterministic across calls with the same input?

Software bugs. All of them. And the gradient check sailed right past.

The template below is what to run before trusting a new layer:
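A compact sketch of such a template, assuming a layer that maps `(batch, in_features)` inputs to per-sample outputs. The helper name `check_layer` and the specific checks are one reasonable selection, not a complete suite:

```python
import torch

def check_layer(make_layer, in_features, batch_sizes=(1, 8, 64)):
    torch.manual_seed(0)
    layer = make_layer().eval()
    for n in batch_sizes:
        x = torch.randn(n, in_features)
        x_before = x.clone()
        out1, out2 = layer(x), layer(x)
        assert out1.shape[0] == n, "batch dimension not preserved"
        assert torch.equal(x, x_before), "forward modified its input in place"
        assert torch.equal(out1, out2), "forward is not deterministic in eval mode"
        zeros_out = layer(torch.zeros(n, in_features))
        assert torch.isfinite(zeros_out).all(), "all-zero input produced NaN or Inf"

    # With a mean-reduced loss over identical rows, parameter gradients
    # should not depend on how many rows there are.
    row = torch.randn(1, in_features)
    grads = []
    for n in (1, 32):
        layer.zero_grad()
        layer(row.repeat(n, 1)).mean().backward()
        grads.append([p.grad.clone() for p in layer.parameters()])
    for g1, g32 in zip(*grads):
        assert torch.allclose(g1, g32, atol=1e-6), "gradient scales with batch size"
    return True

print(check_layer(lambda: torch.nn.Linear(4, 3), in_features=4))
```

The last check is the one that would have caught the bias bug above: the same gradient at batch size 1 and batch size 32.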

The Most Common Layer Bugs

Four patterns account for most silent layer failures. All pass gradient checks. All surface eventually, usually at the worst possible time.

Reproducibility Knobs

The worst debugging experience is a bug that shows up once, then disappears. You run training. Step 247, NaN. You add logging, rerun. Clean. Did you fix it, or did the random seed land differently this time?

You can't know. And you can't know because a training run draws randomness from five independent sources: Python's random module, NumPy's state, PyTorch's global state, CUDA's state across all devices, and CuDNN's algorithm selection (which varies between runs when autotuning is on). Miss any one of them and two "identical" runs aren't identical.

The CuDNN flag is the one people miss. cudnn.benchmark = True (off by default in PyTorch, but widely switched on for speed) tries a few candidate algorithms for your layer shapes and picks the fastest. That selection isn't reproducible across runs. Setting cudnn.benchmark = False and cudnn.deterministic = True forces a consistent choice. The cost is real:

On a 10-hour run, 1.5x slower is 5 extra hours. You don't want this on by default.

Better pattern: treat reproducibility as a mode you switch into for debugging, not a permanent setting. Save the backend state, enable full determinism for the enclosed block, restore when done.
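A sketch of that pattern as a context manager. The names `seed_everything` and `deterministic` are illustrative. It saves and restores only the two CuDNN flags, not the RNG states themselves:

```python
import contextlib
import random

import torch

def seed_everything(seed: int) -> None:
    random.seed(seed)
    # If the project uses NumPy, np.random.seed(seed) belongs here too.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)   # silently ignored when CUDA is unavailable

@contextlib.contextmanager
def deterministic(seed: int):
    """Temporarily force reproducible behaviour, then restore backend flags."""
    saved = (torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    seed_everything(seed)
    try:
        yield
    finally:
        torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = saved

with deterministic(seed=123):
    a = torch.randn(4)
with deterministic(seed=123):
    b = torch.randn(4)
print(torch.equal(a, b))   # same seed, same draw
```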

When an intermittent bug appears, wrap the suspicious block, rerun with the same seed. If it reproduces: real bug, go fix it. If it doesn't: the behavior depends on randomness, and you need to find which specific random state triggers it.

Monitoring Gradient Health

You're at step 3000. Loss is going down. You leave it running overnight.

You come back to step 8000. Validation accuracy is stuck below where it should be. The training loss is still descending, slowly. You assume the learning rate is too high. You halve it. Nothing changes.

Then you add two lines of logging — gradient norm per layer, printed every hundred steps — and scroll back through the output. Layer 11 has been printing 4e-9 since step 200.

It's been dead for 7800 steps. Every other layer kept learning, so the loss kept going down. But the network you thought you were training is not the network you actually have.

Logging gradient statistics is how you see this at step 200 instead of step 8000.

What to Watch For

Most gradient failures fall into four patterns. You can spot all of them from the norm history alone.

Norms climbing exponentially: explosion is coming. The loss hasn't diverged yet, but give it another hundred steps. Norms bouncing wildly between updates: the optimizer is fighting itself, probably a learning rate issue. Norms frozen at some tiny value with near-zero variance: the layer quit. And sparsity creeping upward over time: neurons are dying off, one by one, and they aren't coming back.

The code below detects all four:
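A framework-free sketch of such a detector. It takes one layer's recent gradient-norm history (for example, values of `p.grad.norm()` logged every hundred steps) plus a history of the fraction of zero gradients. The function name and every threshold are guesses to tune for your own model:

```python
import statistics

def diagnose(norms, sparsity=(), window=5):
    """Flag the four patterns from one layer's recent history."""
    flags = []
    recent = list(norms[-window:])
    ratios = [b / a for a, b in zip(recent, recent[1:]) if a > 0]
    mean = statistics.fmean(recent)
    spread = statistics.pstdev(recent) / mean if mean > 0 else 0.0

    if ratios and all(r > 1.5 for r in ratios):
        flags.append("exploding")        # steady multiplicative growth
    elif spread > 1.0:
        flags.append("oscillating")      # swings larger than the mean
    if mean < 1e-7 and spread < 0.1:
        flags.append("dead")             # tiny and not moving
    if len(sparsity) >= window and sparsity[-1] - sparsity[-window] > 0.1:
        flags.append("dying neurons")    # fraction of zero gradients climbing
    return flags

print(diagnose([1.0, 2.0, 4.0, 8.0, 16.0]))
print(diagnose([0.1, 10.0, 0.1, 10.0, 0.1]))
print(diagnose([4e-9] * 5))
print(diagnose([1.0] * 5, sparsity=[0.10, 0.15, 0.20, 0.30, 0.40]))
```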

None of this is the interesting part of training. Nobody got into deep learning to write logging code. But three days into a validation plateau, when you still can't tell which layer died, you'll wish you had written it on day one.

Two lines of per-layer norm logging. That's the difference between catching a dead layer at step 200 and blaming your architecture at step 8000.

Interpretability via Gradients: What Your Network Actually Looks At

Throughout this post, we've explored how gradients flow backward to train networks. But here's something not yet covered: those same gradients can tell you what your network is "looking at" when it makes decisions. Not in some abstract mathematical sense, but literally which pixels in an image or words in a sentence drove the prediction.

This is interpretability through gradients, and it's both simpler and more limited than most people realize. The core idea: if changing a pixel would change the output, that pixel matters. The gradient tells you exactly how much. But as we'll see, this local sensitivity isn't the same as importance, and definitely isn't the same as understanding.

Think about it this way: you have a trained network that correctly classifies an image as a dog. You want to know why. The gradient $\partial L/\partial x$ at each input pixel tells you: "if I slightly increased this pixel's intensity, here's how much the dog score would change." Pixels with large gradients have high influence. Visualize these gradients and you get a saliency map, a heat map of influence.

But there's a catch that trips everyone up: the gradient is purely local. It tells you what would happen if you made a tiny change right now, not what would happen if the pixel wasn't there at all, or what role it plays in the broader computation. It's like asking "which pedal affects your speed?" while driving at 60mph. The answer (brake pedal: massive negative effect, gas pedal: small positive effect) tells you about local sensitivity, not about which pedal got you to 60mph in the first place.

Plain Saliency: The Simplest Attribution

Every gradient so far has been with respect to weights. That's training. But flip the question: what's the gradient with respect to the input?

$\partial f_c(x) / \partial x_i$ tells you how much the class-$c$ score changes if you nudge pixel $i$. Large magnitude means the network is sensitive there. Near zero means it doesn't care. Take absolute values, rescale, and you have a heat map of local sensitivity. That's a saliency map.

The code is almost embarrassingly short:
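A framework-free sketch of the idea. With autodiff the input gradient comes from a single backward pass; here central finite differences stand in for it so the example runs on the standard library alone. `dog_score` is a hypothetical four-pixel "model" invented for illustration.

```python
import math

def saliency(score_fn, x, eps=1e-4):
    """|d score / d x_i| for every input entry, rescaled to [0, 1]."""
    grads = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grads.append((score_fn(hi) - score_fn(lo)) / (2 * eps))
    mags = [abs(g) for g in grads]
    top = max(mags) or 1.0          # avoid dividing by zero on a flat map
    return [m / top for m in mags]

def dog_score(x):
    # Hypothetical class score on a 4-pixel input; x[3] is ignored entirely.
    h = math.tanh(2.0 * x[0] - 1.0 * x[1])
    return 3.0 * h + 0.1 * x[2]

smap = saliency(dog_score, [0.2, 0.1, 0.5, 0.9])
print(smap)   # largest at pixel 0, exactly zero at the ignored pixel 3
```

The finite-difference loop costs two forward passes per pixel, which is why real implementations use one backward pass instead; the resulting map is the same quantity.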

Run this on an image classifier and you get... something. Structure is there. Pixels near edges light up. But the map is noisy, speckled, full of high-frequency patterns that don't correspond to anything you'd recognize. It doesn't look like the network is "focusing on the dog." More like TV static with faint signal underneath.

Two things working against you.

The noise problem. The gradient captures every local sensitivity at once. Sensitivity to meaningful features, sensitivity to high-frequency pixel variations the network incidentally responds to. All of it, mixed together, no way to separate signal from artifact.

The saturation problem. This one is worse. The gradient is zero at any input passing through an inactive ReLU:
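A toy illustration, assuming a hypothetical two-input model in which the first input carries the larger weight. The gradient is written by hand so the ReLU switch is explicit.

```python
def relu(z):
    return max(z, 0.0)

def score(x):
    # Hypothetical model: the first input has five times the weight.
    return 5.0 * relu(x[0]) + 1.0 * relu(x[1])

def grad_score(x):
    # A ReLU passes gradient only when its input is positive.
    return [5.0 if x[0] > 0 else 0.0,
            1.0 if x[1] > 0 else 0.0]

print(grad_score([-0.1, 1.0]))   # [0.0, 1.0]  first ReLU is off
print(grad_score([0.1, 1.0]))    # [5.0, 1.0]  same input, nudged positive
```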

Zero gradient. The saliency map says that first input doesn't matter. Wrong. It means we're in a dead zone right now. If that pixel were slightly positive, it would activate and strongly drive the prediction. The gradient can't see through the switch.

Here's why saturation is the more dangerous failure mode. Noise is at least somewhat symmetric: some unimportant pixels flicker on, but important pixels usually show up somewhere. Saturation silently erases them. A pixel that matters gets reported as irrelevant because a ReLU happened to be off.

The standard fix for noise is smoothing:
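A sketch under the assumption that "smoothing" here means spatially blurring the finished saliency map. A box filter is used for brevity; a Gaussian kernel is the more common choice. The 5x5 map is made-up data.

```python
def box_blur(saliency_map, radius=1):
    """Average each pixel with its neighbours in a (2r+1) x (2r+1) window."""
    h, w = len(saliency_map), len(saliency_map[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            rows = range(max(0, i - radius), min(h, i + radius + 1))
            cols = range(max(0, j - radius), min(w, j + radius + 1))
            vals = [saliency_map[r][c] for r in rows for c in cols]
            out[i][j] = sum(vals) / len(vals)
    return out

# One isolated spike, the kind of speckle a raw gradient map is full of.
noisy = [[0.0] * 5 for _ in range(5)]
noisy[2][2] = 9.0
smooth = box_blur(noisy)
print(smooth[2][2], smooth[1][1], smooth[0][0])   # 1.0 1.0 0.0
```

The spike is spread over its neighbourhood and its peak drops from 9 to 1. That is all blurring does: it redistributes the values the gradient already produced.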

Cleaner maps. But nothing fundamental changed. Saturation is still there. And there's a deeper issue: the gradient is a local linearization. It tells you what happens for infinitesimally small changes. Remove a pixel entirely, shift it by 50 intensity units, and the linear approximation stops being meaningful.

SmoothGrad addresses the noise side more carefully. For the saturation problem, you need integrated gradients.

SmoothGrad: Average Out the Noise

You ran plain saliency. The map came back speckled. Structure was there, maybe, underneath a layer of high-frequency static. The dog's face glowed faintly. So did a dozen random patches of grass. You squinted. You could sort of read it. Not confidence-inspiring.

Where does the noise come from? Think about what the gradient is actually responding to at a single point. Every local sensitivity at once. The edge of the ear that matters for classification, but also some texture pattern in the background that the network incidentally cares about at this exact pixel configuration. Signal and artifact, mixed together. One sample gives you no way to tell them apart.

But here's the thing. Jiggle the image slightly, just a few pixel values of random noise, and compute the gradient again. The meaningful sensitivities stay roughly the same. The dog's ear is still important. The spurious flickers move. They were artifacts of the exact point in input space, not stable features of the network's reasoning.

That's the entire idea behind SmoothGrad. Add small random noise to your image. Compute the gradient. Repeat with different noise. Fifty times. Average the results. The stable signal accumulates. The random jitter cancels out.

Why does averaging work? Same reason you'd take multiple readings from a noisy instrument. Each measurement is the true value plus some random error. If the errors are roughly independent across measurements (and for high-frequency gradient noise, they usually are), averaging $n$ samples cuts the variance by a factor of $n$:

$$\text{Var}(\text{mean of } n \text{ samples}) = \frac{\text{Var}(\text{single sample})}{n}$$

Fifty samples. Variance drops by 50x. The consistent parts survive. The random parts wash out.
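A small sketch of the averaging, assuming a made-up function whose gradient is a stable part `W` plus a high-frequency cosine term that plays the role of the speckle. The noise scale `sigma`, the seed, and the function itself are illustrative assumptions.

```python
import math
import random

W = [1.0, -2.0, 0.5]   # hypothetical stable sensitivities

def grad_f(x):
    # Gradient of f(x) = sum_i (W[i] * x[i] + 0.02 * sin(50 * x[i])).
    # The cosine term stands in for high-frequency gradient noise.
    return [w + math.cos(50.0 * xi) for w, xi in zip(W, x)]

def smoothgrad(grad_fn, x, n_samples=50, sigma=0.1, seed=0):
    """Average the gradient over n_samples Gaussian-perturbed copies of x."""
    rng = random.Random(seed)
    total = [0.0] * len(x)
    for _ in range(n_samples):
        noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
        total = [t + g for t, g in zip(total, grad_fn(noisy))]
    return [t / n_samples for t in total]

x = [0.0, 0.0, 0.0]
plain = grad_f(x)                  # each entry is off from W by exactly 1.0 here
smooth = smoothgrad(grad_f, x)     # entries land much closer to W
print(plain)
print(smooth)
```

The perturbation has to be large relative to the wavelength of the noise but small relative to the structure you care about. That trade-off is what `sigma` controls.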

The maps look noticeably better. Less static, more coherent structure. You can actually see the dog now instead of squinting through noise. That improvement is real.

But be careful about what improved. SmoothGrad is still computing local gradients. It's just averaging them over a small neighborhood of your input instead of reading a single point. The noise problem got better. The saturation problem from the previous section didn't. If the network is saturated at your input, averaging fifty nearby samples gives you a cleaner estimate of near-zero. That's not a better attribution. It's a more confident wrong answer.

Integrated Gradients: The Path Integral Solution

Plain saliency fails when neurons are saturated. SmoothGrad fails too, for the same reason: you're still sampling local gradients, just more of them. Average a hundred samples from a dead neighborhood and you get a stable zero. That's not better than a single zero. It's just quieter.

The saturation problem has a different shape than noise. Noise is random. Saturation is structural. The gradient is near zero not because the pixel doesn't matter, but because the network has already committed. It settled. Moving the pixel slightly doesn't change the output. Removing it entirely would.

There's something wrong with the question.

"What would happen if I nudged this pixel?" That's what the gradient measures. For a saturated neuron: nothing. The ReLU is pegged, the output doesn't budge. Technically correct. Practically useless.

A better question: how much did this pixel contribute to getting the output from nothing to where it is now? Not at one point. Cumulatively. Over the entire arc from blank to final image.

The path. Start from a baseline: a black image, say, where the network outputs near zero for every class. Interpolate toward your actual image. At each step along the way, measure the gradient with respect to each pixel. Some pixels contribute heavily early on, then settle into saturation before you reach the final image. Others kick in late. By the time you arrive at the actual input, each pixel has accumulated a total contribution: the sum of its effect at every step.

That accumulated sum is the attribution. The gradient at the endpoint is one sample. Integrated Gradients uses the whole path.

$$\text{IG}_i(x) = (x_i - x_i') \times \int_{\alpha=0}^{1} \frac{\partial F(x' + \alpha \times (x - x'))}{\partial x_i} \, d\alpha$$

$x'$ is the baseline, $F$ is the model's output for the class of interest. At each interpolation step $\alpha$, evaluate the gradient. Integrate over all $\alpha$, then scale by how much each feature actually changed from baseline to input.

In code, approximate with a Riemann sum:
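A minimal sketch, assuming the gradient is available as a function `grad_fn` (in a framework this would be one backward pass per step). The model `F`, its weights, and the midpoint variant of the Riemann sum are illustrative choices.

```python
import math

def integrated_gradients(grad_fn, x, baseline, steps=50):
    """Midpoint Riemann sum of the path integral from baseline to x."""
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_fn(point)
        avg_grad = [a + gi / steps for a, gi in zip(avg_grad, g)]
    # Scale the averaged gradient by how far each feature moved.
    return [(xi - b) * a for xi, b, a in zip(x, baseline, avg_grad)]

W = [2.0, -1.0, 0.5]   # hypothetical weights of a one-unit model

def F(x):
    return math.tanh(sum(w * xi for w, xi in zip(W, x)))

def grad_F(x):
    s = sum(w * xi for w, xi in zip(W, x))
    d = 1.0 - math.tanh(s) ** 2
    return [d * w for w in W]

x, baseline = [1.0, 0.5, 2.0], [0.0, 0.0, 0.0]
attr = integrated_gradients(grad_F, x, baseline)
print(sum(attr), F(x) - F(baseline))   # the two numbers agree closely
```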

Fifty gradient evaluations per attribution map, each one a forward and a backward pass. That's the cost.

Why saturation doesn't block this. A pixel drives a saturated ReLU at the final image. Zero gradient at the endpoint. But along the path from baseline to input, that pixel's value crossed the ReLU threshold. During that crossing, the gradient was nonzero. Integrated Gradients captured it. Plain saliency saw only the endpoint and reported nothing:
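A one-dimensional sketch of that comparison, assuming a toy saturating model that rises linearly and then goes flat. The function and the 50-step midpoint sum are illustrative.

```python
def F(x):
    # Saturating toy model: rises linearly, flat for x >= 1.
    return 1.0 - max(1.0 - x, 0.0)

def dF(x):
    # The inner ReLU is active only while x < 1.
    return 1.0 if x < 1.0 else 0.0

x, baseline, steps = 2.0, 0.0, 50

plain = dF(x)   # gradient at the endpoint only

# Midpoint Riemann sum along the straight path from baseline to x.
path_grads = [dF(baseline + (k + 0.5) / steps * (x - baseline))
              for k in range(steps)]
ig = (x - baseline) * sum(path_grads) / steps

print(plain)                     # 0.0
print(ig, F(x) - F(baseline))    # 1.0 1.0
```

The first half of the path lies below the threshold, where the gradient is 1; the second half lies above it, where the gradient is 0. The path average of 0.5 times the distance of 2 recovers the full change in output.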

Completeness. The attributions sum to the difference between the model's output at the input and at the baseline:

$$F(x) - F(x') = \sum_i \text{IG}_i(x)$$

That's the fundamental theorem of calculus. Integrate a derivative over a path, get the total change. If the model's confidence went from 0.1 at the baseline to 0.9 at your input, the attributions sum to exactly 0.8. Nothing lost. Nothing invented.

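A second property worth knowing: implementation invariance. Two networks computing the same mathematical function get the same attributions, even with different internal structure. Plain gradients have this property too, because the chain rule gives the same derivative however the function is factored. Attribution methods that swap in modified, layer-by-layer backward rules (DeepLIFT and LRP are the examples Sundararajan et al. give) can lose it. What plain gradients lack is sensitivity: a feature that demonstrably changes the output can still receive zero attribution, which is the saturation problem again.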

What baseline to use. The black image is standard. But it's a choice, and that choice encodes a counterfactual: "compared to this reference, what drove the prediction?"

Blurred baseline: what sharp features matter beyond the overall structure? Dataset mean: what distinguishes this input from the average? Random noise: what structured signal rises above randomness?

Each is a different counterfactual. The black image asks "compared to nothing, what matters?" The dataset mean asks "compared to typical, what matters?" Neither is wrong. Pick based on the question you actually want answered, not on which map looks cleanest.

The Limits of Gradient-Based Interpretation

Plain saliency. SmoothGrad. Integrated Gradients. Each one fixed something real in the previous method. The maps got sharper.

Better visualizations. Not better answers.

None of these methods can tell you whether the network is looking at the right thing. There is no ground truth for "the correct saliency map." The map looks plausible. Regions light up near the cat's face. It feels like an explanation. But plausibility is not evidence, and there's nothing to verify against.

Three places where this breaks down.

Adversarial perturbations.

Take a correctly classified image. Add imperceptible noise, chosen to flip the prediction. Each pixel shifts by at most 0.01. Invisible to you. The prediction flips from "cat" to "dog." What does the saliency map look like now?

Completely different.

Nearly identical inputs. Completely different explanations. If the map captured something stable about what the network learned, a tiny perturbation shouldn't redraw the whole picture. But the gradient is local. It responds to exactly where you are in input space. Shift slightly and the local geometry shifts with you. The "explanation" is not a window into reasoning. It's a snapshot of the loss surface at one point.

Gradient attributions don't compose.

Say pixel A has attribution 0.5 and pixel B has 0.3. Combined effect? You'd guess 0.8. In a nonlinear network, it might be 2.0. Gradients measure what happens when you change one feature while holding everything else fixed. But the network learned interactions between features. A and B don't act independently. The "hold everything else fixed" assumption is exactly wrong for capturing joint effects.

Not a fixable limitation. It's structural. Gradients decompose the output into independent per-feature contributions. Nonlinear interactions don't decompose that way.

Saliency is not causality.

This one takes the longest to absorb.

Take a model trained on cows in grassland, camels in sand. It learns, not unreasonably, that green backgrounds predict cow and sandy backgrounds predict camel. It never needed to learn what a cow actually looks like. The background was a reliable shortcut. Show it a cow on a beach.

The saliency map highlights the sand. Accurately. The sand is driving the prediction. But from the map alone, you can't tell whether this is a failure (learned shortcut) or a success (sand really is relevant). The map looks the same either way.

So what are these methods for?

Not proving what a model learned. Not establishing trust.

Debugging. Hypothesis generation. When a model misclassifies and the saliency map highlights background instead of foreground, you've learned something concrete. When you're comparing two architectures to see which regions they respond to, attributions are a starting point. When you're scanning for dataset biases, saliency maps are a cheap first pass.

But "starting point" is doing real work in those sentences. A saliency map is a hypothesis. Not evidence.

Attribution methods are not explanations. They're queries. You ask "what would change the output if I moved this feature?" and get a local answer. Whether that answer is useful depends on whether you asked the right question.

The flashlight metaphor from earlier holds. A flashlight shows local structure around where you're standing. Not the layout of the room. Gradient attributions are derivatives. Local by construction. They describe what the network is sensitive to at one input. They say nothing about global behavior, learned patterns, or generalization.

Use them. But verify what they suggest. If the saliency map says the network attends to ears when classifying cats, occlude the ears. See if performance drops. If it doesn't, the attribution was misleading. The map is the hypothesis. The occlusion experiment is the evidence.

Advanced Topics and Frontiers: Beyond Standard Backprop

We've explored backprop from every angle: as efficient gradient computation, as reverse-mode autodiff, as graph traversal, as the foundation of deep learning. You understand how gradients flow through layers, how they vanish or explode, how to control them with initialization and normalization. You can trace adjoints through any computational graph and implement VJPs for custom operations.

But backprop's story doesn't end with standard neural networks. The same principles scale to programs far more complex than feedforward networks: differential equation solvers, optimization problems, probabilistic programs, even quantum circuits. Once you understand that backprop is just the chain rule applied systematically to any differentiable computation, whole new domains open up.

This final section surveys the frontiers. We won't dive deep into implementation (each topic deserves its own post), but here's how the principles you've learned extend to cutting-edge research. Think of this as a map of where to explore next, with just enough detail to see how everything connects back to what you already know.

Higher-Order Gradients: Differentiating the Differentiator

The backward pass is a program. Your forward code runs, the framework writes matching backward code that walks the same graph in reverse. We've been treating that backward pass as the final product. But it's just code. And if it's code, you can differentiate it.

Grad-of-grad. Differentiate the gradient computation itself.

Does this break something? $\nabla f$ is just another function, mapping inputs to outputs like any other. Same chain rule, same graph traversal, applied to the backward graph instead of the forward one. The second derivative falls out without any new machinery.

So when would you want this? MAML and other meta-learning methods optimize over gradient steps. To train an optimizer that learns how to learn, you need to differentiate through a gradient update. Newton's method uses curvature (the second derivative) to take better-shaped steps toward a minimum. Trust region methods need to know how fast the gradient itself is changing.

The usefulness was never in question. The cost is what kills you.

The Hessian-Vector Product Pattern

The second derivative of a scalar function with $n$ parameters is the Hessian: an $n \times n$ matrix $H$. For a model with a million parameters, that's a trillion entries. You're not storing that.

But you rarely need the full matrix. Most algorithms that use second-order information only need Hessian-vector products: given a direction $v$, compute $Hv$. One row of information at a time.

Barak Pearlmutter noticed something clean. The product $Hv$ is the gradient of $\nabla f \cdot v$. In other words: take the dot product of the gradient with your vector $v$, then differentiate that scalar with respect to $x$. Two nested calls to grad:
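Nested grad calls need an autodiff framework. As a framework-free stand-in, the sketch below uses the same identity numerically: the Hessian-vector product is the directional derivative of the gradient along the chosen direction, approximated here by a central difference. This is a finite-difference approximation, not the exact autodiff version. The test function and its hand-written gradient are assumptions for illustration.

```python
def f_grad(x):
    # Hand-written gradient of f(x) = x0^2 * x1 + x1^3.
    return [2.0 * x[0] * x[1], x[0] ** 2 + 3.0 * x[1] ** 2]

def hvp_fd(grad_fn, x, v, eps=1e-5):
    """H v ~= (grad(x + eps*v) - grad(x - eps*v)) / (2*eps)."""
    xp = [xi + eps * vi for xi, vi in zip(x, v)]
    xm = [xi - eps * vi for xi, vi in zip(x, v)]
    gp, gm = grad_fn(xp), grad_fn(xm)
    return [(a - b) / (2 * eps) for a, b in zip(gp, gm)]

x, v = [1.0, 2.0], [1.0, -1.0]
Hv = hvp_fd(f_grad, x, v)

# Exact Hessian at x is [[2*x1, 2*x0], [2*x0, 6*x1]] = [[4, 2], [2, 12]],
# so the exact product with v = [1, -1] is [2, -10].
print(Hv)
```

Either way the Hessian is never formed: the cost is a couple of gradient evaluations, and memory stays linear in the number of parameters.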

Cost: roughly 2-3x a single gradient computation. Compare that to $n$ gradient computations for the full Hessian. For a million parameters, that's 3 forward-backward passes versus a million.

The same trick composes. Third derivatives? Differentiate the HVP. Each level adds another layer to the computation graph:

The pattern to notice: you never build the full higher-order tensor. You only compute its contraction with vectors. Everything stays linear in the number of parameters.

The Forward-Over-Reverse Pattern

You can also mix forward-mode and reverse-mode autodiff. Run forward-mode on top of a reverse-mode gradient, and each forward pass gives you one column of the Hessian:
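A minimal sketch of the pattern with hand-rolled dual numbers for the forward-mode half. The hand-written `f_grad` stands in for a reverse-mode gradient pass, and the `Dual` class supports only addition and multiplication; both are simplifying assumptions.

```python
class Dual:
    """Value plus tangent: the carrier for forward-mode autodiff."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def _lift(self, other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        o = self._lift(other)
        return Dual(self.val + o.val, self.tan + o.tan)
    __radd__ = __add__

    def __mul__(self, other):
        o = self._lift(other)
        return Dual(self.val * o.val, self.tan * o.val + self.val * o.tan)
    __rmul__ = __mul__

def f_grad(x):
    # Gradient of f(x) = x0^2 * x1 + x1^3 (stand-in for a reverse-mode pass).
    return [2.0 * x[0] * x[1], x[0] * x[0] + 3.0 * x[1] * x[1]]

def hessian(grad_fn, x):
    n = len(x)
    cols = []
    for j in range(n):
        # Seed the tangent with the j-th basis vector: one pass, one column.
        duals = [Dual(xi, 1.0 if i == j else 0.0) for i, xi in enumerate(x)]
        cols.append([g.tan for g in grad_fn(duals)])
    # cols[j][i] = d(grad_i)/d(x_j); transpose into row-major H[i][j].
    return [[cols[j][i] for j in range(n)] for i in range(n)]

H = hessian(f_grad, [1.0, 2.0])
print(H)   # [[4.0, 2.0], [2.0, 12.0]]
```

Seeding the tangent with an arbitrary vector instead of a basis vector gives an exact Hessian-vector product in a single pass.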

$O(n)$ passes, each about the cost of one gradient. Still expensive for huge $n$, but when you actually need the full Hessian (small networks, critical point analysis), this is the way to compute it.

Implicit Function Differentiation: Solving Without Unrolling

Not every forward pass is a sequence of matrix multiplies. Sometimes it's a solver. You're finding a fixed point, solving a linear system, running an optimization to convergence. The answer is well-defined. You ran 100 iterations and got it. Now you need gradients.

The obvious move: unroll those 100 iterations into your computation graph and backprop through all of them.

This works. It also stores 100 intermediate states for the backward pass, and the gradients are unstable. From backprop's perspective, a 100-iteration solver is a 100-layer network. The vanishing and exploding gradient problems this entire series has been about? They show up here too, except you didn't choose 100 layers for architectural reasons. The solver just needed that many steps.

Here's what makes this wasteful. You don't actually care how the solver got to the answer. You care that it arrived.

At convergence, the solution satisfies a condition. For a fixed-point solver: $x^* = f(x^*, \theta)$. Rearrange that into $F(x^*, \theta) = x^* - f(x^*, \theta) = 0$. This condition holds at the solution regardless of which path led there. A different solver, a different number of iterations, a different starting point. Same $x^*$. Same condition satisfied.

So differentiate the condition. Not the 100 steps that found it. That's the implicit function theorem:

$$\frac{\partial x^*}{\partial \theta} = -\left(\frac{\partial F}{\partial x}\bigg|_{x^*}\right)^{-1} \frac{\partial F}{\partial \theta}\bigg|_{x^*}$$

Two Jacobians, both evaluated at the solution. One linear solve. Instead of carrying 100 intermediate states through backward, you carry one: $x^*$ itself.
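A scalar sketch of the formula, assuming a hypothetical fixed-point map `f(x, theta) = theta * cos(x)` that is a contraction for the chosen `theta`. The implicit gradient is compared against a finite difference taken through the whole solver.

```python
import math

def f(x, theta):
    # Hypothetical fixed-point map; a contraction when |theta| < 1.
    return theta * math.cos(x)

def solve(theta, iters=100):
    x = 0.0
    for _ in range(iters):
        x = f(x, theta)
    return x

def implicit_grad(theta):
    x_star = solve(theta)
    # F(x, theta) = x - theta * cos(x), evaluated at the solution only.
    dF_dx = 1.0 + theta * math.sin(x_star)
    dF_dtheta = -math.cos(x_star)
    return -dF_dtheta / dF_dx

theta, eps = 0.5, 1e-6
ig = implicit_grad(theta)
fd = (solve(theta + eps) - solve(theta - eps)) / (2 * eps)
print(ig, fd)   # the two estimates agree closely
```

`implicit_grad` never looks at the 100 iterates. In the vector case the division becomes a linear solve against the Jacobian of `F` with respect to `x`.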

The pattern generalizes beyond fixed points. Optimization problems, differential equations, eigenvalue computations. Anywhere your forward pass finds a solution by iterating, you can skip the iteration in the backward pass and differentiate the condition that defines the solution instead.

DEQs: Deep Equilibrium Models

This idea has an architectural consequence. If you can differentiate through a fixed-point solver cheaply, why stack explicit layers at all?

Deep Equilibrium Models (DEQs) take this seriously. One layer, applied repeatedly until the representation converges to a fixed point. No layer 1, layer 2, layer 47. Just one function $f$ and an equilibrium $z^* = f(z^*, x, \theta)$.

Infinite effective depth. Constant memory. The backward pass solves one linear system regardless of how many iterations the forward solver needed. The cost is wall-clock time: finding fixed points is slower than a single forward pass through explicit layers. But when memory is the bottleneck, that's a trade worth making.

Neural ODEs: When Depth Becomes Continuous

Common Derivative Reference Tables

You have already computed most of these. The activation derivative showed up in vanishing gradients. The matrix multiply gradient appeared on every backward pass through a linear layer. The softmax-cross-entropy fusion surfaced twice in the numerical stability section alone.

Here they live together, organized by category so the family relationships are visible.

Activation Functions and Their Derivatives

Two relationships worth noticing. Softplus is the smooth ReLU, and its derivative is sigmoid. Not a coincidence. Sigmoid and tanh are rescalings of the same underlying function, $\tanh(x) = 2\sigma(2x) - 1$, both saturating at the extremes, both causing vanishing gradients in deep stacks. ReLU sidesteps saturation on the positive side by staying linear there; the price is a kink at zero and a flat, zero-gradient negative side.
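Both relationships are easy to check numerically. The sketch below compares a central-difference derivative of softplus against sigmoid, and tanh against the rescaled sigmoid. The sample points are arbitrary.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softplus(x):
    return math.log1p(math.exp(x))

def num_deriv(fn, x, eps=1e-6):
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

points = [-3.0, -0.5, 0.0, 1.2, 4.0]

# softplus'(x) should equal sigmoid(x).
softplus_gap = max(abs(num_deriv(softplus, x) - sigmoid(x)) for x in points)

# tanh(x) should equal 2 * sigmoid(2x) - 1.
tanh_gap = max(abs(math.tanh(x) - (2.0 * sigmoid(2.0 * x) - 1.0)) for x in points)

print(softplus_gap, tanh_gap)   # both are tiny
```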

Loss Functions and Their Gradients

One row matters more than the rest: Softmax + CE. The gradient of cross-entropy through softmax, taken with respect to the raw logits, collapses to $\hat{p} - \text{one\_hot}(y)$. All the Jacobian complexity cancels. This is why frameworks expect raw logits, not probabilities. Pass softmax outputs into a cross-entropy function and you undo the fusion, reintroducing the numerical problems the fused form exists to avoid.
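A quick numerical check of that collapse, comparing the fused formula against central differences of the loss. The logits and label are arbitrary example values.

```python
import math

def softmax(z):
    m = max(z)                                  # subtract the max for stability
    exps = [math.exp(zi - m) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(z, y):
    return -math.log(softmax(z)[y])

def fused_grad(z, y):
    # p_hat - one_hot(y), taken with respect to the raw logits.
    return [p - (1.0 if i == y else 0.0) for i, p in enumerate(softmax(z))]

def numeric_grad(z, y, eps=1e-6):
    grads = []
    for i in range(len(z)):
        hi = z[:i] + [z[i] + eps] + z[i + 1:]
        lo = z[:i] + [z[i] - eps] + z[i + 1:]
        grads.append((cross_entropy(hi, y) - cross_entropy(lo, y)) / (2 * eps))
    return grads

logits, label = [2.0, -1.0, 0.5], 0
analytic = fused_grad(logits, label)
max_gap = max(abs(a - b) for a, b in zip(analytic, numeric_grad(logits, label)))
print(analytic, max_gap)
```

The entries of the fused gradient sum to zero, because both the predicted probabilities and the one-hot target sum to one.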

Matrix Operations

One pattern repeats across almost every row: transpose routing. Differentiating through $Y = XW$ sends gradients to $W$ via $X^T$ and to $X$ via $W^T$. Get that in your head and the rest of the table follows. The determinant and inverse rows show up less often in practice, but they appear in generative model objectives and certain optimization methods.

The Shape Discipline Quick Reference

Shape errors are the only bugs in backprop that crash immediately. Wrong sign, missing factor, botched chain rule application: all of those fail silently. Training continues, loss goes down, you blame the learning rate or the architecture. Weeks later, maybe you find it. Maybe you don't.

Shape mismatches? PyTorch stops cold and points at the line. Loud failures are a gift.

Here is the trap, though. The forward pass gives no warning when you set up a backward-pass shape bug. A bias of shape (D,) adds cleanly to a batch of (B, D) because broadcasting handles it. Forward pass works fine. But the backward pass needs to produce a (D,) gradient for that bias, which means summing over the batch axis. Skip that reduction, produce (B, D) instead, and the error only shows up later when that gradient tries to update a (D,)-shaped parameter.

Five patterns cover most of the cases you will hit.

The matrix multiply transposes are not a convention. They are the only arrangement that makes the dimensions work. grad_Y is (B, N), W is (M, N). To get grad_X with shape (B, M), you must hit grad_Y @ W.T. Write the shape annotations first, then the transpose placement writes itself.

Broadcasting is the pattern worth burning into memory. Whatever axis the forward pass creates through broadcasting, the backward pass must destroy through summation. Broadcast over batch? Sum over batch. Broadcast over sequence? Sum over sequence. Creation forward, destruction backward. Symmetric.

One rule sits above all five patterns: a gradient must have the same shape as its parameter. You cannot add a (B, M, N) tensor to an (M, N) parameter. When a shape bug bites, work backward from that rule and one of these five patterns will tell you where the mismatch entered.
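The matrix-multiply and broadcasting rules can be sketched together for one linear layer, in plain nested lists so the shapes are explicit. The helper names, the sizes `B, M, N = 4, 3, 2`, and the all-ones upstream gradient are assumptions for illustration.

```python
def shape(a):
    return (len(a), len(a[0]))

def transpose(a):
    return [list(row) for row in zip(*a)]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def linear_backward(X, W, grad_Y):
    # Forward was Y = X @ W + b with X:(B,M), W:(M,N), b:(N,), Y:(B,N).
    grad_X = matmul(grad_Y, transpose(W))         # (B,N) @ (N,M) -> (B,M)
    grad_W = matmul(transpose(X), grad_Y)         # (M,B) @ (B,N) -> (M,N)
    grad_b = [sum(col) for col in zip(*grad_Y)]   # sum over batch  -> (N,)
    # The rule above all others: each gradient matches its tensor's shape.
    assert shape(grad_X) == shape(X)
    assert shape(grad_W) == shape(W)
    return grad_X, grad_W, grad_b

B, M, N = 4, 3, 2
X = [[float(i + j) for j in range(M)] for i in range(B)]
W = [[0.1 * (i - j) for j in range(N)] for i in range(M)]
grad_Y = [[1.0] * N for _ in range(B)]

gX, gW, gb = linear_backward(X, W, grad_Y)
print(shape(gX), shape(gW), len(gb))   # (4, 3) (3, 2) 2
```

Dropping the batch sum in `grad_b` would leave a `(B, N)` result, which is exactly the silent forward, loud backward bug described above.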

A Closing Thought: The Gradient Perspective

Gradients are local. That's the whole thing, really.

The gradient at a point doesn't know where you're going. It knows which direction is downhill right now. Everything we built across these three posts is scaffolding around that one stubbornly local object. Mixed precision so gradients survive finite arithmetic. Checkpointing so you can afford to keep them. Monitoring so you notice when they stop flowing. Interpretability methods to ask what the network is locally sensitive to. All of it in service of a quantity that can only see one step ahead.

Here's what I find odd: it works. Not in theory. In practice. Nothing in the chain rule promises that chasing local sensitivities will produce a model that generalizes. The math says: here's how to compute these derivatives efficiently. Whether those derivatives are useful depends entirely on what you're differentiating through.

And that's why architecture matters as much as optimization. Residual connections give gradients a short path back to early layers. Attention keeps them from dying across long sequences. LayerNorm keeps the backward pass well-conditioned. Read the last decade of architecture papers through this lens and a pattern emerges: every major design choice is an answer to the same question. How do we make local gradient steps globally trustworthy?


The flashlight metaphor from the interpretability section applies here too. A gradient illuminates what's locally sensitive. It says nothing about whether that sensitivity matters, whether it generalizes, or whether the model learned the right thing for the right reasons. It answers one question: what does changing this input do to that output, right now?

After three posts, you can ask that question at whatever granularity you want. A single weight. An entire layer. A pixel in a saliency map. You can trace the answer back through any computation graph, however deep.

The limits are real. No global guarantees. No causal claims. No certainty that gradient descent finds anything close to optimal. But the terrain keeps getting better, and gradients keep getting further. Not because we made them smarter. Because we learned to build functions they can actually navigate.

References and Further Reading

Part 3 draws on more primary sources than the other two combined. Organized by topic, with notes on which papers are worth reading directly versus just citing.

Autodiff systems

Baydin et al. (2018) is the best single map of autodiff. Forward mode, reverse mode, source transformation, operator overloading, higher-order, tradeoffs between all of them. If you want to understand what PyTorch and JAX are doing at the theory level, start here. The PyTorch autograd mechanics page is the implementation-level companion.

Mixed precision and memory

Two papers behind most of Sections 1.4 and 1.5. Micikevicius et al. (2018) pinned down the rules: which ops stay in fp16, where to accumulate in fp32, how loss scaling prevents underflow. Chen et al. (2016) showed checkpointing at $O(\sqrt{N})$ boundary nodes gets you sublinear memory at constant compute overhead. Both are short. Both are worth reading in full.

Gradient attribution and interpretability

Each of the three attribution methods in Part 3 has a corresponding paper. Sundararajan et al. is the one I'd read first: it derives the axioms a good attribution method should satisfy (completeness, sensitivity, implementation invariance) and shows that plain gradients violate sensitivity. Read it before using any of these methods on real problems.

The frontier: higher-order, implicit, and continuous

If Sections 4.1 through 4.3 made you curious, start here. The Neural ODE paper (Chen et al., 2018) is more readable than its title suggests: depth becomes continuous, and backprop costs two ODE solves regardless of integration depth. OptNet is the cleanest early treatment of wrapping a convex solver as a differentiable layer. MAML is where higher-order gradients became practically important: the algorithm differentiates through multiple gradient steps, so you need Hessian-vector products in the inner loop.

The original paper

Rumelhart, Hinton, and Williams (1986). One of the most cited papers in ML, which means most people have absorbed its conclusions without reading it. Section 3 is where the algorithm lives. After three posts, you have the prerequisites to follow their argument directly.