Understand how FlashAttention cuts auxiliary attention memory from O(n²) to O(n) with tiling and online softmax, and analyze its IO complexity.
The previous chapter showed how prefix caching avoids repeating prefill work across requests. FlashAttention attacks a lower layer of the same bottleneck: it makes each dense attention computation move far less data through GPU memory.
FlashAttention is an attention implementation that reduces memory traffic without changing the mathematical attention operator. This chapter explains why exact attention can be slow on GPUs, then shows how tiling and recomputation can make long-context execution practical.
Imagine trying to answer a support question where every new token must compare itself against every prior token in a long order timeline: customer messages, carrier scans, refund notes, warehouse exceptions, and policy snippets. As the timeline gets longer, the number of connections you need to track explodes. This is exactly the problem modern AI models face with the attention mechanism. Dense attention still performs a quadratic number of score comparisons as input grows. In an implementation that materializes all score or probability values, its auxiliary storage is quadratic too: double the input length, and those matrices need four times the memory.
FlashAttention reorganizes how dense attention is computed at the hardware level. It computes the same operator, but doesn't materialize the full score or probability matrix in HBM (High Bandwidth Memory, the large GPU RAM pool).[1] That lowers auxiliary attention memory and HBM traffic. When attention is the bottleneck for a workload and an eligible FlashAttention kernel is available, it can increase throughput or make a longer context fit.
Before we look at the fix, let's ground the problem with a tiny example. In the scaled dot-product attention article, you saw that each query token scores every key token, normalizes those scores with softmax, and blends the corresponding value vectors.
Suppose you have only three support-ticket tokens and a head dimension of two, so $Q$ and $K$ are each $3 \times 2$.
The score matrix $S = QK^\top$ is $3 \times 3$. That is nine numbers. For three tokens, this is trivial. But for real sequences, the story changes fast.
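To make the three-token case concrete, here is a minimal sketch in plain Python. Only the shapes (three tokens, head dimension two) come from the text; the `Q`, `K`, and `V` numbers are made up for illustration.

```python
import math

# Hypothetical toy inputs: 3 tokens, head dimension 2 (values are illustrative only).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
d = len(Q[0])

def softmax(row):
    # Subtract the row max before exponentiating for numerical stability.
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# Score matrix S = Q K^T / sqrt(d): one row per query, one column per key.
S = [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d) for k_row in K] for q_row in Q]
# Probability matrix P: softmax applied row by row.
P = [softmax(row) for row in S]
# Output O = P V: each output row is a weighted blend of the value vectors.
O = [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(len(V[0]))] for p_row in P]

print(f"score entries: {sum(len(row) for row in S)}")  # prints "score entries: 9"
```

Both `S` and `P` hold nine entries here. A materializing implementation stores both, which is the part that grows quadratically with sequence length.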
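For batch size $B = 8$, $H = 32$ heads, and sequence length $n = 8192$ in 16-bit floating point (FP16):

$$B \cdot H \cdot n^2 \cdot 2\ \text{bytes} = 8 \cdot 32 \cdot 8192^2 \cdot 2\ \text{bytes}$$

That product is exactly $2^{35} = 34{,}359{,}738{,}368$ bytes, which is 32 GiB or about 34 GB depending on whether you count in binary or decimal units. If an attention implementation saves this intermediate, it consumes a large part of an 80 GB A100 before counting model weights, other activations, or gradients.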
Run this calculation before blaming model weights for an out-of-memory error:
```python
batch, heads, sequence, bytes_per_value = 8, 32, 8192, 2
score_bytes = batch * heads * sequence * sequence * bytes_per_value

print(f"score values: {batch * heads * sequence * sequence:,}")
print(f"binary size: {score_bytes / 1024**3:.2f} GiB")
print(f"decimal size: {score_bytes / 1000**3:.2f} GB")
```

```text
score values: 17,179,869,184
binary size: 32.00 GiB
decimal size: 34.36 GB
```

The key insight behind FlashAttention is understanding where data lives. A GPU is not a flat memory space. It has layers:
Memory hierarchy intuition: Think of on-chip SRAM as the packing bench: tiny, but right beside the worker. HBM is warehouse storage: much bigger, but every trip costs time. A materializing attention baseline keeps walking back to storage because the full score matrix doesn't fit on the bench. FlashAttention keeps only active tiles on the bench at a time.
In the original FlashAttention paper, the motivating A100 numbers are roughly 20 MB of aggregate on-chip SRAM at about 19 TB/s bandwidth versus 40 GB of HBM at about 1.5 TB/s.[1] A kernel doesn't get to use that full 20 MB as one giant scratchpad, though: practical tile sizes are bounded by much smaller per-SM shared-memory and register budgets. The exact chip layout varies by GPU generation, but the qualitative gap is the same: on-chip memory is tiny and fast, off-chip memory is large and much slower to revisit.
| Memory tier | Typical role in attention | Capacity / bandwidth intuition |
|---|---|---|
| On-chip SRAM | Hold the current Q/K/V tiles and running softmax statistics | Tiny, but fast enough to reuse the same tile many times |
| HBM | Hold Q, K, V, O, model weights, and other activations | Much larger, but expensive to touch for every intermediate |
| CPU DRAM | Host memory outside the GPU | Larger still, but not suitable for the inner loop of an attention kernel |
The table above summarizes the memory tiers. When a baseline writes score and probability matrices to HBM, those round-trips can dominate execution. FlashAttention pays for local bookkeeping to avoid them.[1]
In a materializing baseline, the score and probability matrices are written to HBM and read again. A fused backend may already avoid those intermediates, so the useful comparison is FlashAttention versus the backend the system would otherwise execute, not versus every call named "attention."
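A rough back-of-envelope sketch shows why those round-trips matter. It reuses the batch 8, 32 heads, 8192 tokens, FP16 shape from earlier and the roughly 1.5 TB/s HBM figure quoted from the paper. The four-pass count (write scores, read scores, write probabilities, read probabilities) and the assumption of perfect peak bandwidth are simplifications, so treat the result as a lower bound rather than a measurement.

```python
# Back-of-envelope estimate, assuming peak bandwidth and nothing else in flight.
batch, heads, sequence, bytes_per_value = 8, 32, 8192, 2
hbm_bytes_per_second = 1.5e12  # ~1.5 TB/s, the A100 figure quoted from the paper

matrix_bytes = batch * heads * sequence * sequence * bytes_per_value
# A materializing baseline writes S, reads S, writes P, and reads P.
passes = 4
traffic_bytes = passes * matrix_bytes
seconds = traffic_bytes / hbm_bytes_per_second

print(f"intermediate traffic: {traffic_bytes / 1e9:.1f} GB")              # 137.4 GB
print(f"lower-bound time at peak bandwidth: {seconds * 1e3:.1f} ms")      # 91.6 ms
```

That time is spent moving intermediates only. It doesn't include any of the matrix multiplies.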
Instead of materializing the full attention matrix, FlashAttention combines three ideas: tiling, online softmax, and recomputation in the backward pass.
This block-wise processing strategy avoids quadratic auxiliary score and probability storage. The input, output, and saved row-statistic tensors still scale with sequence length.
The illustration below contrasts the memory access patterns. A materializing baseline writes the full score matrix to HBM, reads it back for softmax, writes the probability matrix, and reads it again for the final multiply. FlashAttention streams small tiles through SRAM and saves only row-wise statistics.
The diagram shows the two paths side by side. On the left, the materializing path makes multiple round-trips to HBM for the intermediates. On the right, FlashAttention processes small blocks in fast SRAM and writes back the output and compact row statistics.
The illustration below visualizes the data flow: Q, K, and V blocks stream from large HBM into the small SRAM workspace, where only the current tile and running statistics are kept.
During forward execution, $Q$, $K$, and $V$ are read from HBM. The kernel writes final output $O$ and, for training, compact row statistics needed by backward. It doesn't write the full score or probability matrices.
Tile size is bounded by on-chip capacity. This simplified payload check counts four FP16 tile-shaped arrays (Q, K, V, and a partial output), while production kernels also budget for statistics, registers, and implementation overhead:
```python
block_rows, head_dimension, arrays, bytes_per_value = 128, 64, 4, 2
payload_bytes = block_rows * head_dimension * arrays * bytes_per_value

print(f"simplified tile payload: {payload_bytes / 1024:.0f} KiB")
print("also budget: row statistics, registers, and kernel overhead")
```

```text
simplified tile payload: 64 KiB
also budget: row statistics, registers, and kernel overhead
```

Imagine you're summarizing 1,000 package-risk scores, but the packing bench only fits 50 labels at a time. You don't need to hold all 1,000 scores to compute the final normalization. You keep running statistics and update them as each block arrives.
Online softmax works the same way: instead of needing the full attention matrix to compute softmax, it maintains running statistics (max and sum) and updates them block by block, without introducing any approximation.
Let's walk through a tiny numeric example before we show the general formula.
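Suppose a query token sees key scores $[1.0, 2.0, 0.5]$, paired with scalar values $[10, 20, 40]$ (the same numbers the program below uses). Standard softmax works on the whole row at once: find the row max, subtract it and exponentiate, sum the exponentials, then divide.

A materializing baseline would store the full score matrix in HBM just to perform that four-step process for every row.

Now pretend the SRAM bench only fits two scores at a time. We split the scores into Block A $[1.0, 2.0]$ and Block B $[0.5]$.

Processing Block A:

- Running max: $m = 2.0$
- Running denominator: $\ell = e^{1.0 - 2.0} + e^{2.0 - 2.0} \approx 1.3679$
- Running numerator: $10 \cdot e^{-1} + 20 \cdot 1 \approx 23.6788$

Processing Block B:

- New max: $m = \max(2.0, 0.5) = 2.0$, so the old totals are rescaled by $e^{2.0 - 2.0} = 1$
- Running denominator: $\ell \approx 1.3679 + e^{0.5 - 2.0} \approx 1.5910$
- Running numerator: $23.6788 + 40 \cdot e^{-1.5} \approx 32.6040$

After both blocks, the final output for this query row is $32.6040 / 1.5910 \approx 20.4926$, which is exactly the same result as standard softmax. The difference is that we never held all three scores in the fast workspace at once.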
This small program checks the rescaling rule with scalar values, independent of any GPU kernel:
```python
import math

scores_a, values_a = [1.0, 2.0], [10.0, 20.0]
scores_b, values_b = [0.5], [40.0]

def local_state(scores, values):
    max_score = max(scores)
    weights = [math.exp(score - max_score) for score in scores]
    return max_score, sum(weights), sum(weight * value for weight, value in zip(weights, values))

m_a, l_a, n_a = local_state(scores_a, values_a)
m_b, l_b, n_b = local_state(scores_b, values_b)
m = max(m_a, m_b)
l = math.exp(m_a - m) * l_a + math.exp(m_b - m) * l_b
n = math.exp(m_a - m) * n_a + math.exp(m_b - m) * n_b
online = n / l

all_scores = scores_a + scores_b
all_values = values_a + values_b
dense_weights = [math.exp(score - max(all_scores)) for score in all_scores]
dense = sum(w * v for w, v in zip(dense_weights, all_values)) / sum(dense_weights)

print(f"online output: {online:.6f}")
print(f"dense output: {dense:.6f}")
print(f"match: {abs(online - dense) < 1e-12}")
```

```text
online output: 20.492649
dense output: 20.492649
match: True
```

For each new block of scores $S_j$ and value vectors $V_j$:

$$
\begin{aligned}
m_{\text{new}} &= \max\bigl(m_{\text{old}},\ \operatorname{rowmax}(S_j)\bigr) \\
\ell_{\text{new}} &= e^{m_{\text{old}} - m_{\text{new}}}\,\ell_{\text{old}} + \operatorname{rowsum}\bigl(e^{S_j - m_{\text{new}}}\bigr) \\
\tilde{O}_{\text{new}} &= e^{m_{\text{old}} - m_{\text{new}}}\,\tilde{O}_{\text{old}} + e^{S_j - m_{\text{new}}}\,V_j \\
O &= \tilde{O}_{\text{final}} / \ell_{\text{final}}
\end{aligned}
$$

Where $m$ is the running max score (for numerical stability), $\ell$ is the running softmax denominator, $\tilde{O}$ is the running unnormalized numerator accumulator, and $O$ is the normalized output.
The rescaling terms ensure previous results stay correct even though the max changed. This is the mathematical trick that eliminates the need for a second pass over the full row.[2]
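The rescaling step is easiest to see when a later block raises the max. The following is a minimal streaming sketch in plain Python. The function name `online_softmax_weighted_sum` and the score and value numbers are hypothetical, chosen so that a naive `exp` would overflow and so that the second block contains a larger score than the first.

```python
import math

def online_softmax_weighted_sum(score_blocks, value_blocks):
    """Stream blocks of (score, value) pairs, keeping only m, l, and the numerator."""
    m = -math.inf   # running max
    l = 0.0         # running denominator
    acc = 0.0       # running unnormalized numerator
    for scores, values in zip(score_blocks, value_blocks):
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # rescales everything accumulated so far
        weights = [math.exp(s - m_new) for s in scores]
        l = scale * l + sum(weights)
        acc = scale * acc + sum(w * v for w, v in zip(weights, values))
        m = m_new
    return acc / l

# Hypothetical scores: large enough that exp(score) overflows a double.
score_blocks = [[1000.0, 1001.0], [1003.0, 1002.0]]
value_blocks = [[1.0, 2.0], [3.0, 4.0]]

try:
    math.exp(1000.0)
    naive_overflows = False
except OverflowError:
    naive_overflows = True

streamed = online_softmax_weighted_sum(score_blocks, value_blocks)

# Reference: the usual full-row softmax with the global max subtracted.
flat_scores = [s for block in score_blocks for s in block]
flat_values = [v for block in value_blocks for v in block]
top = max(flat_scores)
exps = [math.exp(s - top) for s in flat_scores]
reference = sum(e * v for e, v in zip(exps, flat_values)) / sum(exps)

print(f"naive exp overflows: {naive_overflows}")
print(f"streamed matches reference: {abs(streamed - reference) < 1e-12}")
```

When the second block arrives, the max moves from 1001 to 1003, so the totals from the first block are multiplied by $e^{-2}$ before the new terms are added.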
The diagram above contrasts the two approaches. A materializing baseline makes multiple round-trips to HBM for the full matrix. FlashAttention performs tile-local score and softmax work on-chip, writing output and compact saved statistics to HBM.
The following function demonstrates the core logic of FlashAttention in plain PyTorch. It takes the query, key, and value matrices along with a block size, and it returns the same dense attention output as a materialized reference implementation, but it iterates through tiles so the full score and probability matrices never exist at once. The SRAM comments mark where a real kernel would keep tiles on-chip; the sketch itself doesn't control GPU memory placement.
```python
import torch
import math

def flash_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    block_size: int = 256
) -> torch.Tensor:
    """
    Simplified forward-pass sketch of FlashAttention.

    Args:
        Q: Query tensor of shape (n, d)
        K: Key tensor of shape (n, d)
        V: Value tensor of shape (n, d)
        block_size: Size of blocks to load into SRAM

    Returns:
        O: Output tensor of shape (n, d)
    """
    n, d = Q.shape
    O = torch.zeros_like(Q)

    # Outer loop: iterate over Q blocks and keep the current output tile on-chip
    for i in range(0, n, block_size):
        Qi = Q[i:i+block_size]  # Load Q block to SRAM

        # Initialize running statistics for this Q block
        Oi = torch.zeros_like(Qi)  # Accumulator
        li = torch.zeros(Qi.shape[0], 1, device=Q.device)  # Denominator
        mi = torch.full((Qi.shape[0], 1), -float('inf'), device=Q.device)  # Max

        # Inner loop: Iterate over K, V blocks
        for j in range(0, n, block_size):
            Kj = K[j:j+block_size]  # Load K block to SRAM
            Vj = V[j:j+block_size]  # Load V block to SRAM

            # Compute local attention scores (in SRAM!)
            # Shape: (block_size_q, block_size_k)
            Sij = Qi @ Kj.T / math.sqrt(d)

            # Online softmax update logic
            m_ij = Sij.max(dim=-1, keepdim=True).values
            m_new = torch.max(mi, m_ij)

            exp_old_scale = torch.exp(mi - m_new)
            exp_new = torch.exp(Sij - m_new)

            # Update output accumulator (unnormalized)
            Oi = exp_old_scale * Oi + exp_new @ Vj

            # Update running statistics
            li = exp_old_scale * li + exp_new.sum(dim=-1, keepdim=True)
            mi = m_new

        # Normalize by the final denominator
        O[i:i+block_size] = Oi / li

    return O

# Quick sanity check: FlashAttention should match reference attention on a small tensor
if __name__ == "__main__":
    torch.manual_seed(0)
    n, d = 64, 32
    Q = torch.randn(n, d)
    K = torch.randn(n, d)
    V = torch.randn(n, d)

    # Reference attention (materializes full n x n matrix)
    S = Q @ K.T / math.sqrt(d)
    P = torch.softmax(S, dim=-1)
    expected = P @ V

    # FlashAttention sketch (tiling, no full score/probability materialization)
    got = flash_attention(Q, K, V, block_size=16)

    max_difference = (expected - got).abs().max().item()
    print("Matches dense attention within 1e-5:", max_difference < 1e-5)
```

```text
Matches dense attention within 1e-5: True
```

This sketch keeps `Oi` as an unnormalized numerator accumulator and divides by `li` once per Q tile at the end. The small test at the bottom proves the key claim: for a toy tensor, the tiled loop produces the same result as standard dense attention, with a maximum difference below $10^{-5}$.
Common mistake: Beginners sometimes think `Oi` is already normalized inside the inner loop. It isn't. The division by `li` happens only after every K/V block for that Q tile has been processed. If you normalize early, you lose the exact rescaling that makes online softmax correct.
Production kernels implement the same algebra much more aggressively, while also handling batching, multiple heads, masks, and dropout.
A materializing training implementation can save the massive attention matrices $S$ and $P$ from the forward pass to compute gradients during the backward pass. Those saved intermediates can become a major source of out-of-memory (OOM) errors.
FlashAttention solves this by recomputing the needed score and probability tiles during the backward pass instead of storing them all from the forward pass.[1] Because it saves only row-wise softmax statistics such as the running max $m$ and denominator $\ell$, the saved attention state grows as $O(n)$ rather than $O(n^2)$.
In a warehouse, you could photocopy every shipping label at each packing station and file those copies in a giant archive. That is the materializing baseline. Or you could keep a slim logbook with running totals and reprint any label you need from the original order data. That is FlashAttention.
Recomputation isn't free. It adds arithmetic in backward. The point of the FlashAttention paper is that, on the evaluated GPU workloads, avoiding much larger HBM reads and writes more than paid for that arithmetic cost.[1] Measure the trade-off on the model shape and hardware you deploy.
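The recomputation idea can be sketched without a GPU: keep only the per-row max and denominator from the forward pass, then rebuild any probability tile on demand from `Q`, `K`, and those two numbers. The inputs and the helper name `recompute_tile` below are hypothetical, and for brevity the forward loop reads each full score row rather than accumulating the statistics block by block as a real kernel would.

```python
import math

# Hypothetical single-head toy problem (values are illustrative only).
Q = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0], [2.0, -0.5]]
d = 2

def score(q_row, k_row):
    return sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d)

# Forward pass: keep only the per-row max m and denominator l, not the scores.
saved_m, saved_l = [], []
for q_row in Q:
    row = [score(q_row, k_row) for k_row in K]
    m = max(row)
    saved_m.append(m)
    saved_l.append(sum(math.exp(s - m) for s in row))

def recompute_tile(q_start, q_stop, k_start, k_stop):
    """Rebuild one probability tile from Q, K, and the saved row statistics."""
    return [
        [math.exp(score(Q[i], K[j]) - saved_m[i]) / saved_l[i] for j in range(k_start, k_stop)]
        for i in range(q_start, q_stop)
    ]

# Backward-style access: rebuild the left and right key tiles independently.
left = recompute_tile(0, 4, 0, 2)
right = recompute_tile(0, 4, 2, 4)
row_sums = [sum(a) + sum(b) for a, b in zip(left, right)]
print([round(total, 12) for total in row_sums])
```

Each rebuilt row sums to 1 even though the two tiles were computed separately, because both are normalized against the same saved statistics.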
This calculator makes the saved-state difference concrete. It intentionally counts one score matrix and two row statistics, not every tensor in training:
```python
batch, heads, sequence, bytes_per_value = 8, 32, 8192, 2
materialized_scores = batch * heads * sequence * sequence * bytes_per_value
row_stats = batch * heads * sequence * 2 * bytes_per_value

print(f"one saved score matrix: {materialized_scores / 1024**3:.2f} GiB")
print(f"two row statistics: {row_stats / 1024**2:.2f} MiB")
print(f"size ratio: {materialized_scores / row_stats:,.0f}x")
```

```text
one saved score matrix: 32.00 GiB
two row statistics: 8.00 MiB
size ratio: 4,096x
```

| Property | Materializing baseline | FlashAttention |
|---|---|---|
| Saved attention state for backward | Store $S$ and $P$ explicitly: $O(n^2)$ | Store row-wise statistics such as $m$ and $\ell$: $O(n)$ |
| Backward strategy | Read large intermediates from HBM | Recompute local tiles from $Q$, $K$, $V$ plus saved stats |
| Trade-off | Less recomputation, much higher memory | More recomputation, much lower memory |
FlashAttention changes the IO complexity by tiling the computation. Let $M$ denote the amount of fast SRAM available to hold a tile's working set.
| Property | Materializing baseline | FlashAttention |
|---|---|---|
| Auxiliary attention memory | Materialize $S$ and $P$: $O(n^2)$ | Keep row stats and the current output tile: $O(n)$ |
| FLOPs | $O(n^2 d)$ | $O(n^2 d)$ (same) |
| HBM reads/writes | $\Theta(nd + n^2)$ | $\Theta(n^2 d^2 / M)$ |
| Exact | Yes | Yes |
| Reported wall-clock speed | Baseline | Up to 2-4x faster in evaluated paper workloads[1] |
FLOPs stands for floating-point operations. The memory row here refers to the extra state created by the attention kernel itself, not the shared $Q$, $K$, $V$, and $O$ tensors that both approaches still need to hold. FlashAttention doesn't reduce the asymptotic mathematical work required for dense attention: both paths perform $O(n^2 d)$ operations. Its main algorithmic advantage is avoiding score and probability transfers to and from slow HBM. Tile size, scheduling, datatype, and hardware still affect observed speed.
By keeping the working set within the on-chip SRAM budget $M$, HBM accesses drop from $\Theta(nd + n^2)$ to $\Theta(n^2 d^2 / M)$.[1] Whenever $d^2$ is smaller than $M$, which the paper notes is typical for head dimensions of 64 to 128 and SRAM sizes around 100 KB, that is several times fewer accesses. This means the GPU spends less time waiting for data to arrive from memory and more time keeping its compute cores busy.
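To get a feel for the size of that reduction, the sketch below plugs numbers into the leading-order HBM access counts from the FlashAttention paper[1]: roughly $nd + n^2$ for a materializing baseline and $n^2 d^2 / M$ for the tiled algorithm. Constant factors are dropped, and the `sram_elements` value is an assumed figure (about 100 KB of FP16 values), so the ratio is an order-of-magnitude guide rather than a prediction.

```python
def hbm_accesses(n, d, sram_elements):
    """Leading-order HBM access counts from the FlashAttention paper, constants dropped."""
    materializing = n * d + n * n
    tiled = n * n * d * d / sram_elements
    return materializing, tiled

n, d = 8192, 64
sram_elements = 50_000  # hypothetical: roughly 100 KB of FP16 values
baseline, flash = hbm_accesses(n, d, sram_elements)

print(f"materializing ~ {baseline:,.0f} element accesses")
print(f"tiled         ~ {flash:,.0f} element accesses")
print(f"ratio         ~ {baseline / flash:.1f}x")
```

Doubling `d` quadruples the tiled term while barely changing the baseline term, which is one reason head dimension appears in kernel eligibility rules.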
Because FlashAttention computes the same dense attention operator as the reference formula (thanks to the online softmax trick), it doesn't change the model's mathematical attention rule. Different floating-point operation order can still cause small numeric differences.
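The reordering effect is not specific to attention. A two-line check with ordinary double-precision floats shows that the same three numbers summed in a different order can differ in the last bit:

```python
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)               # False on IEEE 754 doubles
print(abs(left_to_right - right_to_left) < 1e-15)   # True: the gap is tiny
```

A tiled kernel accumulates partial sums block by block, so comparing it with a reference should use a tolerance rather than exact equality.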
The following numbers distinguish score-matrix scaling from row-statistic scaling:
```python
base_sequence = 1024
for sequence in [1024, 2048, 4096, 8192, 16384]:
    materialized_relative = (sequence / base_sequence) ** 2
    row_stats_relative = sequence / base_sequence
    print(
        f"{sequence:>5} tokens: materialized={materialized_relative:>5.0f}x, "
        f"row-stats={row_stats_relative:>2.0f}x"
    )
```

```text
 1024 tokens: materialized=    1x, row-stats= 1x
 2048 tokens: materialized=    4x, row-stats= 2x
 4096 tokens: materialized=   16x, row-stats= 4x
 8192 tokens: materialized=   64x, row-stats= 8x
16384 tokens: materialized=  256x, row-stats=16x
```

For autoregressive transformers, attention is causal: token $i$ can only attend to tokens $j \le i$. FlashAttention handles this efficiently without materializing a dense causal mask in HBM.
If a block of K/V tokens is entirely in the "future" relative to a Q block, the entire block multiplication is skipped. No compute wasted.
For blocks that straddle the causal boundary, FlashAttention applies the mask after computing scores but before the softmax update. The masked positions are set to $-\infty$.
In other words, causal masking is folded into the tile schedule itself: future tiles are skipped, and diagonal tiles apply an in-tile mask before the online softmax update. That reduces wasted work, but the exact speedup depends on sequence length, tile shape, and kernel implementation rather than being a guaranteed 2x.
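A small sketch of the in-tile mask for one diagonal tile follows. The helper name `masked_tile_weights` and the tile values are hypothetical. It assumes an aligned diagonal tile, where every query row still sees at least its own key, so no row is fully masked.

```python
import math

def masked_tile_weights(tile_scores, query_start, key_start):
    """Apply a causal mask inside one tile, then exponentiate against the row max."""
    weights = []
    for r, row in enumerate(tile_scores):
        query_pos = query_start + r
        # Future keys get -inf so that exp() turns them into exactly 0.
        masked = [
            s if key_start + c <= query_pos else -math.inf
            for c, s in enumerate(row)
        ]
        row_max = max(masked)  # finite here because the diagonal key is never masked
        weights.append([math.exp(s - row_max) for s in masked])
    return weights

# Hypothetical 2x2 diagonal tile covering queries 4-5 and keys 4-5.
tile = [[0.3, 0.9], [0.1, 0.7]]
weights = masked_tile_weights(tile, query_start=4, key_start=4)

print(weights[0])  # [1.0, 0.0]: key 5 is in the future for query 4
print(weights[1])  # both keys are visible to query 5
```

Because the masked weight is exactly zero, it contributes nothing to the running denominator or numerator in the online softmax update.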
The same tiled structure also adapts well to local windowed attention. Tiles that fall completely outside the attention window can be skipped before doing the matrix multiply.
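The window case can be sketched with the same kind of tile bookkeeping. The helper `tile_decision` below is hypothetical and assumes a causal sliding window in which query $i$ attends to keys $j$ with $i - w < j \le i$. Real kernels may define the window boundaries differently.

```python
def tile_decision(query_tile, key_tile, tile_size, window):
    """Classify a tile for causal sliding-window attention (hypothetical helper)."""
    q_first, q_last = query_tile * tile_size, (query_tile + 1) * tile_size - 1
    k_first, k_last = key_tile * tile_size, (key_tile + 1) * tile_size - 1
    # Query i may attend to keys j with i - window < j <= i.
    if k_first > q_last:
        return "skip"    # entirely in the future
    if k_last <= q_first - window:
        return "skip"    # older than the window for every query in the tile
    if k_last <= q_first and k_first > q_last - window:
        return "full"    # every (query, key) pair in the tile is allowed
    return "masked"      # tile straddles a boundary and needs an in-tile mask

tiles, tile_size, window = 8, 16, 32
counts = {"skip": 0, "full": 0, "masked": 0}
for query_tile in range(tiles):
    for key_tile in range(tiles):
        counts[tile_decision(query_tile, key_tile, tile_size, window)] += 1

print(counts)
print(f"computed tiles: {counts['full'] + counts['masked']} of {tiles * tiles}")
```

With a fixed window, the number of computed tiles per query tile stays bounded as the sequence grows, so the skipped fraction increases with sequence length.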
You can audit causal tile decisions without any GPU code:
```python
tiles = 4
counts = {"past": 0, "boundary": 0, "future": 0}

for query_tile in range(tiles):
    for key_tile in range(tiles):
        if key_tile < query_tile:
            decision = "past"
        elif key_tile == query_tile:
            decision = "boundary"
        else:
            decision = "future"
        counts[decision] += 1

print(counts)
print(f"computed tiles: {counts['past'] + counts['boundary']} of {tiles * tiles}")
```

```text
{'past': 6, 'boundary': 4, 'future': 6}
computed tiles: 10 of 16
```

Since the introduction of the original algorithm, later versions have evolved to better use modern GPU features and reach a higher fraction of peak hardware throughput.
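FlashAttention-2[3] improves on the original by optimizing the hardware execution, with better parallelism and work partitioning.
FlashAttention-3[4] is designed to target the advanced capabilities of the Hopper architecture (like the H100), using TMA, WGMMA, warp specialization, and FP8 support.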
These subsequent iterations demonstrate that while the mathematical core of tiling and online softmax remains the same, hardware-aware kernel optimization is critical for maximizing performance. Building these highly optimized kernels often requires low-level CUDA programming. Higher-level languages like Triton have also made it much easier to write custom memory-efficient attention kernels without dropping all the way to raw CUDA C++.
During model training, saving quadratic attention intermediates can sharply restrict the maximum sequence length a model can process. As sequence length increases, a materializing baseline may run out of memory even when a fused attention path can still fit.
Common mistake: Assuming FlashAttention is only a "long-sequence hack." The original paper reports a 15% end-to-end training speedup for BERT-large at sequence length 512, so even a moderate evaluated sequence can benefit when attention IO matters.[1]
Compared with a baseline that saves full attention intermediates, FlashAttention's auxiliary attention memory grows linearly rather than quadratically. That can enable longer sequences and improve throughput even while a materializing baseline still fits in memory.[1][3]
| Source | Workload | Reported result |
|---|---|---|
| FlashAttention (2022) | BERT-large, sequence length 512 | 15% end-to-end training speedup[1] |
| FlashAttention (2022) | GPT-2, sequence length 1K | 3× speedup[1] |
| FlashAttention (2022) | Long Range Arena, sequence length 1K-4K | 2.4× speedup[1] |
| FlashAttention-2 (2023) | GPT-style training on A100 | Up to 225 TFLOPs/s per GPU, 72% model FLOPs utilization[3] |
The important point isn't one magic benchmark number. It's that once the attention kernel stops writing giant intermediates to HBM, longer-sequence training becomes much more practical.
During inference, FlashAttention helps most when the workload still looks like dense attention over many prompt tokens, as in prompt prefill.
FlashAttention has its biggest impact when attention itself is the bottleneck. That's usually training and prefill, not single-token decode.
This shape check shows why prefill creates far more score work per request than one decode step:
```python
prompt_tokens = 8192
prefill_scores = prompt_tokens * prompt_tokens
decode_scores = 1 * prompt_tokens

print(f"prefill scores: {prefill_scores:,}")
print(f"one decode step scores: {decode_scores:,}")
print(f"ratio: {prefill_scores // decode_scores:,}x")
```

```text
prefill scores: 67,108,864
one decode step scores: 8,192
ratio: 8,192x
```

The algorithmic idea is general, but the fastest kernels are hardware-specific. FlashAttention-2 describes better parallelism and work partitioning for modern GPUs, while FlashAttention-3 is a Hopper-focused redesign that targets features such as TMA, WGMMA, and FP8 attention.[3][4]
Availability in an application depends on its framework build, device, datatype, tensor shapes, and attention features. Treat backend selection as something to verify, not something to infer from the model name.
In modern deep learning frameworks, you rarely implement FlashAttention from scratch. PyTorch exposes torch.nn.functional.scaled_dot_product_attention (SDPA), which may choose an optimized CUDA implementation when the inputs and build support it. Its sdpa_kernel context manager lets you select permitted implementations while testing or profiling. Eligibility and fallback behavior depend on the installed PyTorch build, device, datatype, layout, and attention features, so consult the documentation for that build and measure the actual path.[5]
First validate operator behavior with a small CPU example. This checks causality and output shape, not FlashAttention dispatch:
```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
query = torch.randn(1, 2, 4, 8)
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)

output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True,
    dropout_p=0.0,
)
print(f"output shape: {tuple(output.shape)}")
print(f"finite values: {torch.isfinite(output).all().item()}")
```

```text
output shape: (1, 2, 4, 8)
finite values: True
```

On CUDA hardware, backend restriction is useful as a test probe. This snippet is intentionally not marked runnable because it requires a suitable installed CUDA build and GPU:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

Q = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)
K = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)
V = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    output = F.scaled_dot_product_attention(Q, K, V, is_causal=True, dropout_p=0.0)
```

If a model library offers a FlashAttention request flag, requesting it isn't evidence that the fast path ran. Record a before/after measurement and backend evidence:
```python
verification = {
    "requested_backend": "flash_attention",
    "operator_correctness_checked": True,
    "profiler_shows_selected_kernel": False,
    "latency_measured": False,
}

active = (
    verification["operator_correctness_checked"]
    and verification["profiler_shows_selected_kernel"]
    and verification["latency_measured"]
)
print(f"enough evidence to claim speedup: {active}")
print("next check: capture backend/profiler output on target GPU")
```

```text
enough evidence to claim speedup: False
next check: capture backend/profiler output on target GPU
```

For example, Hugging Face Transformers exposes a model-load request in versions and models that support the corresponding integration:
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
```

Symptom: You hear FlashAttention grouped with sparse or low-rank attention approximations and assume it drops some connections to save memory.
Cause: The word "efficient" often implies approximation in other contexts.
Fix: FlashAttention is exact. Thanks to the online softmax trick, it computes the same dense attention formula without using sparse or low-rank shortcuts. Numeric outputs can differ slightly from a reference implementation because floating-point operations are associated in a different order, but the mathematical operator is the same. If you need proof, run the small PyTorch test from the pseudocode section and check that the max difference is near zero.
Symptom: You claim in an interview or code review that FlashAttention cuts FLOPs.
Cause: It's natural to equate "faster" with "fewer operations."
Fix: The forward attention computation still has $O(n^2 d)$ floating-point operations (FLOPs). The speedup comes from reduced memory operations (IO), not from changing dense attention into a cheaper mathematical operator. In training, the backward pass can perform more operations because it recomputes tiles. The win is that compute is cheap and memory movement is expensive.
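One way to convince yourself is to count the matrix-multiply work in both schedules. This sketch counts only the two matmuls (scores and the value blend) and ignores the exponentials and rescaling, which add a lower-order cost in the tiled version. The function names are hypothetical.

```python
def matmul_flops(rows, inner, cols):
    # One multiply and one add per inner-product term.
    return 2 * rows * inner * cols

def dense_flops(n, d):
    # Q K^T is (n x d)(d x n); P V is (n x n)(n x d).
    return matmul_flops(n, d, n) + matmul_flops(n, n, d)

def tiled_flops(n, d, block):
    total = 0
    for i in range(0, n, block):
        rows = min(block, n - i)
        for j in range(0, n, block):
            cols = min(block, n - j)
            # Same two matmuls, just on a (rows x cols) tile at a time.
            total += matmul_flops(rows, d, cols) + matmul_flops(rows, cols, d)
    return total

n, d, block = 1024, 64, 128
print(f"dense matmul FLOPs: {dense_flops(n, d):,}")
print(f"tiled matmul FLOPs: {tiled_flops(n, d, block):,}")
print(f"equal: {dense_flops(n, d) == tiled_flops(n, d, block)}")
```

Tiling changes the order of the work and where the intermediates live, not the number of multiply-adds.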
Symptom: You only compare total VRAM capacity and miss why attention still runs slowly on large GPUs.
Cause: HBM, on-chip SRAM, shared memory, and registers have very different capacity and bandwidth profiles.
Fix: Ask where each tensor lives and how often it crosses the HBM/SRAM boundary. FlashAttention wins because it keeps Q/K/V tiles and softmax state on-chip long enough to reuse them, then writes only the final output and row statistics back to HBM.
Symptom: You explain FlashAttention as if it changes attention from quadratic time to linear time.
Cause: The memory table and the FLOP table get mixed together.
Fix: Keep the dimensions separate. Dense attention still does quadratic compute in sequence length. FlashAttention reduces HBM reads and writes, so wall-clock time improves when the workload is memory-bound.
Symptom: You tile attention but normalize each block independently.
Cause: The running max and denominator updates look like an implementation detail.
Fix: Online softmax is the correctness mechanism. The running max rescales old contributions when a later tile contains a larger score, and the running denominator keeps all blocks normalized against the same global row.
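The failure is easy to reproduce with scalars. The sketch below reuses the toy scores and values from the earlier rescaling check and compares a single whole-row softmax with the incorrect approach of normalizing each block on its own.

```python
import math

scores = [1.0, 2.0, 0.5]      # toy row of scores
values = [10.0, 20.0, 40.0]   # toy scalar values
blocks = [(0, 2), (2, 3)]     # block boundaries: two scores, then one

def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Correct: normalize once over the whole row.
exact = sum(p * v for p, v in zip(softmax(scores), values))

# Incorrect: normalize each block on its own, then add the block outputs.
naive = 0.0
for start, stop in blocks:
    probs = softmax(scores[start:stop])
    naive += sum(p * v for p, v in zip(probs, values[start:stop]))

print(f"exact: {exact:.4f}")                    # 20.4926
print(f"per-block normalization: {naive:.4f}")  # 57.3106
```

The per-block result exceeds the largest value in the row, which a true weighted average can never do: the block weights sum to 2 instead of 1.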
Symptom: You skip enabling it on short-context models.
Cause: The OOM headlines make FlashAttention look like a long-sequence-only tool.
Fix: While it enables long sequences by avoiding memory limits, it provides substantial speedups even on shorter sequences because it reduces HBM access. The 15% BERT-large speedup at 512 tokens is a clear example of short-sequence gains.[1]
Symptom: You avoid FlashAttention because you assume it requires low-level GPU programming.
Cause: The original paper describes kernel-level details, which can give the impression that users must write CUDA.
Fix: Use a framework SDPA API or supported model integration, then check backend selection and measure on the target GPU. A request flag is configuration, not proof that an optimized kernel ran.
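After working through this chapter, you should be able to explain how FlashAttention cuts auxiliary attention memory from $O(n^2)$ to $O(n)$ with tiling and online softmax, and analyze its IO complexity.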
FlashAttention avoids storing the full score and probability matrices from the forward pass. It saves row-wise softmax statistics, then recomputes local attention tiles during backpropagation. That trades extra compute for much lower HBM traffic and reduces saved attention state from $O(n^2)$ to $O(n)$.
Yes. The tiled scheduler can skip whole K/V blocks that lie entirely in the future and apply an in-tile mask only on diagonal boundary tiles. That keeps the mask inside the kernel schedule instead of materializing a dense causal mask, although the exact speedup still depends on sequence length, tile shape, and kernel implementation.
FlashAttention is not a universal kernel for every attention variant. Fast paths have hardware, dtype, head-dimension, layout, and masking constraints. Arbitrary sparse patterns or custom score modifications may need a different kernel family or a fallback implementation. On very short sequences, attention may not be memory-bound enough for the extra kernel complexity to matter.
FlashAttention-3 is a Hopper-specific redesign. It uses TMA, WGMMA, warp specialization, and FP8 support to overlap data movement with computation more aggressively. The paper reports up to 740 TFLOPs/s in FP16, about 75% utilization, and close to 1.2 PFLOPs/s in FP8 on H100-class hardware.[4]
[1] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
[2] Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax.
[3] Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024.
[4] Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.
[5] PyTorch Contributors (2026). torch.nn.functional.scaled_dot_product_attention.