LeetLLM
🧠 Hard · Transformer Architecture

Scaled Dot-Product Attention

Master the scaled dot-product attention formula from first principles. Deep dive into the variance proof, multi-head parallelization, O(n²) memory complexity, and the three core attention variants.

45 min read · Google, Meta, OpenAI +3 · 10 key concepts

When you read the sentence "The cat sat on the mat because it was tired," you instantly know that "it" refers to "the cat," not the mat. Your brain focuses on the most relevant word to make sense of the sentence. This ability to selectively focus on the right information is exactly what modern AI models need to understand language. This is called attention.

Attention is the single most important mechanism in modern AI. Every transformer-based model, from encoder-only models like BERT to decoder-only language models and multimodal transformers, is built on it. In this article, you'll understand exactly how attention works, from the intuition all the way to the math and code.

💡 Key insight: Attention is an information routing mechanism. Given a sequence of words, it answers the question: "For each word, which other words should it pay attention to, and how much?" Once you grasp this idea, the math and code become natural extensions of the intuition.

Scaled dot-product attention flow: Query and Key matrices are multiplied, scaled by sqrt(d_k), masked, softmax-normalized, then multiplied with Value to produce the output.
The complete attention pipeline: input embeddings are projected into Query, Key, and Value vectors, then combined through scaled dot-product attention to produce context-aware outputs.

The attention mechanism

Intuition

Attention answers a simple question: "For each token in a sequence, which other tokens should it focus on?"

Think of reading a sentence: when you see "it" in "The cat sat on the mat because it was tired", your brain attends back to "cat" to resolve what "it" means. Self-attention does this computationally: every token computes a weighted combination of all other tokens based on relevance.[1]

Setup: Q, K, V projections

To perform attention, the model transforms each word's representation into three different roles. Think of it like a library system:

  • Query (Q): "What am I looking for?" Like typing a search into Google.
  • Key (K): "What do I contain?" Like the index entry for a book.
  • Value (V): "What information do I carry?" Like the actual content of the book.

The model creates these three vectors by multiplying each word's numerical representation (called an embedding, a list of numbers that captures a word's meaning) through learned weight matrices:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

Reading the formula

Each word's embedding $X$ is multiplied by three different weight matrices ($W_Q$, $W_K$, $W_V$) to produce three different "views" of the same word. One view highlights what this word is searching for (Query), another highlights what this word offers to others (Key), and the third carries the actual information (Value).

Here $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are the learned projections (where $d_{\text{model}}$ is the model's overall hidden dimension, and $d_k$, $d_v$ are the dimensions of the queries/keys and values respectively).[1]
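These projections can be sketched directly in PyTorch. This is a minimal illustration with made-up sizes ($d_{\text{model}} = 512$, $d_k = d_v = 64$, 6 tokens); in a real model the weight matrices are learned during training, not random:

```python
import torch

torch.manual_seed(0)
d_model, d_k, d_v, n = 512, 64, 64, 6  # illustrative sizes, not from any real model

X = torch.randn(n, d_model)                      # n token embeddings
W_Q = torch.randn(d_model, d_k) * d_model**-0.5  # learned in practice; random here
W_K = torch.randn(d_model, d_k) * d_model**-0.5
W_V = torch.randn(d_model, d_v) * d_model**-0.5

Q, K, V = X @ W_Q, X @ W_K, X @ W_V  # three "views" of the same tokens
print(Q.shape, K.shape, V.shape)     # each row is one token's query/key/value
```

Each token ends up with one row in each of $Q$, $K$, and $V$; the rest of the attention computation works purely on these three matrices.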

| Vector | Role | Analogy |
|---|---|---|
| Query (Q) | "What am I looking for?" | A search query |
| Key (K) | "What do I contain?" | A database index key |
| Value (V) | "What information do I carry?" | The actual data returned by the search |

💡 Key insight: Q and K determine where to look. V determines what information to extract. The separation of "routing" (Q, K) from "content" (V) is fundamental to why attention is so powerful.


The core formula

The scaled dot-product attention formula computes a weighted combination of value vectors based on the similarity between queries and keys:[1]

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Step by step

  1. Compute similarity scores: $S = QK^T$ produces an $n \times n$ matrix where $S_{ij}$ measures how much token $i$ should attend to token $j$
  2. Scale: divide by $\sqrt{d_k}$ to keep softmax gradients well-behaved (proof below)
  3. Mask (optional): for causal/decoder models, set future positions to $-\infty$ so they become 0 after softmax
  4. Normalize: apply softmax row-wise, $\alpha_{ij} = \frac{\exp(S_{ij})}{\sum_k \exp(S_{ik})}$, so each row sums to 1
  5. Aggregate: output $= \alpha V$, a weighted blend of value vectors

PyTorch implementation

Here's how the scaled dot-product attention formula translates into practical PyTorch code. The function takes the Query, Key, and Value tensors along with an optional mask (used for causal or padding purposes). It computes the attention scores, scales them by the square root of the dimension, normalizes them into probabilities, and returns both the final weighted output and the attention weights themselves.

python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    K: torch.Tensor,  # (batch, n_heads, seq_len, d_k)
    V: torch.Tensor,  # (batch, n_heads, seq_len, d_v)
    mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention (Vaswani et al., 2017)."""
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, n, n)

    # Step 2: Scale to unit variance
    attn_scores = attn_scores / math.sqrt(d_k)

    # Step 3: Apply mask (causal or padding)
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax normalization (each row sums to 1)
    attn_weights = F.softmax(attn_scores, dim=-1)

    # Optional: dropout on attention weights (regularization)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Step 5: Weighted aggregation of values
    output = torch.matmul(attn_weights, V)  # (B, h, n, d_v)

    return output, attn_weights
Attention heatmap for the sentence "The cat sat on the mat": each cell shows how much one token attends to another, with bright cells indicating strong attention.
Attention weight matrix for "The cat sat on the mat": each row shows how much a query token attends to each key token. Notice how "sat" strongly attends to "cat" (its subject) and "on" attends to "mat" (its object).

Why scale by $\sqrt{d_k}$? The variance proof

💡 Analogy: Imagine you're mixing music tracks. Without scaling, adding more instruments (higher $d_k$) makes the mix louder and louder until everything distorts (softmax saturates). Dividing by $\sqrt{d_k}$ is like an automatic volume leveler: no matter how many instruments are playing, the overall volume stays in the sweet spot where you can distinguish each one clearly.

Here's the formal derivation of why the scaling factor is necessary. For the proof, we assume the entries of $q$ and $k$ are independent with mean 0 and variance 1. This holds approximately at initialization, before the weight matrices shift the distributions.[1]

Assume

$q_i, k_j \sim \mathcal{N}(0, 1)$ are independent components. The dot product is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each term $q_i k_i$ has $\mathbb{E}[q_i k_i] = 0$ (since $\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0$ and they're independent). The variance is:

$$\text{Var}(q_i k_i) = \mathbb{E}[(q_i k_i)^2] - (\mathbb{E}[q_i k_i])^2 = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - 0 = 1 \cdot 1 = 1$$

By the sum of independent variances:

$$\text{Var}(q \cdot k) = d_k$$

So the standard deviation of the raw dot product is $\sqrt{d_k}$. As $d_k$ grows, dot products get larger and push softmax into saturation, where the output becomes nearly one-hot and gradients vanish:

| $d_k$ | $\text{Std}(q \cdot k)$ | Softmax behavior |
|---|---|---|
| 16 | 4.0 | Moderate peakiness |
| 64 | 8.0 | Getting peaked |
| 512 | 22.6 | Extremely peaked ⚠️ |
| 4096 | 64.0 | Nearly one-hot 💀 |

After dividing by $\sqrt{d_k}$

$$\text{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$$

regardless of $d_k$, maintaining well-behaved softmax gradients throughout training.
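The variance argument is easy to verify empirically. A quick simulation (sizes here are arbitrary) draws standard-normal queries and keys and checks that raw dot products have variance close to $d_k$ while scaled ones have variance close to 1:

```python
import torch

torch.manual_seed(0)
d_k, trials = 512, 100_000

q = torch.randn(trials, d_k)  # components ~ N(0, 1), independent
k = torch.randn(trials, d_k)

raw = (q * k).sum(dim=-1)     # 100k raw dot products
scaled = raw / d_k**0.5

print(raw.var().item())       # ≈ 512 (= d_k)
print(scaled.var().item())    # ≈ 1.0
```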

💡 Key insight: Many engineers know to scale, but few can derive why $\sqrt{d_k}$ specifically. The variance argument is the mathematical justification for this architectural choice.


Three types of attention

The transformer architecture uses attention in three distinct patterns. Understanding when and why each is used is critical:

1. Bidirectional self-attention (encoder)

💡 Analogy: Bidirectional attention is like a conference call: everyone can hear everyone else at the same time. Causal attention is like a chain of walkie-talkies where each person can only hear messages from people who spoke before them. Cross-attention is like a translator listening to one conversation (encoder) and relaying information into another (decoder).

Every token attends to every other token with no masking. Used in encoder architectures (BERT, RoBERTa). For example, given a complete sentence, every word's representation is updated by looking at every other word, forming a full bidirectional context:

text
"The cat sat on the mat"
  Token "sat" attends to: [The, cat, sat, on, the, mat]  ← full context

Use case

Understanding/classification tasks where you have the full input.

2. Causal self-attention (decoder)

Each token can only attend to itself and previous tokens. Future positions are masked with $-\infty$. Used in decoder architectures such as GPT-style and other autoregressive language models. For instance, when processing a sequence step-by-step, the model progressively builds context but remains strictly blind to upcoming words:

text
"The cat sat on the mat"
  Token "sat" attends to: [The, cat, sat]                ← only past + self
  Token "mat" attends to: [The, cat, sat, on, the, mat]  ← full history

The causal mask is a lower-triangular matrix that ensures each position only looks at itself and the positions before it. Here's a simple implementation that takes the sequence length as input and outputs a binary matrix where 1 indicates an allowed connection and 0 indicates a masked one.

python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular mask: position i can attend to positions [0, i]."""
    return torch.tril(torch.ones(seq_len, seq_len))

# Result for seq_len=4:
# [[1, 0, 0, 0],  ← token 0 sees only itself
#  [1, 1, 0, 0],  ← token 1 sees tokens 0, 1
#  [1, 1, 1, 0],  ← token 2 sees tokens 0, 1, 2
#  [1, 1, 1, 1]]  ← token 3 sees all tokens

Use case

Autoregressive generation (generating text one step at a time), where the model must predict the next token without seeing the future.

3. Cross-attention (encoder-decoder)

Queries come from one sequence, Keys and Values from another. Used in encoder-decoder architectures (T5, original Transformer, diffusion models). For example, in a translation task, the encoder processes the source sentence and provides the Keys and Values, while the decoder uses the current target sentence state as the Query. The resulting attention weights dictate how much of each source word to incorporate into the next predicted target word:

text
Encoder output (source): "Le chat est assis"  → provides K, V
Decoder state (target):  "The cat is ___"     → provides Q

Q from decoder × K from encoder → attention weights
Weights × V from encoder → decoder incorporates source information

Use case

Translation, summarization, text-to-image (DALL-E cross-attends text embeddings).
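The cross-attention flow above can be sketched with plain tensor operations. Everything here is illustrative: random tensors stand in for real encoder/decoder representations, and the learned projections are omitted for brevity:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 64
enc_out = torch.randn(4, d)    # stand-in encoder states for "Le chat est assis"
dec_state = torch.randn(3, d)  # stand-in decoder states for "The cat is"

Q = dec_state                  # queries come from the decoder (target)
K = V = enc_out                # keys and values come from the encoder (source)

weights = F.softmax(Q @ K.T / math.sqrt(d), dim=-1)  # (3, 4): target × source
context = weights @ V          # each target token blends source information
print(context.shape)           # one context vector per decoder position
```

Note the attention matrix is rectangular (target length × source length), and no causal mask is applied: the decoder may look at the entire source sequence.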

| Attention Type | Q Source | K, V Source | Mask | Architecture |
|---|---|---|---|---|
| Bidirectional Self | Same sequence | Same sequence | None | BERT, Vision Transformer (ViT) |
| Causal Self | Same sequence | Same sequence | Lower-triangular | GPT, Qwen3.5 |
| Cross | Target sequence | Source sequence | None (typically) | T5, DALL-E |

Multi-head attention

💡 Analogy: Imagine you're reading a legal contract and you consult a committee: a grammar expert checks sentence structure, a domain expert understands the legal terms, and a context expert tracks which clauses reference each other. Multi-head attention works the same way. Each head becomes a specialist that focuses on different types of relationships (syntax, semantics, position), and their findings are combined for a complete understanding.

Instead of one large attention operation, we run $h$ parallel heads, each with $d_k = d_{\text{model}} / h$:[1]

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

Reading the formula

Instead of running one big attention operation, we split the input into $h$ smaller "heads" that each attend independently, like having $h$ specialists each looking at a different aspect of the same text. After each head produces its output, we stitch them back together and multiply by a final weight matrix $W^O$ to blend the specialists' findings into a single result.

Each head computes: $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$


What do attention heads learn?

Research has revealed that individual heads specialize for different functions (Voita et al., 2019[2]; Olsson et al., 2022[3]):

| Head Type | What It Does | Found In |
|---|---|---|
| Positional heads | Attend to adjacent or nearby tokens | Early layers |
| Syntactic heads | Track subject-verb agreement, dependency arcs | Middle layers |
| Induction heads | Copy patterns from earlier context (key for in-context learning) | Middle layers |
| Rare word heads | Upweight infrequent tokens | Various layers |
| Semantic heads | Capture coreference, entity relationships | Later layers |

🔬 Research insight: Studies show that a large proportion of attention heads can be pruned without noticeable performance degradation, and in some layers, all but a single head can be removed (Michel et al., 2019)[4]. Models maintain significant redundancy as a form of robustness. Not all heads are equally important, but the system is robust to losing any individual head.

Critical nuance: same total FLOPs

Multi-head attention doesn't add computation. It restructures it. Single-head attention on $d_{\text{model}} = 512$ uses the same FLOPs (floating-point operations) as 8-head attention with $d_k = 64$ each, because $h \times d_k = d_{\text{model}}$.
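A quick back-of-the-envelope check of the score computation $QK^T$ (the numbers below are illustrative):

```python
n, d_model, h = 1024, 512, 8  # illustrative sequence length, width, head count
d_k = d_model // h

# FLOPs for Q·K^T: a matmul of (n, d) @ (d, n) costs ~2·n²·d
single_head = 2 * n * n * d_model   # one big head of width d_model
multi_head = h * (2 * n * n * d_k)  # h small heads of width d_k each

print(single_head == multi_head)    # identical, since h × d_k = d_model
```

The same equality holds for the projection matmuls, which is why the choice of $h$ is about representational structure, not compute budget.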

Here's a production-ready PyTorch implementation. Note that we use single large projection matrices (W_q, W_k, W_v) rather than per-head matrices. The reshape-and-transpose (view + transpose) splits the projection into heads internally, which is far more memory-efficient than storing $h$ separate weight matrices.

python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Single large projections (more efficient than per-head matrices)
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, D = x.shape

        # Project and reshape: (B, N, D) → (B, h, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        out, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and project back: (B, h, N, d_k) → (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

Complexity analysis

| Metric | Complexity | Explanation |
|---|---|---|
| Time | $O(n^2 \cdot d_k)$ | $QK^T$: $(n \times d_k) \cdot (d_k \times n) = O(n^2 d_k)$ |
| Memory | $O(n^2)$ | Must store the full $n \times n$ attention weight matrix |
| Parameters | $O(d^2)$ | Weight matrices $W_Q, W_K, W_V, W_O$, each $d \times d$ |

Concrete memory example

Batch of 8, 32 heads, $n = 8192$, FP16 (16-bit floating-point format):

$$\text{Attention matrix} = 8 \times 32 \times 8192^2 \times 2 \text{ bytes} \approx 32 \text{ GB}$$

Reading the formula

8 sequences in the batch × 32 attention heads × 8192² entries per attention map × 2 bytes per number (FP16) = 32 GB just for the attention scores. A single A100 GPU has 80 GB of HBM, so the attention matrix alone would consume 40% of it for one batch.

This is why the $O(n^2)$ memory is the primary bottleneck for long sequences, not the compute. As large language models (LLMs) are tasked with reading entire books or codebases, this quadratic memory scaling becomes a hard wall: doubling the context window quadruples the memory needed for the attention mechanism alone. This physical hardware constraint is the primary reason why expanding context windows from 4K tokens to 128K+ tokens required completely new engineering approaches rather than just adding more GPUs.
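The arithmetic above is easy to reproduce:

```python
batch, heads, n, fp16_bytes = 8, 32, 8192, 2

# One (n × n) score matrix per head, per sequence in the batch
attn_bytes = batch * heads * n * n * fp16_bytes
print(attn_bytes / 2**30)  # 32.0 GiB for the attention scores alone
```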


FlashAttention: The production solution

FlashAttention (Dao et al., 2022)[5] is now the standard in virtually every production LLM. It reduces memory from $O(n^2)$ to $O(n)$ without any approximation: the output is mathematically identical.

Core idea

Instead of materializing the full $n \times n$ attention matrix in GPU HBM (High Bandwidth Memory, the main GPU memory; slow, ~2 TB/s), FlashAttention computes attention in tiles that fit in SRAM (Static Random Access Memory, a small, ultra-fast on-chip cache; ~19 TB/s on an A100), using online softmax to avoid storing the full matrix:[6]

| Property | Standard Attention | FlashAttention |
|---|---|---|
| Memory | $O(n^2)$ | $O(n)$ ✅ |
| HBM I/O passes | ~35 passes over the $n \times n$ matrix | 3–4 passes total |
| Exact | Yes | Yes ✅ |
| Wall-clock speed | 1× | 2–4× faster |

The key insight

FlashAttention is IO-aware. It minimizes reads/writes between slow GPU HBM and fast SRAM by restructuring the computation order, using the tiling technique from online softmax.

Using FlashAttention in practice doesn't require writing custom CUDA kernels. PyTorch provides a unified interface that automatically routes to the most efficient implementation (like FlashAttention) based on your hardware. It takes the Query, Key, and Value tensors and returns the output directly, bypassing the need to store the full attention matrix.

python
# In production, simply use PyTorch's built-in SDPA:
from torch.nn.functional import scaled_dot_product_attention as sdpa

# This automatically dispatches to FlashAttention when available
output = sdpa(Q, K, V, is_causal=True)  # For causal attention
# No manual implementation needed; PyTorch handles the kernel selection

Softmax numerical stability

A subtle but critical implementation detail in attention mechanisms is numerical stability. The naive mathematical formulation of softmax overflows for large inputs because the exponential function $e^x$ grows exceptionally fast. When dealing with dot products from large vector spaces, even scaled ones can occasionally produce large positive values, causing the exponentiation to produce Inf (infinity) floating-point values and completely breaking the model's training process.

The max-shift technique

To prevent this overflow, deep learning frameworks implement a numerically stable version of softmax. They subtract the maximum value from the input vector before exponentiating:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(\mathbf{x})}}{\sum_j e^{x_j - \max(\mathbf{x})}}$$

Reading the formula

Subtracting $\max(\mathbf{x})$ shifts all logits by the same constant. Because of the properties of exponentials, this constant factors out and cancels between the numerator and denominator. The resulting probabilities don't change, but the largest value being exponentiated is exactly $0$ (since $\max(\mathbf{x}) - \max(\mathbf{x}) = 0$), so the maximum exponential evaluated is $e^0 = 1$. This keeps all numbers numerically safe and avoids overflow.[6]

Building on this, online softmax (Milakov & Gimelshein, 2018)[6] extends the max-shift technique to compute softmax in a single streaming pass. Instead of needing to read the entire vector into memory to find the maximum, online softmax tracks the running maximum and the running sum of exponentials simultaneously, dynamically correcting the sum as a new, larger maximum is found. This streaming capability is exactly the mathematical foundation that allows FlashAttention's memory-efficient block tiling.
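Here is a minimal scalar-loop sketch of that running-max/running-sum update (real kernels process blocks of values rather than single elements, but the correction step is the same):

```python
import torch

def online_softmax(x: torch.Tensor) -> torch.Tensor:
    """One streaming pass builds the normalizer, tracking a running max."""
    m = torch.tensor(float('-inf'))  # running maximum seen so far
    s = torch.tensor(0.0)            # running sum of exp(x_i - m)
    for xi in x:
        m_new = torch.maximum(m, xi)
        # Rescale the old sum when a new maximum appears, then add the new term
        s = s * torch.exp(m - m_new) + torch.exp(xi - m_new)
        m = m_new
    return torch.exp(x - m) / s      # final probabilities use the streamed (m, s)

x = torch.tensor([1.0, 3.0, 2.0, 5.0])
print(torch.allclose(online_softmax(x), torch.softmax(x, dim=0)))  # True
```

The key line is the rescaling `s * exp(m - m_new)`: whenever a larger maximum is found, all previously accumulated exponentials are corrected in place, so the full vector never needs to be resident at once.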


Attention in modern architectures

Different AI tasks require different ways of routing information, which is why various architectures implement attention differently. By looking at how a model structures its attention mask and where it sources its queries, keys, and values, we can immediately understand its primary use case.

Understanding which attention pattern each architecture uses helps you work effectively with modern foundation models:

| Architecture | Self-Attention | Cross-Attention | Pre-training |
|---|---|---|---|
| BERT | Bidirectional | ✗ | Masked language modeling |
| Decoder-only LM | Causal | ✗ | Next-token prediction |
| T5 | Bidirectional (enc) + Causal (dec) | ✓ | Span corruption |
| DALL-E / Stable Diffusion | Self (in U-Net) | ✓ (text → image) | Diffusion |
| Whisper | Bidirectional (enc) + Causal (dec) | ✓ | Audio → text |
| Vision Transformer (ViT) | Bidirectional | ✗ | Image patches |

Encoder-only models like BERT and ViT use bidirectional self-attention because their goal is to build a comprehensive understanding of the entire input at once. They process the full text or image simultaneously to generate rich embeddings. In contrast, decoder-only language models use causal self-attention because they must generate output step-by-step. If they could look ahead at future tokens during training, they would simply "cheat" instead of learning to predict.

Encoder-decoder models (like T5 and Whisper) mix these approaches. They use a bidirectional encoder to fully understand the input (text or audio), and a causal decoder to generate the output. Cross-attention acts as the bridge between the two, allowing the decoder to continuously refer back to the encoder's rich representation of the source material.

💡 Key insight: The industry has largely converged on decoder-only (causal self-attention) for large language models. While encoder-decoder architectures remain important for specific sequence-to-sequence tasks (like translation or speech recognition), GPT-style models dominate general-purpose AI due to the simplicity and unified generation framework of pure autoregressive modeling.


Key takeaways

Summary

  • Attention's information routing: The mechanism computes a weighted combination of Value vectors based on Query-Key compatibility, allowing tokens to contextually aggregate information.
  • Variance scaling is critical: The $\sqrt{d_k}$ factor normalizes dot products to unit variance. Without it, larger dimensions push softmax into saturation, causing vanishing gradients.
  • Multi-head capacity: Parallel heads allow the model to attend to different representation subspaces (positional, syntactic, semantic) simultaneously without increasing total FLOPs.
  • Memory is the bottleneck: While time complexity is $O(n^2 d_k)$, the $O(n^2)$ memory complexity for the attention matrix limits sequence length.
  • Production solutions: Modern LLMs use IO-aware exact attention (FlashAttention) or approximation methods (sliding windows[7], linear attention/SSMs (state space models)[8]) to handle long contexts efficiently.

Common misconceptions

  • Mixing up time and memory bottlenecks: The $O(n^2)$ memory cost (storing the full attention matrix) is the primary limiter for sequence length, not just the compute.
  • Misunderstanding scaling: Scaling isn't just about magnitude; it's mathematically derived to maintain unit variance for independent random variables, preserving gradient flow through softmax.
  • Assuming one attention type: Transformers use three distinct patterns: bidirectional (encoder), causal (decoder), and cross-attention (encoder-decoder).
  • Routing vs. content: $Q$ and $K$ determine where to look (routing), while $V$ determines what information to retrieve (content).

Going deeper

"Why not use additive attention instead of dot-product?"

Bahdanau et al. (2015)[9] used additive attention: $\text{score}(q, k) = v^T \tanh(W_q q + W_k k)$. Dot-product attention is faster because it takes advantage of optimized matrix-multiplication hardware (tensor cores), but additive attention can perform better when $d_k$ is small. In practice, scaling the dot product gives equivalent quality with much better throughput.
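For comparison, here is a sketch of the additive score function. The dimensions and weights are arbitrary stand-ins for what would be learned parameters:

```python
import torch

torch.manual_seed(0)
d, d_attn = 64, 32            # illustrative dimensions

W_q = torch.randn(d_attn, d)  # learned in practice; random here
W_k = torch.randn(d_attn, d)
v = torch.randn(d_attn)

def additive_score(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Bahdanau-style score: v^T tanh(W_q q + W_k k), a scalar per (q, k) pair."""
    return v @ torch.tanh(W_q @ q + W_k @ k)

q, k = torch.randn(d), torch.randn(d)
score = additive_score(q, k)  # replaces q·k / sqrt(d_k) inside the softmax
```

Because each (query, key) pair goes through its own small feed-forward network, this cannot be expressed as one large matmul over all pairs at once, which is exactly the throughput disadvantage described above.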

"Can attention attend to nothing / everything equally?"

If all scores are equal, softmax produces a uniform distribution $1/n$, and the output is the average of all value vectors. If one score dominates, softmax approximates an argmax, so the output is approximately one value vector. The temperature (scaling) controls where on this spectrum you land.
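Both extremes are easy to demonstrate with toy values:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = torch.randn(5, 8)  # five value vectors of dimension 8 (arbitrary sizes)

# Equal scores → uniform weights → output is the mean of all value vectors
uniform_out = F.softmax(torch.zeros(5), dim=0) @ V
print(torch.allclose(uniform_out, V.mean(dim=0), atol=1e-6))  # True

# One dominant score → near-argmax → output is ≈ that single value vector
peaked_out = F.softmax(torch.tensor([0., 0., 50., 0., 0.]), dim=0) @ V
print(torch.allclose(peaked_out, V[2], atol=1e-4))  # True
```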

Evaluation Rubric
  1. Correctly derives Attention(Q, K, V) = softmax(QK^T / √d_k) · V with dimension annotations
  2. Proves why scaling by √d_k prevents softmax saturation using the variance argument
  3. Distinguishes self-attention (encoder), causal self-attention (decoder), and cross-attention
  4. Explains multi-head as parallel attention in different subspaces with same total FLOPs
  5. Shows O(n²·d) time and O(n²) memory, with a concrete memory calculation example
  6. Explains FlashAttention as IO-aware exact attention with O(n) memory
  7. Mentions softmax numerical stability (max-shift technique / online softmax)
  8. Describes attention head specialization (positional, syntactic, induction heads)
Common Pitfalls
  • Neglecting the variance scaling justification for √d_k
  • Confusing time complexity O(n²d) with memory complexity O(n²)
  • Assuming attention is always O(n²) without mentioning efficient alternatives like FlashAttention
  • Misunderstanding the distinction between the three attention types (self, causal, cross)
  • Confusing Q/K/V roles: Q and K do routing, V carries content
  • Overlooking softmax numerical stability (naive implementation overflows)
Key Concepts Tested
  • Scaled dot-product attention formula derivation
  • Q/K/V projections and their roles (routing vs content)
  • Variance proof for √d_k scaling factor
  • Three attention types: bidirectional, causal, cross
  • Multi-head attention: parallel subspaces, same total FLOPs
  • Attention head specialization and redundancy
  • Time O(n²d) and memory O(n²) complexity analysis
  • FlashAttention: IO-aware tiling for O(n) memory
  • Softmax numerical stability (max-shift technique)
  • Additive vs dot-product attention
References

[1] Vaswani, A., et al. (2017). Attention Is All You Need.
[2] Voita, E., et al. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting. ACL 2019.
[3] Olsson, C., et al. (2022). In-context Learning and Induction Heads.
[4] Michel, P., Levy, O., & Neubig, G. (2019). Are Sixteen Heads Really Better than One? NeurIPS 2019.
[5] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
[6] Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax.
[7] Jiang, A. Q., et al. (2023). Mistral 7B.
[8] Katharopoulos, A., et al. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020.
[9] Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015.
