One-Line Summary: Gradient checkpointing trades compute for memory by storing activations only at selected layers and recomputing the rest during backpropagation.
Prerequisites: Understanding of forward and backward passes, where activations live in memory, and the rough memory footprint of large transformer training.
What It Is
In a normal training step, the forward pass stores every intermediate activation so the backward pass can reuse them when computing gradients. For a 70-billion-parameter transformer at long context, those activations alone can dominate GPU memory — easily 100+ GB.
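The 100+ GB figure can be sanity-checked with the per-layer estimate from Korthikanti et al. (2022). That estimate is roughly s·b·h·(34 + 5·a·s/h) bytes for a standard GPT-style block with 16-bit activations, where s is sequence length, b is batch size, h is hidden size, and a is the number of attention heads. The sketch below assumes a hypothetical 70B-class shape (80 layers, hidden size 8192, 64 heads, one 8192-token sequence) rather than any specific model. It also ignores grouped-query attention, gated MLPs, and tensor or sequence parallelism:
```python
# Rough activation-memory estimate for a GPT-style transformer.
# ASSUMPTION: per-layer bytes ~= s*b*h*(34 + 5*a*s/h) with 16-bit activations
# and no tensor/sequence parallelism (Korthikanti et al., 2022). The 5*a*s/h
# part comes from the materialized attention-score and softmax tensors.


def activation_bytes_per_layer(s: int, b: int, h: int, a: int,
                               materialize_attention: bool = True) -> float:
    """Bytes one layer keeps for the backward pass
    (s=sequence length, b=batch, h=hidden size, a=attention heads)."""
    width = 34.0
    if materialize_attention:
        width += 5.0 * a * s / h
    return s * b * h * width


def total_activation_gb(num_layers: int, **shape) -> float:
    return num_layers * activation_bytes_per_layer(**shape) / 1e9


# HYPOTHETICAL 70B-class shape: 80 layers, hidden 8192, 64 heads,
# a single 8192-token sequence.
shape = dict(s=8192, b=1, h=8192, a=64)
with_scores = total_activation_gb(80, **shape)
without_scores = total_activation_gb(80, materialize_attention=False, **shape)
print(f"attention scores materialized: {with_scores:,.0f} GB")
print(f"attention scores not stored:   {without_scores:,.0f} GB")
```
Under these assumptions the estimate comes to about 1.9 TB when the attention scores are stored and about 180 GB when they are not. Both figures are far beyond what a single 80 GB accelerator can hold.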
Gradient checkpointing (sometimes called activation recomputation) makes a trade: store activations only at a sparse set of "checkpoint" layers and throw the rest away. Then, in the backward pass, when you need an activation that wasn't kept, recompute it by re-running the forward pass of that segment, starting from the nearest preceding checkpoint.
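The keep-some, recompute-the-rest mechanism can be shown on a toy problem. The sketch below assumes a hypothetical "network" made of a chain of scalar layers, sin(w·x). It computes the input gradient twice: once with plain backprop that keeps every layer input, and once keeping only every k-th input and re-running each segment's forward pass during backward. The function names are illustrative only:
```python
import math

# Toy model (HYPOTHETICAL): a chain of scalar "layers", layer i computes
# sin(w_i * x). Backprop through layer i needs that layer's input, so the
# input is the "activation" we either keep or recompute.


def layer(w: float, x: float) -> float:
    return math.sin(w * x)


def layer_grad(w: float, x: float) -> float:
    return w * math.cos(w * x)  # d/dx sin(w * x)


def grad_store_all(ws, x0):
    """Plain backprop: keep every layer input from the forward pass."""
    inputs, x = [], x0
    for w in ws:
        inputs.append(x)
        x = layer(w, x)
    g = 1.0  # d(output)/d(output)
    for w, x_in in zip(reversed(ws), reversed(inputs)):
        g *= layer_grad(w, x_in)
    return g, len(inputs), len(ws)  # gradient, values kept, layer evaluations


def grad_checkpointed(ws, x0, k):
    """Keep only every k-th layer input; recompute each segment on backward."""
    checkpoints, x, evals = {}, x0, 0
    for i, w in enumerate(ws):
        if i % k == 0:
            checkpoints[i] = x  # start of a segment
        x = layer(w, x)
        evals += 1
    peak, g = len(checkpoints), 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, len(ws))
        seg_inputs, x = [], checkpoints[start]
        for i in range(start, end):  # re-run this segment's forward pass
            seg_inputs.append(x)
            x = layer(ws[i], x)
            evals += 1
        # Values held: all checkpoints plus one recomputed segment.
        peak = max(peak, len(checkpoints) + len(seg_inputs))
        for i in reversed(range(start, end)):  # backprop through the segment
            g *= layer_grad(ws[i], seg_inputs[i - start])
        # seg_inputs is dropped before the next (earlier) segment.
    return g, peak, evals


ws = [0.9 + 0.01 * i for i in range(16)]
g_full, kept_full, evals_full = grad_store_all(ws, 0.5)
g_ckpt, kept_ckpt, evals_ckpt = grad_checkpointed(ws, 0.5, k=4)
print(g_full, g_ckpt)          # same gradient
print(kept_full, kept_ckpt)    # 16 vs 8 values held at peak
print(evals_full, evals_ckpt)  # 16 vs 32 forward layer evaluations
```
In this example, checkpointing every 4 of 16 layers halves the number of values held at peak and doubles the forward layer evaluations. The backward work itself is unchanged, so the whole step costs about one extra forward pass.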
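without checkpointing: store activations for every layer (memory ↑↑, compute baseline)
with checkpointing: store activations every √L layers (memory ↓↓, compute ↑ ~33%)
If you checkpoint every √L layers of an L-layer network, you hold about √L checkpoints plus the √L recomputed activations of one segment at a time, so peak activation memory drops from O(L) to roughly O(√L). The price is about one extra forward pass per step. A backward pass costs roughly twice a forward pass, so that extra forward pass adds roughly 33% to the compute of each training step.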
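The √L rule and the ~33% figure both follow from simple counting. The sketch below assumes memory is measured in "layer activations": at peak you hold one checkpoint per segment plus the activations of the single segment being recomputed. It also assumes relative costs of forward = 1 unit and backward ≈ 2 units. The function names are hypothetical:
```python
import math

# ASSUMPTIONS: memory counted in layer activations; peak = one checkpoint per
# segment + the activations of the one segment being recomputed.
# Costs: forward = 1 unit, backward ~= 2 units.


def peak_activations(num_layers: int, segment: int) -> int:
    return math.ceil(num_layers / segment) + segment


def best_segment(num_layers: int) -> int:
    """Segment length that minimizes peak activations."""
    return min(range(1, num_layers + 1),
               key=lambda k: peak_activations(num_layers, k))


FORWARD, BACKWARD = 1.0, 2.0
baseline = FORWARD + BACKWARD
with_recompute = FORWARD + FORWARD + BACKWARD  # one extra forward pass
overhead = with_recompute / baseline - 1

L = 64
k = best_segment(L)
print(k, peak_activations(L, k))  # 8 16  (vs 64 without checkpointing)
print(f"extra compute per step: {overhead:.0%}")  # 33%
```
The minimum lands at k = √L because ⌈L/k⌉ + k is smallest when the two terms balance. Shorter segments mean more checkpoints, and longer segments mean a larger recomputed block.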
Why It Matters
Checkpointing is one of the four pillars that made very-large-model training tractable on commodity-ish hardware. The other three are:
- Mixed-precision training (bf16 / fp8 weights and activations).
- Sharded optimizer state (ZeRO, FSDP) so optimizer memory doesn't live on every device.
- FlashAttention, which never materializes the full O(n²) attention matrix and so eliminates a huge activation-memory term on its own.
Together they shift the dominant memory cost off activations and onto weights and optimizer state — exactly where sharding strategies can take over. A 70B-parameter model goes from "infeasible without a hyperscaler" to "trainable on a serious-but-not-absurd cluster" largely because of these four ideas.
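The shift onto weights and optimizer state, and why sharding then matters, can be sized with back-of-envelope arithmetic. The sketch below assumes mixed-precision Adam, accounted as in the ZeRO paper at about 16 bytes per parameter: 2 for 16-bit weights, 2 for 16-bit gradients, and 12 for fp32 master weights, momentum, and variance. It ignores activations, buffers, and fragmentation, and the helper name is hypothetical:
```python
# ASSUMPTION: mixed-precision Adam, ZeRO-paper accounting per parameter:
# 16-bit weights (2 B) + 16-bit gradients (2 B) + fp32 master weights,
# momentum, and variance (4 + 4 + 4 = 12 B). Activations, buffers, and
# fragmentation are ignored.
BYTES_PER_PARAM = {"weights": 2, "grads": 2, "optimizer": 12}


def model_state_gb_per_device(num_params: float, num_devices: int,
                              sharded=("weights", "grads", "optimizer")) -> float:
    """Per-device GB for weights, grads, and optimizer state when the
    listed parts are split evenly across num_devices."""
    total = 0.0
    for part, nbytes in BYTES_PER_PARAM.items():
        part_bytes = num_params * nbytes
        if part in sharded:
            part_bytes /= num_devices
        total += part_bytes
    return total / 1e9


P = 70e9
print(model_state_gb_per_device(P, 1))                           # ~1120 GB, nothing split
print(model_state_gb_per_device(P, 64, sharded=("optimizer",)))  # ~293 GB, optimizer only
print(model_state_gb_per_device(P, 64))                          # ~17.5 GB, all sharded
```
Sharding only the optimizer state roughly corresponds to ZeRO stage 1. Sharding all three parts corresponds to ZeRO stage 3 or FSDP full sharding. The 64-device count is an arbitrary example.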
Key Technical Details
The PyTorch primitive is torch.utils.checkpoint, which wraps a chunk of the forward pass and arranges for it to be recomputed during the backward pass. DeepSpeed and Megatron-LM both expose this at the layer level with sensible defaults. The trade-off is real: checkpointing slows training by roughly 25–35%, in line with the cost of one extra forward pass. It is usually worth paying anyway, because if the model does not fit in memory at all, training speed is irrelevant.
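A minimal sketch of torch.utils.checkpoint.checkpoint in use follows. The Block and Stack classes are hypothetical stand-ins for transformer layers (a residual MLP, no attention), and the sizes are arbitrary. The sketch runs the same weights with and without checkpointing and compares the resulting gradients. It passes use_reentrant=False, the non-reentrant mode that current PyTorch recommends. In that mode, gradients reach the parameters even though the input tensor does not require grad:
```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """HYPOTHETICAL stand-in for one transformer layer: a residual MLP."""

    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                 nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(x)


class Stack(nn.Module):
    def __init__(self, dim: int, depth: int, use_checkpoint: bool):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Save only this block's input; its internal activations are
                # recomputed when the backward pass reaches the block.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
plain = Stack(dim=64, depth=8, use_checkpoint=False)
ckpt = Stack(dim=64, depth=8, use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights in both

x = torch.randn(4, 64)
plain(x).sum().backward()
ckpt(x).sum().backward()

# Gradients should match up to floating-point noise.
max_diff = max((p.grad - q.grad).abs().max().item()
               for p, q in zip(plain.parameters(), ckpt.parameters()))
print(f"max gradient difference: {max_diff:.2e}")
```
This sketch checkpoints at every block. Wrapping groups of blocks instead gives the segment layout with fewer checkpoints. For nn.Sequential models, torch.utils.checkpoint.checkpoint_sequential does this directly.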
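FlashAttention removes the quadratic attention-score term, and checkpointing shrinks what each layer keeps. Together they leave activation memory growing only linearly in sequence length, with a small constant factor, instead of quadratically. That is a large part of what made very-long-context training feasible.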
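One way to see how sequence length s enters activation memory is to double s under three storage strategies. This sketch reuses the per-layer estimate from Korthikanti et al. (2022) for a GPT-style block with 16-bit activations. It simplifies by treating FlashAttention as simply deleting the s² score term, and it assumes full recomputation keeps only each layer's input. The hidden size, head count, and function names are illustrative assumptions:
```python
# ASSUMPTIONS: per-layer activation bytes for a GPT-style block with 16-bit
# activations (after Korthikanti et al., 2022):
#   "materialized": s*b*h*(34 + 5*a*s/h)  (attention scores stored)
#   "flash":        s*b*h*34              (no s-by-s score tensors)
#   "checkpointed": 2*s*b*h               (only the layer input is kept;
#                    the one layer being recomputed is not modeled)


def layer_bytes(s, b=1, h=8192, a=64, mode="materialized"):
    if mode == "materialized":
        return s * b * h * (34 + 5 * a * s / h)
    if mode == "flash":
        return s * b * h * 34
    if mode == "checkpointed":
        return 2 * s * b * h
    raise ValueError(f"unknown mode: {mode}")


def growth_when_doubling(s, mode):
    """Factor by which per-layer activation memory grows when s doubles."""
    return layer_bytes(2 * s, mode=mode) / layer_bytes(s, mode=mode)


for mode in ("materialized", "flash", "checkpointed"):
    print(mode, round(growth_when_doubling(8192, mode), 2))
```
Going from 8192 to 16384 tokens, the materialized case grows by about 3.8x and heads toward 4x as the s² term dominates. The other two cases exactly double. Checkpointing lowers the constant but does not change the linear order.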