LeetLLM
© 2026 LeetLLM. All rights reserved.

🧠 Hard · Transformer Architecture

Scaled Dot-Product Attention

Learn scaled dot-product attention from first principles, including Q/K/V routing, variance scaling, masks, multi-head shapes, KV-cache costs, and FlashAttention.

38 min read

The previous chapter showed how to compare embedding vectors for retrieval. Attention layers do a related job inside a sequence: each token compares itself to other tokens, decides which positions matter, and blends information from those positions.

When you read "The package missed the dock because it was closed," you instantly know that "it" refers to the dock, not the package. Your brain routes focus to the relevant word. Scaled dot-product attention is the transformer operation that learns this kind of routing, using matrix multiplications and a softmax.

Attention is the core information-routing mechanism in modern transformers. Given a sequence of token representations, it answers: "For each token, which other positions should it use, and how much?" Once that routing idea is clear, the math and code become direct translations of it.

Scaled dot-product attention pipeline: Query, Key, and Value vectors flow through score computation, scaling, masking, softmax, and value aggregation to produce context-aware outputs.

The attention mechanism

Intuition

Attention answers a simple question: "For each token in a sequence, which other tokens should it focus on?"

Think of reading a sentence: when you see "it" in "The package missed the dock because it was closed", your brain attends back to "dock" to resolve what "it" means. Self-attention does this computationally: every token computes a weighted combination of the sequence's token representations (including its own), with weights based on learned relevance.[1]

Setup: Q, K, V projections

To perform attention, the model transforms each token representation into three different roles. Think of it like a support search system:

| Vector | Question it answers | Support-search analogy |
|---|---|---|
| Query (Q) | What am I looking for? | An order-status question |
| Key (K) | What do I contain? | Searchable tags on a shipment event |
| Value (V) | What information do I carry? | The actual carrier scan or policy text |

The model creates these three vectors by multiplying each token representation by learned weight matrices:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

From one embedding to three views

Each token representation (a row of $X$) is multiplied by three different weight matrices ($W_Q$, $W_K$, $W_V$) to produce three different views of the same token. One view highlights what this token is searching for (Query), another highlights what this token offers to others (Key), and the third carries the actual information (Value).

Here $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are the learned projections, where $d_{\text{model}}$ is the model's overall hidden dimension size, and $d_k$ and $d_v$ are the dimensions of the queries/keys and values, respectively.[1]
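A minimal sketch of the projection step, using plain Python lists so every shape stays visible. The sizes `n`, `d_model`, `d_k`, and `d_v` are illustrative, and random Gaussian matrices stand in for learned weights (an assumption: a trained model would load its own `W_Q`, `W_K`, `W_V`). `matmul`, `random_matrix`, and `shape` are hypothetical helpers, not library functions.

```python
import random


def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Multiply an (n x m) matrix by an (m x p) matrix, returning (n x p)."""
    assert len(a[0]) == len(b), "inner dimensions must match"
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def random_matrix(rows: int, cols: int, rng: random.Random) -> list[list[float]]:
    # Random stand-in for a learned weight matrix (hypothetical, for shapes only).
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def shape(m: list[list[float]]) -> tuple[int, int]:
    return (len(m), len(m[0]))


rng = random.Random(0)
n, d_model, d_k, d_v = 3, 8, 4, 5  # illustrative sizes; d_v may differ from d_k

X = random_matrix(n, d_model, rng)  # one row per token
W_Q = random_matrix(d_model, d_k, rng)
W_K = random_matrix(d_model, d_k, rng)
W_V = random_matrix(d_model, d_v, rng)

Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
print("Q:", shape(Q), "K:", shape(K), "V:", shape(V))  # Q: (3, 4) K: (3, 4) V: (3, 5)
```

Q and K must share $d_k$ because their rows are dotted together; the width of V only sets the width of the attention output.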

Q, K, and V are learned views of the same input embedding. Q and K do routing; V carries the content that gets mixed after softmax.

Q and K determine where to look. V determines what information to extract. Because routing (Q, K) and content (V) come from separate learned projections, the features that make a position relevant can differ from the information that position contributes.


The core formula

The scaled dot-product attention formula computes a weighted combination of value vectors based on the compatibility between queries and keys:[1]

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

When some key positions aren't allowed, add a mask matrix $M$ after scaling. An allowed location has $M_{ij} = 0$; a blocked location has $M_{ij} = -\infty$:

$$\text{Attention}(Q, K, V; M) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$
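The additive mask can be sketched in plain Python, assuming a boolean `visible` matrix where `True` means the key is allowed (the same convention the PyTorch function later in this article uses). The logits are the scaled values from the two-token trace below, rounded to two decimals; `additive_mask` is a hypothetical helper name.

```python
import math


def softmax(row: list[float]) -> list[float]:
    shift = max(row)  # subtract the row max for numerical stability
    exps = [math.exp(x - shift) for x in row]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def additive_mask(visible: list[list[bool]]) -> list[list[float]]:
    # M_ij = 0 keeps logit (i, j); M_ij = -inf removes it from the softmax.
    return [[0.0 if ok else float("-inf") for ok in row] for row in visible]


logits = [[0.64, 0.53], [0.42, 0.74]]   # scaled logits QK^T / sqrt(d_k)
visible = [[True, False], [True, True]]  # causal: token 0 cannot see token 1
M = additive_mask(visible)

weights = [
    softmax([logit + m for logit, m in zip(logit_row, mask_row)])
    for logit_row, mask_row in zip(logits, M)
]
print([[round(w, 2) for w in row] for row in weights])  # [[1.0, 0.0], [0.42, 0.58]]
```

If every entry in a row were blocked, `max(row)` would be `-inf`, and `-inf - (-inf)` is NaN, which is the fully masked failure mode discussed with the PyTorch code.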

Step by step

| Step | Operation | What it means |
|---|---|---|
| Compute scores | $S = QK^T$ | Build an $n \times n$ matrix where $S_{ij}$ measures how much token $i$ should attend to token $j$. |
| Scale | $L = S / \sqrt{d_k}$ | Keep softmax logits in a range where gradients stay useful. |
| Mask | Add $0$ or $-\infty$ to each logit | Block future keys for causal attention or padded keys for batching. |
| Normalize | $\alpha_{ij} = \frac{\exp(L_{ij})}{\sum_k \exp(L_{ik})}$ | Turn each row into weights that sum to 1. |
| Aggregate | $\alpha V$ | Blend value vectors according to the attention weights. |

The shape flow is compact enough to keep beside the formula. Queries and keys build one routing matrix; values join only after softmax:

Shape flow: Q [B, h, Nq, d_k] times Kᵀ [B, h, d_k, Nk] gives QKᵀ / √d_k + mask [B, h, Nq, Nk], and softmax runs over the keys axis of that [B, h, Nq, Nk] tensor.
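Shape bookkeeping for the batched form fits in a few lines. `attention_shapes` is a hypothetical helper that assumes exactly matching batch and head dimensions (real libraries also broadcast); it makes explicit that the score matrix is $N_q \times N_k$ and that the output takes its width from V.

```python
Shape4 = tuple[int, int, int, int]


def attention_shapes(q_shape: Shape4, k_shape: Shape4, v_shape: Shape4) -> tuple[Shape4, Shape4]:
    """Return (scores_shape, output_shape) for Q [B, h, Nq, d_k], K [B, h, Nk, d_k], V [B, h, Nk, d_v]."""
    B, h, n_q, d_k = q_shape
    Bk, hk, n_k, d_k_keys = k_shape
    Bv, hv, n_v, d_v = v_shape
    if (B, h) != (Bk, hk) or (B, h) != (Bv, hv):
        raise ValueError("batch and head dimensions must match")
    if d_k != d_k_keys:
        raise ValueError("queries and keys must share d_k")
    if n_k != n_v:
        raise ValueError("each key position needs a value vector")
    # QK^T: [B, h, Nq, d_k] @ [B, h, d_k, Nk] -> [B, h, Nq, Nk]; softmax runs over the Nk axis.
    # weights @ V: [B, h, Nq, Nk] @ [B, h, Nk, d_v] -> [B, h, Nq, d_v]
    return (B, h, n_q, n_k), (B, h, n_q, d_v)


scores, output = attention_shapes((2, 8, 5, 64), (2, 8, 7, 64), (2, 8, 7, 32))
print(scores, output)  # (2, 8, 5, 7) (2, 8, 5, 32)
```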
A small numeric trace makes the formula concrete: dot products become scaled logits, softmax turns them into row-normalized weights, and those weights blend the value vectors.

A trace with real numbers

Before running PyTorch code, walk through one step with actual vectors. Imagine a two-word sequence, "order delayed." For this toy example each vector is 2-dimensional, and tiny learned projection matrices give:

| Token | Query vector | Key vector | Value vector |
|---|---|---|---|
| order | [1.0, 0.5] | [0.8, 0.2] | [2.0, 1.0] |
| delayed | [0.5, 1.0] | [0.3, 0.9] | [1.0, 2.0] |

Step 1: raw scores

Compute the dot product of every query with every key:

$$S = QK^T = \begin{bmatrix} 1.0 \cdot 0.8 + 0.5 \cdot 0.2 & 1.0 \cdot 0.3 + 0.5 \cdot 0.9 \\ 0.5 \cdot 0.8 + 1.0 \cdot 0.2 & 0.5 \cdot 0.3 + 1.0 \cdot 0.9 \end{bmatrix} = \begin{bmatrix} 0.90 & 0.75 \\ 0.60 & 1.05 \end{bmatrix}$$

Step 2: scale

With $d_k = 2$, divide by $\sqrt{2} \approx 1.414$:

$$L = \begin{bmatrix} 0.64 & 0.53 \\ 0.42 & 0.74 \end{bmatrix}$$

Step 3: softmax

Normalize each row so it sums to 1:

$$\alpha = \begin{bmatrix} 0.53 & 0.47 \\ 0.42 & 0.58 \end{bmatrix}$$

Step 4: weighted values

Multiply the weights by the value vectors:

  • New "order" vector = $0.53 \cdot [2.0, 1.0] + 0.47 \cdot [1.0, 2.0] = [1.53, 1.47]$
  • New "delayed" vector = $0.42 \cdot [2.0, 1.0] + 0.58 \cdot [1.0, 2.0] = [1.42, 1.58]$

The output replaces each original embedding with a blend of the whole sequence, weighted by relevance. Even in this toy example, "order" pulls slightly more from its own value (0.53) than from "delayed" (0.47), while "delayed" leans somewhat more toward its own value (0.58 vs. 0.42). In a real model with hidden sizes such as 512 or 4096, this blending happens across thousands of numbers at once.

The same arithmetic is easy to verify without a tensor library:

```python
import math

Q = [[1.0, 0.5], [0.5, 1.0]]
K = [[0.8, 0.2], [0.3, 0.9]]
V = [[2.0, 1.0], [1.0, 2.0]]

def softmax(row: list[float]) -> list[float]:
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [x / total for x in exps]

scores = [[sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in K] for q in Q]
scaled = [[score / math.sqrt(2) for score in row] for row in scores]
weights = [softmax(row) for row in scaled]
outputs = [
    [sum(weight * value[col] for weight, value in zip(row, V)) for col in range(2)]
    for row in weights
]

print([[round(x, 2) for x in row] for row in weights])
print([[round(x, 2) for x in row] for row in outputs])
print([round(sum(row), 3) for row in weights])
```

Output:

```text
[[0.53, 0.47], [0.42, 0.58]]
[[1.53, 1.47], [1.42, 1.58]]
[1.0, 1.0]
```

PyTorch implementation

Here's how the scaled dot-product attention formula translates into practical PyTorch code. The function takes the Query, Key, and Value tensors along with an optional boolean mask (True marks a visible key) for causal or padding constraints. It computes the attention scores, scales them by $\sqrt{d_k}$, applies the mask, normalizes each row into probabilities, and returns both the final weighted output and the attention weights themselves.

```python
import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    K: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    V: torch.Tensor,  # (batch, n_heads, seq_len, d_v)
    mask: torch.Tensor | None = None,  # boolean, True = visible key; broadcastable to (B, h, n, n)
    dropout_p: float = 0.0,
    training: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention (Vaswani et al., 2017)."""
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, n, n)

    # Step 2: Scale by head width to control logit spread
    attn_scores = attn_scores / math.sqrt(d_k)

    # Step 3: Apply mask (causal or padding); reject rows with no visible key
    if mask is not None:
        if not bool(mask.any(dim=-1).all()):
            raise ValueError("each query row must have at least one visible key")
        attn_scores = attn_scores.masked_fill(~mask, float("-inf"))

    # Step 4: Softmax normalization over keys (each row sums to 1)
    attn_weights = F.softmax(attn_scores, dim=-1)

    # Optional: dropout on attention weights (regularization)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p, training=training)

    # Step 5: Weighted aggregation of values
    output = torch.matmul(attn_weights, V)  # (B, h, n, d_v)

    return output, attn_weights

Q = torch.tensor([[[[1.0, 0.5], [0.5, 1.0]]]])
K = torch.tensor([[[[0.8, 0.2], [0.3, 0.9]]]])
V = torch.tensor([[[[2.0, 1.0], [1.0, 2.0]]]])
causal = torch.tensor([[[[True, False], [True, True]]]])
output, weights = scaled_dot_product_attention(Q, K, V, causal)

pretty_weights = [[round(float(value), 3) for value in row] for row in weights[0, 0]]
print("causal weights:", pretty_weights)
print("token 0 future weight:", weights[0, 0, 0, 1].item())
print("output shape:", tuple(output.shape))
```

Output:

```text
causal weights: [[1.0, 0.0], [0.421, 0.579]]
token 0 future weight: 0.0
output shape: (1, 1, 2, 2)
```

Two masking details matter in real code. First, softmax must run along the key dimension (dim=-1) so each query row sums to 1; normalizing along the query axis silently changes the operation. Second, a query row with no permitted key has no valid attention distribution. Filling with -inf makes that mistake visible as NaN. Replacing it with a finite negative value hides the bug: every logit in the row ties at the fill value, so softmax spreads weight evenly across keys that should be invisible. Ensure each active query has at least one allowed key (causal attention includes its own position), or explicitly suppress outputs for padded query rows.

This short failure case makes the second rule concrete:

```python
import torch

scores = torch.tensor([[0.8, 0.1], [0.4, -0.2]])
visible = torch.tensor([[True, False], [False, False]])

finite_fill_weights = torch.softmax(scores.masked_fill(~visible, -1e4), dim=-1)
active_rows = visible.any(dim=-1, keepdim=True)
served_weights = torch.where(active_rows, finite_fill_weights, torch.zeros_like(finite_fill_weights))

print("finite fill, invalid row:", finite_fill_weights[1].tolist())
print("after padded-query suppression:", served_weights[1].tolist())
print("valid row ignores blocked key:", served_weights[0].tolist())
```

Output:

```text
finite fill, invalid row: [0.5, 0.5]
after padded-query suppression: [0.0, 0.0]
valid row ignores blocked key: [1.0, 0.0]
```

Common shape mistakes

The most frequent debugging task in transformer work is tracing a dimension mismatch. Here are two typical mistakes and the exact symptoms they produce.

Forgetting to transpose K

If you write torch.matmul(Q, K) instead of torch.matmul(Q, K.transpose(-2, -1)), the inner dimensions usually won't align. With Q shaped (B, h, n, d_k) and K shaped (B, h, n, d_k), the product pairs Q's last axis ($d_k$) with K's second-to-last axis ($n$), so it raises a shape error before softmax whenever $n \ne d_k$. When $n = d_k$, the call succeeds and silently returns a wrong $n \times n$ matrix.

The fix is always the same: swap the last two dimensions of K so the matrix multiplication becomes $(n \times d_k) \cdot (d_k \times n)$, yielding the $(n \times n)$ score matrix.
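The shape check only catches this mistake when $n \ne d_k$. In the two-token trace earlier, $n = d_k = 2$, so the untransposed product runs and returns a plausible-looking $2 \times 2$ matrix with the wrong numbers. A plain-Python sketch (`matmul` and `transpose` are hypothetical helpers) makes the difference visible:

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    assert len(a[0]) == len(b), f"inner dimensions differ: {len(a[0])} vs {len(b)}"
    # zip(*b) iterates over the columns of b
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*m)]


Q = [[1.0, 0.5], [0.5, 1.0]]  # n = 2 tokens, d_k = 2
K = [[0.8, 0.2], [0.3, 0.9]]

right = matmul(Q, transpose(K))  # S_ij = q_i . k_j
wrong = matmul(Q, K)             # pairs query feature t with key *position* t

print([[round(x, 2) for x in row] for row in right])  # [[0.9, 0.75], [0.6, 1.05]]
print([[round(x, 2) for x in row] for row in wrong])  # [[0.95, 0.65], [0.7, 1.0]]
```

Because both results are $2 \times 2$, nothing downstream complains; testing with $n \ne d_k$ (or asserting against a hand-computed score) is a cheap way to catch the bug.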

Forgetting to scale

If you skip attn_scores / math.sqrt(d_k), the code won't crash. Under the independent unit-variance setup below, a $d_k = 64$ raw dot product has standard deviation 8 rather than 1. That larger spread can saturate softmax and reduce routing gradients. Inspect score statistics and attention entropy when debugging, then restore the scale factor unless you're intentionally testing a different attention formulation.
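One way to see that saturation is to compare attention entropy with and without the scale factor. This sketch assumes the same independent unit-variance Gaussian queries and keys as the derivation below; the sizes (`d_k = 64`, 16 keys, 200 queries) and the seed are arbitrary choices, so the exact numbers will vary with them, but scaled rows should stay much closer to the maximum entropy $\log 16$.

```python
import math
import random

rng = random.Random(0)
d_k, n_keys, n_queries = 64, 16, 200


def softmax(row: list[float]) -> list[float]:
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: list[float]) -> float:
    # Shannon entropy in nats; entries that underflow to 0 contribute nothing.
    return -sum(x * math.log(x) for x in p if x > 0.0)


def gaussian_vector(width: int) -> list[float]:
    return [rng.gauss(0.0, 1.0) for _ in range(width)]


keys = [gaussian_vector(d_k) for _ in range(n_keys)]
stats: dict[str, list[tuple[float, float]]] = {"raw": [], "scaled": []}
for _ in range(n_queries):
    q = gaussian_vector(d_k)
    raw = [sum(a * b for a, b in zip(q, k)) for k in keys]
    for name, logits in (("raw", raw), ("scaled", [x / math.sqrt(d_k) for x in raw])):
        weights = softmax(logits)
        stats[name].append((entropy(weights), max(weights)))

for name, rows in stats.items():
    mean_entropy = sum(e for e, _ in rows) / len(rows)
    mean_top = sum(m for _, m in rows) / len(rows)
    print(f"{name:6s}: mean entropy={mean_entropy:.2f} nats "
          f"(max {math.log(n_keys):.2f}), mean top weight={mean_top:.2f}")
```

Low entropy with a top weight near 1 on unscaled logits is the saturated regime the paragraph above describes.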

Invented attention weights for "The package missed the dock": each row shows how a query could distribute weight across keys. Use this to learn matrix orientation, not to claim what a trained head represents.

Why scale by $\sqrt{d_k}$? The variance proof

Imagine combining many warehouse sensor signals into one routing score. Without scaling, adding more signals (a larger $d_k$) widens the spread of raw dot products until softmax saturates. Dividing by $\sqrt{d_k}$ keeps that spread roughly constant, so the model can still distinguish several plausible routes instead of locking onto one too early.

Here's the derivation that motivated the transformer's scale factor. For this calculation, assume entries of qqq and kkk are independent with mean 0 and variance 1. Learned activations won't satisfy those assumptions exactly, but the calculation explains why unscaled logits begin with dimension-dependent spread.[1]

Assume

$q_i, k_i$ are independent components, each with mean $0$ and variance $1$. The argument needs only independence and those two moments, not a specific distribution such as a Gaussian. The dot product is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each term $q_i k_i$ has $\mathbb{E}[q_i k_i] = 0$ (since $\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0$ and they're independent). The variance is:

$$\text{Var}(q_i k_i) = \mathbb{E}[(q_i k_i)^2] - (\mathbb{E}[q_i k_i])^2 = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - 0 = 1 \cdot 1 = 1$$

Because the $d_k$ terms are independent, their variances add:

$$\text{Var}(q \cdot k) = d_k$$

So the standard deviation of the raw dot product is $\sqrt{d_k}$. As $d_k$ grows under this model, logits spread more widely before softmax:

| $d_k$ | $\text{Std}(q \cdot k)$ | $\text{Std}(q \cdot k / \sqrt{d_k})$ |
|---|---|---|
| 16 | 4.0 | 1.0 |
| 64 | 8.0 | 1.0 |
| 512 | 22.6 | 1.0 |
| 4096 | 64.0 | 1.0 |

After dividing by $\sqrt{d_k}$

Dividing a random variable by a constant $c$ divides its variance by $c^2$, so $\text{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$ under these assumptions. Scaling doesn't promise a particular learned attention pattern; it removes a predictable source of width-dependent logit growth.

Under the independent unit-variance assumptions, dividing by $\sqrt{d_k}$ keeps logit variance stable as head dimension grows; without scaling, wider logits can make softmax gradients weak.

See it in code. This experiment samples independent unit-variance query/key vectors and checks the standard-deviation calculation rather than choosing one dramatic softmax row:

```python
import math
import random
import statistics

rng = random.Random(7)

def dot_products(width: int, samples: int = 5000) -> list[float]:
    return [
        sum(rng.gauss(0, 1) * rng.gauss(0, 1) for _ in range(width))
        for _ in range(samples)
    ]

for width in (16, 64, 512):
    raw = dot_products(width)
    raw_std = statistics.pstdev(raw)
    scaled_std = statistics.pstdev([x / math.sqrt(width) for x in raw])
    print(f"d_k={width:3d}: raw std={raw_std:5.2f}, scaled std={scaled_std:4.2f}")
```

Output:

```text
d_k= 16: raw std= 4.07, scaled std=1.02
d_k= 64: raw std= 8.09, scaled std=1.01
d_k=512: raw std=22.39, scaled std=0.99
```

The sampled values won't be exactly the theoretical values, but their trend should match: raw spread grows with width while scaled spread stays near one.


Three types of attention

The transformer architecture uses attention in three distinct patterns. Bidirectional attention is like a support dashboard where every event in an order timeline can see every other event. Causal attention processes that timeline left to right, where each new event can only use earlier events. Cross-attention is like a reply generator reading from a separate policy document while writing the customer response.

1. Bidirectional self-attention (encoder)

Every non-padding token may attend to every other non-padding token; there is no future-token mask. Encoder architectures such as BERT use this pattern for tasks where the full input is available. For example, given a complete sentence, every word's representation can use every other word:

```text
"The package missed the dock"
  Token "missed" attends to: [The, package, missed, the, dock] (full context)
```

2. Causal self-attention (decoder)

Each token can only attend to itself and previous tokens. Future positions are masked with $-\infty$. Decoder-only architectures, such as GPT-style autoregressive language models, use this pattern for generation, where the model must predict the next token without seeing the future. When processing a sequence, each position builds context from earlier tokens but stays blind to later ones:

```text
"The package missed the dock"
  Token "missed" attends to: [The, package, missed] (only past + self)
  Token "dock" attends to: [The, package, missed, the, dock] (full history)
```

The causal mask is a lower-triangular matrix that ensures each position only looks at itself and the positions before it. This minimal Python version takes the sequence length as input and outputs a boolean matrix where True indicates an allowed connection and False indicates a masked one.

```python
def create_causal_mask(seq_len: int) -> list[list[bool]]:
    return [[key_pos <= query_pos for key_pos in range(seq_len)] for query_pos in range(seq_len)]

mask = create_causal_mask(4)
for row in mask:
    print(row)
print(mask[0] == [True, False, False, False])
print(mask[1] == [True, True, False, False])
print(mask[3] == [True, True, True, True])
```

Output:

```text
[True, False, False, False]
[True, True, False, False]
[True, True, True, False]
[True, True, True, True]
True
True
True
```

3. Cross-attention (encoder-decoder)

Queries come from one sequence, Keys and Values from another. The original Transformer uses this pattern in its decoder: encoder outputs provide keys and values, while current decoder states provide queries.[1] The resulting weights choose which source positions contribute to each target representation:

```text
Encoder output (source): "package delayed warehouse" provides K, V
Decoder state (target): "order status ___" provides Q

Q from decoder times K from encoder gives attention weights
Weights times V from encoder give decoder context
```

In practice, cross-attention usually applies a source padding mask so decoder tokens don't attend to padded encoder positions. It doesn't use a causal mask over the source sequence, because the encoder has already seen the whole input.

Self-attention produces a square query-by-key matrix. Cross-attention doesn't have to: two target queries reading three source positions produce a 2 x 3 routing matrix.

```python
import math

decoder_queries = [[1.0, 0.0], [0.0, 1.0]]
encoder_keys = [[1.0, 0.0], [0.2, 0.8], [0.0, 1.0]]
encoder_values = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]

def softmax(row: list[float]) -> list[float]:
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [value / total for value in exps]

logits = [
    [sum(q_i * k_i for q_i, k_i in zip(q, k)) / math.sqrt(2) for k in encoder_keys]
    for q in decoder_queries
]
weights = [softmax(row) for row in logits]
context = [
    [sum(weight * value[d] for weight, value in zip(row, encoder_values)) for d in range(2)]
    for row in weights
]

print(f"routing shape: {len(weights)} x {len(weights[0])}")
print("row sums:", [round(sum(row), 3) for row in weights])
print("context:", [[round(x, 3) for x in row] for row in context])
```

Output:

```text
routing shape: 2 x 3
row sums: [1.0, 1.0]
context: [[0.623, 0.377], [0.393, 0.607]]
```

| Attention type | Q source | K, V source | Mask | Architecture |
|---|---|---|---|---|
| Bidirectional self | Same sequence | Same sequence | Padding mask only (if needed) | BERT, Vision Transformer (ViT) |
| Causal self | Same sequence | Same sequence | Lower-triangular, plus padding if needed | GPT, Llama |
| Cross | Target sequence | Source sequence | Source padding mask common | T5, original Transformer |
The three attention patterns differ by mask and by where Q, K, and V come from. Causal self-attention is the only one that blocks future tokens by default.
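The "lower-triangular, plus padding if needed" mask for causal self-attention can be sketched as an AND of two boolean conditions. This assumes right-padded sequences and the `True` = visible convention from the PyTorch function earlier; `causal_padding_mask` is a hypothetical helper, not a library function.

```python
def causal_padding_mask(lengths: list[int], max_len: int) -> list[list[list[bool]]]:
    """One (max_len x max_len) mask per right-padded sequence; True means the query may see the key."""
    masks = []
    for length in lengths:
        masks.append([
            # causal condition AND the key is a real (non-padding) token
            [key <= query and key < length for key in range(max_len)]
            for query in range(max_len)
        ])
    return masks


masks = causal_padding_mask(lengths=[4, 2], max_len=4)
for row in masks[1]:
    print(row)
# [True, False, False, False]
# [True, True, False, False]
# [True, True, False, False]
# [True, True, False, False]
```

With right padding and every length at least 1, key 0 stays visible in every row, so no row is fully masked. The padded query rows (rows 2 and 3 of the second sequence) still produce outputs, which downstream code should ignore or zero out.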

Multi-head attention

Imagine giving several analysts separate learned views of an order timeline, then letting a final projection combine their reports. Multi-head attention makes that capacity available: each head has its own query, key, and value projections. A trained head may acquire a recognizable routing pattern, but the architecture doesn't assign jobs such as "policy clause head" in advance.

Instead of one large attention operation, we run $h$ parallel heads, each with $d_k = d_{\text{model}} / h$:[1]

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O$$

Reading the formula

Instead of running one big attention operation, we project into $h$ narrower heads that attend independently. After each head produces its output, we concatenate them and multiply by a final matrix $W^O$ so information from those routes rejoins the residual stream.

Each head computes: $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
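The split, attend, and concatenate pattern can be sketched in plain Python. This assumes the projections have already been applied (the hypothetical `Q`, `K`, `V` below are already projected, with $d_{\text{model}} = 4$ split into 2 heads) and leaves out $W^O$, which would be one more dense projection of the concatenated result.

```python
import math
from itertools import chain

Matrix = list[list[float]]


def softmax(row: list[float]) -> list[float]:
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Single-head scaled dot-product attention on lists of rows."""
    d_k = len(Q[0])
    weights = [
        softmax([sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K])
        for q_row in Q
    ]
    return [[sum(w * v_row[c] for w, v_row in zip(w_row, V)) for c in range(len(V[0]))]
            for w_row in weights]


def split_heads(M: Matrix, n_heads: int) -> list[Matrix]:
    # (n, d_model) -> n_heads matrices of shape (n, d_model // n_heads), contiguous slices
    d_head = len(M[0]) // n_heads
    return [[row[h * d_head:(h + 1) * d_head] for row in M] for h in range(n_heads)]


def concat_heads(heads: list[Matrix]) -> Matrix:
    # n_heads matrices of shape (n, d_head) -> (n, n_heads * d_head)
    return [list(chain.from_iterable(head[i] for head in heads)) for i in range(len(heads[0]))]


Q = [[1.0, 0.0, 0.5, 0.5], [0.0, 1.0, 0.5, -0.5], [1.0, 1.0, 0.0, 1.0]]
K = [[0.5, 0.5, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 1.0]]
V = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]
n_heads = 2

head_outputs = [
    attention(q, k, v)
    for q, k, v in zip(split_heads(Q, n_heads), split_heads(K, n_heads), split_heads(V, n_heads))
]
combined = concat_heads(head_outputs)  # W^O omitted in this sketch
print(len(combined), "x", len(combined[0]))  # 3 x 4
```

Slicing contiguous blocks of width $d_{\text{model}} / h$ matches what `.view(B, N, n_heads, d_k)` does in the PyTorch class below; each head runs its own softmax, so the heads can route differently.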

Multi-head attention splits the projected representation into smaller routing heads, runs attention in parallel, then concatenates and projects back to the model dimension.

What do attention heads learn?

Interpretability results here are evidence about specific trained models, not a promise about every transformer. Voita et al. found positional and syntactic patterns among heads in neural machine translation encoders, and Michel et al. found that many heads in the models they tested could be removed at inference with limited quality loss.[2][3] Olsson et al. studied induction heads, circuits that support copying patterns in the autoregressive models they analyzed.[4]

| Result from a study | Useful inference | Unsafe inference |
|---|---|---|
| Some heads show consistent patterns | Inspect heads when debugging or researching a trained model | Every head has a named human-readable purpose |
| Some tested models tolerate head pruning | Redundancy can exist and can be measured | Arbitrarily deleting heads preserves a new model's quality |
| Induction-head circuits can emerge | Attention can implement copy-like sequence algorithms | An attention heatmap alone proves causal model behavior |

Critical nuance: same asymptotic attention FLOPs

Multi-head attention doesn't increase the asymptotic cost of the attention core when you keep $d_{\text{model}}$ fixed. It restructures the work. Single-head attention with $d_{\text{model}} = 512$ uses roughly the same leading-order FLOPs (floating-point operations) for score computation and value mixing as 8-head attention with $d_k = 64$ per head, because $h \times d_k = d_{\text{model}}$. The dense Q/K/V and output projections still cost $O(n\,d_{\text{model}}^2)$ either way.

The arithmetic also makes the compute comparison testable. Splitting a fixed width into more heads doesn't change the total number of score-and-value multiply-adds in the attention core:

```python
seq_len = 2048
d_model = 512

for heads in (1, 8, 16):
    d_head = d_model // heads
    score_and_value_work = 2 * heads * seq_len**2 * d_head
    print(f"heads={heads:2d}, d_head={d_head:3d}, core units={score_and_value_work:,}")
```

Output:

```text
heads= 1, d_head=512, core units=4,294,967,296
heads= 8, d_head= 64, core units=4,294,967,296
heads=16, d_head= 32, core units=4,294,967,296
```
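The same counting convention can be extended to the four dense projections (Q, K, V, and output), assuming each is a $d_{\text{model}} \times d_{\text{model}}$ matrix applied at every position and ignoring biases. Under that assumption, projection work doesn't depend on the head count either, and the core-to-projection ratio simplifies to $n / (2\,d_{\text{model}})$, so the attention core dominates once the sequence is longer than $2\,d_{\text{model}}$.

```python
seq_len = 2048
d_model = 512

# Four (seq_len x d_model) @ (d_model x d_model) products: W_q, W_k, W_v, W_o.
projection_work = 4 * seq_len * d_model**2

for heads in (1, 8, 16):
    d_head = d_model // heads
    core_work = 2 * heads * seq_len**2 * d_head  # QK^T plus weights @ V
    print(f"heads={heads:2d}: projection units={projection_work:,}, "
          f"core units={core_work:,}, core/projection={core_work / projection_work:.1f}")
```

At `seq_len = 2048` and `d_model = 512`, the ratio is 2.0 for every head count, which is why head count is a modeling choice rather than a compute knob at fixed width.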

Here's a runnable PyTorch implementation that prints the input, per-head, and output shapes. To keep shapes simple, it uses the common case $d_v = d_k = d_{\text{model}} / h$. The Q, K, and V projection layers each produce all head slices at once; reshaping exposes those slices to batched attention.

critical-nuance-same-asymptotic-attention.py

```python
import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Each projection produces all per-head slices in one dense operation.
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, D = x.shape

        # Project and reshape: (B, N, D) -> (B, h, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head.
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

        # Concatenate heads and project back: (B, h, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

torch.manual_seed(0)
layer = MultiHeadAttention(d_model=8, n_heads=2)
x = torch.randn(1, 4, 8)
y = layer(x)
print("input shape:", tuple(x.shape))
print("head shape:", (1, layer.n_heads, 4, layer.d_k))
print("output shape:", tuple(y.shape))
```

Output

```text
input shape: (1, 4, 8)
head shape: (1, 2, 4, 4)
output shape: (1, 4, 8)
```

From-scratch checklist

Before relying on torch.nn.MultiheadAttention, make sure you can implement the pieces above by hand. High-level modules are useful later, but they hide the exact shape and masking mistakes that break production attention code.

For a decoder-only block, your implementation should do each step explicitly:

  1. Project x into Q, K, and V.
  2. Reshape (B, N, D) into (B, h, N, d_k).
  3. Compute Q @ K.transpose(-2, -1) / sqrt(d_k).
  4. Apply the causal mask before softmax.
  5. Use numerically stable softmax along the key dimension.
  6. Multiply attention weights by V.
  7. Concatenate heads back to (B, N, D).
  8. Apply the output projection.

Two shape assertions catch many bugs:

from-scratch-checklist.py

```python
assert Q.shape == (B, n_heads, N, d_k)
assert attn_weights.shape == (B, n_heads, N, N)
```

One mask assertion catches leakage: if causal[i][j] is True when query i may attend to key j, then assert not causal[0][1]. Token 0 must not see token 1.

If a model can read future tokens during training, the loss can look excellent while generation fails. That's why causal masking isn't a detail. It's the contract that makes next-token prediction honest.
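The eight steps and the assertions above fit in one explicit function. Below is a minimal sketch, assuming PyTorch; explicit_causal_mha and the randomly initialized torch.nn.Linear layers passed into it are illustrative names, not a library API. It checks itself against F.scaled_dot_product_attention run on the same projected Q, K, and V.

```python
import math
import torch
import torch.nn.functional as F

def explicit_causal_mha(x, W_q, W_k, W_v, W_o, n_heads):
    B, N, D = x.shape
    d_k = D // n_heads

    # Steps 1-2: project, then reshape (B, N, D) -> (B, h, N, d_k).
    Q = W_q(x).view(B, N, n_heads, d_k).transpose(1, 2)
    K = W_k(x).view(B, N, n_heads, d_k).transpose(1, 2)
    V = W_v(x).view(B, N, n_heads, d_k).transpose(1, 2)
    assert Q.shape == (B, n_heads, N, d_k)

    # Step 3: scaled scores with shape (B, h, N, N).
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

    # Step 4: causal[i][j] is True when query i may attend to key j.
    causal = torch.ones(N, N, dtype=torch.bool).tril()
    assert not causal[0][1]  # token 0 must not see token 1
    scores = scores.masked_fill(~causal, float("-inf"))

    # Step 5: stable softmax over the key dimension (subtract the row max first).
    scores = scores - scores.amax(dim=-1, keepdim=True)
    attn_weights = scores.exp()
    attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True)
    assert attn_weights.shape == (B, n_heads, N, N)

    # Step 6: mix values. Steps 7-8: concatenate heads, then output projection.
    out = attn_weights @ V
    out = out.transpose(1, 2).contiguous().view(B, N, D)
    return W_o(out), Q, K, V

torch.manual_seed(0)
D, h = 8, 2
W_q, W_k, W_v, W_o = (torch.nn.Linear(D, D) for _ in range(4))
x = torch.randn(1, 4, D)

y, Q, K, V = explicit_causal_mha(x, W_q, W_k, W_v, W_o, h)
reference = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
reference = W_o(reference.transpose(1, 2).reshape(1, 4, D))
print("output shape:", tuple(y.shape))
print("matches SDPA:", torch.allclose(y, reference, atol=1e-5))
```

Every causal row keeps at least its diagonal entry, so the row max is always finite and the manual max-shift never computes -inf minus -inf.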

Gradient flow through attention

Attention has three gradient paths:

| Path | What receives gradient | Why it matters |
|---|---|---|
| Output to V | Value projection and upstream token representations | Teaches what information each token should carry |
| Output to attention weights | Softmax probabilities | Teaches which source positions should matter |
| Weights back to Q and K | Query/key projections | Teaches the routing function itself |

The scale factor helps keep the Q/K path trainable. If scores become too large, softmax saturates, attention weights become nearly one-hot, and the gradient through the routing path becomes tiny. If the causal mask is wrong, gradients flow through illegal future positions and the model learns a shortcut it can't use at inference time.
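To see saturation numerically, the pure-Python sketch below computes the Frobenius norm of the softmax Jacobian $J = \operatorname{diag}(p) - p p^T$, which is the term that carries gradient from attention weights back to scores, for one made-up row of scores multiplied by growing factors. The multipliers stand in for scores that were never divided by $\sqrt{d_k}$; they are illustrative, not measurements from a model.

```python
import math

def softmax(logits):
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian_norm(logits):
    # Frobenius norm of d softmax / d logits = diag(p) - p p^T.
    p = softmax(logits)
    sq = 0.0
    for i in range(len(p)):
        for j in range(len(p)):
            entry = (p[i] if i == j else 0.0) - p[i] * p[j]
            sq += entry * entry
    return math.sqrt(sq)

base = [1.0, 0.5, 0.0, -0.5]  # illustrative scores for one query row
norms = {}
for scale in (1, 10, 100):
    row = [scale * z for z in base]
    norms[scale] = softmax_jacobian_norm(row)
    top = max(softmax(row))
    print(f"scale={scale:3d}  max weight={top:.4f}  jacobian norm={norms[scale]:.2e}")
```

As the largest weight approaches 1, the Jacobian collapses toward zero, so almost no gradient reaches W_q and W_k through that row.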

When debugging attention, don't only print the final output. Inspect the score range, the mask, one row of attention weights, and the gradient norm on W_q and W_k. Those four checks tell you whether the model is learning routing or only moving values through a broken router.
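Here is a minimal sketch of those four checks, assuming PyTorch, a single causal head, and a made-up regression target used only to produce gradients. W_q, W_k, and W_v are plain torch.nn.Linear layers here, not a specific library's attention module.

```python
import math
import torch

torch.manual_seed(0)
N, d = 5, 16
x = torch.randn(N, d)
W_q, W_k, W_v = (torch.nn.Linear(d, d, bias=False) for _ in range(3))

scores = W_q(x) @ W_k(x).T / math.sqrt(d)
causal = torch.ones(N, N, dtype=torch.bool).tril()
weights = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
out = weights @ W_v(x)

# Hypothetical loss: any scalar that depends on the output is enough to get gradients.
target = torch.randn(N, d)
loss = (out - target).pow(2).mean()
loss.backward()

# 1. Score range: very large magnitudes hint at softmax saturation.
print("score range:", round(scores.min().item(), 3), round(scores.max().item(), 3))
# 2. Mask: weights above the diagonal (future keys) must be exactly zero.
print("future weights all zero:", bool((weights.triu(diagonal=1) == 0).all()))
# 3. One row: the last query may see every key, and its weights sum to 1.
print("last row:", [round(w, 3) for w in weights[-1].detach().tolist()])
# 4. Routing gradients: nonzero norms show gradient reaches the Q/K projections.
print("grad norm W_q:", W_q.weight.grad.norm().item())
print("grad norm W_k:", W_k.weight.grad.norm().item())
```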


Complexity analysis

| Metric | Complexity | Explanation |
|---|---|---|
| Time (attention core, single head) | $O(n^2 \cdot d_k)$ | Both $QK^T$ and $\alpha V$ touch all $n^2$ query-key pairs |
| Time (attention core, full multi-head) | $O(n^2 \cdot d_{\text{model}})$ | Across $h$ heads, $h \cdot d_k = d_{\text{model}}$ |
| Time (Q/K/V + output projections) | $O(n \cdot d_{\text{model}}^2)$ | Dense linear layers before and after the attention core |
| Memory (naive weights) | $O(n^2)$ per head | The score or weight matrix is $n \times n$ |
| Parameters | $O(d_{\text{model}}^2)$ | $W_Q, W_K, W_V, W_O$ are dense projections |
Attention complexity chart showing attention matrix entries growing quadratically with sequence length.
Attention score matrices grow with $n^2$. Moving from 512 to 2K tokens is 4x more tokens but 16x more score entries.
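The table lists two time terms that compete as context grows. Counting multiply-adds under simplified assumptions (four square $d_{\text{model}} \times d_{\text{model}}$ projections, biases and softmax ignored) gives $4 n d_{\text{model}}^2$ for the projections and $2 n^2 d_{\text{model}}$ for the attention core, so their ratio is $n / (2 d_{\text{model}})$ and the core dominates once $n > 2 d_{\text{model}}$. The sketch below only tabulates those two formulas with an illustrative width; it's a counting model, not a profiler measurement.

```python
d_model = 4096  # illustrative width

def projection_macs(n: int) -> int:
    # Q, K, V, and output projections: four (n x d_model) @ (d_model x d_model) matmuls.
    return 4 * n * d_model**2

def core_macs(n: int) -> int:
    # QK^T plus weights @ V, summed over heads: 2 * n^2 * (h * d_k) = 2 * n^2 * d_model.
    return 2 * n**2 * d_model

for n in (1024, 8192, 32768):
    ratio = core_macs(n) / projection_macs(n)
    print(f"n={n:6d}  core/projection multiply-adds = {ratio:.2f}")
```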

Concrete memory example

Batch of 8, 32 heads, $n = 8192$, FP16 (16-bit floating-point format):

$$\text{Attention matrix} = 8 \times 32 \times 8192^2 \times 2 \text{ bytes} \approx 34.4 \text{ GB} \approx 32 \text{ GiB}$$

How the numbers work

8 sequences in the batch × 32 attention heads × 8192² entries per attention map × 2 bytes per number (FP16) = about 34.4 GB of raw storage, or exactly 32 GiB, for a single attention-score tensor. That's before counting activations saved for backpropagation, model weights, or optimizer state.

how-the-numbers-work.py

```python
batch = 8
heads = 32
seq_len = 8192
bytes_per_fp16 = 2

bytes_total = batch * heads * seq_len**2 * bytes_per_fp16
gb = bytes_total / 1_000_000_000
gib = bytes_total / 1024**3

print(round(gb, 1), "GB")
print(round(gib, 1), "GiB")
print(round(gb, 1) == 34.4, round(gib, 1) == 32.0)
```

Output

```text
34.4 GB
32.0 GiB
True True
```

In the naive formulation, this is why $O(n^2)$ temporary memory becomes a bottleneck for long sequences. Doubling sequence length makes one materialized score tensor four times larger. Fused kernels can avoid storing that full tensor, but exact dense attention still computes interactions across all query-key pairs.

Training vs. inference bottlenecks

During training, or in a naive implementation, the temporary $n \times n$ score matrix is the obvious $O(n^2)$ memory problem. During autoregressive decoding, optimized kernels often avoid materializing that matrix, but the model still has to read the accumulated KV cache (past keys and values) again for every new token. That makes incremental inference heavily constrained by memory bandwidth, not just FLOPs.[5]
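Counting bytes and FLOPs for one decoding step makes the bandwidth argument concrete. The sketch below is a rough model under stated assumptions: FP16 K/V, the attention core only (weight reads, projections, and the MLP are ignored), one new query token, and illustrative layer and head sizes rather than any particular model. Each cached element is read once and used for about one multiply-add per query head, so with as many K/V heads as query heads the ratio lands near 1 FLOP per byte, typically far below what accelerators need to stay compute-bound.

```python
layers, q_heads, kv_heads, head_dim = 32, 32, 32, 128  # illustrative MHA shapes
context = 8192             # tokens already in the KV cache
bytes_per_value = 2        # FP16

# Bytes read per decoding step: every cached K and V element, once.
kv_bytes = layers * context * kv_heads * head_dim * 2 * bytes_per_value

# FLOPs per step for one new query token: q . K^T and weights @ V,
# each one multiply-add (2 FLOPs) per cached element per query head.
core_flops = layers * q_heads * context * head_dim * 2 * 2

print(f"KV bytes read per token: {kv_bytes / 1024**3:.2f} GiB")
print(f"attention-core FLOPs per token: {core_flops / 1e9:.2f} GFLOP")
print(f"arithmetic intensity: {core_flops / kv_bytes:.1f} FLOPs per byte")
```

In this counting model the intensity equals q_heads / kv_heads, so reducing kv_heads is what raises it.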

Architectural variants attack that persistent cache cost directly. Multi-query attention (MQA) shares one key/value head across all query heads, while grouped-query attention (GQA) uses a smaller number of key/value heads than query heads.[5][6] Both shrink KV-cache bytes. Whether they improve latency enough for a workload is a measurement question because kernel choice, batch size, and quality requirements also matter.
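One simple way to express GQA for illustration is to store fewer K/V heads and repeat each across its group of query heads before the usual attention call. The sketch below assumes PyTorch, 8 query heads sharing 2 K/V heads, and uses repeat_interleave for readability; optimized kernels typically avoid materializing the repeated tensors. Setting kv_heads = 1 turns the same code into MQA.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, d_head = 1, 6, 16
q_heads, kv_heads = 8, 2          # GQA: 4 query heads per K/V head
group = q_heads // kv_heads

Q = torch.randn(B, q_heads, N, d_head)
K = torch.randn(B, kv_heads, N, d_head)   # only kv_heads would be cached
V = torch.randn(B, kv_heads, N, d_head)

# Query heads 0-3 share K/V head 0; query heads 4-7 share K/V head 1.
K_rep = K.repeat_interleave(group, dim=1)
V_rep = V.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(Q, K_rep, V_rep, is_causal=True)

print("cached K shape:", tuple(K.shape))
print("expanded K shape:", tuple(K_rep.shape))
print("output shape:", tuple(out.shape))
```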

For a simplified decoder cache, the storage count is proportional to layers x tokens x kv_heads x head_dim x 2 (the final factor stores both K and V). Keeping 32 query heads but reducing KV heads changes this count directly:

kv-head-count-controls-cache-bytes.py

```python
layers = 32
tokens = 8192
head_dim = 128
bytes_per_value = 2  # FP16

def cache_gib(kv_heads: int) -> float:
    bytes_total = layers * tokens * kv_heads * head_dim * 2 * bytes_per_value
    return bytes_total / 1024**3

mha = cache_gib(32)
for label, kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = cache_gib(kv_heads)
    print(f"{label}: kv_heads={kv_heads:2d}, cache={size:.2f} GiB, reduction={mha / size:.0f}x")
```

Output

```text
MHA: kv_heads=32, cache=4.00 GiB, reduction=1x
GQA: kv_heads= 8, cache=1.00 GiB, reduction=4x
MQA: kv_heads= 1, cache=0.12 GiB, reduction=32x
```

FlashAttention and MQA/GQA solve different problems. FlashAttention cuts temporary attention I/O. MQA/GQA cut persistent KV-cache size.


FlashAttention: tiled exact attention

FlashAttention (Dao et al., 2022)[7] computes exact dense attention with an IO-aware algorithm. Its attention working memory is linear in sequence length rather than storing a quadratic score matrix; it changes the execution order, not the attention definition.

Core idea

Instead of materializing the full $n \times n$ attention matrix in GPU HBM, FlashAttention computes attention in tiles that fit in on-chip SRAM, using online softmax to avoid storing the full matrix while preserving the exact result.[7][8]

| Property | Standard attention | FlashAttention |
|---|---|---|
| Memory for attention computation | $O(n^2)$ | $O(n)$ |
| HBM traffic | Writes and rereads large score matrices | Keeps tiles on chip and avoids materializing the full score matrix |
| Exact | Yes | Yes |
| Wall-clock speed | Baseline | Depends on hardware, shapes, dtype, and kernel availability |

FlashAttention is IO-aware. It minimizes traffic between large off-chip HBM and small on-chip SRAM by restructuring the computation order, not by changing the mathematical result.

FlashAttention tiling diagram showing Q tiles streamed against K and V tiles in SRAM without materializing the full attention matrix.
FlashAttention streams tiles through fast on-chip memory and uses online softmax state so the full attention matrix never has to live in HBM.
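The tiling math can be checked on a CPU without any custom kernel. The sketch below, assuming PyTorch and a single unmasked head, walks over key/value blocks and keeps, for each query, a running max, a running normalizer, and an unnormalized output accumulator, rescaling all three whenever a new block raises the max. It only illustrates the arithmetic FlashAttention relies on; it doesn't reproduce SRAM management, query tiling, causal masking, or the backward pass.

```python
import math
import torch

def tiled_attention(Q, K, V, block: int = 4):
    # Q: (n_q, d); K, V: (n_k, d). Single head, no mask.
    scale = 1.0 / math.sqrt(Q.shape[-1])
    row_max = torch.full((Q.shape[0], 1), float("-inf"))  # running max score per query
    row_sum = torch.zeros(Q.shape[0], 1)                   # running sum of exp(score - row_max)
    acc = torch.zeros(Q.shape[0], V.shape[-1])             # running sum of exp(score - row_max) * V

    for start in range(0, K.shape[0], block):
        K_blk, V_blk = K[start:start + block], V[start:start + block]
        s = (Q @ K_blk.T) * scale                              # scores for this tile only
        new_max = torch.maximum(row_max, s.amax(dim=-1, keepdim=True))
        rescale = torch.exp(row_max - new_max)                 # move old state onto the new max
        p = torch.exp(s - new_max)
        row_sum = row_sum * rescale + p.sum(dim=-1, keepdim=True)
        acc = acc * rescale + p @ V_blk
        row_max = new_max
    return acc / row_sum

torch.manual_seed(0)
Q, K, V = torch.randn(8, 16), torch.randn(10, 16), torch.randn(10, 16)
reference = torch.softmax(Q @ K.T / math.sqrt(16), dim=-1) @ V
tiled = tiled_attention(Q, K, V, block=4)
print("max abs difference:", (tiled - reference).abs().max().item())
```

With 10 keys and block=4, the last tile is shorter, which exercises the uneven-tail case. On the first tile the rescale factor is exp(-inf) = 0, so the empty initial state contributes nothing.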

Using fused attention in practice doesn't require writing custom CUDA kernels. PyTorch exposes scaled dot-product attention; the backend chosen for a given run depends on device, dtype, shapes, masks, and framework version. On CUDA, the function attempts to select an enabled implementation based on its inputs, but a fused kernel isn't guaranteed for every call.[9] The public contract is the result, so start by checking it against the explicit computation:

sdpa-matches-explicit-causal-attention.py

```python
import math
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

torch.manual_seed(4)
Q = torch.randn(1, 1, 4, 8)
K = torch.randn(1, 1, 4, 8)
V = torch.randn(1, 1, 4, 8)

scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
causal = torch.ones(4, 4, dtype=torch.bool).tril()
explicit = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ V
library = sdpa(Q, K, V, is_causal=True)

print("output shape:", tuple(library.shape))
print("matches explicit computation:", torch.allclose(library, explicit, atol=1e-6))
```

Output

```text
output shape: (1, 1, 4, 8)
matches explicit computation: True
```

Two API details prevent quiet bugs when you replace explicit attention with PyTorch's fused primitive. For F.scaled_dot_product_attention, True in a boolean attn_mask means the position participates in attention; that's the inverse of nn.MultiheadAttention's boolean key_padding_mask. Also, dropout_p is always applied when it's greater than zero, so pass 0.0 during evaluation.[9]
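A small check makes the inverted conventions concrete. The sketch below, assuming PyTorch, builds one boolean "real token" mask for a padded batch, passes it directly as attn_mask to F.scaled_dot_product_attention, and passes its negation as key_padding_mask to nn.MultiheadAttention. It then adds large values to the padded slots: if each mask is oriented correctly, neither output changes. The lengths, shapes, and +100 perturbation are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, D = 2, 5, 8
lengths = torch.tensor([5, 3])                       # sequence 1 has 2 padded slots
keep = torch.arange(N)[None, :] < lengths[:, None]   # (B, N), True = real token

# SDPA: True in a boolean attn_mask means "this key participates".
q, k, v = (torch.randn(B, 1, N, D) for _ in range(3))
sdpa_mask = keep[:, None, None, :]                   # broadcasts to (B, 1, N, N)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask, dropout_p=0.0)

# nn.MultiheadAttention: True in key_padding_mask means "ignore this key".
mha = torch.nn.MultiheadAttention(D, num_heads=2, batch_first=True).eval()
x = torch.randn(B, N, D)
kv = x.clone()
y, _ = mha(x, kv, kv, key_padding_mask=~keep)

# If the masks are right, garbage in padded slots can't change either output.
k2, v2, kv2 = k.clone(), v.clone(), kv.clone()
k2[1, :, 3:] += 100.0
v2[1, :, 3:] += 100.0
kv2[1, 3:] += 100.0
out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=sdpa_mask, dropout_p=0.0)
y2, _ = mha(x, kv2, kv2, key_padding_mask=~keep)

print("SDPA ignores padding:", torch.allclose(out, out2))
print("MHA ignores padding:", torch.allclose(y, y2))
```

Passing keep as key_padding_mask (without the ~) would do the opposite: it would hide the real tokens and attend only to padding.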


Softmax numerical stability

A subtle but critical implementation detail in attention is numerical stability. The naive softmax formula overflows for large inputs because $e^{x}$ grows extremely fast. Even after scaling, dot products can produce large positive scores. Exponentiating them yields Inf (infinity), Inf / Inf then becomes NaN, and the NaN propagates through the rest of training. Low-precision formats make this worse: FP16's largest finite value is 65504, so $e^{x}$ already overflows in FP16 once $x$ exceeds about 11.09.
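A quick way to see the failure is to exponentiate raw logits directly. In the pure-Python sketch below, the logits are made up for illustration. Python floats are IEEE 754 doubles, which overflow once the exponent passes roughly 709.78, so the naive formula raises OverflowError on logits near 1000. Tensor libraries typically return inf instead of raising, and the following division then produces NaN.

```python
import math
import sys

def naive_softmax(logits):
    exps = [math.exp(x) for x in logits]   # no max-shift: exp(1000) is out of range
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 1001.0, 999.0]
try:
    naive_softmax(logits)
    overflowed = False
except OverflowError:
    overflowed = True

print("naive softmax overflowed:", overflowed)
# exp(x) overflows a double once x exceeds ln(largest finite double).
print("largest safe exponent:", math.log(sys.float_info.max))
```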

The max-shift technique

To prevent this overflow, deep learning frameworks implement a numerically stable version of softmax. They subtract the maximum value from the input vector before exponentiating:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(\mathbf{x})}}{\sum_j e^{x_j - \max(\mathbf{x})}}$$

Why shifting doesn't change the answer

Subtracting $\max(\mathbf{x})$ shifts all logits by the same constant $c$. Because $e^{x_i - c} = e^{x_i} e^{-c}$, the factor $e^{-c}$ appears in the numerator and in every term of the denominator, so it cancels. The resulting probabilities don't change, but the largest exponent is now exactly $0$ (since $\max(\mathbf{x}) - \max(\mathbf{x}) = 0$), so the largest exponential evaluated is $e^0 = 1$. Every exponential is at most 1 and the denominator is at least 1, which avoids overflow.[8]

Building on this, online softmax (Milakov & Gimelshein, 2018)[8] extends the max-shift technique to compute softmax in a single streaming pass. Instead of needing to read the entire vector into memory to find the maximum, online softmax tracks the running maximum and the running sum of exponentials simultaneously, dynamically correcting the sum as a new, larger maximum is found. This streaming capability is exactly the mathematical foundation that allows FlashAttention's memory-efficient block tiling.

why-shifting-doesnt-change-the-answer.py

```python
import math

def stable_softmax(logits: list[float]) -> list[float]:
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]

probs = stable_softmax([1000.0, 1001.0, 999.0])
print([round(p, 4) for p in probs])
print(round(sum(probs), 6))
print(probs[1] == max(probs))
```

Output

```text
[0.2447, 0.6652, 0.09]
1.0
True
```

Attention in modern architectures

The mask shape and the source of Q, K, and V reveal which information an architecture permits each output to use:

| Architecture | Self-attention | Cross-attention | Typical objective / task |
|---|---|---|---|
| BERT | Bidirectional | No | Masked language modeling |
| Decoder-only LM | Causal | No | Next-token prediction |
| T5 | Bidirectional (enc) + causal (dec) | Yes | Span corruption |
| Stable Diffusion | Self (in U-Net) | Yes (text to image) | Diffusion denoising |
| Whisper | Bidirectional (enc) + causal (dec) | Yes | Speech-to-text seq2seq |
| Vision Transformer (ViT) | Bidirectional | No | Image classification / self-supervision |

Encoder-only models like BERT and ViT use bidirectional self-attention because their goal is to build a comprehensive understanding of the entire input at once. They process the full text or image simultaneously to generate rich embeddings. In contrast, decoder-only language models use causal self-attention because they must generate output step-by-step. If they could look ahead at future tokens during training, they would simply "cheat" instead of learning to predict.

Encoder-decoder models (like T5 and Whisper) mix these approaches. They use a bidirectional encoder to fully understand the input (text or audio), and a causal decoder to generate the output. Cross-attention acts as the bridge between the two, allowing the decoder to continuously refer back to the encoder's rich representation of the source material.
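The patterns in the table differ mainly in which boolean mask each attention layer uses and where K and V come from. The sketch below builds each mask with PyTorch under the convention that True means "may attend"; the sequence lengths and the padded source positions are illustrative.

```python
import torch

tgt_len, src_len = 4, 6
src_keep = torch.tensor([True, True, True, True, False, False])  # last 2 source slots are padding

# Encoder self-attention (e.g., BERT): every position sees every non-padded position.
encoder_mask = torch.ones(src_len, src_len, dtype=torch.bool) & src_keep[None, :]

# Decoder-only self-attention: query i sees keys 0..i.
causal_mask = torch.ones(tgt_len, tgt_len, dtype=torch.bool).tril()

# Cross-attention (e.g., T5 or Whisper decoders): Q from the decoder, K/V from the encoder.
# No causal constraint over source positions, but padded source keys stay hidden.
cross_mask = torch.ones(tgt_len, src_len, dtype=torch.bool) & src_keep[None, :]

for name, mask in [("encoder", encoder_mask), ("causal", causal_mask), ("cross", cross_mask)]:
    print(name, tuple(mask.shape))
    print(mask.int())
```

The cross-attention mask is rectangular, shaped (target length, source length), which is a quick way to tell it apart from the square self-attention masks.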

Decoder-only language models exercise causal self-attention; encoder-decoder models add cross-attention; ViT uses bidirectional attention over image tokens. Those mechanics matter more than a market survey here. The next chapter will reuse the same attention operation after turning an image into patch tokens.


What you should be able to defend

| Tier | You should be able to defend |
|---|---|
| Foundational | Derive $\text{Attention}(Q,K,V)=\text{softmax}(QK^T/\sqrt{d_k})V$ and annotate the dimensions of each tensor. |
| Intermediate | Derive why dividing by $\sqrt{d_k}$ normalizes score variance under independent unit-variance assumptions. |
| Advanced | Distinguish bidirectional self-attention, causal self-attention, and cross-attention by Q/K/V source and mask. |
| Advanced | Explain why multi-head attention runs several lower-dimensional attention heads without changing leading attention-core asymptotic FLOPs when $d_{\text{model}}$ is fixed. |
| Advanced | Separate $O(n^2 d_{\text{model}})$ attention-core time, $O(n d_{\text{model}}^2)$ projection time, and $O(n^2)$ naive attention memory. |
| Advanced | Explain why FlashAttention cuts temporary attention I/O while MQA/GQA shrink persistent KV-cache traffic during decoding. |
| Advanced | Describe max-shift softmax, online softmax, and why numerical stability matters inside attention kernels. |

Mistakes that break attention

| Mistake | Symptom | Fix |
|---|---|---|
| Forgetting the K transpose | Matmul shape error before softmax | Compute Q @ K.transpose(-2, -1) so scores have shape (B, h, N, N). |
| Skipping $\sqrt{d_k}$ scaling | Score spread rises with head width under the initialization model; attention may saturate | Divide scores by math.sqrt(d_k) before masking and softmax. |
| Mixing up attention memory and time | Bad long-context sizing estimates | Track attention-core FLOPs, projection FLOPs, temporary score memory, and KV-cache memory separately. |
| Treating FlashAttention and GQA as the same optimization | Wrong performance diagnosis | Use FlashAttention for temporary attention I/O; use MQA/GQA for persistent KV-cache traffic. |
| Forgetting padding masks in cross-attention | Decoder attends to fake source tokens | Mask padded encoder positions even though source positions don't need a causal mask. |
| Confusing Q/K/V roles | Hard-to-debug routing behavior | Remember: Q and K choose where information flows; V carries the content being mixed. |
| Using naive softmax | Inf, NaN, or unstable probabilities | Subtract the row max, or use a framework primitive that already applies stable softmax. |
| Softmax on the wrong axis | Rows don't sum to 1; routing is meaningless | Apply softmax over the key dimension (dim=-1), not the query dimension. |
| Allowing a fully masked query row | NaN with -inf, or silent blocked-key mixing with a finite fill | Guarantee one valid key per active query, or zero/skip outputs for padded queries. |
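The last row of the table is easy to reproduce. In the PyTorch sketch below (made-up scores), a query row whose keys are all masked with -inf turns into NaN after softmax, while a large finite fill silently spreads weight uniformly over keys that were supposed to be blocked. Zeroing the outputs of queries with no valid key is one possible fix; the name valid_query is illustrative.

```python
import torch

scores = torch.tensor([[0.5, 1.0, -0.3],
                       [0.2, 0.1, 0.4]])
allowed = torch.tensor([[True, True, False],
                        [False, False, False]])  # row 1: a padded query with no valid key

inf_weights = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1)
finite_weights = torch.softmax(scores.masked_fill(~allowed, -1e9), dim=-1)
print("with -inf:", inf_weights[1].tolist())      # NaN row
print("with -1e9:", finite_weights[1].tolist())   # uniform weight on blocked keys

# One fix: mark which queries have at least one valid key and zero the rest.
valid_query = allowed.any(dim=-1, keepdim=True)
safe_weights = torch.where(valid_query, inf_weights, torch.zeros_like(inf_weights))
print("fixed:", safe_weights[1].tolist())
```

This repairs the forward values, but during training a NaN inside the discarded torch.where branch can still leak into gradients. Guaranteeing one valid key per active query avoids creating the NaN at all.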

Going deeper

"Why not use additive attention instead of dot-product?"

Bahdanau et al. (2015)[10] used additive attention: $\text{score}(q, k) = v^T \tanh(W_q q + W_k k)$. In the Transformer paper, Vaswani et al. note that the two mechanisms perform similarly for small $d_k$. They also note that dot-product attention is much faster and more space-efficient in practice because it maps onto optimized matrix multiplication, and that without scaling, additive attention outperforms dot-product attention for larger $d_k$.[1] The $\frac{1}{\sqrt{d_k}}$ scaling is what keeps dot-product attention stable as $d_k$ grows.
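The cost difference shows up directly in the tensor shapes. Additive attention runs a small feed-forward computation on every query-key pair, while scaled dot-product attention is a single matrix multiply. In the PyTorch sketch below, W_q, W_k, and v are randomly initialized stand-ins for learned parameters, and hidden = 32 is an arbitrary choice.

```python
import math
import torch

torch.manual_seed(0)
n_q, n_k, d, hidden = 3, 5, 16, 32
q = torch.randn(n_q, d)
k = torch.randn(n_k, d)

# Additive (Bahdanau-style) scores: v^T tanh(W_q q + W_k k) for every query-key pair.
W_q = torch.randn(hidden, d) / math.sqrt(d)   # stand-ins for learned parameters
W_k = torch.randn(hidden, d) / math.sqrt(d)
v = torch.randn(hidden)
pair_hidden = torch.tanh((q @ W_q.T)[:, None, :] + (k @ W_k.T)[None, :, :])  # (n_q, n_k, hidden)
additive_scores = pair_hidden @ v                                           # (n_q, n_k)

# Scaled dot-product scores: one matmul, no per-pair hidden tensor.
dot_scores = q @ k.T / math.sqrt(d)                                         # (n_q, n_k)

print("per-pair hidden tensor:", tuple(pair_hidden.shape))
print("additive scores:", tuple(additive_scores.shape))
print("dot-product scores:", tuple(dot_scores.shape))
```

Both produce an (n_q, n_k) score matrix. The additive path also materializes an (n_q, n_k, hidden) intermediate, which is the part that maps poorly onto plain matrix multiplication.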

"Can attention attend to nothing / everything equally?"

If all scores are equal, softmax produces a uniform distribution $1/n$, and the output is the average of all value vectors. If one score dominates, softmax approximates an argmax, so the output is approximately one value vector. The sharpness of the logits controls where on this spectrum you land.
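Both extremes fit in a few lines of pure Python. The value vectors and scores below are made up; multiplying the same scores by 20 stands in for sharper logits.

```python
import math

def softmax(scores):
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mix(weights, values):
    # Weighted sum of value vectors: sum_j w_j * v_j, computed per output dimension.
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]

values = [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]]

equal = softmax([0.7, 0.7, 0.7])
print("equal scores ->", [round(x, 3) for x in mix(equal, values)])   # average of the values

sharp = softmax([20 * s for s in [1.0, 0.2, 0.1]])
print("sharp scores ->", [round(x, 3) for x in mix(sharp, values)])   # close to values[0]
```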

Next Step
Continue to Vision Transformers and Image Encoders

You now understand attention over token sequences; the next chapter shows how images become patch tokens so the same attention machinery can process visual data.

References

[1] Vaswani, A., et al. (2017). Attention Is All You Need.

[2] Voita, E., et al. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting. ACL 2019.

[3] Michel, P., Levy, O., & Neubig, G. (2019). Are Sixteen Heads Really Better than One? NeurIPS 2019.

[4] Olsson, C., et al. (2022). In-context Learning and Induction Heads.

[5] Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint.

[6] Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.

[7] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

[8] Milakov, M. & Gimelshein, N. (2018). Online normalizer calculation for softmax.

[9] PyTorch Contributors (2026). torch.nn.functional.scaled_dot_product_attention.

[10] Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015.