Master the scaled dot-product attention formula from first principles. Deep dive into the variance proof, multi-head parallelization, O(n²) memory complexity, and the three core attention variants.
When you read the sentence "The cat sat on the mat because it was tired," you instantly know that "it" refers to "the cat," not the mat. Your brain focuses on the most relevant word to make sense of the sentence. This ability to selectively focus on the right information is exactly what modern AI models need to understand language. This is called attention.
Attention is the single most important mechanism in modern AI. Every transformer-based model, from encoder-only models like BERT to decoder-only language models and multimodal transformers, is built on it. In this article, you'll understand exactly how attention works, from the intuition all the way to the math and code.
💡 Key insight: Attention is an information routing mechanism. Given a sequence of words, it answers the question: "For each word, which other words should it pay attention to, and how much?" Once you grasp this idea, the math and code become natural extensions of the intuition.
Attention answers a simple question: "For each token in a sequence, which other tokens should it focus on?"
Think of reading a sentence: when you see "it" in "The cat sat on the mat because it was tired", your brain attends back to "cat" to resolve what "it" means. Self-attention does this computationally: every token computes a weighted combination of all other tokens based on relevance.[1]
To perform attention, the model transforms each word's representation into three different roles. Think of it like a library system:
The model creates these three vectors by multiplying each word's numerical representation (called an embedding, a list of numbers that captures a word's meaning) through learned weight matrices:
Each word's embedding is multiplied by three different weight matrices ($W_Q$, $W_K$, $W_V$) to produce three different "views" of the same word. One view highlights what this word is searching for (Query), another highlights what this word offers to others (Key), and the third carries the actual information (Value).
$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

Here $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are the learned projections (where $d_{\text{model}}$ is the model's overall hidden dimension size, and $d_k$ and $d_v$ are the dimensions of the queries/keys and values respectively).[1]
| Vector | Role | Analogy |
|---|---|---|
| Query (Q) | "What am I looking for?" | A search query |
| Key (K) | "What do I contain?" | A database index key |
| Value (V) | "What information do I carry?" | The actual data returned by the search |
💡 Key insight: Q and K determine where to look. V determines what information to extract. The separation of "routing" (Q, K) from "content" (V) is fundamental to why attention is so powerful.
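To make the three projections concrete, here's a minimal sketch with toy dimensions, where random matrices stand in for the learned weights. Each token's embedding row is multiplied by the same three matrices to produce its Query, Key, and Value:

```python
import torch

torch.manual_seed(0)
d_model, d_k, d_v, seq_len = 8, 4, 4, 3    # toy sizes, chosen for illustration

X = torch.randn(seq_len, d_model)           # one embedding row per token
W_Q = torch.randn(d_model, d_k)             # learned in practice; random here
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

Q = X @ W_Q   # "what am I looking for?"
K = X @ W_K   # "what do I contain?"
V = X @ W_V   # "what information do I carry?"
print(Q.shape, K.shape, V.shape)  # each is (seq_len, d_k or d_v)
```

The key point: all three come from the same embedding matrix `X`; only the projection differs.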
The scaled dot-product attention formula computes a weighted combination of value vectors based on the similarity between queries and keys:[1]

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Here's how the scaled dot-product attention formula translates into practical PyTorch code. The function takes the Query, Key, and Value tensors along with an optional mask (used for causal or padding purposes). It computes the attention scores, scales them by the square root of the dimension, normalizes them into probabilities, and returns both the final weighted output and the attention weights themselves.
```python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    K: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    V: torch.Tensor,  # (batch, n_heads, seq_len, d_v)
    mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention (Vaswani et al., 2017)."""
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, n, n)

    # Step 2: Scale to unit variance
    attn_scores = attn_scores / math.sqrt(d_k)

    # Step 3: Apply mask (causal or padding)
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax normalization (each row sums to 1)
    attn_weights = F.softmax(attn_scores, dim=-1)

    # Optional: dropout on attention weights (regularization)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Step 5: Weighted aggregation of values
    output = torch.matmul(attn_weights, V)  # (B, h, n, d_v)

    return output, attn_weights
```
💡 Analogy: Imagine you're mixing music tracks. Without scaling, adding more instruments (higher $d_k$) makes the mix louder and louder until everything distorts (softmax saturates). Dividing by $\sqrt{d_k}$ is like an automatic volume leveler: no matter how many instruments are playing, the overall volume stays in the sweet spot where you can distinguish each one clearly.
Here's the formal derivation for why the scaling factor $\frac{1}{\sqrt{d_k}}$ is strictly necessary. For the sake of the proof, we assume the entries of $q$ and $k$ are independent with mean 0 and variance 1. This holds approximately true at initialization before the weight matrices shift the distributions.[1]
Let $q = (q_1, \ldots, q_{d_k})$ and $k = (k_1, \ldots, k_{d_k})$ be a query and key vector whose entries are independent components. The dot product is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$
Each term $q_i k_i$ has $\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0$ (since $\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0$ and they're independent). The variance is:

$$\text{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1 \cdot 1 = 1$$
By the sum of independent variances:

$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$
So the standard deviation of the raw dot product is $\sqrt{d_k}$. As $d_k$ grows, dot products get larger and push softmax into saturation, where the output becomes nearly one-hot and gradients vanish:
| $d_k$ | $\sqrt{d_k}$ (score std.) | Softmax behavior |
|---|---|---|
| 16 | 4.0 | Moderate peakiness |
| 64 | 8.0 | Getting peaked |
| 512 | 22.6 | Extremely peaked ⚠️ |
| 4096 | 64.0 | Nearly one-hot 💀 |
Dividing by $\sqrt{d_k}$ restores unit variance, $\text{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$, regardless of $d_k$, maintaining well-behaved softmax gradients throughout training.
💡 Key insight: Many engineers know to scale, but few can derive why $\sqrt{d_k}$ specifically. The variance argument is the mathematical justification for this architectural choice.
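The variance argument is easy to check empirically. This short sketch (toy sample count, standard-normal entries as the proof assumes) estimates the variance of raw and scaled dot products:

```python
import torch

torch.manual_seed(0)
d_k, n_samples = 512, 20_000

q = torch.randn(n_samples, d_k)   # entries with mean 0, variance 1
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)         # raw dot products, one per sample
scaled = raw / d_k ** 0.5         # after dividing by sqrt(d_k)

print(raw.var().item())     # close to d_k = 512
print(scaled.var().item())  # close to 1
```

The raw dot products have variance near $d_k$, exactly as the derivation predicts, while the scaled ones sit near unit variance.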
The transformer architecture uses attention in three distinct patterns. Understanding when and why each is used is critical:
💡 Analogy: Bidirectional attention is like a conference call: everyone can hear everyone else at the same time. Causal attention is like a chain of walkie-talkies where each person can only hear messages from people who spoke before them. Cross-attention is like a translator listening to one conversation (encoder) and relaying information into another (decoder).
Every token attends to every other token with no masking. Used in encoder architectures (BERT, RoBERTa). For example, given a complete sentence, every word's representation is updated by looking at every other word, forming a full bidirectional context:
```text
"The cat sat on the mat"
  Token "sat" attends to: [The, cat, sat, on, the, mat]  ← full context
```
**Best for:** understanding/classification tasks where you have the full input.
Each token can only attend to itself and previous tokens. Future positions are masked with $-\infty$. Used in decoder architectures such as GPT-style and other autoregressive language models. For instance, when processing a sequence step-by-step, the model progressively builds context but remains strictly blind to upcoming words:
```text
"The cat sat on the mat"
  Token "sat" attends to: [The, cat, sat]                ← only past + self
  Token "mat" attends to: [The, cat, sat, on, the, mat]  ← full history
```
The causal mask is a lower-triangular matrix that ensures each position only looks at itself and the positions before it. Here's a simple implementation that takes the sequence length as input and outputs a binary matrix where 1 indicates an allowed connection and 0 indicates a masked one.
```python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular mask: position i can attend to positions [0, i]."""
    return torch.tril(torch.ones(seq_len, seq_len))

# Result for seq_len=4:
# [[1, 0, 0, 0],  ← token 0 sees only itself
#  [1, 1, 0, 0],  ← token 1 sees tokens 0, 1
#  [1, 1, 1, 0],  ← token 2 sees tokens 0, 1, 2
#  [1, 1, 1, 1]]  ← token 3 sees all tokens
```
**Best for:** autoregressive generation (generating text one step at a time), where the model must predict the next token without seeing the future.
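To see the causal mask in action, this small sketch sets all raw scores equal (an arbitrary choice for illustration) and applies the lower-triangular mask before softmax; each row's probability mass lands only on past positions and the token itself:

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.zeros(seq_len, seq_len)            # pretend all raw scores tie
mask = torch.tril(torch.ones(seq_len, seq_len))   # lower-triangular causal mask

scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
print(weights)
# Row i spreads probability uniformly over positions 0..i and puts 0 on the future:
# row 0 → [1, 0, 0, 0], row 1 → [0.5, 0.5, 0, 0], ...
```

Because masked positions hold $-\infty$, their exponentials are exactly 0 after softmax, so no probability ever leaks to future tokens.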
Queries come from one sequence, Keys and Values from another. Used in encoder-decoder architectures (T5, original Transformer, diffusion models). For example, in a translation task, the encoder processes the source sentence and provides the Keys and Values, while the decoder uses the current target sentence state as the Query. The resulting attention weights dictate how much of each source word to incorporate into the next predicted target word:
```text
Encoder output (source): "Le chat est assis"  → provides K, V
Decoder state (target):  "The cat is ___"     → provides Q

Q from decoder × K from encoder → attention weights
Weights × V from encoder → decoder incorporates source information
```
**Best for:** translation, summarization, text-to-image (DALL-E cross-attends text embeddings).
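A minimal shape-level sketch of cross-attention (toy dimensions, random tensors standing in for real encoder outputs and decoder states) makes the Q-from-target, K/V-from-source split concrete:

```python
import torch

torch.manual_seed(0)
d_model, n_src, n_tgt = 16, 4, 3          # toy sizes, chosen for illustration

enc_out = torch.randn(n_src, d_model)     # encoder output: source sentence
dec_state = torch.randn(n_tgt, d_model)   # decoder states: target so far

Q = dec_state      # queries come from the decoder (target)...
K = V = enc_out    # ...keys and values from the encoder (source)

weights = torch.softmax(Q @ K.T / d_model ** 0.5, dim=-1)  # (n_tgt, n_src)
out = weights @ V                                          # (n_tgt, d_model)
print(weights.shape, out.shape)
```

Note the attention matrix is rectangular, $n_{tgt} \times n_{src}$: each target position distributes its attention over the source tokens.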
| Attention Type | Q Source | K, V Source | Mask | Architecture |
|---|---|---|---|---|
| Bidirectional Self | Same sequence | Same sequence | None | BERT, Vision Transformer (ViT) |
| Causal Self | Same sequence | Same sequence | Lower-triangular | GPT, Qwen3.5 |
| Cross | Target sequence | Source sequence | None (typically) | T5, DALL-E |
💡 Analogy: Imagine you're reading a legal contract and you consult a committee: a grammar expert checks sentence structure, a domain expert understands the legal terms, and a context expert tracks which clauses reference each other. Multi-head attention works the same way. Each head becomes a specialist that focuses on different types of relationships (syntax, semantics, position), and their findings are combined for a complete understanding.
Instead of one large attention operation, we run $h$ parallel heads, each with dimension $d_k = d_{\text{model}}/h$:[1]

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O$$
Instead of running one big attention operation, we split the input into smaller "heads" that each attend independently, like having specialists each looking at a different aspect of the same text. After each head produces its output, we stitch them back together and multiply by a final weight matrix to blend the specialists' findings into a single result.
Each head computes:

$$\text{head}_i = \text{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)$$
Research has revealed that individual heads specialize for different functions (Voita et al., 2019[2]; Olsson et al., 2022[3]):
| Head Type | What It Does | Found In |
|---|---|---|
| Positional heads | Attend to adjacent or nearby tokens | Early layers |
| Syntactic heads | Track subject-verb agreement, dependency arcs | Middle layers |
| Induction heads | Copy patterns from earlier context (key for in-context learning) | Middle layers |
| Rare word heads | Upweight infrequent tokens | Various layers |
| Semantic heads | Capture coreference, entity relationships | Later layers |
🔬 Research insight: Studies show that a large proportion of attention heads can be pruned without noticeable performance degradation, and in some layers, all but a single head can be removed (Michel et al., 2019)[4]. Models maintain significant redundancy as a form of robustness. Not all heads are equally important, but the system is robust to losing any individual head.
Multi-head attention doesn't add computation. It restructures it. Single-head attention with $d_{\text{model}} = 512$ uses the same FLOPs (Floating Point Operations) as 8-head attention with $d_k = 64$ each, because $8 \times 64 = 512$.
Here's a production-ready PyTorch implementation. Note that we use single large projection matrices (W_q, W_k, W_v) rather than per-head matrices. The reshape-and-transpose (view + transpose) splits the projection into heads internally, which is far more memory-efficient than storing separate weight matrices.
```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Single large projections (more efficient than per-head matrices)
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, D = x.shape

        # Project and reshape: (B, N, D) → (B, h, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        out, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and project back: (B, h, N, d_k) → (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)
```
| Metric | Complexity | Explanation |
|---|---|---|
| Time | $O(n^2 \cdot d)$ | $QK^\top$ produces $n \times n$ scores, each a $d_k$-dimensional dot product |
| Memory | $O(n^2)$ | Must store the full $n \times n$ attention weight matrix |
| Parameters | $O(d_{\text{model}}^2)$ | Weight matrices $W_Q$, $W_K$, $W_V$, $W_O$, each $d_{\text{model}} \times d_{\text{model}}$ |
Batch of 8, 32 heads, $n = 8192$, FP16 (16-bit floating point format):
8 sequences in the batch × 32 attention heads × 8192² entries per attention map × 2 bytes per number (FP16) = 32 GB just for the attention scores. A single A100 GPU has 80 GB of HBM, so the attention matrix alone would consume 40% of it for one batch.
This is why the memory is the primary bottleneck for long sequences, not the compute. As large language models (LLMs) are tasked with reading entire books or codebases, this quadratic memory scaling becomes a hard wall. If you want to double the context window of a model, you need four times the memory just for the attention mechanism. This physical hardware constraint is the primary reason why expanding context windows from 4K tokens to 128K+ tokens required completely new engineering approaches rather than just adding more GPUs.
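The arithmetic generalizes into a one-line calculator; `attn_scores_gib` is a hypothetical helper written for this article, not a library function:

```python
def attn_scores_gib(batch: int, heads: int, seq_len: int, bytes_per_el: int = 2) -> float:
    """GiB needed to materialize the (batch, heads, seq_len, seq_len) score tensor."""
    return batch * heads * seq_len**2 * bytes_per_el / 2**30

print(attn_scores_gib(8, 32, 8192))   # 32.0 GiB — the example above
print(attn_scores_gib(8, 32, 16384))  # 128.0 GiB — doubling context quadruples memory
```

The second call shows the quadratic wall directly: doubling `seq_len` multiplies the attention-score memory by four.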
FlashAttention (Dao et al., 2022)[5] is now the standard in virtually every production LLM. It reduces memory from $O(n^2)$ to $O(n)$ without any approximation: the output is mathematically identical.
Instead of materializing the full attention matrix in GPU HBM (High Bandwidth Memory, the main GPU memory; slow, ~2TB/s), FlashAttention computes attention in tiles that fit in SRAM (Static Random Access Memory, a small, ultra-fast on-chip cache; fast, ~19TB/s on A100), using online softmax to avoid storing the full matrix:[6]
| Property | Standard Attention | FlashAttention |
|---|---|---|
| Memory | $O(n^2)$ | $O(n)$ ✅ |
| HBM I/O passes | ~35 passes over matrix | 3–4 passes total |
| Exact | Yes | Yes ✅ |
| Wall-clock speed | 1× | 2-4× faster |
FlashAttention is IO-aware. It minimizes reads/writes between slow GPU HBM and fast SRAM by restructuring the computation order, using the tiling technique from online softmax.
Using FlashAttention in practice doesn't require writing custom CUDA kernels. PyTorch provides a unified interface that automatically routes to the most efficient implementation (like FlashAttention) based on your hardware. It takes the Query, Key, and Value tensors and returns the output directly, bypassing the need to store the full attention matrix.
```python
# In production, simply use PyTorch's built-in SDPA:
from torch.nn.functional import scaled_dot_product_attention as sdpa

# This automatically dispatches to FlashAttention when available
output = sdpa(Q, K, V, is_causal=True)  # For causal attention
# No manual implementation needed; PyTorch handles the kernel selection
```
A subtle but critical implementation detail in attention mechanisms is numerical stability. The naive mathematical formulation of softmax overflows for large inputs because the exponential function grows exceptionally fast. When dealing with dot products from large vector spaces, even scaled ones can occasionally produce large positive values, causing the exponentiation to result in Inf (infinity) floating-point values and completely breaking the model's training process.
To prevent this overflow, deep learning frameworks implement a numerically stable version of softmax. They subtract the maximum value from the input vector before exponentiating:

$$\text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$
Subtracting $\max(x)$ shifts all logits by the same constant. Because of the properties of exponentials, this constant factors out and cancels between the numerator and denominator. The resulting probabilities don't change, but the highest exponent is exactly $0$ (since $\max(x) - \max(x) = 0$), so the largest exponential evaluated is $e^0 = 1$. This keeps all numbers numerically safe and avoids overflow.[6]
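A quick sketch contrasts the naive and max-shifted formulations (the function names here are illustrative, not a framework API):

```python
import torch

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    e = torch.exp(x)                 # overflows to inf for large inputs
    return e / e.sum()

def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    e = torch.exp(x - x.max())       # largest exponent is exactly 0
    return e / e.sum()

x = torch.tensor([1000.0, 1001.0, 1002.0])
print(naive_softmax(x))   # tensor([nan, nan, nan]) — exp overflows, inf / inf
print(stable_softmax(x))  # tensor([0.0900, 0.2447, 0.6652])
```

The naive version produces `inf / inf = nan`; the shifted version returns the same mathematical result without ever exceeding $e^0 = 1$ in the numerator's largest term.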
Building on this, online softmax (Milakov & Gimelshein, 2018)[6] extends the max-shift technique to compute softmax in a single streaming pass. Instead of needing to read the entire vector into memory to find the maximum, online softmax tracks the running maximum and the running sum of exponentials simultaneously, dynamically correcting the sum as a new, larger maximum is found. This streaming capability is exactly the mathematical foundation that allows FlashAttention's memory-efficient block tiling.
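The running-max correction can be sketched in a few lines; `online_softmax_denominator` is an illustrative name, and the rescaling step `s * exp(m - m_new)` is the key trick that fixes up the old sum whenever a larger maximum appears:

```python
import math

def online_softmax_denominator(xs):
    """One streaming pass: track running max m and running exp-sum s."""
    m, s = float('-inf'), 0.0
    for x in xs:
        m_new = max(m, x)
        s = s * math.exp(m - m_new) + math.exp(x - m_new)  # rescale old sum
        m = m_new
    return m, s

xs = [3.0, 1.0, 5.0, 2.0]
m, s = online_softmax_denominator(xs)

# Reference: the classic two-pass computation (find max first, then sum)
two_pass = sum(math.exp(x - max(xs)) for x in xs)
print(s, two_pass)  # identical denominators, one pass vs two
```

Because the stream never needs to be revisited, this is exactly the property FlashAttention exploits to process attention block by block without materializing the full score matrix.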
Different AI tasks require different ways of routing information, which is why various architectures implement attention differently. By looking at how a model structures its attention mask and where it sources its queries, keys, and values, we can immediately understand its primary use case.
Understanding which attention pattern each architecture uses helps you work effectively with modern foundation models:
| Architecture | Self-Attention | Cross-Attention | Pre-training |
|---|---|---|---|
| BERT | Bidirectional | ✗ | Masked Language Model (LM) |
| Decoder-only LM | Causal | ✗ | Next-token prediction |
| T5 | Bidirectional (enc) + Causal (dec) | ✓ | Span corruption |
| DALL-E / Stable Diffusion | Self (in U-Net) | ✓ (text → image) | Diffusion |
| Whisper | Bidirectional (enc) + Causal (dec) | ✓ | Audio → text |
| Vision Transformer (ViT) | Bidirectional | ✗ | Image patches |
Encoder-only models like BERT and ViT use bidirectional self-attention because their goal is to build a comprehensive understanding of the entire input at once. They process the full text or image simultaneously to generate rich embeddings. In contrast, decoder-only language models use causal self-attention because they must generate output step-by-step. If they could look ahead at future tokens during training, they would simply "cheat" instead of learning to predict.
Encoder-decoder models (like T5 and Whisper) mix these approaches. They use a bidirectional encoder to fully understand the input (text or audio), and a causal decoder to generate the output. Cross-attention acts as the bridge between the two, allowing the decoder to continuously refer back to the encoder's rich representation of the source material.
💡 Key insight: The industry has largely converged on decoder-only (causal self-attention) for large language models. While encoder-decoder architectures remain important for specific sequence-to-sequence tasks (like translation or speech recognition), GPT-style models dominate general-purpose AI due to the simplicity and unified generation framework of pure autoregressive modeling.
Bahdanau et al. (2015)[9] used additive attention: $\text{score}(q, k) = v^\top \tanh(W_1 q + W_2 k)$. Dot-product attention is faster because it takes advantage of optimized matrix multiplication hardware (tensor cores), but additive attention can perform better when $d_k$ is small. In practice, scaling the dot product gives equivalent quality with much better throughput.
If all scores are equal, softmax produces a uniform distribution ($1/n$ weight per token), and the output is the average of all value vectors. If one score dominates, softmax approximates an argmax, so the output is approximately one value vector. The temperature (scaling) controls where on this spectrum you land.
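This spectrum is easy to visualize by sweeping a temperature over a fixed set of scores (the scores here are arbitrary, chosen for illustration):

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5, 0.1])

for temp in [10.0, 1.0, 0.1]:
    probs = F.softmax(scores / temp, dim=-1)
    print(temp, probs)
# High temperature → near-uniform (averaging over values);
# low temperature  → near one-hot (argmax-like selection).
```

Dividing by $\sqrt{d_k}$ plays the same role as a fixed temperature: it keeps the model in the middle of this spectrum where gradients flow well.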
[1] Vaswani, A., et al. (2017). Attention Is All You Need.
[2] Voita, E., et al. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting. ACL 2019.
[3] Olsson, C., et al. (2022). In-context Learning and Induction Heads.
[4] Michel, P., Levy, O., & Neubig, G. (2019). Are Sixteen Heads Really Better than One? NeurIPS 2019.
[5] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
[6] Milakov, M. & Gimelshein, N. (2018). Online normalizer calculation for softmax.
[7] Jiang, A. Q., et al. (2023). Mistral 7B.
[8] Katharopoulos, A., et al. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020.
[9] Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015.