One-Line Summary: The realization that attention's bottleneck was not arithmetic but memory bandwidth, and the tiling algorithm that turned that insight into a 2-4x speedup with zero approximation.
Prerequisites: 01-attention-mechanism-evolution.md
What Is Flash Attention?
Imagine you are a chef in a kitchen where the pantry is in a separate building. The actual cooking — chopping, seasoning, sauteing — takes seconds. But every time you need an ingredient, you walk five minutes to the pantry and five minutes back. The obvious solution is not to cook faster but to reorganize your workflow: bring everything you need for one dish at a time, finish it completely, then move to the next. FlashAttention is exactly this reorganization applied to the attention mechanism.
Standard attention computes the full N x N attention matrix, writes it to GPU high-bandwidth memory (HBM), reads it back for the softmax, writes the result again, and reads it once more for the value multiplication. Each read/write to HBM takes orders of magnitude longer than the actual arithmetic. Tri Dao, working at Stanford, recognized that this memory traffic — not the floating-point operations — was the true bottleneck. His solution: tile the computation so that everything happens in fast SRAM (the GPU's on-chip memory), never materializing the full attention matrix in slow HBM.
FlashAttention is not an approximation. It computes mathematically exact attention. It is a systems-level optimization that respects the hardware memory hierarchy, and it changed what was practical in LLM training and inference. It is now built into PyTorch, JAX, and every major training framework.
How It Works
FlashAttention: IO-Aware Tiling on GPU Memory Hierarchy:
Standard Attention FlashAttention
(Memory-Bound) (IO-Aware Tiling)
Q K V Q K V
│ │ │ │ │ │
▼ ▼ ▼ ▼ ▼ ▼
┌──────────────┐ ┌──────────────┐
│ Compute S │ │ Load tile │
│ = QK^T │ │ into SRAM │─┐
└──────┬───────┘ └──────┬───────┘ │
│ write to HBM (SLOW) │ │
┌──────▼───────┐ ┌──────▼───────┐ │ All in SRAM
│ Read S │ │ Compute QK^T│ │ (FAST,
│ from HBM │ │ + softmax │ │ ~20MB,
└──────┬───────┘ │ + output │ │ 19 TB/s)
│ │ per tile │ │
┌──────▼───────┐ └──────┬───────┘ │
│ Softmax(S) │ │ │
└──────┬───────┘ ┌──────▼───────┐ │
│ write to HBM (SLOW) │ Write only │─┘
┌──────▼───────┐ │ final output│
│ Read P │ │ to HBM │
│ Compute PV │ └──────────────┘
└──────────────┘
Memory: O(N) not O(N^2)
O(N^2) HBM reads/writes Speedup: 2-4x (exact output)
GPU Memory Hierarchy:
┌────────────────────┐
│ SRAM: ~20 MB │ ◀── ~19 TB/s bandwidth (FAST)
├────────────────────┤
│ HBM: 40-80 GB │ ◀── ~2 TB/s bandwidth (10x slower)
└────────────────────┘The Memory Hierarchy Problem
Modern GPUs like the NVIDIA A100 have two levels of memory: SRAM (on-chip, ~20 MB, ~19 TB/s bandwidth) and HBM (off-chip, 40-80 GB, ~2 TB/s bandwidth). That is a 10x bandwidth gap. Standard attention for a sequence of length N requires O(N^2) reads and writes to HBM — reading Q, K, V, writing the N x N attention matrix S, reading S for softmax, writing the normalized matrix P, reading P and V for the output. For N = 8192 and d = 128, the attention matrix alone is 8192 x 8192 x 4 bytes = 256 MB. This dwarfs the useful computation, which is just matrix multiplications that GPUs execute extremely fast.
FlashAttention v1 — IO-Aware Tiling (May 2022)
Dao's key insight was to decompose the attention computation into tiles that fit in SRAM. The algorithm processes Q in blocks of size B_r and K, V in blocks of size B_c. For each block of Q, it iterates over all blocks of K and V, computing partial attention scores in SRAM. The challenge is softmax, which normally requires knowing the maximum value across the entire row. FlashAttention uses an online softmax trick: it maintains running statistics (max and sum of exponentials) that are updated as each K block is processed, allowing exact softmax computation without ever storing the full row. The result: O(N) memory instead of O(N^2), and 2-4x wall-clock speedup on A100 GPUs. Training BERT-large went from 14 hours to 7.6 hours. GPT-2 training saw similar improvements.
FlashAttention-2 — Better Parallelism (July 2023)
The second version improved hardware utilization from roughly 50-70% of theoretical FLOPs to 70-90%. Key changes: reducing non-matmul FLOPs (which GPUs handle less efficiently), better work partitioning across thread blocks (parallelizing over the sequence length dimension rather than batch and heads only), and native support for Multi-Query Attention and Grouped-Query Attention. FlashAttention-2 achieved up to 230 TFLOPs/s on A100 (73% utilization), compared to 125 TFLOPs/s for v1. It also added support for head dimensions up to 256 (up from 128), which matters for models with fewer, larger heads.
FlashAttention-3 — Asynchronous and Low-Precision (2024)
Targeting NVIDIA's Hopper architecture (H100), FlashAttention-3 exploited new hardware features: asynchronous execution of WGMMA (Warp Group Matrix Multiply-Accumulate) operations, FP8 tensor core support for attention computation, and hardware-assisted memory operations. The asynchronous approach overlaps data loading with computation — while one tile is being computed, the next is being loaded. FP8 support enables roughly 2x throughput over FP16 for attention, with minimal accuracy loss. On H100, FlashAttention-3 achieves up to 740 TFLOPs/s with FP8, approaching the theoretical maximum.
The Broader Hardware-Aware Computing Movement
FlashAttention catalyzed a mindset shift: algorithms must be designed for the hardware they run on, not in abstract mathematical elegance. This spawned a generation of hardware-aware optimizations. FlashDecoding (Nov 2023) applied similar tiling principles to inference, parallelizing over the KV sequence length for faster decoding. Ring Attention (2023) distributed attention computation across multiple GPUs for extremely long sequences. xFormers (Meta) collected memory-efficient attention implementations. ThunderKittens (2024) provided a DSL for writing GPU-aware kernels. The common thread: stop treating GPUs as abstract compute machines and start respecting their memory hierarchy.
Why It Matters
Enabling Long-Context Training
Before FlashAttention, training on sequences longer than 2-4K tokens was prohibitively expensive due to the O(N^2) memory of standard attention. FlashAttention's O(N) memory made 16K, 32K, and eventually 128K-token training sequences practical. Without it, the long-context revolution described in 07-long-context-techniques.md would not have happened.
Universal Adoption
FlashAttention is built into PyTorch (torch.nn.functional.scaled_dot_product_attention), JAX, Hugging Face Transformers, and every major training framework. It is the rare research contribution that became immediate infrastructure. Virtually every model trained after mid-2022 benefits from it.
Complementary to Algorithmic Innovations
FlashAttention and attention mechanism variants like GQA/MLA solve different problems. GQA reduces what needs to be stored; FlashAttention reduces how expensively it is computed. They stack: a model using GQA with FlashAttention gets benefits from both. This complementarity means FlashAttention remains relevant regardless of which attention variant wins.
Key Technical Details
- Standard attention memory: O(N^2) for the attention matrix. FlashAttention: O(N) — only the output and running statistics.
- FlashAttention v1 (2022): 2-4x speedup over PyTorch attention. BERT-large training: 14 hours → 7.6 hours on 8x A100s.
- FlashAttention-2 (2023): 70-90% of theoretical GPU FLOPs. Up to 230 TFLOPs/s on A100 (vs 125 for v1).
- FlashAttention-3 (2024): Asynchronous execution + FP8. Up to 740 TFLOPs/s on H100. ~1.5-2x faster than v2 on Hopper.
- GPU memory hierarchy: A100 SRAM ~20 MB at ~19 TB/s; HBM 40-80 GB at ~2 TB/s. ~10x bandwidth gap drives the optimization.
- Block sizes tuned to SRAM capacity: typical B_r, B_c = 128-256, chosen so QK^T block fits in SRAM.
- Online softmax trick: maintains running max and sum of exponentials, enabling exact softmax without full-row materialization.
- Adopted universally: PyTorch 2.0+, JAX, DeepSpeed, Megatron-LM, Hugging Face, vLLM.
Common Misconceptions
-
"FlashAttention is an approximate attention method." It computes mathematically exact standard attention. The output is bit-for-bit identical (up to floating-point reordering) to naive attention. It is a systems optimization, not an algorithmic approximation.
-
"FlashAttention only matters for training." FlashDecoding and FlashAttention's inference mode provide significant speedups for serving as well. The memory savings are especially critical during inference with long contexts.
-
"FlashAttention makes sparse attention unnecessary." FlashAttention optimizes dense attention — it still computes all N^2 interactions, just more efficiently. Sparse attention methods reduce the number of interactions computed. For extremely long sequences (millions of tokens), sparse or linear attention may still be needed.
-
"Any GPU benefits equally." The speedup depends on the GPU's SRAM-to-HBM bandwidth ratio. GPUs with relatively more SRAM (like H100) benefit most. On CPUs or very old GPUs, the benefit is minimal.
-
"FlashAttention solves the quadratic scaling problem." It solves the memory problem (O(N^2) → O(N)) but the compute remains O(N^2). For 1M-token sequences, the arithmetic cost is still enormous — FlashAttention just ensures it is not multiplied by memory inefficiency.
Connections to Other Concepts
FlashAttention is the hardware complement to 01-attention-mechanism-evolution.md — GQA and MLA reduce the data, FlashAttention speeds up its processing. 06-kv-cache-and-serving-optimization.md builds on FlashAttention's memory management principles for inference. 07-long-context-techniques.md depends on FlashAttention for practical long-sequence training and inference. 07-training-efficiency-breakthroughs.md includes FlashAttention as one of several techniques that dramatically reduced training costs. 05-state-space-models-and-mamba.md represents the alternative path — if attention's quadratic compute is ultimately intractable for very long sequences, SSMs offer a fundamentally different approach.
Further Reading
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022) — the original paper.
- Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023) — the improved version.
- Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024) — Hopper-optimized version.
- Dao et al., "FlashDecoding" (2023) — applying similar principles to inference.
- Rabe and Staats, "Self-attention Does Not Need O(n^2) Memory" (2021) — the theoretical foundation for memory-efficient attention.