Learn scaled dot-product attention from first principles, including Q/K/V routing, variance scaling, masks, multi-head shapes, KV-cache costs, and FlashAttention.
The previous chapter showed how to compare embedding vectors for retrieval. Attention layers do a related job inside a sequence: each token compares itself to other tokens, decides which positions matter, and blends information from those positions.
When you read "The package missed the dock because it was closed," you instantly know that "it" refers to the dock, not the package. Your brain routes focus to the relevant word. Scaled dot-product attention is the transformer operation that learns that kind of routing with multiplication.
Attention is the core information-routing mechanism in modern transformers. Given a sequence of tokens, it answers: "For each token, which other positions should it use, and how much?" Once that routing idea is clear, the math and code become direct translations of it.
Attention answers a simple question: "For each token in a sequence, which other tokens should it focus on?"
Think of reading a sentence: when you see "it" in "The package missed the dock because it was closed", your brain attends back to "dock" to resolve what "it" means. Self-attention does this computationally: every token computes a weighted combination of other token representations based on learned relevance.[1]
To perform attention, the model transforms each token representation into three different roles. Think of it like a support search system:
| Vector | Question it answers | Support-search analogy |
|---|---|---|
| Query (Q) | What am I looking for? | An order-status question |
| Key (K) | What do I contain? | Searchable tags on a shipment event |
| Value (V) | What information do I carry? | The actual carrier scan or policy text |
The model creates these three vectors by multiplying each token representation through learned weight matrices:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

where $X$ stacks the token representations row by row. Each token representation is multiplied by three different weight matrices ($W^Q$, $W^K$, $W^V$) to produce three different views of the same token. One view highlights what this token is searching for (Query), another highlights what this token offers to others (Key), and the third carries the actual information (Value).

Here $W^Q, W^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are the learned projections (where $d_{\text{model}}$ is the model's overall hidden dimension size, and $d_k$ and $d_v$ are the dimensions of the queries/keys and values respectively).[1]
Q and K determine where to look. V determines what information to extract. The separation of "routing" (Q, K) from "content" (V) is fundamental to why attention is so powerful.
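As a minimal sketch of the projection step, the snippet below multiplies two toy token vectors by three small weight matrices. The matrices `W_Q`, `W_K`, and `W_V`, the input `X`, and the helper `matmul` are hypothetical values chosen for illustration, not learned weights from any model.

```python
# Toy token representations: 2 tokens, d_model = 3 (illustrative values).
X = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 1.0]]

# Hypothetical projection matrices: d_model = 3 -> d_k = d_v = 2.
W_Q = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
W_K = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.0]]
W_V = [[1.0, 1.0], [0.0, 2.0], [1.0, 0.0]]

def matmul(a, b):
    """Plain-Python matrix product of a (m x n) and b (n x p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Three views of the same two tokens.
Q = matmul(X, W_Q)
K = matmul(X, W_K)
V = matmul(X, W_V)

print("Q:", Q)
print("K:", K)
print("V:", V)
```

Each token keeps one row in `Q`, `K`, and `V`, but the three rows differ because each projection mixes the input features differently.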
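The scaled dot-product attention formula computes a weighted combination of value vectors based on the compatibility between queries and keys:[1]

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

When some key positions aren't allowed, add a mask matrix $M$ after scaling. An allowed location has $M_{ij} = 0$; a blocked location has $M_{ij} = -\infty$:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$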
| Step | Operation | What it means |
|---|---|---|
| Compute scores | $S = QK^\top$ | Build an $n \times n$ matrix where $S_{ij}$ measures how much token $i$ should attend to token $j$. |
| Scale | $S / \sqrt{d_k}$ | Keep softmax logits in a range where gradients stay useful. |
| Mask | Add $0$ or $-\infty$ to each logit | Block future keys for causal attention or padded keys for batching. |
| Normalize | $A = \text{softmax}(\cdot)$ along each row | Turn each row into weights that sum to 1. |
| Aggregate | $AV$ | Blend value vectors according to the attention weights. |
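The shape flow is compact enough to keep beside the formula. Queries and keys build one routing matrix; values join only after softmax:

$$\underbrace{Q}_{n \times d_k}\,\underbrace{K^\top}_{d_k \times n} \;\to\; \underbrace{\text{scores}}_{n \times n} \;\to\; \underbrace{\text{softmax weights}}_{n \times n} \cdot \underbrace{V}_{n \times d_v} \;\to\; \underbrace{\text{output}}_{n \times d_v}$$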
Before running PyTorch code, walk through one step with actual vectors. Imagine a two-word sequence, "order delayed." Each word lives in a 2-dimensional space for this toy example, and we have learned tiny projection matrices that give:
| Token | Query vector | Key vector | Value vector |
|---|---|---|---|
| order | [1.0, 0.5] | [0.8, 0.2] | [2.0, 1.0] |
| delayed | [0.5, 1.0] | [0.3, 0.9] | [1.0, 2.0] |
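Compute the dot product of every query with every key:

$$QK^\top = \begin{bmatrix} 1.0 \cdot 0.8 + 0.5 \cdot 0.2 & 1.0 \cdot 0.3 + 0.5 \cdot 0.9 \\ 0.5 \cdot 0.8 + 1.0 \cdot 0.2 & 0.5 \cdot 0.3 + 1.0 \cdot 0.9 \end{bmatrix} = \begin{bmatrix} 0.90 & 0.75 \\ 0.60 & 1.05 \end{bmatrix}$$

With $d_k = 2$, divide by $\sqrt{2} \approx 1.414$:

$$\frac{QK^\top}{\sqrt{2}} \approx \begin{bmatrix} 0.636 & 0.530 \\ 0.424 & 0.742 \end{bmatrix}$$

Normalize each row so it sums to 1:

$$A = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{2}}\right) \approx \begin{bmatrix} 0.53 & 0.47 \\ 0.42 & 0.58 \end{bmatrix}$$

Multiply the weights by the value vectors:

$$AV \approx \begin{bmatrix} 0.53 \cdot 2.0 + 0.47 \cdot 1.0 & 0.53 \cdot 1.0 + 0.47 \cdot 2.0 \\ 0.42 \cdot 2.0 + 0.58 \cdot 1.0 & 0.42 \cdot 1.0 + 0.58 \cdot 2.0 \end{bmatrix} \approx \begin{bmatrix} 1.53 & 1.47 \\ 1.42 & 1.58 \end{bmatrix}$$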
The output replaces each original embedding with a blend of the whole sequence, weighted by relevance. Even in this toy example, "order" pulls slightly more from its own value (0.53) than from "delayed" (0.47), while "delayed" leans a little more strongly toward its own value (0.58 versus 0.42). In a real model with 512 or 4096 dimensions, this blending happens across thousands of numbers at once.
The same arithmetic is easy to verify without a tensor library:
```python
import math

Q = [[1.0, 0.5], [0.5, 1.0]]
K = [[0.8, 0.2], [0.3, 0.9]]
V = [[2.0, 1.0], [1.0, 2.0]]

def softmax(row: list[float]) -> list[float]:
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [x / total for x in exps]

scores = [[sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in K] for q in Q]
scaled = [[score / math.sqrt(2) for score in row] for row in scores]
weights = [softmax(row) for row in scaled]
outputs = [
    [sum(weight * value[col] for weight, value in zip(row, V)) for col in range(2)]
    for row in weights
]

print([[round(x, 2) for x in row] for row in weights])
print([[round(x, 2) for x in row] for row in outputs])
print([round(sum(row), 3) for row in weights])
```

```text
[[0.53, 0.47], [0.42, 0.58]]
[[1.53, 1.47], [1.42, 1.58]]
[1.0, 1.0]
```

Here's how the scaled dot-product attention formula translates into practical PyTorch code. The function takes the Query, Key, and Value tensors along with an optional mask (used for causal or padding purposes). It computes the attention scores, scales them by the square root of the dimension, normalizes them into probabilities, and returns both the final weighted output and the attention weights themselves.
```python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    K: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    V: torch.Tensor,  # (batch, n_heads, seq_len, d_v)
    mask: torch.Tensor | None = None,  # complete visibility mask, broadcastable to (B, h, n, n)
    dropout_p: float = 0.0,
    training: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention (Vaswani et al., 2017)."""
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, n, n)

    # Step 2: Scale by head width to control logit spread
    attn_scores = attn_scores / math.sqrt(d_k)

    # Step 3: Apply mask (causal or padding)
    if mask is not None:
        if not bool(mask.any(dim=-1).all()):
            raise ValueError("each query row must have at least one visible key")
        attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

    # Step 4: Softmax normalization (each row sums to 1)
    attn_weights = F.softmax(attn_scores, dim=-1)

    # Optional: dropout on attention weights (regularization)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p, training=training)

    # Step 5: Weighted aggregation of values
    output = torch.matmul(attn_weights, V)  # (B, h, n, d_v)

    return output, attn_weights

Q = torch.tensor([[[[1.0, 0.5], [0.5, 1.0]]]])
K = torch.tensor([[[[0.8, 0.2], [0.3, 0.9]]]])
V = torch.tensor([[[[2.0, 1.0], [1.0, 2.0]]]])
causal = torch.tensor([[[[True, False], [True, True]]]])
output, weights = scaled_dot_product_attention(Q, K, V, causal)

pretty_weights = [[round(float(value), 3) for value in row] for row in weights[0, 0]]
print("causal weights:", pretty_weights)
print("token 0 future weight:", weights[0, 0, 0, 1].item())
print("output shape:", tuple(output.shape))
```

```text
causal weights: [[1.0, 0.0], [0.421, 0.579]]
token 0 future weight: 0.0
output shape: (1, 1, 2, 2)
```

Two masking details matter in real code. First, softmax must run along the key dimension (`dim=-1`) so each query row sums to 1; normalizing along the query axis silently changes the operation. Second, a row with no permitted key has no valid attention distribution. `-inf` makes that mistake visible as `NaN`. Replacing it with a finite negative value hides the bug: softmax assigns weight to blocked keys because every blocked logit ties. Ensure each active query has at least one allowed key (causal attention includes its own position), or explicitly suppress outputs for padded query rows.
This short failure case makes the second rule concrete:
```python
import torch

scores = torch.tensor([[0.8, 0.1], [0.4, -0.2]])
visible = torch.tensor([[True, False], [False, False]])

finite_fill_weights = torch.softmax(scores.masked_fill(~visible, -1e4), dim=-1)
active_rows = visible.any(dim=-1, keepdim=True)
served_weights = torch.where(active_rows, finite_fill_weights, torch.zeros_like(finite_fill_weights))

print("finite fill, invalid row:", finite_fill_weights[1].tolist())
print("after padded-query suppression:", served_weights[1].tolist())
print("valid row ignores blocked key:", served_weights[0].tolist())
```

```text
finite fill, invalid row: [0.5, 0.5]
after padded-query suppression: [0.0, 0.0]
valid row ignores blocked key: [1.0, 0.0]
```

The most frequent debugging task in transformer work is tracing a dimension mismatch. Here are two typical mistakes and the exact symptoms they produce.
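If you write `torch.matmul(Q, K)` instead of `torch.matmul(Q, K.transpose(-2, -1))`, the inner dimensions won't align. With Q shape `(B, h, n, d_k)` and K shape `(B, h, n, d_k)`, the operation attempts to multiply two arrays whose final matrix axes are $(n \times d_k)(n \times d_k)$ rather than $(n \times d_k)(d_k \times n)$, so it fails before softmax. The one exception is $n = d_k$, where the shapes happen to align and the call silently computes the wrong product.

The fix is always the same: the last two dimensions of K must be swapped so the matrix multiplication becomes $(n \times d_k)(d_k \times n)$, yielding the $n \times n$ score matrix.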
If you skip `attn_scores / math.sqrt(d_k)`, the code won't crash. Under the independent unit-variance setup below with $d_k = 64$, a raw dot product has standard deviation 8 rather than 1. That larger spread can saturate softmax and reduce routing gradients. Inspect score statistics and attention entropy when debugging, then restore the scale factor unless you're intentionally testing a different attention formulation.
Imagine combining many warehouse sensor signals into one routing score. Without scaling, adding more signals (higher $d_k$) makes raw dot products larger and larger until softmax saturates. Dividing by $\sqrt{d_k}$ normalizes the score so the model can still distinguish several plausible routes instead of locking onto one too early.
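One way to inspect attention entropy is sketched below. It assumes random unit-variance queries and keys rather than activations from a trained model, and the sizes (`d_k = 64`, 16 keys, 200 query rows) and helper names are illustrative choices. Lower mean entropy means rows are closer to one-hot.

```python
import math
import random

rng = random.Random(0)
d_k, n_keys, n_queries = 64, 16, 200  # illustrative sizes, not from a real model

def softmax(row):
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    # Shannon entropy in nats; skip zero weights to avoid log(0).
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def mean_entropy(scale):
    """Average row entropy when logits are q.k multiplied by `scale`."""
    total = 0.0
    for _ in range(n_queries):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        keys = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(n_keys)]
        logits = [sum(a * b for a, b in zip(q, k)) * scale for k in keys]
        total += entropy(softmax(logits))
    return total / n_queries

unscaled = mean_entropy(1.0)
scaled = mean_entropy(1.0 / math.sqrt(d_k))

print(f"max possible entropy: {math.log(n_keys):.2f}")
print(f"mean entropy, unscaled: {unscaled:.2f}")
print(f"mean entropy, scaled:   {scaled:.2f}")
```

In a real debugging session the same `entropy` helper would be applied to rows of the model's attention weights; a value near zero across most rows is a hint that logits are too large.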
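Here's the derivation that motivated the transformer's scale factor. For this calculation, assume entries of $q$ and $k$ are independent with mean 0 and variance 1. Learned activations won't satisfy those assumptions exactly, but the calculation explains why unscaled logits begin with dimension-dependent spread.[1]

$q_i$ and $k_i$ are independent components, each with mean $0$ and variance $1$. The proof needs only those two moments, not a specific distribution, so it holds for any initialization with independent zero-mean, unit-variance entries. The dot product is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each term has $\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0$ (since $\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0$ and they're independent). The variance is:

$$\text{Var}(q_i k_i) = \mathbb{E}[q_i^2 k_i^2] - \left(\mathbb{E}[q_i k_i]\right)^2 = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - 0 = 1$$

By the sum of independent variances:

$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$

So the standard deviation of the raw dot product is $\sqrt{d_k}$. As $d_k$ grows under this model, logits spread more widely before softmax: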
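| $d_k$ | Std of raw $q \cdot k$ | Std after dividing by $\sqrt{d_k}$ |
|---|---|---|
| 16 | 4.0 | 1.0 |
| 64 | 8.0 | 1.0 |
| 512 | 22.6 | 1.0 |
| 4096 | 64.0 | 1.0 |

$\text{Var}\!\left(q \cdot k / \sqrt{d_k}\right) = 1$ under these assumptions. Scaling doesn't promise a particular learned attention pattern; it removes a predictable source of width-dependent logit growth.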
See it in code. This experiment samples independent unit-variance query/key vectors and checks the standard-deviation calculation rather than choosing one dramatic softmax row:
```python
import math
import random
import statistics

rng = random.Random(7)

def dot_products(width: int, samples: int = 5000) -> list[float]:
    return [
        sum(rng.gauss(0, 1) * rng.gauss(0, 1) for _ in range(width))
        for _ in range(samples)
    ]

for width in (16, 64, 512):
    raw = dot_products(width)
    raw_std = statistics.pstdev(raw)
    scaled_std = statistics.pstdev([x / math.sqrt(width) for x in raw])
    print(f"d_k={width:3d}: raw std={raw_std:5.2f}, scaled std={scaled_std:4.2f}")
```

```text
d_k= 16: raw std= 4.07, scaled std=1.02
d_k= 64: raw std= 8.09, scaled std=1.01
d_k=512: raw std=22.39, scaled std=0.99
```

The sampled values won't be exactly the theoretical values, but their trend should match: raw spread grows with width while scaled spread stays near one.
The transformer architecture uses attention in three distinct patterns. Bidirectional attention is like a support dashboard where every event in an order timeline can see every other event. Causal attention processes that timeline left to right, where each new event can only use earlier events. Cross-attention is like a reply generator reading from a separate policy document while writing the customer response.
Every non-padding token may attend to every other non-padding token; there is no future-token mask. Encoder architectures such as BERT use this pattern for tasks where the full input is available. For example, given a complete sentence, every word's representation can use every other word:

```text
"The package missed the dock"
  Token "missed" attends to: [The, package, missed, the, dock]  (full context)
```

Each token can only attend to itself and previous tokens. Future positions are masked with $-\infty$. Decoder architectures such as GPT-style and other autoregressive language models use this pattern for generation tasks, where the model must predict the next token without seeing the future. For instance, when processing a sequence step-by-step, the model progressively builds context but remains strictly blind to upcoming words:

```text
"The package missed the dock"
  Token "missed" attends to: [The, package, missed]            (only past + self)
  Token "dock"   attends to: [The, package, missed, the, dock] (full history)
```

The causal mask is a lower-triangular matrix that ensures each position only looks at itself and the positions before it. This minimal Python version takes the sequence length as input and outputs a boolean matrix where `True` indicates an allowed connection and `False` indicates a masked one.
```python
def create_causal_mask(seq_len: int) -> list[list[bool]]:
    return [[key_pos <= query_pos for key_pos in range(seq_len)] for query_pos in range(seq_len)]

mask = create_causal_mask(4)
for row in mask:
    print(row)
print(mask[0] == [True, False, False, False])
print(mask[1] == [True, True, False, False])
print(mask[3] == [True, True, True, True])
```

```text
[True, False, False, False]
[True, True, False, False]
[True, True, True, False]
[True, True, True, True]
True
True
True
```

Queries come from one sequence, Keys and Values from another. The original Transformer uses this pattern in its decoder: encoder outputs provide keys and values, while current decoder states provide queries.[1] The resulting weights choose which source positions contribute to each target representation:
```text
Encoder output (source): "package delayed warehouse"  provides K, V
Decoder state (target):  "order status ___"           provides Q

Q from decoder times K from encoder gives attention weights
Weights times V from encoder give decoder context
```

In practice, cross-attention usually applies a source padding mask so decoder tokens don't attend to padded encoder positions. It doesn't use a causal mask over the source sequence, because the encoder has already seen the whole input.
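A source padding mask can be sketched in a few lines. The example below assumes a three-position source whose last position is padding; the flag list `source_is_real` and the helper `masked_softmax` are hypothetical names introduced for this illustration, and the helper assumes every query has at least one visible key.

```python
import math

decoder_queries = [[1.0, 0.0], [0.0, 1.0]]
encoder_keys = [[1.0, 0.0], [0.2, 0.8], [0.0, 0.0]]  # last row is padding
source_is_real = [True, True, False]                  # hypothetical padding flags

def masked_softmax(logits, visible):
    # Blocked logits become -inf, so exp(-inf) contributes exactly 0.
    # Assumes at least one visible key; otherwise the row has no valid distribution.
    masked = [x if keep else float("-inf") for x, keep in zip(logits, visible)]
    shift = max(masked)
    exps = [math.exp(x - shift) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

weights = []
for q in decoder_queries:
    logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in encoder_keys]
    weights.append(masked_softmax(logits, source_is_real))

for row in weights:
    print([round(w, 3) for w in row])
```

The same mask row is reused for every decoder query because padding is a property of the source sequence, not of the query position.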
Self-attention produces a square query-by-key matrix. Cross-attention doesn't have to: two target queries reading three source positions produce a 2 x 3 routing matrix.
```python
import math

decoder_queries = [[1.0, 0.0], [0.0, 1.0]]
encoder_keys = [[1.0, 0.0], [0.2, 0.8], [0.0, 1.0]]
encoder_values = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]

def softmax(row: list[float]) -> list[float]:
    exps = [math.exp(x - max(row)) for x in row]
    return [value / sum(exps) for value in exps]

logits = [
    [sum(q_i * k_i for q_i, k_i in zip(q, k)) / math.sqrt(2) for k in encoder_keys]
    for q in decoder_queries
]
weights = [softmax(row) for row in logits]
context = [
    [sum(weight * value[d] for weight, value in zip(row, encoder_values)) for d in range(2)]
    for row in weights
]

print(f"routing shape: {len(weights)} x {len(weights[0])}")
print("row sums:", [round(sum(row), 3) for row in weights])
print("context:", [[round(x, 3) for x in row] for row in context])
```

```text
routing shape: 2 x 3
row sums: [1.0, 1.0]
context: [[0.623, 0.377], [0.393, 0.607]]
```

| Attention Type | Q Source | K, V Source | Mask | Architecture |
|---|---|---|---|---|
| Bidirectional Self | Same sequence | Same sequence | Padding mask only (if needed) | BERT, Vision Transformer (ViT) |
| Causal Self | Same sequence | Same sequence | Lower-triangular, plus padding if needed | GPT, Llama |
| Cross | Target sequence | Source sequence | Source padding mask common | T5, original Transformer |
Imagine giving several analysts separate learned views of an order timeline, then letting a final projection combine their reports. Multi-head attention makes that capacity available: each head has its own query, key, and value projections. A trained head may acquire a recognizable routing pattern, but the architecture doesn't assign jobs such as "policy clause head" in advance.
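Instead of one large attention operation, we run $h$ parallel heads, each with $d_k = d_v = d_{\text{model}} / h$:[1]

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

Here $X$ is the matrix of token representations. We project into narrower heads that attend independently. After each head produces its output, we concatenate them and multiply by a final matrix $W^O$ so information from those routes rejoins the residual stream.

Each head computes:

$$\text{head}_i = \text{Attention}(X W_i^Q,\; X W_i^K,\; X W_i^V)$$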
Interpretability results here are evidence about specific trained models, not a promise about every transformer. Voita et al. found positional and syntactic patterns among heads in neural machine translation encoders, and Michel et al. found that many heads in the models they tested could be removed at inference with limited quality loss.[2][3] Olsson et al. studied induction heads, circuits that support copying patterns in autoregressive models under their experiments.[4]
| Result from a study | Useful inference | Unsafe inference |
|---|---|---|
| Some heads show consistent patterns | Inspect heads when debugging or researching a trained model | Every head has a named human-readable purpose |
| Some tested models tolerate head pruning | Redundancy can exist and can be measured | Arbitrarily deleting heads preserves a new model's quality |
| Induction-head circuits can emerge | Attention can implement copy-like sequence algorithms | An attention heatmap alone proves causal model behavior |
Multi-head attention doesn't increase the asymptotic cost of the attention core when you keep $d_{\text{model}}$ fixed. It restructures the work. Single-head attention on $d_k = 512$ uses roughly the same leading-order FLOPs (Floating Point Operations) for score computation and value mixing as 8-head attention with $d_k = 64$ each, because $8 \times 64 = 512$. The dense Q/K/V and output projections still cost $O(n \cdot d_{\text{model}}^2)$ either way.
The arithmetic also makes the compute comparison testable. Splitting a fixed width into more heads doesn't change the total number of score-and-value multiply-adds in the attention core:
```python
seq_len = 2048
d_model = 512

for heads in (1, 8, 16):
    d_head = d_model // heads
    score_and_value_work = 2 * heads * seq_len**2 * d_head
    print(f"heads={heads:2d}, d_head={d_head:3d}, core units={score_and_value_work:,}")
```

```text
heads= 1, d_head=512, core units=4,294,967,296
heads= 8, d_head= 64, core units=4,294,967,296
heads=16, d_head= 32, core units=4,294,967,296
```

Here's a runnable PyTorch shape implementation. To keep shapes simple, it uses the common case $d_k = d_v = d_{\text{model}} / h$. The three dense projection layers each produce all head slices at once; reshaping exposes those slices to batched attention.
```python
import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Each projection produces all per-head slices in one dense operation.
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, D = x.shape

        # Project and reshape: (B, N, D) -> (B, h, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head.
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

        # Concatenate heads and project back: (B, h, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

torch.manual_seed(0)
layer = MultiHeadAttention(d_model=8, n_heads=2)
x = torch.randn(1, 4, 8)
y = layer(x)
print("input shape:", tuple(x.shape))
print("head shape:", (1, layer.n_heads, 4, layer.d_k))
print("output shape:", tuple(y.shape))
```

```text
input shape: (1, 4, 8)
head shape: (1, 2, 4, 4)
output shape: (1, 4, 8)
```

Before relying on `torch.nn.MultiheadAttention`, make sure you can implement the pieces above by hand. High-level modules are useful later, but they hide the exact shape and masking mistakes that break production attention code.
For a decoder-only block, your implementation should do each step explicitly:
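1. Project `x` into Q, K, and V.
2. Reshape `(B, N, D)` into `(B, h, N, d_k)`.
3. Compute scores as `Q @ K.transpose(-2, -1) / sqrt(d_k)`.
4. Apply the causal mask, take softmax over the key axis, and multiply by V.
5. Merge the heads back into `(B, N, D)`.

Two shape assertions catch many bugs:

```python
assert Q.shape == (B, n_heads, N, d_k)
assert attn_weights.shape == (B, n_heads, N, N)
```

One mask assertion catches leakage: `assert not causal[0][1]`. Token 0 must not see token 1.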
If a model can read future tokens during training, the loss can look excellent while generation fails. That's why causal masking isn't a detail. It's the contract that makes next-token prediction honest.
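A behavioral check complements the mask assertion: perturb a future token and confirm that earlier outputs don't move. The sketch below assumes single-head attention with identity projections (Q = K = V = x), which is a simplification for illustration; `causal_self_attention` is a hypothetical helper, not a library function.

```python
import math

def causal_self_attention(x):
    """Single-head causal attention with Q = K = V = x (identity projections, for illustration)."""
    d = len(x[0])
    outputs = []
    for i, q in enumerate(x):
        # Only keys at positions <= i are visible to query i.
        logits = [sum(a * b for a, b in zip(q, x[j])) / math.sqrt(d) for j in range(i + 1)]
        shift = max(logits)
        exps = [math.exp(v - shift) for v in logits]
        total = sum(exps)
        weights = [e / total for e in exps]
        outputs.append([sum(w * x[j][c] for j, w in enumerate(weights)) for c in range(d)])
    return outputs

x = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
x_changed = [row[:] for row in x]
x_changed[2] = [9.0, -9.0]  # perturb only the final (future) token

before = causal_self_attention(x)
after = causal_self_attention(x_changed)

print("earlier outputs unchanged:", before[:2] == after[:2])
print("final output changed:", before[2] != after[2])
```

If the first check ever fails in a real model, some path (mask, positional logic, or batching) is letting future information reach earlier positions.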
Attention has three gradient paths:
| Path | What receives gradient | Why it matters |
|---|---|---|
| Output to V | value projection and upstream token representations | teaches what information each token should carry |
| Output to attention weights | softmax probabilities | teaches which source positions should matter |
| Weights back to Q and K | query/key projections | teaches the routing function itself |
The scale factor helps keep the Q/K path trainable. If scores become too large, softmax saturates, attention weights become nearly one-hot, and the gradient through the routing path becomes tiny. If the causal mask is wrong, gradients flow through illegal future positions and the model learns a shortcut it can't use at inference time.
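The saturation effect can be seen directly in the softmax Jacobian, whose entries are $\partial A_i / \partial z_j = A_i(\delta_{ij} - A_j)$. The sketch below multiplies one illustrative logit row by growing factors and reports the largest Jacobian entry; the logit values and the helper `max_jacobian_entry` are assumptions made for this example.

```python
import math

def softmax(row):
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def max_jacobian_entry(logits):
    # Softmax Jacobian: dA_i/dz_j = A_i * (delta_ij - A_j).
    a = softmax(logits)
    n = len(a)
    return max(
        abs(a[i] * ((1.0 if i == j else 0.0) - a[j]))
        for i in range(n)
        for j in range(n)
    )

base = [1.0, 0.5, -0.5]  # illustrative logits for one query row
results = {}
for scale in (1, 8, 64):
    logits = [scale * x for x in base]
    results[scale] = max_jacobian_entry(logits)
    top = max(softmax(logits))
    print(f"scale={scale:2d}: max weight={top:.4f}, max |dA/dz|={results[scale]:.2e}")
```

As the top weight approaches 1, every Jacobian entry shrinks toward zero, so very little gradient reaches the query and key projections.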
When debugging attention, don't only print the final output. Inspect the score range, the mask, one row of attention weights, and the gradient norm on W_q and W_k. Those four checks tell you whether the model is learning routing or only moving values through a broken router.
| Metric | Complexity | Explanation |
|---|---|---|
| Time (attention core, single head) | $O(n^2 d_k)$ | Both $QK^\top$ and $AV$ touch all query-key pairs |
| Time (attention core, full multi-head) | $O(n^2 d_{\text{model}})$ | Across heads, $h \cdot d_k = d_{\text{model}}$ |
| Time (Q/K/V + output projections) | $O(n\, d_{\text{model}}^2)$ | Dense linear layers before and after the attention core |
| Memory (naive weights) | $O(n^2)$ per head | The score or weight matrix is $n \times n$ |
| Parameters | $O(d_{\text{model}}^2)$ | $W^Q$, $W^K$, $W^V$, and $W^O$ are dense projections |
Batch of 8, 32 heads, $n = 8192$, FP16 (16-bit floating point format):
8 sequences in the batch x 32 attention heads x 8192² entries per attention map x 2 bytes per number (FP16) = about 34.4 GB of raw storage, or about 32 GiB, just for one attention-score tensor before values needed for backpropagation, model weights, or optimizer state.
```python
batch = 8
heads = 32
seq_len = 8192
bytes_per_fp16 = 2

bytes_total = batch * heads * seq_len**2 * bytes_per_fp16
gb = bytes_total / 1_000_000_000
gib = bytes_total / 1024**3

print(round(gb, 1), "GB")
print(round(gib, 1), "GiB")
print(round(gb, 1) == 34.4, round(gib, 1) == 32.0)
```

```text
34.4 GB
32.0 GiB
True True
```

In the naive formulation, this is why temporary memory becomes a bottleneck for long sequences. Doubling sequence length makes one materialized score tensor four times larger. Fused kernels can avoid storing that full tensor, but exact dense attention still computes interactions across all query-key pairs.
During training, or in a naive implementation, the temporary score matrix is the obvious memory problem. During autoregressive decoding, optimized kernels often avoid materializing that matrix, but the model still has to repeatedly read the accumulated KV cache (past keys and values) for every new token. That makes incremental inference heavily constrained by memory bandwidth, not just FLOPs.[5]
Architectural variants attack that persistent cache cost directly. Multi-query attention (MQA) shares one key/value head across all query heads, while grouped-query attention (GQA) uses a smaller number of key/value heads than query heads.[5][6] Both shrink KV-cache bytes. Whether they improve latency enough for a workload is a measurement question because kernel choice, batch size, and quality requirements also matter.
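To make the sharing concrete, the sketch below maps each query head to the key/value head it reads. It assumes contiguous grouping (query heads 0..g-1 share KV head 0, and so on), which is one common convention rather than a requirement of the papers; `kv_head_for` and the head counts are hypothetical.

```python
n_query_heads = 8  # illustrative head count

def kv_head_for(query_head: int, n_kv_heads: int) -> int:
    """Map a query head to its key/value head under contiguous grouping (an assumed convention)."""
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size

mappings = {}
for label, n_kv_heads in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mappings[label] = [kv_head_for(h, n_kv_heads) for h in range(n_query_heads)]
    print(f"{label}: kv head used by each query head = {mappings[label]}")
```

With one KV head per query head the mapping is the identity (standard multi-head attention); with a single KV head every query head reads the same keys and values.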
For a simplified decoder cache, the storage count is proportional to layers x tokens x kv_heads x head_dim x 2 (the final factor stores both K and V). Keeping 32 query heads but reducing KV heads changes this count directly:
```python
layers = 32
tokens = 8192
head_dim = 128
bytes_per_value = 2  # FP16

def cache_gib(kv_heads: int) -> float:
    bytes_total = layers * tokens * kv_heads * head_dim * 2 * bytes_per_value
    return bytes_total / 1024**3

mha = cache_gib(32)
for label, kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = cache_gib(kv_heads)
    print(f"{label}: kv_heads={kv_heads:2d}, cache={size:.2f} GiB, reduction={mha / size:.0f}x")
```

```text
MHA: kv_heads=32, cache=4.00 GiB, reduction=1x
GQA: kv_heads= 8, cache=1.00 GiB, reduction=4x
MQA: kv_heads= 1, cache=0.12 GiB, reduction=32x
```

FlashAttention and MQA/GQA solve different problems. FlashAttention cuts temporary attention I/O. MQA/GQA cut persistent KV-cache size.
FlashAttention (Dao et al., 2022)[7] computes exact dense attention with an IO-aware algorithm. Its attention working memory is linear in sequence length rather than storing a quadratic score matrix; it changes the execution order, not the attention definition.
Instead of materializing the full attention matrix in GPU HBM, FlashAttention computes attention in tiles that fit in on-chip SRAM, using online softmax to avoid storing the full matrix while preserving the exact result.[7][8]
| Property | Standard Attention | FlashAttention |
|---|---|---|
| Memory for attention computation | $O(n^2)$ | $O(n)$ |
| HBM traffic | Writes and rereads large score matrices | Keeps tiles on chip and avoids materializing full score matrix |
| Exact | Yes | Yes |
| Wall-clock speed | Baseline | Depends on hardware, shapes, dtype, and kernel availability |
FlashAttention is IO-aware. It minimizes traffic between large off-chip HBM and small on-chip SRAM by restructuring the computation order, not by changing the mathematical result.
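The reordering idea can be sketched for a single query in plain Python. This is not the FlashAttention kernel: it ignores GPU memory tiers entirely and only shows how a running maximum, a running sum, and a rescaled accumulator reproduce the exact result without holding all weights at once. The function names, tile size, and input values are illustrative assumptions.

```python
import math

def reference_attention(scores, values):
    """Materialize all weights, then mix values (the naive order)."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e / total * v[c] for e, v in zip(exps, values)) for c in range(dim)]

def tiled_attention(scores, values, tile_size=2):
    """Process keys tile by tile, keeping only a running max, sum, and output accumulator."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(scores), tile_size):
        tile_scores = scores[start:start + tile_size]
        tile_values = values[start:start + tile_size]
        new_max = max(running_max, max(tile_scores))
        # Rescale earlier partial results to the new reference maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(tile_scores, tile_values):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * v_c for a, v_c in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

scores = [0.2, 3.1, -1.0, 4.5, 0.7]  # one query's scaled logits (illustrative)
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.5, 0.5]]

ref = reference_attention(scores, values)
tiled = tiled_attention(scores, values)
match = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(ref, tiled))

print("reference:", [round(x, 6) for x in ref])
print("tiled:    ", [round(x, 6) for x in tiled])
print("match:", match)
```

The tiled version never stores more than one tile of weights, yet it agrees with the reference up to floating-point rounding, which is the sense in which the reordered computation stays exact.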
Using fused attention in practice doesn't require writing custom CUDA kernels. PyTorch exposes scaled dot-product attention; the backend chosen for a given run depends on device, dtype, shapes, masks, and framework version. On CUDA, the function attempts to select an enabled implementation based on its inputs, but a fused kernel isn't guaranteed for every call.[9] The public contract is the result, so start by checking it against the explicit computation:
```python
import math
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

torch.manual_seed(4)
Q = torch.randn(1, 1, 4, 8)
K = torch.randn(1, 1, 4, 8)
V = torch.randn(1, 1, 4, 8)

scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
causal = torch.ones(4, 4, dtype=torch.bool).tril()
explicit = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ V
library = sdpa(Q, K, V, is_causal=True)

print("output shape:", tuple(library.shape))
print("matches explicit computation:", torch.allclose(library, explicit, atol=1e-6))
```

```text
output shape: (1, 1, 4, 8)
matches explicit computation: True
```

Two API details prevent quiet bugs when you replace explicit attention with PyTorch's fused primitive. For `F.scaled_dot_product_attention`, `True` in a boolean `attn_mask` means the position participates in attention; that's the inverse of `nn.MultiheadAttention`'s boolean `key_padding_mask`. Also, `dropout_p` is always applied when it's greater than zero, so pass `0.0` during evaluation.[9]
A subtle but critical implementation detail in attention mechanisms is numerical stability. The naive mathematical formulation of softmax overflows for large inputs because the exponential function grows exceptionally fast. When dealing with dot products from large vector spaces, even scaled ones can occasionally produce large positive values, causing the exponentiation to result in Inf (infinity) floating-point values and completely breaking the model's training process.
To prevent this overflow, deep learning frameworks implement a numerically stable version of softmax. They subtract the maximum value from the input vector before exponentiating:

$$\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j$$

Subtracting $m$ shifts all logits by the same constant. Because of the properties of exponentials, this constant factors out and cancels between the numerator and denominator. The resulting probabilities don't change, but the highest value being exponentiated is exactly $0$ (since $\max_j x_j - m = 0$), so the maximum exponential evaluated is $e^0 = 1$. This keeps all numbers numerically safe and avoids overflow.[8]

Building on this, online softmax (Milakov & Gimelshein, 2018)[8] extends the max-shift technique so the softmax normalizer can be computed in a single streaming pass. Instead of needing to read the entire vector into memory to find the maximum first, online softmax tracks the running maximum and the running sum of exponentials simultaneously, rescaling the sum whenever a new, larger maximum is found. This streaming capability is the mathematical foundation that allows FlashAttention's memory-efficient block tiling.
```python
import math

def stable_softmax(logits: list[float]) -> list[float]:
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]

probs = stable_softmax([1000.0, 1001.0, 999.0])
print([round(p, 4) for p in probs])
print(round(sum(probs), 6))
print(probs[1] == max(probs))
```

```text
[0.2447, 0.6652, 0.09]
1.0
True
```

The mask shape and the source of Q, K, and V reveal which information an architecture permits each output to use:
| Architecture | Self-Attention | Cross-Attention | Typical objective / task |
|---|---|---|---|
| BERT | Bidirectional | No | Masked language modeling |
| Decoder-only LM | Causal | No | Next-token prediction |
| T5 | Bidirectional (enc) + Causal (dec) | Yes | Span corruption |
| Stable Diffusion | Self (in U-Net) | Yes (text to image) | Diffusion denoising |
| Whisper | Bidirectional (enc) + Causal (dec) | Yes | Speech-to-text seq2seq |
| Vision Transformer (ViT) | Bidirectional | No | Image classification / self-supervision |
Encoder-only models like BERT and ViT use bidirectional self-attention because their goal is to build a comprehensive understanding of the entire input at once. They process the full text or image simultaneously to generate rich embeddings. In contrast, decoder-only language models use causal self-attention because they must generate output step-by-step. If they could look ahead at future tokens during training, they would simply "cheat" instead of learning to predict.
Encoder-decoder models (like T5 and Whisper) mix these approaches. They use a bidirectional encoder to fully understand the input (text or audio), and a causal decoder to generate the output. Cross-attention acts as the bridge between the two, allowing the decoder to continuously refer back to the encoder's rich representation of the source material.
Decoder-only language models exercise causal self-attention; encoder-decoder models add cross-attention; ViT uses bidirectional attention over image tokens. Those mechanics matter more than a market survey here. The next chapter will reuse the same attention operation after turning an image into patch tokens.
| Tier | You should be able to defend |
|---|---|
| Foundational | Derive and annotate the dimensions of each tensor. |
| Intermediate | Derive why dividing by $\sqrt{d_k}$ normalizes score variance under independent unit-variance assumptions. |
| Advanced | Distinguish bidirectional self-attention, causal self-attention, and cross-attention by Q/K/V source and mask. |
| Advanced | Explain why multi-head attention runs several lower-dimensional attention heads without changing the leading-order attention-core FLOPs when $d_{\text{model}}$ is fixed. |
| Advanced | Separate attention-core time, projection time, and naive attention memory. |
| Advanced | Explain why FlashAttention cuts temporary attention I/O while MQA/GQA shrink persistent KV-cache traffic during decoding. |
| Advanced | Describe max-shift softmax, online softmax, and why numerical stability matters inside attention kernels. |
| Mistake | Symptom | Fix |
|---|---|---|
| Forgetting the K transpose | Matmul shape error before softmax | Compute Q @ K.transpose(-2, -1) so scores have shape (B, h, N, N). |
| Skipping scaling | Score spread rises with head width under the initialization model; attention may saturate | Divide scores by math.sqrt(d_k) before masking and softmax. |
| Mixing up attention memory and time | Bad long-context sizing estimates | Track attention-core FLOPs, projection FLOPs, temporary score memory, and KV-cache memory separately. |
| Treating FlashAttention and GQA as the same optimization | Wrong performance diagnosis | Use FlashAttention for temporary attention I/O; use MQA/GQA for persistent KV-cache traffic. |
| Forgetting padding masks in cross-attention | Decoder attends to fake source tokens | Mask padded encoder positions even though source positions don't need a causal mask. |
| Confusing Q/K/V roles | Hard-to-debug routing behavior | Remember: Q and K choose where information flows; V carries the content being mixed. |
| Using naive softmax | Inf, NaN, or unstable probabilities | Subtract the row max, or use a framework primitive that already applies stable softmax. |
| Softmax on the wrong axis | Rows don't sum to 1; routing is meaningless | Apply softmax over the key dimension (dim=-1), not the query dimension. |
| Allowing a fully masked query row | NaN with -inf, or silent blocked-key mixing with a finite fill | Guarantee one valid key per active query, or zero/skip outputs for padded queries. |
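Several rows in the table (naive softmax, masking, and the fully masked query row) meet in one small function. The sketch below is a plain-Python illustration for a single query row; `masked_softmax` is a hypothetical helper, and returning all zeros for a fully masked row is one possible policy, not the behavior of any particular framework.

```python
import math

def masked_softmax(scores, allowed):
    """Softmax over one query row, using only keys where allowed[j] is True.

    Returns all zeros when no key is allowed, instead of dividing by zero.
    """
    kept = [s for s, ok in zip(scores, allowed) if ok]
    if not kept:
        # Fully masked row (for example, a padded query): emit zero weights.
        return [0.0] * len(scores)
    row_max = max(kept)  # max-shift keeps exp() from overflowing
    exps = [math.exp(s - row_max) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]

partial = masked_softmax([2.0, 1.0, 0.0], [True, True, False])
print([round(p, 4) for p in partial])  # [0.7311, 0.2689, 0.0]

empty = masked_softmax([2.0, 1.0, 0.0], [False, False, False])
print(empty)  # [0.0, 0.0, 0.0]
```

The blocked key receives exactly zero weight, and the fully masked row produces zeros rather than NaN.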
Bahdanau et al. (2015)[10] used additive attention: $e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$, where $s_{i-1}$ is the previous decoder state and $h_j$ is the encoder annotation for source position $j$. In the Transformer paper, Vaswani et al. note that additive and dot-product attention behave similarly at small dimensions, but dot-product attention maps much better to batched matrix multiplication and is much faster at the larger dimensions used in transformers.[1] The $1/\sqrt{d_k}$ scaling is what keeps dot-product attention stable as $d_k$ grows.
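The effect of the scaling factor can be checked empirically. The sketch below draws query and key vectors with independent standard normal components (an assumption that mirrors the initialization model, not trained weights) and estimates the variance of the dot product before and after dividing by $\sqrt{d_k}$. `score_variance` is a hypothetical helper written for this illustration.

```python
import math
import random

def score_variance(d_k, trials=2000, seed=0):
    """Estimate Var(q . k) and Var(q . k / sqrt(d_k)) for random unit-variance vectors."""
    rng = random.Random(seed)
    raw = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(raw) / trials
    var_raw = sum((s - mean) ** 2 for s in raw) / trials
    # Dividing every score by sqrt(d_k) divides the variance by d_k.
    return var_raw, var_raw / d_k

for d_k in (16, 64, 256):
    var_raw, var_scaled = score_variance(d_k)
    print(d_k, round(var_raw, 1), round(var_scaled, 2))
```

Under these assumptions the unscaled variance should track $d_k$, while the scaled variance should stay near 1 regardless of head width.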
If all scores are equal, softmax produces the uniform distribution $(1/N, \dots, 1/N)$ over the $N$ keys, and the output is the average of all value vectors. If one score dominates, softmax approximates an argmax, so the output is approximately one value vector. The sharpness of the logits controls where on this spectrum you land.
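Both ends of that spectrum are easy to see with three toy value vectors. This is a minimal plain-Python sketch; `softmax` and `mix` are illustrative helpers and the numbers are made up for the example.

```python
import math

def softmax(row):
    """Max-shifted softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def mix(weights, values):
    """Weighted sum of value vectors: one attention output row."""
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

values = [[1.0, 0.0], [0.0, 1.0], [4.0, 4.0]]

flat = mix(softmax([0.0, 0.0, 0.0]), values)    # equal scores: plain average
sharp = mix(softmax([0.0, 0.0, 50.0]), values)  # one dominant score: near argmax

print([round(x, 3) for x in flat])   # [1.667, 1.667]
print([round(x, 3) for x in sharp])  # [4.0, 4.0]
```

With equal scores the output is the mean of the three value vectors; with one dominant score it is almost exactly the third value vector.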
[1] Attention Is All You Need. Vaswani, A., et al. · 2017
[2] Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting. Voita, E., et al. · 2019 · ACL 2019
[3] Are Sixteen Heads Really Better than One? Michel, P., Levy, O., & Neubig, G. · 2019 · NeurIPS 2019
[4] In-context Learning and Induction Heads. Olsson, C., et al. · 2022
[5] Fast Transformer Decoding: One Write-Head is All You Need. Shazeer, N. · 2019 · arXiv preprint
[6] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. Ainslie, J., et al. · 2023 · EMNLP 2023
[7] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. · 2022 · NeurIPS 2022
[8] Online normalizer calculation for softmax. Milakov, M. & Gimelshein, N. · 2018
[9] torch.nn.functional.scaled_dot_product_attention. PyTorch Contributors · 2026
[10] Neural Machine Translation by Jointly Learning to Align and Translate. Bahdanau, D., Cho, K., & Bengio, Y. · 2015 · ICLR 2015