Derive the attention formula, prove the scaling factor, and implement multi-head attention. Analyze O(n²) complexity and understand the three attention variants (self, causal, cross) in modern architectures.
When you read the sentence "The cat sat on the mat because it was tired," you instantly know that "it" refers to "the cat," not the mat. Your brain focuses on the most relevant word to make sense of the sentence. This ability to selectively focus on the right information is exactly what AI models need to understand language, and it's called attention.
Attention is the single most important mechanism in modern AI. Every transformer-based model (ChatGPT, BERT, Llama, Gemini) is built on it. In this article, you'll understand exactly how attention works, from the intuition all the way to the math and code.
💡 Key insight: Attention is best understood as an information routing mechanism. Given a sequence of words, it answers the question: "For each word, which other words should it pay attention to, and how much?" Once you grasp this idea, the math and code become natural extensions of the intuition.
Attention answers a simple question: "For each token in a sequence, which other tokens should it focus on?"
Think of reading a sentence: when you see "it" in "The cat sat on the mat because it was tired", your brain attends back to "cat" to resolve what "it" means. Self-attention does this computationally: every token computes a weighted combination of all other tokens based on relevance.[1]
To perform attention, the model transforms each word's representation into three different roles. Think of it like a library system:
The model creates these three vectors by multiplying each word's numerical representation (called an embedding, a list of numbers that captures a word's meaning) through learned weight matrices:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V$$

Reading the formula: each word's embedding is multiplied by three different weight matrices ($W_Q$, $W_K$, $W_V$) to produce three different "views" of the same word. Think of it like putting the same sentence through three different lenses. One highlights what this word is searching for, another highlights what this word offers, and the third captures the actual information it carries.

Here $W_Q$, $W_K$, and $W_V$ are the learned projections.[1]
| Vector | Role | Analogy |
|---|---|---|
| Query (Q) | "What am I looking for?" | A search query |
| Key (K) | "What do I contain?" | A database index key |
| Value (V) | "What information do I carry?" | The actual data returned by the search |
💡 Key insight: Q and K determine where to look. V determines what information to extract. The separation of "routing" (Q, K) from "content" (V) is fundamental to why attention is so powerful.
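To make the three roles concrete, here is a toy sketch of the projections. The weights are random stand-ins for trained ones, and the dimensions are illustrative, not from any real model:

```python
import torch

torch.manual_seed(0)
d_model = 8  # toy size; real models use e.g. 4096

# Learned projections (random weights here stand in for trained ones)
W_q = torch.nn.Linear(d_model, d_model, bias=False)
W_k = torch.nn.Linear(d_model, d_model, bias=False)
W_v = torch.nn.Linear(d_model, d_model, bias=False)

x = torch.randn(6, d_model)  # 6 token embeddings, one per word

# Three "views" of the same tokens
Q, K, V = W_q(x), W_k(x), W_v(x)
print(Q.shape, K.shape, V.shape)  # each torch.Size([6, 8])
```

Same input, three different linear lenses: that is all the projection step is.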
Putting Q, K, and V together, scaled dot-product attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Step by step:

1. Compute raw scores $QK^\top$: how well each query matches each key.
2. Scale by $\sqrt{d_k}$ to keep the scores at unit variance.
3. Apply an optional mask (causal or padding).
4. Softmax each row into a probability distribution over tokens.
5. Use those probabilities to take a weighted average of the value vectors.
Here is how the scaled dot-product attention formula translates into practical PyTorch code. The function takes the Query, Key, and Value tensors along with an optional mask (used for causal or padding purposes). It computes the attention scores, scales them by the square root of the dimension, normalizes them into probabilities, and returns both the final weighted output and the attention weights themselves.
```python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    K: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    V: torch.Tensor,  # (batch, n_heads, seq_len, d_v)
    mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention (Vaswani et al., 2017)."""
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, n, n)

    # Step 2: Scale to unit variance
    attn_scores = attn_scores / math.sqrt(d_k)

    # Step 3: Apply mask (causal or padding)
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax normalization (each row sums to 1)
    attn_weights = F.softmax(attn_scores, dim=-1)  # (B, h, n, n)

    # Optional: dropout on attention weights (regularization)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Step 5: Weighted aggregation of values
    output = torch.matmul(attn_weights, V)  # (B, h, n, d_v)

    return output, attn_weights
```
💡 Analogy: Imagine you're mixing music tracks. Without scaling, adding more instruments (higher $d_k$) makes the mix louder and louder until everything distorts (softmax saturates). Dividing by $\sqrt{d_k}$ is like an automatic volume leveler: no matter how many instruments are playing, the overall volume stays in the sweet spot where you can distinguish each one clearly.
Here's the formal derivation of why the $\sqrt{d_k}$ scaling factor is necessary:

Assume $q, k \in \mathbb{R}^{d_k}$ have independent components $q_i, k_i \sim \mathcal{N}(0, 1)$. The dot product is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each term has $\mathbb{E}[q_i k_i] = 0$ and $\mathrm{Var}(q_i k_i) = 1$ (since $q_i$ and $k_i$ are independent with unit variance, $\mathrm{Var}(q_i k_i) = \mathrm{Var}(q_i)\,\mathrm{Var}(k_i) = 1$).

By the sum of independent variances:

$$\mathrm{Var}(q \cdot k) = \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) = d_k$$

So the standard deviation of the raw dot product is $\sqrt{d_k}$. As $d_k$ grows, dot products get larger and push softmax into saturation, where the output becomes nearly one-hot and gradients vanish:
| $d_k$ | $\sqrt{d_k}$ (typical score magnitude) | Softmax behavior |
|---|---|---|
| 16 | 4.0 | Moderate peakiness |
| 64 | 8.0 | Getting peaked |
| 512 | 22.6 | Extremely peaked ⚠️ |
| 4096 | 64.0 | Nearly one-hot 💀 |
After dividing by $\sqrt{d_k}$: $\mathrm{Var}(q \cdot k / \sqrt{d_k}) = 1$ regardless of $d_k$, maintaining well-behaved softmax gradients throughout training.
💡 Key insight: Many engineers know to scale, but few can derive why $\sqrt{d_k}$ specifically. The variance argument is the mathematical justification for this architectural choice.
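The variance argument is easy to check empirically. A quick Monte-Carlo sketch (illustrative only, not part of any model code):

```python
import math
import torch

torch.manual_seed(0)

def dot_product_std(d_k: int, n_samples: int = 100_000) -> float:
    """Empirical std of q · k when all components are i.i.d. N(0, 1)."""
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    return (q * k).sum(dim=-1).std().item()

for d_k in (16, 64, 512):
    raw = dot_product_std(d_k)
    print(f"d_k={d_k:4d}  std≈{raw:6.2f}  theory={math.sqrt(d_k):6.2f}  scaled≈{raw / math.sqrt(d_k):.2f}")
```

The measured standard deviation tracks $\sqrt{d_k}$ closely, and dividing by $\sqrt{d_k}$ brings it back to roughly 1 for every width.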
The transformer architecture uses attention in three distinct patterns. Understanding when and why each is used is critical:
💡 Analogy: Bidirectional attention is like a conference call: everyone can hear everyone else at the same time. Causal attention is like a chain of walkie-talkies where each person can only hear messages from people who spoke before them. Cross-attention is like a translator listening to one conversation (encoder) and relaying information into another (decoder).
Every token attends to every other token with no masking. Used in encoder architectures (BERT, RoBERTa):
```text
"The cat sat on the mat"
  Token "sat" attends to: [The, cat, sat, on, the, mat]  ← full context
```
Use case: Understanding/classification tasks where you have the full input.
Each token can only attend to itself and previous tokens. Future positions are masked with $-\infty$. Used in decoder architectures (GPT, Llama, Claude):
```text
"The cat sat on the mat"
  Token "sat" attends to: [The, cat, sat]                ← only past + self
  Token "mat" attends to: [The, cat, sat, on, the, mat]  ← full history
```
The causal mask is a lower-triangular matrix that ensures each position only looks at itself and the positions before it. Here is a simple implementation that takes the sequence length as input and outputs a binary matrix where 1 indicates an allowed connection and 0 indicates a masked one.
```python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular mask: position i can attend to positions [0, i]."""
    return torch.tril(torch.ones(seq_len, seq_len))

# Result for seq_len=4:
# [[1, 0, 0, 0],   ← token 0 sees only itself
#  [1, 1, 0, 0],   ← token 1 sees tokens 0, 1
#  [1, 1, 1, 0],   ← token 2 sees tokens 0, 1, 2
#  [1, 1, 1, 1]]   ← token 3 sees all tokens
```
Use case: Autoregressive generation, where the model must predict the next token without seeing the future.
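Combining the mask with scaled dot-product attention shows the effect directly. A self-contained single-head toy (random tensors, no learned weights):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k = 4, 8
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

mask = torch.tril(torch.ones(n, n))                    # causal (lower-triangular) mask
scores = (Q @ K.T) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, float('-inf'))  # future positions → -inf
weights = F.softmax(scores, dim=-1)

print(weights)
# The upper triangle is exactly 0 (no token attends to the future),
# yet each row still sums to 1.
```

The $-\infty$ entries become exact zeros after softmax, so the probability mass is redistributed over the allowed past positions only.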
Queries come from one sequence, Keys and Values from another. Used in encoder-decoder architectures (T5, original Transformer, diffusion models):
```text
Encoder output (source): "Le chat est assis"  → provides K, V
Decoder state (target):  "The cat is ___"     → provides Q

Q from decoder × K from encoder → attention weights
Weights × V from encoder → decoder incorporates source information
```
Use case: Translation, summarization, text-to-image (DALL-E cross-attends text embeddings).
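A minimal sketch of the shapes involved (identity projections instead of learned $W_Q$, $W_K$, $W_V$, purely for brevity):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
src_len, tgt_len = 4, 3                  # e.g. "Le chat est assis" → "The cat is"

encoder_out = torch.randn(src_len, d)    # source sequence: provides K and V
decoder_state = torch.randn(tgt_len, d)  # target sequence: provides Q

Q, K, V = decoder_state, encoder_out, encoder_out
weights = F.softmax((Q @ K.T) / math.sqrt(d), dim=-1)  # (tgt_len, src_len)
out = weights @ V                                      # (tgt_len, d)

print(weights.shape, out.shape)  # torch.Size([3, 4]) torch.Size([3, 8])
```

Note the attention matrix is rectangular: each target position distributes its attention over the source positions, and the sequences need not have the same length.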
| Attention Type | Q Source | K, V Source | Mask | Architecture |
|---|---|---|---|---|
| Bidirectional Self | Same sequence | Same sequence | None | BERT, Vision Transformer (ViT) |
| Causal Self | Same sequence | Same sequence | Lower-triangular | GPT, Llama |
| Cross | Target sequence | Source sequence | None (typically) | T5, DALL-E |
💡 Analogy: Imagine you're reading a legal contract and you consult a committee: a grammar expert checks sentence structure, a domain expert understands the legal terms, and a context expert tracks which clauses reference each other. Multi-head attention works the same way. Each head becomes a specialist that focuses on different types of relationships (syntax, semantics, position), and their findings are combined for a complete understanding.
Instead of one large attention operation, we run $h$ parallel heads, each with dimension $d_k = d_{\text{model}} / h$:[1]

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W_O$$
Reading the formula: instead of running one big attention operation, we split the input into $h$ smaller "heads" that each attend independently, like having $h$ specialists each looking at a different aspect of the same text. After each head produces its output, we stitch them back together and multiply by a final weight matrix $W_O$ to blend the specialists' findings into a single result.
Each head computes:

$$\mathrm{head}_i = \mathrm{Attention}(XW_Q^{(i)},\; XW_K^{(i)},\; XW_V^{(i)})$$
Research has revealed that individual heads specialize for different functions (Voita et al., 2019[2]; Olsson et al., 2022[3]):
| Head Type | What It Does | Found In |
|---|---|---|
| Positional heads | Attend to adjacent or nearby tokens | Early layers |
| Syntactic heads | Track subject-verb agreement, dependency arcs | Middle layers |
| Induction heads | Copy patterns from earlier context (key for in-context learning) | Middle layers |
| Rare word heads | Upweight infrequent tokens | Various layers |
| Semantic heads | Capture coreference, entity relationships | Later layers |
🔬 Research insight: Research shows that 20–40% of attention heads can be pruned without noticeable performance degradation, and in some layers, all but a single head can be removed (Michel et al., 2019)[4]. Models maintain significant redundancy as a form of robustness. Not all heads are equally important, but the system is resilient to losing any individual head.
Multi-head attention doesn't add computation. It restructures it. Single-head attention on $d_{\text{model}} = 512$ uses the same FLOPs (floating point operations) as 8-head attention with $d_k = 64$ each, because $8 \times 64 = 512$.
The following module implements this parallel attention mechanism in PyTorch. It takes an input tensor of shape (batch, seq_len, d_model) and projects it into multiple heads for queries, keys, and values. After computing scaled dot-product attention independently for each head, it concatenates the results and applies a final linear projection to produce the output.
```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Single large projections (more efficient than per-head matrices)
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, D = x.shape

        # Project and reshape: (B, N, D) → (B, h, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        out, _ = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        # Concatenate heads and project back: (B, h, N, d_k) → (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)
```
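For a quick shape sanity check without the custom class, PyTorch ships an equivalent built-in module; a minimal self-attention call looks like this:

```python
import torch

torch.manual_seed(0)
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)
out, attn = mha(x, x, x)    # self-attention: Q, K, V all come from x

print(out.shape)   # torch.Size([2, 10, 64]): same shape as the input
print(attn.shape)  # torch.Size([2, 10, 10]): weights averaged over heads by default
```

The output has the same shape as the input, which is what lets transformer blocks stack attention layers arbitrarily deep.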
| Metric | Complexity | Explanation |
|---|---|---|
| Time | $O(n^2 \cdot d)$ | $n^2$ token pairs, each scored with a $d$-dimensional dot product |
| Memory | $O(n^2)$ per head | Must store the full $n \times n$ attention weight matrix |
| Parameters | $4d_{\text{model}}^2$ | Weight matrices $W_Q, W_K, W_V, W_O$, each $d_{\text{model}} \times d_{\text{model}}$ |
Concrete memory example: batch of 8, 32 heads, $n = 8192$, FP16:

$$8 \times 32 \times 8192^2 \times 2 \text{ bytes} = 32\ \text{GB}$$

Reading the formula: 8 sequences in the batch × 32 attention heads × 8192² entries per attention map (one for every pair of tokens) × 2 bytes per number (FP16). That's 32 GB just for the attention scores, more than the entire memory of many GPUs.
This is why the $O(n^2)$ memory is the primary bottleneck for long sequences, not the compute.
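The arithmetic above generalizes into a two-line helper. This is a back-of-envelope estimate for the attention scores alone, ignoring activations and the KV cache:

```python
def attention_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_el: int = 2) -> int:
    """Bytes needed to materialize the full (seq_len × seq_len) attention weights."""
    return batch * heads * seq_len ** 2 * bytes_per_el

print(attention_matrix_bytes(8, 32, 8192) / 2**30)   # 32.0 (GiB)
print(attention_matrix_bytes(8, 32, 16384) / 2**30)  # 128.0 (GiB): doubling n quadruples memory
```

The quadratic term dominates: doubling the context length quadruples the memory, which is exactly the scaling FlashAttention is designed to avoid.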
FlashAttention (Dao et al., 2022)[5] is now the standard in virtually every production LLM. It reduces attention memory from $O(n^2)$ to $O(n)$ without any approximation: the output is mathematically identical.
Core idea: Instead of materializing the full attention matrix in GPU HBM (High Bandwidth Memory, the main GPU memory; slow, ~2TB/s), compute attention in tiles that fit in SRAM (Static Random Access Memory, a small, ultra-fast on-chip cache; fast, ~19TB/s on A100), using online softmax to avoid storing the full matrix:[6]
| Property | Standard Attention | FlashAttention |
|---|---|---|
| Memory | $O(n^2)$ | $O(n)$ ✅ |
| I/O (HBM accesses) | $\Theta(nd + n^2)$ | $\Theta(n^2 d^2 / M)$ |
| Exact | Yes | Yes ✅ |
| Wall-clock speed | 1× | 2–4× faster |
Where $M$ is the size of GPU SRAM (~20 MB on an A100).
The key insight: FlashAttention is IO-aware. It minimizes reads/writes between slow GPU HBM and fast SRAM by restructuring the computation order, using the tiling trick from online softmax.
Using FlashAttention in practice does not require writing custom CUDA kernels. PyTorch provides a unified interface that automatically routes to the most efficient implementation (like FlashAttention) based on your hardware. It takes the Query, Key, and Value tensors and returns the output directly, bypassing the need to store the full attention matrix.
```python
# In production, simply use PyTorch's built-in SDPA:
from torch.nn.functional import scaled_dot_product_attention as sdpa

# This automatically dispatches to FlashAttention when available.
# Note: pass either an explicit attn_mask OR is_causal=True, not both.
output = sdpa(Q, K, V, is_causal=True)
# No manual implementation needed; PyTorch handles the kernel selection
```
A subtle but critical implementation detail: naive softmax overflows for large inputs because $e^{x_i}$ grows exponentially (in FP32, $e^x$ overflows once $x$ exceeds roughly 88).
The max-shift trick: subtract the maximum before exponentiating:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$

Reading the formula: subtracting $\max(x)$ shifts all logits by the same constant, so probabilities do not change, but every exponent is at most 0, keeping the exponentials numerically safe and avoiding overflow (Inf).[6]
Online softmax (Milakov & Gimelshein, 2018)[6] extends this to compute softmax in a single streaming pass. This is the foundation for FlashAttention's tiling.
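A minimal pure-Python sketch of the online (streaming) softmax idea: one pass, constant extra memory, with the accumulated normalizer rescaled whenever the running max changes. This is illustrative, not the fused GPU kernel:

```python
import math

def online_softmax(xs: list[float]) -> list[float]:
    """Single-pass softmax: maintain a running max and a running normalizer."""
    m, s = float('-inf'), 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the accumulated sum when the running max changes
        s = s * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / s for x in xs]

# Naive softmax would compute exp(1000) and overflow; the online version is safe
probs = online_softmax([1000.0, 1001.0, 1002.0])
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
```

Because the normalizer can be updated incrementally like this, FlashAttention can process the attention matrix tile by tile without ever holding a full row of scores in memory.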
Understanding which attention pattern each architecture uses:
| Architecture | Self-Attention | Cross-Attention | Pre-training |
|---|---|---|---|
| BERT | Bidirectional | ✗ | Masked LM |
| GPT / Llama | Causal | ✗ | Next-token prediction |
| T5 | Bidirectional (enc) + Causal (dec) | ✓ | Span corruption |
| DALL-E / Stable Diffusion | Self (in U-Net) | ✓ (text → image) | Diffusion |
| Whisper | Bidirectional (enc) + Causal (dec) | ✓ | Audio → text |
| Vision Transformer (ViT) | Bidirectional | ✗ | Image patches |
💡 Key insight: The industry has largely converged on decoder-only (causal self-attention) for language models. Encoder-decoder architectures like T5 remain important for specific tasks (translation, summarization) but GPT-style models dominate due to their unified generation framework.
"Why not use additive attention instead of dot-product?" Bahdanau et al. (2015)[9] used additive attention: $\mathrm{score}(q, k) = v^\top \tanh(W_1 q + W_2 k)$. Dot-product attention is faster because it takes advantage of optimized matrix multiplication hardware (tensor cores), but additive attention can perform better when $d_k$ is small. In practice, scaling the dot product gives equivalent quality with much better throughput.
"Can attention attend to nothing / everything equally?" If all scores are equal, softmax produces a uniform distribution ($1/n$ for each of the $n$ positions), and the output is the average of all value vectors. If one score dominates, softmax approximates an argmax, so the output is approximately one value vector. The temperature (scaling) controls where on this spectrum you land.
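Both extremes are easy to verify. A toy example using one-hot value vectors so the mixing is visible:

```python
import torch
import torch.nn.functional as F

n = 4
V = torch.eye(n)  # one-hot value vectors make the mixing easy to read

uniform = F.softmax(torch.zeros(n), dim=-1)  # all scores equal
print(uniform @ V)                           # [0.25, 0.25, 0.25, 0.25]: a plain average

peaked = F.softmax(torch.tensor([10.0, 0.0, 0.0, 0.0]), dim=-1)
print(peaked @ V)                            # ≈ [1, 0, 0, 0]: effectively picks one value
```

Everything in between is a soft interpolation, which is what makes attention differentiable end to end.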
[1] Vaswani, A., et al. (2017). Attention Is All You Need.
[2] Voita, E., et al. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting. ACL 2019.
[3] Olsson, C., et al. (2022). In-context Learning and Induction Heads.
[4] Michel, P., Levy, O., & Neubig, G. (2019). Are Sixteen Heads Really Better than One? NeurIPS 2019.
[5] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
[6] Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax.
[7] Jiang, A. Q., et al. (2023). Mistral 7B.
[8] Katharopoulos, A., et al. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020.
[9] Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015.