

# FlashAttention — In-Depth Technical Review (English)

Author: Zhongzhu Zhou
Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arXiv 2022 / NeurIPS 2022)
ArXiv: https://arxiv.org/abs/2205.14135


## Abstract

FlashAttention is one of the papers that changed how the community thinks about efficient Transformer attention. The core point is subtle but extremely important: many previous attempts to accelerate attention focused on reducing FLOPs, while real GPU runtime was often dominated by memory traffic rather than arithmetic. This paper argues that exact attention can be made much faster without changing model semantics if we redesign the algorithm around the GPU memory hierarchy instead of around the usual matrix formula alone. The result is an exact attention kernel that avoids materializing the full attention matrix in high-bandwidth memory, cuts the extra memory required from quadratic to linear in sequence length, and delivers large end-to-end speedups on BERT, GPT-2, and long-context tasks. In my view, the paper matters because it turned “attention optimization” from a mostly mathematical approximation game into a systems problem with a rigorous IO model.

## 1. Prerequisites: What to Know Before Reading This Paper

This review is deliberately written for a reader who may be older, careful, and not immersed in GPU systems jargon. So I will slow down and build the needed background before touching the algorithm itself.

### 1.1 What self-attention actually computes

Suppose we have a sequence of tokens, such as words in a sentence. A Transformer turns each token into three vectors:

* **Query** $Q$: what this token is looking for.
* **Key** $K$: what this token offers to others.
* **Value** $V$: the information this token will send if selected.

For one attention head, if the sequence length is $N$ and the head dimension is $d$, then $Q, K, V \in \mathbb{R}^{N \times d}$. Standard scaled dot-product attention does three conceptual steps:

1. Compute the score matrix

$$S = QK^T \in \mathbb{R}^{N \times N}.$$

Every row compares one query token against every key token.

2. Apply row-wise softmax

$$P = \text{softmax}(S).$$

This turns each row into attention weights that sum to one.

3. Mix values using those weights

$$O = PV \in \mathbb{R}^{N \times d}.$$
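The three steps above can be sketched in a few lines of NumPy. This is only a reference sketch, not the paper's implementation; the $1/\sqrt{d}$ scaling and the max-subtraction for numerical stability are the standard practical details:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive scaled dot-product attention for one head.

    Materializes the full N x N score matrix S and the
    probability matrix P -- exactly the memory pattern
    FlashAttention is designed to avoid.
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                 # (N, N) score matrix
    S = S - S.max(axis=-1, keepdims=True)    # subtract row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)    # row-wise softmax
    return P @ V                             # (N, d) output

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (8, 4)
```

Note that `S` and `P` both have shape `(N, N)`; everything else in the computation is only `(N, d)`.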

The problem is the middle object $S$ or $P$: it has size $N \times N$. If $N$ doubles, that matrix becomes 4 times larger. For long sequences, this is painful both for runtime and memory.

### 1.2 FLOPs are not the whole story

A beginner often hears “this algorithm is faster because it uses fewer floating-point operations.” That is sometimes true, but not always. GPUs are not magic arithmetic boxes. They are machines with different layers of memory:

* **HBM (High Bandwidth Memory)**: large but relatively slower GPU memory.
* **SRAM / on-chip memory / shared memory / registers**: tiny compared to HBM, but much faster.

The paper uses the term **IO-awareness** to mean: count how often data must travel between these memory levels, not just how many multiply-adds you do.

This matters because in modern GPUs, arithmetic throughput has improved faster than memory movement. So an operation can be “cheap in math” but still “slow in real life” if it keeps reading and writing huge tensors from HBM. Attention is a classic example. The standard implementation repeatedly materializes large intermediate matrices in HBM. That memory traffic can dominate wall-clock time.

A useful mental picture is a kitchen. Arithmetic is like chopping ingredients with a very sharp knife. Memory traffic is like walking back and forth to the pantry. If your knife becomes twice as fast but you still walk to the pantry 50 times, dinner is still slow.

### 1.3 Why softmax usually seems to require the full row

The attention formula looks like it wants the whole score row before normalization. If one row is $s_1, s_2, \dots, s_N$, then softmax needs:

$$\text{softmax}(s_j) = \frac{e^{s_j}}{\sum_k e^{s_k}}.$$

At first glance, this suggests we must compute and store all $N$ scores in that row. But there is a trick: we can process scores block by block if we maintain running statistics. In particular, we can maintain:

* the running row maximum $m$, used for numerical stability,
* the running denominator $\ell = \sum_j e^{s_j - m}$.

This “online softmax” idea is one of the essential ingredients that makes FlashAttention possible.

### 1.4 Exact attention versus approximate attention

Before FlashAttention, much efficient-attention work tried to avoid the $N^2$ cost by approximating attention:

* sparse attention,
* low-rank approximations,
* kernelized linear attention,
* locality assumptions.

These methods can be useful, but they alter the operation. The paper makes a different promise: *keep exact attention semantics, but compute it in a more hardware-friendly way.* That distinction is important. If you deploy FlashAttention, you are not changing the model’s mathematical definition of dense attention. You are changing the schedule of computation and memory movement.

## 2. What This Paper Does (The Core Idea)

The core idea is simple to state and deep in consequence:

> Do not materialize the full $N \times N$ attention matrix in HBM. Instead, tile Q, K, and V into blocks, move small blocks through fast SRAM, and update the output incrementally while carrying enough softmax statistics to remain exact.

Figure 1 in the paper is worth understanding carefully. The left side shows standard attention as a pipeline that effectively creates and stores the large attention matrix. FlashAttention instead loops over blocks of $K$ and $V$, copies each block into fast SRAM, and then loops over blocks of $Q$. Inside SRAM it computes partial scores, partial softmax contributions, and partial output updates. Then it writes only the evolving output and a few per-row statistics back to HBM.
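To make the streaming idea concrete, here is a minimal Python sketch for a single query row, checked against ordinary full-row softmax attention. The block size and names are illustrative; the real kernel applies the same recurrence per tile of queries inside SRAM:

```python
import numpy as np

def streaming_attention_row(q, K, V, block=4):
    """Exact attention for one query row q, processing K/V in blocks.

    Carries a running max m and running denominator l so the result
    matches full softmax attention without ever forming the whole
    score row at once.
    """
    d = q.shape[0]
    m = -np.inf            # running row maximum
    l = 0.0                # running sum of exponentials (relative to m)
    o = np.zeros(d)        # running unnormalized output
    for j in range(0, K.shape[0], block):
        s = K[j:j+block] @ q / np.sqrt(d)   # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)           # rescale old stats to new max
        p = np.exp(s - m_new)
        l = scale * l + p.sum()
        o = scale * o + p @ V[j:j+block]
        m = m_new
    return o / l                            # normalize once at the end

rng = np.random.default_rng(0)
N, d = 16, 8
K, V = rng.standard_normal((N, d)), rng.standard_normal((N, d))
q = rng.standard_normal(d)

# reference: full-row softmax attention
s = K @ q / np.sqrt(d)
p = np.exp(s - s.max()); p /= p.sum()
assert np.allclose(streaming_attention_row(q, K, V), p @ V)
```

The assertion at the end is the whole point: the blocked computation is bit-for-bit the same mathematical function, just scheduled differently.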
This solves two problems at once:

* **Memory problem**: extra memory becomes linear in sequence length rather than quadratic because the giant attention matrix is not stored.
* **Runtime problem**: HBM reads and writes drop dramatically, which speeds up execution on real GPUs.

The paper is careful to say that FlashAttention may do *more* arithmetic in some places, especially due to recomputation in the backward pass. But because the dominant bottleneck was often memory traffic, reducing HBM access can still make the method much faster overall. This is a systems lesson that shows up repeatedly in high-performance ML infrastructure.

My interpretation is that the paper’s true breakthrough is not merely “fusing kernels.” Many earlier kernels fused operations. The stronger idea is to redesign the *algorithmic structure* around the memory hierarchy and prove that this new structure is IO-optimal over a meaningful range of SRAM sizes.

## 3. Method Details

### 3.1 Standard attention and where the waste happens

The paper summarizes the standard implementation as three big stages:

1. Read $Q$ and $K$, compute $S = QK^T$, write $S$ to HBM.

2. Read $S$, compute $P = \text{softmax}(S)$, write $P$ to HBM.
3. Read $P$ and $V$, compute $O = PV$, write $O$ to HBM.
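To get a feel for the scale of those HBM round trips, here is a quick byte count of one materialized score matrix. The numbers are illustrative, not figures from the paper:

```python
# Size of one materialized N x N score matrix in fp16
# (2 bytes per entry), per attention head and per batch element.
# Multiply by heads x batch size to see why this dominates HBM traffic.
def score_matrix_bytes(N, bytes_per_elem=2):
    return N * N * bytes_per_elem

for N in (1024, 4096, 16384):
    print(f"N={N}: {score_matrix_bytes(N) / 2**20:.0f} MiB")
# N=1024:   2 MiB
# N=4096:   32 MiB
# N=16384:  512 MiB  -- per head, per batch element
```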

This means the enormous $N \times N$ objects $S$ and $P$ are written to and read from HBM. That is exactly what FlashAttention tries to avoid.

Notice something subtle. The matrix multiply itself is usually a compute-friendly operation. The softmax and masking steps are much more memory-sensitive. If these are implemented as separate kernels, you load and store large tensors multiple times. So even when the math is straightforward, the execution plan is not memory-friendly.

### 3.2 Tiling: break the big matrices into small blocks

FlashAttention partitions the matrices into tiles:

* $Q$ is partitioned into row blocks $Q_i$ of size $B_r \times d$.
* $K$ and $V$ are partitioned into row blocks $K_j, V_j$ of size $B_c \times d$.

The algorithm then uses nested loops:

* Outer loop over $K_j, V_j$ blocks.
* Inner loop over $Q_i$ blocks.

For each pair $(i, j)$, it loads one query block and one key/value block into SRAM, computes a local score block

$$S_{ij} = Q_i K_j^T,$$

then updates per-row softmax statistics and the partial output block.

Why this layout? Because it ensures the expensive part—the block score matrix and its softmax—is handled while the data is resident in fast SRAM. The algorithm never needs to store the global $N \times N$ score matrix.

A practical way to think about it is this: the algorithm sweeps through the sequence in small rectangular windows, but it carefully preserves exactly the same final result as if we had formed the whole attention matrix at once.

### 3.3 Online softmax: how to remain exact without seeing the full row at once

This is the heart of the paper. Suppose we have already processed some key blocks for a row of queries and stored two quantities:

* $m_i$: the running maximum score for that row,
* $\ell_i$: the running sum of exponentials, measured relative to the running maximum.

When a new score block $S_{ij}$ arrives, we compute a local row maximum $\tilde{m}_{ij}$. Then we update the row maximum:

$$m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij}).$$

We also rescale the old denominator and add the new block contribution:

$$\ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}}\, \ell_i + \sum_j e^{S_{ij} - m_i^{\text{new}}}.$$

This is the key numerical trick. Because the reference maximum can change when a new block contains a larger score, the previously accumulated denominator must be rescaled. The same rescaling idea is also applied to the accumulated output contribution.

In plain language: *We keep enough summary information about the part we have already seen so that when a new tile appears, we can merge it exactly and stably.* That is why the algorithm does not approximate softmax. It performs exact softmax row by row, but in a streaming tiled fashion.

### 3.4 Updating the output block incrementally

The output block $O_i$ is also accumulated online. After processing a new $(i,j)$ block, the algorithm combines the old partial output and the new block’s weighted value contribution. Conceptually it does:

$$O_i^{\text{new}} = \frac{e^{m_i - m_i^{\text{new}}}\, \ell_i \, O_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}}\, \tilde{P}_{ij} V_j}{\ell_i^{\text{new}}}.$$

The exact notation in the paper is slightly more careful, but the message is the same: the old partial output must be renormalized if the running max changes, then combined with the new contribution. This is why FlashAttention needs to retain the per-row statistics $(m, \ell)$. Without them, you could not merge blocks exactly.

### 3.5 Why the backward pass uses recomputation

In standard training, the backward pass often reuses large saved intermediates. For attention, that would normally include the full attention probabilities $P$. But FlashAttention does not want to store $P$ because that would bring back quadratic memory and much of the HBM traffic. So the authors make a deliberate tradeoff:

* save only the output $O$ and softmax statistics $(m, \ell)$,
* recompute needed local attention quantities on-chip during backward.

This increases FLOPs, but the paper argues that it is still faster overall because reading a giant $N \times N$ tensor from HBM is more expensive than recomputing local values in SRAM. This is one of those cases where “extra computation” is actually the faster systems choice.

The paper compares this to checkpointing, but FlashAttention’s backward derivation is cleaner than generic checkpointing because it analytically simplifies what must be recomputed and what can be summarized.

### 3.6 IO complexity analysis

The theoretical contribution is not decoration; it supports the systems claim. The paper shows:

* Standard attention requires $\Theta(Nd + N^2)$ HBM accesses.
* FlashAttention requires $\Theta(N^2 d^2 / M)$ HBM accesses, where $M$ is the SRAM size.

For practical GPU regimes, this is substantially smaller. Theorem 2 and the related lower bound argue that FlashAttention is asymptotically optimal across a range of SRAM sizes for exact attention. That matters because it says the method is not just “a clever engineering trick that happened to work on one kernel.” It is close to the best possible under the stated model.
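Plugging illustrative numbers into the two asymptotic counts gives a feel for the gap. Constants are dropped, and the SRAM size `M` here is an assumed round number, so only the order of magnitude is meaningful:

```python
# Asymptotic HBM-access counts from the paper's analysis,
# with constants dropped. M is SRAM size in elements (assumed value).
def standard_hbm(N, d):
    return N * d + N * N          # Theta(N*d + N^2)

def flash_hbm(N, d, M):
    return N * N * d * d // M     # Theta(N^2 * d^2 / M)

N, d, M = 4096, 64, 100_000       # M ~ 100K on-chip elements (illustrative)
print(standard_hbm(N, d))         # ~1.7e7 accesses
print(flash_hbm(N, d, M))         # ~6.9e5 accesses -- roughly 25x fewer
```

For small head dimension $d$ and realistic $M$, the $d^2/M$ factor is well below one, which is exactly why the tiled schedule moves so much less data.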
A beginner may ask: “Why is the complexity still quadratic in $N$?” Because exact dense attention still compares all token pairs. FlashAttention does not repeal that fact. What it does is reduce *data movement cost* and eliminate the quadratic-size materialized matrix. So it attacks the practical bottleneck without changing the semantics of dense attention.

### 3.7 Block-sparse FlashAttention

The paper also extends the kernel to a block-sparse pattern. Here the attention matrix is not fully dense; only certain blocks are computed. The same IO-aware design is reused, but now only nonzero blocks are visited. This is no longer exact dense attention, but it is an efficient approximate/sparse variant built on the same systems principle. The authors show that block-sparse FlashAttention can be even faster and can scale to contexts like 64K, with good results on pathfinding tasks.

I think this extension is strategically important. It shows FlashAttention is not a one-off optimization for one exact kernel. It is a reusable primitive for a whole family of attention implementations.

## 4. Experiment Setup

### 4.1 Hardware and implementation environment

The paper centers on modern GPUs, especially the NVIDIA A100. The hardware discussion is essential because FlashAttention’s argument depends on the gap between on-chip SRAM bandwidth and HBM bandwidth. The paper cites roughly terabytes-per-second bandwidth for HBM and an order-of-magnitude higher effective bandwidth for SRAM, though SRAM capacity is tiny.

Implementation-wise, the authors write a custom CUDA kernel so they can control data movement directly. This is an important engineering cost. You do not get FlashAttention’s benefits merely by writing a mathematically equivalent PyTorch expression and hoping the compiler figures it out.

### 4.2 Models and tasks

The evaluation is broad rather than narrow. The authors test:

* **BERT-large** training on Wikipedia, comparing with the NVIDIA MLPerf 1.1 implementation.
* **GPT-2** small and medium on OpenWebText, comparing with HuggingFace and Megatron-LM implementations.
* **Long Range Arena (LRA)** tasks, including ListOps, Text, Retrieval, Image, and Pathfinder.
* **Long document classification** on MIMIC-III and ECtHR.
* **Path-X and Path-256**, famous long-context stress tests.
* Single-kernel benchmarks for runtime and memory across sequence lengths.

This combination is good experimental design. It checks both micro-level kernel behavior and macro-level training impact.

### 4.3 Metrics

The paper does not rely on one vanity metric. It reports:

* runtime of attention forward + backward,
* attention memory usage,
* end-to-end training time,
* perplexity for language modeling,
* accuracy or F1 for downstream tasks.

This is exactly the right way to evaluate a systems paper in ML. If you only report FLOPs or kernel speed, you might miss whether the optimization matters for actual model training. If you only report task accuracy, you might hide systems tradeoffs. FlashAttention reports both.

## 5. Results & Analysis

### 5.1 Figure 1: why avoiding materialization matters

Figure 1 is arguably the most educational figure in the paper. The left panel shows the dataflow change: standard attention writes the large attention matrix to HBM, whereas FlashAttention streams blocks through SRAM. The right panel shows a large speedup on GPT-2 attention computation—reported as up to 7.6× relative to the PyTorch implementation.

My reading is that Figure 1 does more than show a speedup. It reframes the problem. If someone still thinks attention optimization is mostly about changing the $QK^T$ formula, they have missed the lesson. The true bottleneck here is that the default implementation drags an enormous intermediate through slow memory.

### 5.2 Figure 2: HBM accesses track runtime better than FLOPs

Figure 2 is the most important systems-validation result.
The paper shows that FlashAttention can have more FLOPs than standard attention because of recomputation, yet still run faster because HBM accesses are much smaller. This directly supports the main thesis: for this workload on this hardware, IO—not raw arithmetic count—is the dominant factor.

This is why the paper aged so well. Later GPU kernel work across ML repeatedly rediscovered the same principle: arithmetic is cheap, memory movement is expensive, and good kernels are often about keeping data on-chip.

### 5.3 BERT-large: real end-to-end training speedup

On BERT-large, the paper reports training time to a target masked-language-modeling accuracy of 72.0%:

* NVIDIA MLPerf 1.1 implementation: **20.0 ± 1.5 minutes**
* FlashAttention: **17.4 ± 1.4 minutes**

That is about a **15% speedup** on 8×A100 GPUs. This result is more meaningful than an isolated kernel benchmark because BERT training includes many operations besides attention. So when an attention optimization still improves end-to-end training by 15%, it means the optimization is genuinely moving the bottleneck.

### 5.4 GPT-2: strong speedups without changing perplexity

Table 2 is especially convincing. For GPT-2 small and medium on OpenWebText:

* GPT-2 small:
  * HuggingFace: 9.5 days
  * Megatron-LM: 4.7 days
  * FlashAttention: 2.7 days
* GPT-2 medium:
  * HuggingFace: 21.0 days
  * Megatron-LM: 11.5 days
  * FlashAttention: 6.9 days

Perplexity remains essentially unchanged:

* GPT-2 small: 18.2 across implementations
* GPT-2 medium: 14.2–14.3 across implementations

This is exactly what we want from an exact attention optimization. Same model quality, much lower training time.

One subtle point: the comparison is not against a toy baseline. Megatron-LM was already an optimized systems implementation. So beating it by around 1.7–1.8× is a serious result.

### 5.5 Long Range Arena: exact attention becomes surprisingly competitive

Many people would assume approximate attention must dominate on long-range tasks.
Table 3 complicates that story. On LRA, FlashAttention reaches an average score of **59.8**, compared with **59.3** for the standard Transformer, while also reporting around a **2.4× speedup**. Block-sparse FlashAttention is even faster at **2.8×** while keeping similar average quality (**59.6**).

The lesson is not that approximation is useless. The lesson is that naive dense-attention baselines often looked bad partly because their implementation was bad. Once exact attention is implemented well, the crossover point where approximations become clearly better may move much farther out than many expected.

### 5.6 Longer contexts produce better models

Table 4 is one of my favorite results because it connects systems efficiency to model quality. FlashAttention allows GPT-2 small to train with longer context lengths while still staying fast:

* Megatron-LM, 1K context: 18.2 perplexity, 4.7 days
* FlashAttention, 1K context: 18.2 perplexity, 2.7 days
* FlashAttention, 2K context: 17.6 perplexity, 3.0 days
* FlashAttention, 4K context: 17.5 perplexity, 3.6 days

So with FlashAttention, a 4K-context GPT-2 is still **30% faster** than the 1K Megatron baseline while also improving perplexity by about **0.7**. That is a beautiful systems-and-modeling outcome. Efficiency is not merely saving money; it changes which model configurations are feasible.

### 5.7 Long document tasks and Path-X / Path-256

On long document classification, increasing context with FlashAttention improves micro-F1 substantially:

* On **MIMIC-III**, length 16K reaches **57.1** vs **52.8** at length 512.
* On **ECtHR**, length 8K reaches **80.7** vs **72.2** at length 512.

These are large practical gains. Then comes the headline-grabbing result on pathfinding:

* FlashAttention achieves **61.4** on **Path-X**.
* Block-sparse FlashAttention achieves **63.1** on **Path-256**.

The paper emphasizes these were the first Transformer results above chance on these benchmarks.
This is a powerful example of a systems primitive enabling a capability result, not just a benchmark win.

### 5.8 Figure 3: runtime and memory scaling

Figure 3 shows two important curves:

* runtime of forward + backward,
* memory footprint.

FlashAttention’s memory grows linearly in sequence length rather than requiring the huge quadratic attention matrix. The paper reports up to **20×** better memory efficiency than exact attention baselines, and notes that many competing methods run out of memory before 64K on an A100, while FlashAttention still fits.

This matters because out-of-memory failure is a hard constraint, not a mild inconvenience. Many modeling ideas are dead on arrival if the kernel cannot fit the problem.

## 6. Limitations & Boundary Conditions

### 6.1 It does not remove quadratic pairwise interaction

FlashAttention is often misunderstood as “solving quadratic attention.” That is not correct. Exact dense attention still has to consider all token pairs. The paper reduces IO complexity and memory footprint, but it does not magically turn exact dense attention into linear-time modeling. So if your sequence length becomes enormous, approximate or structured methods may still be necessary.

### 6.2 The kernel engineering burden is real

The paper openly admits that the implementation requires custom CUDA work. That is a genuine limitation. Research code and production code are different things; maintaining highly optimized kernels across GPU architectures is hard.

This is one reason FlashAttention later became a software ecosystem, not just a paper. To make the idea useful, the community had to package it, maintain it, and integrate it into frameworks.

### 6.3 Benefits depend on hardware and workload regime

The speedup comes from the gap between on-chip and off-chip memory behavior. If the hardware changes, the sweet spots and block sizes change too.
The paper’s conclusions are highly relevant for modern GPUs, but one should not blindly assume identical gains on every accelerator. Also, some approximate methods can overtake FlashAttention at sufficiently long sequence lengths because they change the asymptotic compute itself. The paper even shows crossover behavior in benchmarking.

### 6.4 Numerical stability must be engineered carefully

The online softmax update is exact, but only if implemented with the right rescaling and stable max-tracking. This is not the kind of kernel where one casual sign error merely hurts speed; it can silently damage correctness. So one practical boundary condition is engineering maturity. You want the battle-tested implementation, not a rushed re-implementation, unless you have a strong reason.

### 6.5 Exact attention is not always the globally best design choice

Even though FlashAttention makes exact attention much more attractive, there remain scenarios where locality, recurrence, state-space models, retrieval, or sparse patterns are a better modeling decision. The paper’s contribution is to improve one important primitive, not to prove that every long-context problem should use dense attention forever.

## 7. Reproducibility & Practical Notes

### 7.1 Why this paper is highly reproducible by systems-paper standards

The paper is unusually reproducible for three reasons:

* the core algorithm is specified clearly,
* the IO argument is formalized with theorems,
* the authors released code.

That said, “reproducible” does not mean “easy.” Reproducing the exact CUDA-level performance of a top kernel is much harder than reproducing a model-training recipe from Python.

### 7.2 What a practitioner needs to use FlashAttention in production

If you are a practitioner rather than a kernel author, you usually do **not** want to rewrite the algorithm from scratch.
Instead you want:

* a maintained FlashAttention library or framework integration,
* GPU support matching your deployment fleet,
* correctness tests against a trusted dense-attention reference,
* profiling tools to verify the kernel is actually selected,
* careful handling of masking, dropout, causal attention, and mixed precision.

In other words, operational success depends as much on software integration and testing discipline as on the paper itself.

### 7.3 When FlashAttention is most compelling

In my opinion, FlashAttention is most compelling when at least one of the following is true:

* sequence length is moderately large to large,
* memory pressure is blocking batch size or context length,
* attention is a noticeable share of runtime,
* you want exact dense attention semantics,
* you need a production-quality primitive rather than a research-only approximation.

This is why it became standard infrastructure for LLM training and inference stacks.

### 7.4 Practical mental model for beginners

If I had to explain the entire paper to a careful beginner in one paragraph, I would say this:

> Standard attention is like writing a giant temporary notebook page for every pairwise token interaction, then repeatedly carrying that notebook between a big cupboard and your desk. FlashAttention instead keeps a small working set on the desk, updates the answer as it goes, and only stores the final result plus a few small bookkeeping numbers. The math result is the same, but the walking is much smaller, so the whole job finishes faster.

That is, in spirit, the entire contribution.

## References

1. Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. *FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness*. arXiv:2205.14135, 2022.
2. Ashish Vaswani et al. *Attention Is All You Need*. NeurIPS, 2017.
3. Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. *Reformer: The Efficient Transformer*. ICLR, 2020.
4. Sinong Wang et al. *Linformer: Self-Attention with Linear Complexity*. arXiv, 2020.
5. Tri Dao. FlashAttention repository and follow-up implementations.

---

*Review written on 2026-03-14.*