
FlashAttention & Memory Efficiency

Understand how FlashAttention cuts auxiliary attention memory from O(n²) to O(n) with tiling and online softmax, and analyze its IO complexity.


The previous chapter showed how prefix caching avoids repeating prefill work across requests. FlashAttention attacks a lower layer of the same bottleneck: it makes each dense attention computation move far less data through GPU memory.

FlashAttention is an attention implementation that reduces memory traffic without changing the mathematical attention operator. This chapter explains why exact attention can be slow on GPUs, then shows how tiling and recomputation can make long-context execution practical.

Imagine trying to answer a support question where every new token must compare itself against every prior token in a long order timeline: customer messages, carrier scans, refund notes, warehouse exceptions, and policy snippets. As the timeline gets longer, the number of connections you need to track explodes. This is exactly the problem modern AI models face with the attention mechanism. Dense attention still performs a quadratic number of score comparisons as input grows. In an implementation that materializes all score or probability values, its auxiliary storage is quadratic too: double the input length, and those matrices need four times the memory.

FlashAttention reorganizes how dense attention is computed at the hardware level. It computes the same operator, but doesn't materialize the full score or probability matrix in HBM (High Bandwidth Memory, the large GPU RAM pool).[1] That lowers auxiliary attention memory and HBM traffic. When attention is a bottleneck for the workload and an eligible FlashAttention kernel is available, it can increase throughput or make a longer context fit.

The memory wall

A quick recap of materialized attention

Before we look at the fix, let's ground the problem with a tiny example. In the scaled dot-product attention article, you saw that each query token scores every key token, normalizes those scores with softmax, and blends the corresponding value vectors.

Suppose you have only three support-ticket tokens and a head dimension of two:

  • Query $Q$ is a $3 \times 2$ matrix (one row per token).
  • Key $K$ and Value $V$ are also $3 \times 2$.

The score matrix $S = QK^T$ is $3 \times 3$. That is nine numbers. For three tokens, this is trivial. But for real sequences, the story changes fast.
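To make the shapes concrete, here is a minimal pure-Python sketch that builds the $3 \times 3$ score matrix from two $3 \times 2$ inputs. The numbers in `Q` and `K` are made up for illustration, `score_matrix` is a helper invented here, and the $1/\sqrt{d}$ scaling is left out so the arithmetic stays easy to check by hand.

```python
# Illustrative values only: three tokens, head dimension two.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]

def score_matrix(queries, keys):
    """Return S = Q K^T as nested lists: one row per query, one column per key."""
    return [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in keys]
            for q_row in queries]

S = score_matrix(Q, K)
score_count = sum(len(row) for row in S)

for row in S:
    print(row)
print(f"shape: {len(S)} x {len(S[0])}, values: {score_count}")
```

Every query row produces one score per key, which is why the count grows with the square of the sequence length.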

For $batch=8$, $heads=32$, $seq\_len=8192$ in 16-bit floating point (FP16):

$$\text{Attention matrix} = 8 \times 32 \times 8192^2 \times 2 \text{ bytes} \approx 34 \text{ GB}$$

That product is exactly $34{,}359{,}738{,}368$ bytes, which is 32 GiB or about 34 GB depending on whether you count in binary or decimal units. If an attention implementation saves this intermediate, it consumes a large part of an 80 GB A100 before counting model weights, other activations, or gradients.

Run this calculation before blaming model weights for an out-of-memory error:

score-matrix-size.py

    batch, heads, sequence, bytes_per_value = 8, 32, 8192, 2
    score_bytes = batch * heads * sequence * sequence * bytes_per_value

    print(f"score values: {batch * heads * sequence * sequence:,}")
    print(f"binary size: {score_bytes / 1024**3:.2f} GiB")
    print(f"decimal size: {score_bytes / 1000**3:.2f} GB")

Output

    score values: 17,179,869,184
    binary size: 32.00 GiB
    decimal size: 34.36 GB

GPU memory hierarchy

The key insight behind FlashAttention is understanding where data lives. A GPU is not a flat memory space. It has layers:

  • On-chip SRAM (Static Random Access Memory) is the tiny, blisteringly fast scratchpad right next to the compute cores.
  • HBM is the large pool of GPU memory (VRAM) that holds model weights and tensors.
  • CPU DRAM is host memory outside the GPU, used only when data must leave the card entirely.

Memory hierarchy intuition: Think of on-chip SRAM as the packing bench: tiny, but right beside the worker. HBM is warehouse storage: much bigger, but every trip costs time. A materializing attention baseline keeps walking back to storage because the full $n \times n$ score matrix doesn't fit on the bench. FlashAttention keeps only active tiles on the bench at a time.

In the original FlashAttention paper, the motivating A100 numbers are roughly 20 MB of aggregate on-chip SRAM at about 19 TB/s bandwidth versus 40 GB of HBM at about 1.5 TB/s.[1] A kernel doesn't get to use that full 20 MB as one giant scratchpad, though: practical tile sizes are bounded by much smaller per-SM shared-memory and register budgets. The exact chip layout varies by GPU generation, but the qualitative gap is the same: on-chip memory is tiny and fast, off-chip memory is large and much slower to revisit.

| Memory tier | Typical role in attention | Capacity / bandwidth intuition |
| --- | --- | --- |
| On-chip SRAM | Hold the current Q/K/V tiles and running softmax statistics | Tiny, but fast enough to reuse the same tile many times |
| HBM | Hold Q, K, V, O, model weights, and other activations | Much larger, but expensive to touch for every intermediate |
| CPU DRAM | Host memory outside the GPU | Larger still, but not suitable for the inner loop of an attention kernel |
Figure: Three memory tiers. SRAM (on-chip): tiny per-kernel working set, ~20 MB aggregate on A100, ~19 TB/s bandwidth. HBM (GPU memory): 40 GB on A100, ~1.5 TB/s bandwidth. CPU DRAM: much larger, far lower bandwidth to the GPU.

The flow above shows the three-tier memory pyramid. When a baseline writes score and probability matrices to HBM, those round-trips can dominate execution. FlashAttention pays for local bookkeeping to avoid them.[1]

Materialized attention can become IO-limited

In a materializing baseline, the $n \times n$ score and probability matrices are written to HBM and read again. A fused backend may already avoid those intermediates, so the useful comparison is FlashAttention versus the backend the system would otherwise execute, not versus every call named "attention."
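A rough back-of-the-envelope sketch can show why those round-trips matter. It assumes the ~1.5 TB/s HBM figure quoted earlier, a peak of 312 TFLOP/s for A100 FP16 tensor cores (an assumption introduced here, not a number from this article), and four full passes over an $n \times n$ intermediate (write $S$, read $S$, write $P$, read $P$). Real kernels overlap work and rarely reach either peak, so treat the result as order-of-magnitude intuition rather than a prediction.

```python
# Back-of-the-envelope estimate; every hardware number here is an assumption.
HBM_BYTES_PER_SECOND = 1.5e12     # ~1.5 TB/s, the A100 figure quoted in the text
PEAK_FLOPS_PER_SECOND = 312e12    # assumed A100 FP16 tensor-core peak
INTERMEDIATE_PASSES = 4           # write S, read S, write P, read P

def estimate(batch, heads, sequence, head_dim, bytes_per_value=2):
    """Return (intermediate IO seconds, matmul seconds) for one attention layer."""
    matrix_values = batch * heads * sequence * sequence
    io_bytes = INTERMEDIATE_PASSES * matrix_values * bytes_per_value
    # Two matmuls (QK^T and PV), each about 2 * n * n * d FLOPs per head.
    flops = 2 * 2 * matrix_values * head_dim
    return io_bytes / HBM_BYTES_PER_SECOND, flops / PEAK_FLOPS_PER_SECOND

io_seconds, compute_seconds = estimate(batch=8, heads=32, sequence=8192, head_dim=64)
print(f"intermediate HBM traffic: {io_seconds * 1e3:.1f} ms at peak bandwidth")
print(f"matmul arithmetic:        {compute_seconds * 1e3:.1f} ms at peak throughput")
print(f"IO / compute ratio:       {io_seconds / compute_seconds:.1f}x")
```

Under these assumptions the intermediate traffic alone takes several times longer than the matrix math, which is the sense in which a materializing baseline becomes IO-limited.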

The FlashAttention algorithm

Core idea: tiling + online softmax

Instead of materializing the full $n \times n$ attention matrix, FlashAttention combines three ideas:

  1. Tile the computation into blocks that fit in SRAM.
  2. Use online softmax to compute exact softmax without seeing all values at once.
  3. Never materialize the full attention matrix in HBM.

This block-wise processing strategy avoids quadratic auxiliary score and probability storage. The input, output, and saved row-statistic tensors still scale with sequence length.

The illustration below contrasts the memory access patterns. A materializing baseline writes the full score matrix to HBM, reads it back for softmax, writes the probability matrix, and reads it again for the final multiply. FlashAttention streams small tiles through SRAM and saves only row-wise statistics.

Figure: A materializing baseline writes large intermediates to HBM. FlashAttention keeps tile work in SRAM and writes compact row statistics plus final output.

The diagram shows the two paths side by side. On the left, the materializing path makes multiple round-trips to HBM for the $n \times n$ intermediates. On the right, FlashAttention processes small blocks in fast SRAM and writes back the output and compact row statistics.

How tiling fits in SRAM

The illustration below visualizes the data flow: Q, K, and V blocks stream from large HBM into the small SRAM workspace, where only the current tile and running statistics are kept.

Figure: FlashAttention tiling strategy. Only active Q/K/V tiles and running softmax statistics sit in fast SRAM. Output and compact row statistics are written, not full attention matrices.

During forward execution, $Q$, $K$, and $V$ are read from HBM. The kernel writes final output $O$ and, for training, compact row statistics needed by backward. It doesn't write the full score or probability matrices.

Tile size is bounded by on-chip capacity. This simplified payload check counts four FP16 tile-shaped arrays (Q, K, V, and a partial output), while production kernels also budget for statistics, registers, and implementation overhead:

tile-working-set.py

    block_rows, head_dimension, arrays, bytes_per_value = 128, 64, 4, 2
    payload_bytes = block_rows * head_dimension * arrays * bytes_per_value

    print(f"simplified tile payload: {payload_bytes / 1024:.0f} KiB")
    print("also budget: row statistics, registers, and kernel overhead")

Output

    simplified tile payload: 64 KiB
    also budget: row statistics, registers, and kernel overhead
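The same arithmetic can be inverted to ask how many rows fit in a given budget. The sketch below is hypothetical: the budgets (48, 96, and 160 KiB) are illustrative stand-ins for a per-block shared-memory limit, and `largest_block_rows` is a helper invented here, not a function from any FlashAttention library. It keeps the simplified four-array payload and only considers power-of-two tile heights.

```python
def largest_block_rows(budget_bytes, head_dimension, arrays=4, bytes_per_value=2):
    """Largest power-of-two row count whose simplified tile payload fits the budget."""
    bytes_per_row = head_dimension * arrays * bytes_per_value
    rows = 1
    while rows * 2 * bytes_per_row <= budget_bytes:
        rows *= 2
    # Return 0 when even a single row does not fit.
    return rows if rows * bytes_per_row <= budget_bytes else 0

for budget_kib in (48, 96, 160):          # illustrative budgets, not hardware specs
    for head_dimension in (64, 128):
        rows = largest_block_rows(budget_kib * 1024, head_dimension)
        print(f"budget={budget_kib:>3} KiB, d={head_dimension:>3}: block_rows={rows}")
```

Doubling the head dimension halves the tile height that fits, which is one reason production kernels tune tile shapes per head dimension and per GPU.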

Online softmax with a concrete example

Imagine you're summarizing 1,000 package-risk scores, but the packing bench only fits 50 labels at a time. You don't need to hold all 1,000 scores to compute the final normalization. You keep running statistics and update them as each block arrives.

Online softmax works the same way: instead of needing the full $n \times n$ attention matrix to compute softmax, it maintains running statistics (max and sum) and updates them block by block, without introducing any approximation.

Let's walk through a tiny numeric example before we show the general formula.

Standard softmax on three scores

Suppose a query token sees key scores $[1.0,\; 2.0,\; 0.5]$.

  1. Find the max: $m = 2.0$.
  2. Exponentiate relative to the max: $[e^{-1.0},\; e^{0},\; e^{-1.5}] \approx [0.368,\; 1.0,\; 0.223]$.
  3. Sum: $\ell = 0.368 + 1.0 + 0.223 = 1.591$.
  4. Normalize: $[0.231,\; 0.629,\; 0.140]$.
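The four steps above map directly onto a few lines of Python. This is a minimal sketch using only the standard library; `stable_softmax` is a name chosen here for illustration.

```python
import math

def stable_softmax(scores):
    """Softmax computed relative to the row max so exp() never overflows."""
    row_max = max(scores)                                  # step 1: find the max
    shifted = [math.exp(s - row_max) for s in scores]      # step 2: exponentiate
    denominator = sum(shifted)                             # step 3: sum
    return [value / denominator for value in shifted]      # step 4: normalize

probabilities = stable_softmax([1.0, 2.0, 0.5])
print([round(p, 3) for p in probabilities])
print(f"sum: {sum(probabilities):.6f}")

# The max shift matters for large scores: a naive math.exp(1000.0) raises OverflowError.
print([round(p, 3) for p in stable_softmax([1000.0, 999.0])])
```

Subtracting the max changes nothing mathematically, because the same factor cancels in the numerator and denominator. It only keeps the exponentials in a safe numeric range.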

A materializing baseline would store the full $3 \times 3$ score matrix in HBM just to perform that four-step process for every row.

Online softmax with two blocks

Now pretend the SRAM bench only fits two scores at a time. We split the scores into Block A $[1.0,\; 2.0]$ and Block B $[0.5]$.

Processing Block A:

  • Local max: $m_A = 2.0$.
  • Local denominator: $\ell_A = e^{-1.0} + e^{0} \approx 1.368$.
  • Local unnormalized numerator: $N_A = e^{-1.0} \cdot V_1 + e^{0} \cdot V_2$.

Processing Block B:

  • Local max: $m_B = 0.5$.
  • New global max: $m_{new} = \max(2.0,\; 0.5) = 2.0$.
  • Rescale the old denominator to the new max: $\ell_{new} = e^{2.0 - 2.0} \cdot 1.368 + e^{0.5 - 2.0} = 1.368 + 0.223 = 1.591$.
  • Update the numerator by rescaling the old accumulator and adding the new block: $N_{new} = e^{0} \cdot N_A + e^{-1.5} \cdot V_3$.

After both blocks, the final output for this query row is $O = N_{new} / 1.591$, which is exactly the same result as standard softmax. The difference is that we never held all three scores in the fast workspace at once.

This small program checks the rescaling rule with scalar values, independent of any GPU kernel:

online-softmax-two-blocks.py

    import math

    scores_a, values_a = [1.0, 2.0], [10.0, 20.0]
    scores_b, values_b = [0.5], [40.0]

    def local_state(scores, values):
        max_score = max(scores)
        weights = [math.exp(score - max_score) for score in scores]
        return max_score, sum(weights), sum(weight * value for weight, value in zip(weights, values))

    m_a, l_a, n_a = local_state(scores_a, values_a)
    m_b, l_b, n_b = local_state(scores_b, values_b)
    m = max(m_a, m_b)
    l = math.exp(m_a - m) * l_a + math.exp(m_b - m) * l_b
    n = math.exp(m_a - m) * n_a + math.exp(m_b - m) * n_b
    online = n / l

    all_scores = scores_a + scores_b
    all_values = values_a + values_b
    dense_weights = [math.exp(score - max(all_scores)) for score in all_scores]
    dense = sum(w * v for w, v in zip(dense_weights, all_values)) / sum(dense_weights)

    print(f"online output: {online:.6f}")
    print(f"dense output: {dense:.6f}")
    print(f"match: {abs(online - dense) < 1e-12}")

Output

    online output: 20.492649
    dense output: 20.492649
    match: True

The general update rule

For each new block of scores $s_{\text{new}}$ and value vectors $V_{\text{block}}$:

$$m_{\text{new}} = \max(m_{\text{old}},\; \max(s_{\text{new}}))$$

$$\ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \cdot \ell_{\text{old}} + \sum_j e^{s_j - m_{\text{new}}}$$

$$N_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \cdot N_{\text{old}} + e^{s_{\text{new}} - m_{\text{new}}} V_{\text{block}}$$

$$O_{\text{new}} = \frac{N_{\text{new}}}{\ell_{\text{new}}}$$

Where $m$ is the running max score (for numerical stability), $\ell$ is the running softmax denominator, $N$ is the running unnormalized numerator accumulator, and $O$ is the normalized output.

The rescaling terms $e^{m_{\text{old}} - m_{\text{new}}}$ ensure previous results stay correct even though the max changed. This is the mathematical trick that eliminates the need for a second pass over the full row.[2]
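The update rule can be exercised as a running loop over any number of blocks. The sketch below uses scalar values in place of value vectors, and the scores are chosen (as an illustration, not from any real model) so that the running max changes in the second block and the rescaling term actually does some work. The function names are hypothetical.

```python
import math

def online_softmax_average(score_blocks, value_blocks):
    """Stream blocks of (scores, scalar values) and return the softmax-weighted average."""
    m = -math.inf   # running max
    l = 0.0         # running denominator
    n = 0.0         # running unnormalized numerator
    for scores, values in zip(score_blocks, value_blocks):
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)          # rescales everything accumulated so far
        weights = [math.exp(s - m_new) for s in scores]
        l = scale * l + sum(weights)
        n = scale * n + sum(w * v for w, v in zip(weights, values))
        m = m_new
    return n / l

def dense_softmax_average(scores, values):
    """Reference: see the whole row at once."""
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

# The max rises from 1.0 to 3.0 in the second block, forcing a rescale.
score_blocks = [[0.5, 1.0], [3.0], [2.0, -1.0]]
value_blocks = [[10.0, 20.0], [30.0], [40.0, 50.0]]

streamed = online_softmax_average(score_blocks, value_blocks)
dense = dense_softmax_average(sum(score_blocks, []), sum(value_blocks, []))
print(f"streamed: {streamed:.6f}")
print(f"dense:    {dense:.6f}")
```

On the first block the running max is $-\infty$, so the scale factor is $e^{-\infty} = 0$ and the empty accumulators contribute nothing.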

Figure: Materializing baseline versus tiled FlashAttention. The baseline computes the full S = QKᵀ (writing n² values to HBM), then finds the max (reading n² values from HBM).

The diagram above contrasts the two approaches. A materializing baseline makes multiple round-trips to HBM for the full $n \times n$ matrix. FlashAttention performs tile-local score and softmax work on-chip, writing output and compact saved statistics to HBM.

Pseudocode

The following function demonstrates the core logic of FlashAttention. It takes the Query, Key, and Value matrices along with a specified block size to load into SRAM. It returns the same dense attention output as a materialized reference implementation, but iterates through tiles to avoid full score and probability matrices.

pseudocode.py

    import torch
    import math

    def flash_attention(
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        block_size: int = 256
    ) -> torch.Tensor:
        """
        Simplified forward-pass sketch of FlashAttention.

        Args:
            Q: Query tensor of shape (n, d)
            K: Key tensor of shape (n, d)
            V: Value tensor of shape (n, d)
            block_size: Size of blocks to load into SRAM

        Returns:
            O: Output tensor of shape (n, d)
        """
        n, d = Q.shape
        O = torch.zeros_like(Q)

        # Outer loop: iterate over Q blocks and keep the current output tile on-chip
        for i in range(0, n, block_size):
            Qi = Q[i:i+block_size]  # Load Q block to SRAM

            # Initialize running statistics for this Q block
            Oi = torch.zeros_like(Qi)  # Accumulator
            li = torch.zeros(Qi.shape[0], 1, device=Q.device)  # Denominator
            mi = torch.full((Qi.shape[0], 1), -float('inf'), device=Q.device)  # Max

            # Inner loop: Iterate over K, V blocks
            for j in range(0, n, block_size):
                Kj = K[j:j+block_size]  # Load K block to SRAM
                Vj = V[j:j+block_size]  # Load V block to SRAM

                # Compute local attention scores (in SRAM!)
                # Shape: (block_size_q, block_size_k)
                Sij = Qi @ Kj.T / math.sqrt(d)

                # Online softmax update logic
                m_ij = Sij.max(dim=-1, keepdim=True).values
                m_new = torch.max(mi, m_ij)

                exp_old_scale = torch.exp(mi - m_new)
                exp_new = torch.exp(Sij - m_new)

                # Update output accumulator (unnormalized)
                Oi = exp_old_scale * Oi + exp_new @ Vj

                # Update running statistics
                li = exp_old_scale * li + exp_new.sum(dim=-1, keepdim=True)
                mi = m_new

            # Normalize by the final denominator
            O[i:i+block_size] = Oi / li

        return O

    # Quick sanity check: FlashAttention should match reference attention on a small tensor
    if __name__ == "__main__":
        torch.manual_seed(0)
        n, d = 64, 32
        Q = torch.randn(n, d)
        K = torch.randn(n, d)
        V = torch.randn(n, d)

        # Reference attention (materializes full n x n matrix)
        S = Q @ K.T / math.sqrt(d)
        P = torch.softmax(S, dim=-1)
        expected = P @ V

        # FlashAttention sketch (tiling, no full score/probability materialization)
        got = flash_attention(Q, K, V, block_size=16)

        max_difference = (expected - got).abs().max().item()
        print("Matches dense attention within 1e-5:", max_difference < 1e-5)

Output

    Matches dense attention within 1e-5: True

This sketch keeps Oi as an unnormalized numerator accumulator and divides by li once per Q tile at the end. The small test at the bottom checks the key claim: for a $64 \times 32$ toy tensor, the tiled loop produces the same result as standard dense attention, with a maximum difference below $10^{-5}$.

Common mistake: Beginners sometimes think Oi is already normalized inside the inner loop. It isn't. The division by li happens only after every K/V block for that Q tile has been processed. If you normalize early, you lose the exact rescaling that makes online softmax correct.

Production kernels implement the same algebra much more aggressively, while also handling batching, multiple heads, masks, and dropout.

The backward pass: recomputation can win

A materializing training implementation can save the massive $n \times n$ attention matrices $S$ and $P$ from the forward pass to compute gradients during the backward pass. Those saved intermediates can become a major source of out-of-memory (OOM) errors.

FlashAttention solves this by recomputing the needed score and probability tiles during the backward pass instead of storing them all from the forward pass.[1] Because it saves only row-wise softmax statistics such as the running max $m$ and denominator $\ell$, the saved attention state grows as $O(n)$ rather than $O(n^2)$.

In a warehouse, you could photocopy every shipping label at each packing station and file those copies in a giant archive. That is the materializing baseline. Or you could keep a slim logbook with running totals and reprint any label you need from the original order data. That is FlashAttention.

Recomputation isn't free. It adds arithmetic in backward. The point of the FlashAttention paper is that, on the evaluated GPU workloads, avoiding much larger HBM reads and writes more than paid for that arithmetic cost.[1] Measure the trade-off on the model shape and hardware you deploy.
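The recomputation idea can be sketched for a single query row in plain Python. The forward function keeps only the two row statistics. The backward-style helper then rebuilds any probability tile from the original query, a key block, and those statistics. All names and numeric values are illustrative assumptions, the $1/\sqrt{d}$ scaling is omitted, and a real backward pass would go on to compute gradients from the rebuilt tile.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def forward_row_stats(query, keys, block_size):
    """Forward pass for one query row: return only the running max and denominator."""
    m, l = -math.inf, 0.0
    for start in range(0, len(keys), block_size):
        scores = [dot(query, key) for key in keys[start:start + block_size]]
        m_new = max(m, max(scores))
        l = math.exp(m - m_new) * l + sum(math.exp(s - m_new) for s in scores)
        m = m_new
    return m, l

def recompute_probability_tile(query, key_block, m, l):
    """Backward-style recomputation: rebuild one tile of P from inputs plus saved stats."""
    return [math.exp(dot(query, key) - m) / l for key in key_block]

query = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0], [2.0, 2.0], [0.5, 0.5]]
m, l = forward_row_stats(query, keys, block_size=2)

# Reference: materialize the whole probability row at once.
scores = [dot(query, key) for key in keys]
top = max(scores)
total = sum(math.exp(s - top) for s in scores)
dense = [math.exp(s - top) / total for s in scores]

tile = recompute_probability_tile(query, keys[2:4], m, l)
print("recomputed tile:", [round(p, 6) for p in tile])
print("dense slice:    ", [round(p, 6) for p in dense[2:4]])
```

Two saved numbers per row are enough to regenerate any slice of the probability row on demand, at the cost of redoing the dot products.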

This calculator makes the saved-state difference concrete. It intentionally counts one score matrix and two row statistics, not every tensor in training:

saved-state-comparison.py

    batch, heads, sequence, bytes_per_value = 8, 32, 8192, 2
    materialized_scores = batch * heads * sequence * sequence * bytes_per_value
    row_stats = batch * heads * sequence * 2 * bytes_per_value

    print(f"one saved score matrix: {materialized_scores / 1024**3:.2f} GiB")
    print(f"two row statistics: {row_stats / 1024**2:.2f} MiB")
    print(f"size ratio: {materialized_scores / row_stats:,.0f}x")

Output

    one saved score matrix: 32.00 GiB
    two row statistics: 8.00 MiB
    size ratio: 4,096x

| Property | Materializing baseline | FlashAttention |
| --- | --- | --- |
| Saved attention state for backward | Store $S$ and $P$ explicitly: $O(n^2)$ | Store row-wise statistics such as $m$ and $\ell$: $O(n)$ |
| Backward strategy | Read large intermediates from HBM | Recompute local tiles from $Q, K, V$ plus saved stats |
| Trade-off | Less recomputation, much higher memory | More recomputation, much lower memory |

Complexity analysis

FlashAttention changes the IO complexity by tiling the computation. Let $M$ denote the amount of fast SRAM available to hold a tile's working set.

| Property | Materializing baseline | FlashAttention |
| --- | --- | --- |
| Auxiliary attention memory | Materialize $S$ and $P$: $O(n^2)$ | Keep row stats and the current output tile: $O(n)$ |
| FLOPs | $O(n^2 d)$ | $O(n^2 d)$ (same) |
| HBM reads/writes | $O(nd + n^2)$ | $O(n^2 d^2 / M)$ |
| Exact | Yes | Yes |
| Reported wall-clock speed | Baseline | Up to 2-4x faster in evaluated paper workloads[1] |

Figure: Auxiliary attention memory scaling. FlashAttention reduces auxiliary attention memory from saved score/probability matrices to row-wise statistics and the current output tile.

FLOPs stands for floating-point operations. The memory row here refers to the extra state created by the attention kernel itself, not the shared $Q$, $K$, $V$, and $O$ tensors that both approaches still need to hold. FlashAttention doesn't reduce the asymptotic mathematical work required for dense attention: both paths perform $O(n^2 d)$ operations. Its main algorithmic advantage is avoiding $n^2$ score and probability transfers to and from slow HBM. Tile size, scheduling, datatype, and hardware still affect observed speed.

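Because the working set is constrained to the on-chip SRAM budget $M$, HBM traffic drops from $O(nd + n^2)$ to $O(n^2 d^2 / M)$. The FlashAttention paper notes that for typical head dimensions (64 to 128) and on-chip budgets around 100 KB, $d^2$ is many times smaller than $M$, so the tiled kernel needs far fewer HBM accesses.[1] This means the GPU spends less time waiting for data to arrive from memory and more time keeping its compute cores busy.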
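Plugging numbers into the two HBM expressions from the table gives a feel for the gap. This sketch drops all constant factors, so only the trend is meaningful. The SRAM budget (about 100 KiB of FP16 values, counted in elements) is an assumed illustrative value, and both function names are invented here.

```python
def materialized_hbm_accesses(n, d):
    """O(n*d + n^2) with constants dropped."""
    return n * d + n * n

def flash_hbm_accesses(n, d, sram_elements):
    """O(n^2 * d^2 / M) with constants dropped."""
    return n * n * d * d / sram_elements

d = 64
sram_elements = 100 * 1024 // 2      # assumed: ~100 KiB holding 2-byte values

for n in (1024, 4096, 16384):
    baseline = materialized_hbm_accesses(n, d)
    flash = flash_hbm_accesses(n, d, sram_elements)
    print(f"n={n:>5}: baseline~{baseline:.2e}, flash~{flash:.2e}, "
          f"ratio~{baseline / flash:.1f}x")
```

Both expressions are still quadratic in $n$. The saving comes from the $d^2 / M$ factor, so it shrinks when the head dimension grows or the usable on-chip budget drops.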

Because FlashAttention computes the same dense attention operator as the reference formula (thanks to the online softmax trick), it doesn't change the model's mathematical attention rule. Different floating-point operation order can still cause small numeric differences.
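A three-number example shows why operation order matters in floating point, independent of any attention kernel. It is the reason the checks in this chapter compare outputs with a tolerance instead of exact equality.

```python
import math

left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)                      # False
print(abs(left_to_right - right_to_left))                  # about 1.1e-16
print(math.isclose(left_to_right, right_to_left, rel_tol=1e-12))  # True
```

A tiled kernel accumulates partial sums in a different order than a dense reference, so tiny discrepancies of this kind are expected and don't indicate an approximation.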

The following numbers distinguish score-matrix scaling from row-statistic scaling:

auxiliary-state-scaling.py

    base_sequence = 1024
    for sequence in [1024, 2048, 4096, 8192, 16384]:
        materialized_relative = (sequence / base_sequence) ** 2
        row_stats_relative = sequence / base_sequence
        print(
            f"{sequence:>5} tokens: materialized={materialized_relative:>5.0f}x, "
            f"row-stats={row_stats_relative:>2.0f}x"
        )

Output

     1024 tokens: materialized=    1x, row-stats= 1x
     2048 tokens: materialized=    4x, row-stats= 2x
     4096 tokens: materialized=   16x, row-stats= 4x
     8192 tokens: materialized=   64x, row-stats= 8x
    16384 tokens: materialized=  256x, row-stats=16x

Causal masking in FlashAttention

For autoregressive transformers, attention is causal: token $i$ can only attend to tokens $j \leq i$. FlashAttention handles this efficiently without materializing a dense causal mask in HBM.

Block-level skipping

If a block of K/V tokens is entirely in the "future" relative to a Q block, the entire block multiplication is skipped. No compute wasted.

Within-block masking

For blocks that straddle the causal boundary, FlashAttention applies the mask after computing scores but before the softmax update. The masked positions are set to $-\infty$.
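Here is a minimal sketch of that masking step, assuming square tiles, scalar values in place of value vectors, and a precomputed score table so the example stays in pure Python. The function names and numbers are illustrative. Future tiles are never visited, and the diagonal tile sets positions with $j > i$ to $-\infty$ before the online softmax update.

```python
import math

def causal_tiled_rows(scores, values, tile):
    """Causal softmax-weighted averages computed tile by tile with in-tile masking."""
    outputs = []
    for i in range(len(scores)):                 # one query row at a time for clarity
        m, l, acc = -math.inf, 0.0, 0.0
        query_tile = i // tile
        for key_tile in range(query_tile + 1):   # future tiles are never visited
            start = key_tile * tile
            block = scores[i][start:start + tile]
            if key_tile == query_tile:           # boundary tile: mask positions j > i
                block = [s if start + offset <= i else -math.inf
                         for offset, s in enumerate(block)]
            m_new = max(m, max(block))
            scale = math.exp(m - m_new)
            weights = [math.exp(s - m_new) for s in block]   # exp(-inf) == 0.0
            l = scale * l + sum(weights)
            acc = scale * acc + sum(w * v for w, v in zip(weights, values[start:start + tile]))
            m = m_new
        outputs.append(acc / l)
    return outputs

def causal_dense_rows(scores, values):
    """Reference: softmax over the visible prefix of each row."""
    outputs = []
    for i, row in enumerate(scores):
        visible = row[:i + 1]
        top = max(visible)
        weights = [math.exp(s - top) for s in visible]
        outputs.append(sum(w * v for w, v in zip(weights, values)) / sum(weights))
    return outputs

scores = [[0.2, 1.5, -0.3, 0.8],
          [1.1, -0.4, 0.9, 2.0],
          [0.0, 0.7, 1.3, -1.2],
          [2.1, 0.3, -0.6, 0.5]]
values = [1.0, 2.0, 3.0, 4.0]

tiled = causal_tiled_rows(scores, values, tile=2)
dense = causal_dense_rows(scores, values)
print([round(x, 6) for x in tiled])
print([round(x, 6) for x in dense])
```

Because the diagonal entry is always visible, every boundary tile row has at least one finite score, so the running max never stays at $-\infty$.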

Figure: FlashAttention causal tile scheduler. Causal masking becomes a scheduling decision. Entire future tiles are skipped, past tiles run fully, and only boundary tiles need an in-tile mask.

In other words, causal masking is folded into the tile schedule itself: future tiles are skipped, and diagonal tiles apply an in-tile mask before the online softmax update. That reduces wasted work, but the exact speedup depends on sequence length, tile shape, and kernel implementation rather than being a guaranteed 2x.
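A short calculation illustrates why the saving is not a flat 2x. It assumes square tiles and equal cost per computed tile, which ignores the extra masking work on diagonal tiles. `causal_tile_fraction` is a hypothetical helper.

```python
def causal_tile_fraction(tiles):
    """Fraction of tile pairs that are computed under causal masking with square tiles."""
    computed = tiles * (tiles + 1) // 2      # past tiles plus diagonal tiles
    return computed / (tiles * tiles)

for tiles in (1, 4, 16, 64, 256):
    fraction = causal_tile_fraction(tiles)
    print(f"{tiles:>3} tiles per side: {fraction:.3f} of tiles computed, "
          f"upper-bound speedup {1 / fraction:.2f}x")
```

The computed fraction approaches one half only as the number of tiles per side grows, so short sequences or large tiles see a smaller benefit from skipping.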

The same tiled structure also adapts well to local windowed attention. Tiles that fall completely outside the attention window can be skipped before doing the matrix multiply.
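The tile decision for a local window can be audited the same way. This sketch assumes a causal sliding window in which query $i$ sees keys $j$ with $i - w < j \leq i$, plus square tiles. The function name and the chosen sizes are illustrative, not taken from any particular kernel.

```python
def window_tile_decision(query_tile, key_tile, tile, window):
    """Classify a tile for causal sliding-window attention (i - window < j <= i)."""
    q_start, q_end = query_tile * tile, query_tile * tile + tile - 1
    k_start, k_end = key_tile * tile, key_tile * tile + tile - 1
    # Range of the gap i - j over every (query, key) pair inside the tile.
    min_gap, max_gap = q_start - k_end, q_end - k_start
    if max_gap < 0 or min_gap > window - 1:
        return "skip"      # entirely in the future or entirely outside the window
    if min_gap >= 0 and max_gap <= window - 1:
        return "full"      # every pair is visible, no in-tile mask needed
    return "masked"        # straddles a boundary, needs an in-tile mask

tiles, tile, window = 4, 4, 8
counts = {"skip": 0, "full": 0, "masked": 0}
for query_tile in range(tiles):
    for key_tile in range(tiles):
        counts[window_tile_decision(query_tile, key_tile, tile, window)] += 1

print(counts)
print(f"computed tiles: {counts['full'] + counts['masked']} of {tiles * tiles}")
```

With a fixed window, the number of computed tiles per query tile stays bounded as the sequence grows, so total work scales roughly linearly in sequence length rather than quadratically.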

You can audit causal tile decisions without any GPU code:

causal-tile-schedule.py

    tiles = 4
    counts = {"past": 0, "boundary": 0, "future": 0}

    for query_tile in range(tiles):
        for key_tile in range(tiles):
            if key_tile < query_tile:
                decision = "past"
            elif key_tile == query_tile:
                decision = "boundary"
            else:
                decision = "future"
            counts[decision] += 1

    print(counts)
    print(f"computed tiles: {counts['past'] + counts['boundary']} of {tiles * tiles}")

Output

    {'past': 6, 'boundary': 4, 'future': 6}
    computed tiles: 10 of 16

FlashAttention-2 and FlashAttention-3

Since the original algorithm was introduced, its kernels have been reworked to better use modern GPU features and reach a higher fraction of peak hardware throughput.

FlashAttention-2[3] improves on the original by optimizing the hardware execution:

  • Partitions work better across GPU thread blocks and warps, which reduces synchronization overhead and increases occupancy.
  • Reduces non-matmul FLOPs (like causal masking and softmax operations).
  • Runs about 2× faster than FlashAttention-1, reaching 50-73% of theoretical max FLOPs/s on A100 GPUs.[3]

FlashAttention-3[4] is designed to target the advanced capabilities of the Hopper architecture (like the H100):

  • Exploits asynchronous execution using the TMA (Tensor Memory Accelerator) unit and WGMMA (Warpgroup Matrix-Matrix Multiply Accumulate) instructions, so data loading can overlap with matrix math.
  • Adds native FP8 (8-bit floating point) support for further computational speedup.
  • Reaches up to about 75% of theoretical FLOPs/s on H100 GPUs with FP16.[4]

These subsequent iterations demonstrate that while the mathematical core of tiling and online softmax remains the same, hardware-aware kernel optimization is critical for maximizing performance. Building these highly optimized kernels often requires low-level CUDA programming. Higher-level languages like Triton have also made it much easier to write custom memory-efficient attention kernels without dropping all the way to raw CUDA C++.

Real-world performance

Training throughput

During model training, saving quadratic attention intermediates can sharply restrict the maximum sequence length a model can process. As sequence length increases, a materializing baseline may run out of memory even when a fused attention path can still fit.

Common mistake: Assuming FlashAttention is only a "long-sequence hack." The original paper reports a 15% end-to-end training speedup for BERT-large at sequence length 512, so even moderate sequence lengths can benefit when attention IO matters.[1]

Compared with a baseline that saves full attention intermediates, FlashAttention's auxiliary attention memory grows linearly rather than quadratically. That can enable longer sequences and improve throughput even while a materializing baseline still fits in memory.[1][3]
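To make the linear-versus-quadratic contrast concrete, here is a small back-of-the-envelope sketch. It assumes batch size 1, 16 heads, fp16 scores, and one fp32 softmax statistic saved per query row. All of these are illustrative assumptions, and the helper names are hypothetical. It counts a single saved matrix, so a baseline that keeps both scores and probabilities would need roughly twice the first figure.

```python
def materialized_bytes(batch: int, heads: int, seq_len: int, bytes_per_score: int = 2) -> int:
    """Bytes for one full (seq_len x seq_len) matrix per head."""
    return batch * heads * seq_len * seq_len * bytes_per_score


def row_statistic_bytes(batch: int, heads: int, seq_len: int, bytes_per_stat: int = 4) -> int:
    """Bytes for one saved softmax statistic per query row."""
    return batch * heads * seq_len * bytes_per_stat


batch, heads = 1, 16  # assumed workload shape
for seq_len in (2048, 8192, 32768):
    full = materialized_bytes(batch, heads, seq_len)
    stats = row_statistic_bytes(batch, heads, seq_len)
    print(f"n={seq_len:>6}: scores {full / 2**30:8.3f} GiB, row stats {stats / 2**10:8.1f} KiB")
```

Under these assumptions, quadrupling the sequence length multiplies the materialized matrix by 16 but the row statistics by only 4.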

| Source | Workload | Reported result |
| --- | --- | --- |
| FlashAttention (2022) | BERT-large, sequence length 512 | 15% end-to-end training speedup[1] |
| FlashAttention (2022) | GPT-2, sequence length 1K | 3× speedup[1] |
| FlashAttention (2022) | Long Range Arena, sequence length 1K-4K | 2.4× speedup[1] |
| FlashAttention-2 (2023) | GPT-style training on A100 | Up to 225 TFLOPs/s per GPU, 72% model FLOPs utilization[3] |

The important point isn't one magic benchmark number. It's that once the attention kernel stops writing giant intermediates to HBM, longer-sequence training becomes much more practical.

Inference impact

During inference, FlashAttention helps most when the workload still looks like dense attention over many prompt tokens:

  • Prefill phase benefits the most, because the model still performs full prompt self-attention. Think of it as sorting a full truckload of packages at once: you need to compare every item against the master manifest, and the comparison matrix is huge.
  • Decode phase benefits less, because each step introduces one new query token and reuses the KV cache (Key-Value cache), so other bottlenecks often dominate.
  • Long prompts benefit more than short prompts, because avoiding a materialized $n \times n$ score matrix matters more as $n$ grows.

FlashAttention has its biggest impact when attention itself is the bottleneck. That's usually training and prefill, not single-token decode.

This shape check shows why prefill creates far more score work per request than one decode step:

prefill-versus-decode.py

```python
prompt_tokens = 8192
prefill_scores = prompt_tokens * prompt_tokens  # every prompt token scores every prompt token
decode_scores = 1 * prompt_tokens               # one new query scores the cached keys

print(f"prefill scores: {prefill_scores:,}")
print(f"one decode step scores: {decode_scores:,}")
print(f"ratio: {prefill_scores // decode_scores:,}x")
```

Output

```
prefill scores: 67,108,864
one decode step scores: 8,192
ratio: 8,192x
```

Hardware compatibility

The algorithmic idea is general, but the fastest kernels are hardware-specific. FlashAttention-2 describes better parallelism and work partitioning for modern GPUs, while FlashAttention-3 is a Hopper-focused redesign that targets features such as TMA, WGMMA, and FP8 attention.[3][4]

Availability in an application depends on its framework build, device, datatype, tensor shapes, and attention features. Treat backend selection as something to verify, not something to infer from the model name.

Using FlashAttention in practice

In modern deep learning frameworks, you rarely implement FlashAttention from scratch. PyTorch exposes torch.nn.functional.scaled_dot_product_attention (SDPA), which may choose an optimized CUDA implementation when the inputs and build support it. Its sdpa_kernel context manager lets you select permitted implementations while testing or profiling. Eligibility and fallback behavior depend on the installed PyTorch build, device, datatype, layout, and attention features, so consult the documentation for that build and measure the actual path.[5]

First validate operator behavior with a small CPU example. This checks causality and output shape, not FlashAttention dispatch:

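sdpa-operator-check.py

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
query = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)

output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True,
    dropout_p=0.0,
)

# Causality probe: perturb only the last key/value position.
# Rows 0..2 can't attend to position 3, so they must not change.
key_perturbed, value_perturbed = key.clone(), value.clone()
key_perturbed[:, :, -1] += 1.0
value_perturbed[:, :, -1] += 1.0
perturbed = F.scaled_dot_product_attention(
    query, key_perturbed, value_perturbed, is_causal=True, dropout_p=0.0
)

print(f"output shape: {tuple(output.shape)}")
print(f"finite values: {torch.isfinite(output).all().item()}")
print(f"earlier rows unchanged: {torch.allclose(output[:, :, :-1], perturbed[:, :, :-1])}")
```

Output

```
output shape: (1, 2, 4, 8)
finite values: True
earlier rows unchanged: True
```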

On CUDA hardware, backend restriction is useful as a test probe. This snippet is intentionally not marked runnable because it requires a suitable installed CUDA build and GPU:

request-cuda-flash-backend.py

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim) in fp16 on the GPU
Q = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)
K = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)
V = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)

# Only the FlashAttention backend is permitted inside this block.
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    output = F.scaled_dot_product_attention(Q, K, V, is_causal=True, dropout_p=0.0)
```

If a model library offers a FlashAttention request flag, requesting it isn't evidence that the fast path ran. Record a before/after measurement and backend evidence:

backend-verification-record.py

```python
verification = {
    "requested_backend": "flash_attention",
    "operator_correctness_checked": True,
    "profiler_shows_selected_kernel": False,
    "latency_measured": False,
}

# A speedup claim needs all three pieces of evidence, not just the request.
active = (
    verification["operator_correctness_checked"]
    and verification["profiler_shows_selected_kernel"]
    and verification["latency_measured"]
)
print(f"enough evidence to claim speedup: {active}")
print("next check: capture backend/profiler output on target GPU")
```

Output

```
enough evidence to claim speedup: False
next check: capture backend/profiler output on target GPU
```
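For the latency half of that record, a minimal timing harness might look like the sketch below. It uses only the standard library. The helper name `median_latency_ms` and the stand-in workloads are hypothetical, and on a GPU the timed callable would also need to end with a device synchronization so that asynchronous kernel launches are included in the measurement.

```python
import statistics
import time
from typing import Callable


def median_latency_ms(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Median wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm caches and lazy initialization before timing
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Stand-in CPU workloads; replace with the baseline and candidate attention calls.
baseline = median_latency_ms(lambda: sum(i * i for i in range(20_000)))
candidate = median_latency_ms(lambda: sum(i * i for i in range(10_000)))
print(f"baseline {baseline:.3f} ms, candidate {candidate:.3f} ms")
```

Reporting the median over several iterations, after a warmup, keeps one slow first call from dominating the comparison.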

For example, Hugging Face Transformers exposes a model-load request in versions and models that support the corresponding integration:

using-flashattention-in-practice-2.py

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # a request, not proof the fast path ran
    device_map="auto",
)
```

Common mistakes checklist

"FlashAttention is an approximation"

Symptom: You hear FlashAttention grouped with sparse or low-rank attention approximations and assume it drops some connections to save memory.

Cause: The word "efficient" often implies approximation in other contexts.

Fix: FlashAttention is exact. Thanks to the online softmax trick, it computes the same dense attention formula without using sparse or low-rank shortcuts. Numeric outputs can differ slightly from a reference implementation because floating-point operations are associated in a different order, but the mathematical operator is the same. If you need proof, run the small PyTorch test from the pseudocode section and check that the max difference is near zero.

"FlashAttention reduces the number of compute operations"

Symptom: You claim in an interview or code review that FlashAttention cuts FLOPs.

Cause: It's natural to equate "faster" with "fewer operations."

Fix: The forward attention computation still has $O(n^2 d)$ floating-point operations (FLOPs). The speedup comes from reduced memory operations (IO), not from changing dense attention into a cheaper mathematical operator. In training, the backward pass can perform more operations because it recomputes tiles. The win is that compute is cheap and memory movement is expensive.

"GPU memory is one big pool"

Symptom: You only compare total VRAM capacity and miss why attention still runs slowly on large GPUs.

Cause: HBM, on-chip SRAM, shared memory, and registers have very different capacity and bandwidth profiles.

Fix: Ask where each tensor lives and how often it crosses the HBM/SRAM boundary. FlashAttention wins because it keeps Q/K/V tiles and softmax state on-chip long enough to reuse them, then writes only the final output and row statistics back to HBM.

"IO complexity is the same as time complexity"

Symptom: You explain FlashAttention as if it changes attention from quadratic time to linear time.

Cause: The memory table and the FLOP table get mixed together.

Fix: Keep the dimensions separate. Dense attention still does quadratic compute in sequence length. FlashAttention reduces HBM reads and writes, so wall-clock time improves when the workload is memory-bound.
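One way to keep the two tables apart is to compute them side by side. The sketch below uses the asymptotic HBM-access estimates from the original paper's IO analysis[1], with constants dropped, so the numbers are order-of-magnitude only. The head dimension and the on-chip capacity `sram_elements` are assumed values, and the function names are hypothetical.

```python
def attention_flops(n: int, d: int) -> int:
    """Matmul FLOPs for one head: QK^T and PV each cost 2 * n * n * d."""
    return 4 * n * n * d


def standard_hbm_accesses(n: int, d: int) -> float:
    """Asymptotic estimate Theta(n*d + n^2), constants dropped."""
    return n * d + n * n


def flash_hbm_accesses(n: int, d: int, sram_elements: int) -> float:
    """Asymptotic estimate Theta(n^2 * d^2 / M), constants dropped."""
    return n * n * d * d / sram_elements


d = 64
sram_elements = 48 * 1024  # hypothetical on-chip capacity M, in elements
for n in (1024, 4096, 16384):
    ratio = standard_hbm_accesses(n, d) / flash_hbm_accesses(n, d, sram_elements)
    print(f"n={n:>5}: FLOPs {attention_flops(n, d):.2e} either way, "
          f"HBM-access estimate ratio ~{ratio:.1f}x")
```

The FLOP column is identical for both paths and grows quadratically with n. Only the HBM-access estimate differs between them.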

"Online softmax is optional bookkeeping"

Symptom: You tile attention but normalize each block independently.

Cause: The running max and denominator updates look like an implementation detail.

Fix: Online softmax is the correctness mechanism. The running max rescales old contributions when a later tile contains a larger score, and the running denominator keeps all blocks normalized against the same global row.
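A minimal pure-Python sketch of that failure mode, for a single attention row split into two blocks. The scores and values are made up for illustration, and the scalar "values" stand in for value vectors.

```python
import math

scores = [1.0, 2.0, 6.0, 3.0]      # one attention row, split into two blocks
values = [10.0, 20.0, 30.0, 40.0]  # scalar stand-ins for value vectors
blocks = [(0, 2), (2, 4)]

# Reference: softmax over the whole row.
row_max = max(scores)
weights = [math.exp(s - row_max) for s in scores]
reference = sum(w * v for w, v in zip(weights, values)) / sum(weights)

# Wrong: normalize each block on its own, then add the block outputs.
naive = 0.0
for start, end in blocks:
    block_max = max(scores[start:end])
    w = [math.exp(s - block_max) for s in scores[start:end]]
    naive += sum(wi * vi for wi, vi in zip(w, values[start:end])) / sum(w)

# Online softmax: carry a running max, denominator, and unnormalized output.
running_max, denom, acc = -math.inf, 0.0, 0.0
for start, end in blocks:
    new_max = max(running_max, max(scores[start:end]))
    scale = math.exp(running_max - new_max)  # rescales earlier contributions
    w = [math.exp(s - new_max) for s in scores[start:end]]
    denom = denom * scale + sum(w)
    acc = acc * scale + sum(wi * vi for wi, vi in zip(w, values[start:end]))
    running_max = new_max
online = acc / denom

print(f"reference: {reference:.6f}")
print(f"online:    {online:.6f}")
print(f"naive:     {naive:.6f}")
```

The online result matches the reference up to floating-point rounding, while the per-block version does not, because each block was normalized against its own denominator.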

"FlashAttention is only useful for long sequences"

Symptom: You skip enabling it on short-context models.

Cause: The OOM headlines make FlashAttention look like a long-sequence-only tool.

Fix: FlashAttention does make long sequences feasible by avoiding memory limits, but it can also speed up moderate-length workloads because it reduces HBM access. The 15% end-to-end BERT-large speedup at sequence length 512 is one reported example.[1]

"You have to write custom CUDA kernels to use it"

Symptom: You avoid FlashAttention because you assume it requires low-level GPU programming.

Cause: The original paper describes kernel-level details, which can give the impression that users must write CUDA.

Fix: Use a framework SDPA API or supported model integration, then check backend selection and measure on the target GPU. A request flag is configuration, not proof that an optimized kernel ran.

What you should be able to defend

After working through this chapter, you should be able to:

  1. Identify the real bottleneck. Explain when HBM-to-SRAM traffic can dominate a materializing long-sequence attention path.
  2. Explain tiling. Show how Q, K, and V blocks fit into SRAM so the full $n \times n$ score matrix never has to live in HBM.
  3. Derive online softmax. Walk through the running max, denominator, and numerator updates for a small split-row example.
  4. Separate compute from IO. State that FLOPs remain $O(n^2 d)$ while auxiliary attention memory drops from $O(n^2)$ to $O(n)$.
  5. Defend exactness. Explain why FlashAttention is not sparse attention, not low-rank attention, and not an approximation.
  6. Explain backward recomputation. Describe why saving row-wise statistics and recomputing tiles can be faster than storing $S$ and $P$.
  7. Place later versions. Say what FlashAttention-2 improves with parallelism and work partitioning, and what FlashAttention-3 adds for Hopper GPUs.

Production questions

How does FlashAttention change the backward pass?

FlashAttention avoids storing the full $N \times N$ score and probability matrices from the forward pass. It saves row-wise softmax statistics, then recomputes local attention tiles during backpropagation. That trades extra compute for much lower HBM traffic and reduces saved attention state from $O(N^2)$ to $O(N)$.
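The recomputation step can be sketched in a few lines of pure Python. This assumes the saved row-wise statistic is a single logsumexp value per row, and it takes the scores as given for brevity, whereas a real kernel would first recompute each score tile from Q and K tiles. The helper name `recompute_probs` is hypothetical.

```python
import math

scores = [[0.5, 2.0, -1.0], [3.0, 0.0, 1.0]]  # a tiny 2 x 3 score matrix

# Forward pass: keep one logsumexp value per row instead of the probabilities.
row_logsumexp = []
for row in scores:
    m = max(row)
    row_logsumexp.append(m + math.log(sum(math.exp(s - m) for s in row)))


def recompute_probs(row_index: int, col_start: int, col_end: int) -> list[float]:
    """Rebuild one tile of softmax probabilities from scores and the saved statistic."""
    lse = row_logsumexp[row_index]
    return [math.exp(s - lse) for s in scores[row_index][col_start:col_end]]


tile = recompute_probs(1, 0, 2)      # first two columns of row 1
full_row = recompute_probs(1, 0, 3)  # whole row, for a normalization check
print(f"tile: {tile}")
print(f"row sum: {sum(full_row):.6f}")
```

Any tile can be rebuilt independently because the saved statistic already encodes the normalization over the full row.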

Can FlashAttention support causal masking efficiently?

Yes. The tiled scheduler can skip whole K/V blocks that lie entirely in the future and apply an in-tile mask only on diagonal boundary tiles. That keeps the mask inside the kernel schedule instead of materializing a dense causal mask, although the exact speedup still depends on sequence length, tile shape, and kernel implementation.

What are the limitations of FlashAttention?

FlashAttention is not a universal kernel for every attention variant. Fast paths have hardware, dtype, head-dimension, layout, and masking constraints. Arbitrary sparse patterns or custom score modifications may need a different kernel family or a fallback implementation. On very short sequences, attention may not be memory-bound enough for the extra kernel complexity to matter.

How does FlashAttention-3 use Hopper GPU features?

FlashAttention-3 is a Hopper-specific redesign. It uses TMA, WGMMA, warp specialization, and FP8 support to overlap data movement with computation more aggressively. The paper reports up to 740 TFLOPs/s in FP16, about 75% utilization, and close to 1.2 PFLOPs/s in FP8 on H100-class hardware.[4]
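The utilization percentages are simple ratios of achieved to peak throughput. The sketch below checks them against nominal dense tensor-core peaks. Those peak values are assumed datasheet numbers rather than figures from this article, so treat the result as a sanity check, not a measurement. The A100 row uses the 225 TFLOPs/s training figure from the results table earlier in the article.

```python
# Nominal dense tensor-core peaks in TFLOPs/s (assumed datasheet values).
assumed_peak = {"A100 FP16": 312.0, "H100 FP16": 989.0, "H100 FP8": 1979.0}

reported = {
    "A100 FP16": 225.0,  # FlashAttention-2 training figure [3]
    "H100 FP16": 740.0,  # FlashAttention-3 FP16 figure [4]
    "H100 FP8": 1200.0,  # FlashAttention-3, "close to 1.2 PFLOPs/s" [4]
}

utilization = {name: reported[name] / assumed_peak[name] for name in reported}
for name, fraction in utilization.items():
    print(f"{name}: {fraction:.0%} of assumed peak")
```

Under these assumed peaks, the FP16 ratios land near the 72% and 75% figures quoted in the text.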

Next Step
Continue to Continuous Batching & Scheduling

There, you'll understand how LLM schedulers use continuous batching, chunked prefill, and prefill-decode disaggregation to improve throughput without blowing up time-to-first-token (TTFT) or inter-token latency.

References

[1] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

[2] Milakov, M. & Gimelshein, N. (2018). Online normalizer calculation for softmax.

[3] Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024.

[4] Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.

[5] PyTorch Contributors. (2026). torch.nn.functional.scaled_dot_product_attention.