LeetLLM
© 2026 LeetLLM. All rights reserved.


Multi-Query & Grouped-Query Attention

Compare MHA, MQA, and GQA architectures, calculate their KV cache footprint, and reason about memory-limited serving tradeoffs.


The previous chapter made the KV cache cost concrete: every active token consumes GPU memory during decode. This chapter asks how the model architecture can shrink that cache before serving tricks like paging, scheduling, or quantization are applied.

Multi-query and grouped-query attention reduce KV cache memory by sharing key and value heads across attention heads. This chapter explains why that memory saving matters for long prompts, high concurrency, and production serving cost.

Imagine a fulfillment center where every picker keeps their own personal copy of the route sheet. With five pickers, the paperwork pile is small. With sixty-four pickers and thousands of orders, the storage room overflows. The fix is simple: have small teams share one route sheet instead of giving every picker a personal copy.

That's the idea behind Grouped-Query Attention (GQA) and Multi-Query Attention (MQA). In standard decoder-side multi-head attention[1], each query head has its own Key and Value projections, so incremental decoding keeps separate cached K/V state per head. As conversations get longer and you serve more users simultaneously, this cache can grow to hundreds of gigabytes. GQA and MQA reduce that dynamic memory cost by sharing cached data across heads. If KV memory is the admission bottleneck, that saving can increase concurrency; it is not an automatic throughput or quality guarantee.

Why KV cache memory becomes the bottleneck

As we saw in the previous article on inference mechanics, every time a model generates a token, it saves the Key and Value from that step so it doesn't have to recompute them later. This KV cache is essential for performance, but it grows linearly with sequence length, and in standard Multi-Head Attention (MHA), every single head keeps its own separate copy. For big models, this adds up fast.

A tiny example to build intuition

Before we scale up to billions of parameters, let's walk through a miniature case you can count on your fingers.

Suppose a single layer has 2 attention heads and each head has a dimension of 4. For one token, we must store a Key vector and a Value vector for every head. That's 2 heads times 2 vectors (K and V) times 4 numbers each:

| Component | Count | Numbers stored |
|---|---|---|
| Key vectors | 2 heads | 2 x 4 = 8 |
| Value vectors | 2 heads | 2 x 4 = 8 |
| Total per token | | 16 |

If you serve 8 requests in parallel and each request reaches 2,048 tokens, the cache explodes to 16 x 2,048 x 8 = 262,144 numbers for that one layer alone. Scale that to 80 layers and the pile becomes enormous. The problem isn't the math itself; it's the memory needed to hold all those numbers while the GPU streams them through High Bandwidth Memory (HBM) on every decoding step.

count-kv-elements.py

```python
def kv_elements(layers: int, tokens: int, batch: int, kv_heads: int, head_dim: int) -> int:
    return 2 * layers * tokens * batch * kv_heads * head_dim

tiny_one_layer = kv_elements(layers=1, tokens=2048, batch=8, kv_heads=2, head_dim=4)
scaled_layers = kv_elements(layers=80, tokens=2048, batch=8, kv_heads=2, head_dim=4)
print("tiny one-layer elements:", tiny_one_layer)
print("with 80 layers:", scaled_layers)
```

Output

```
tiny one-layer elements: 262144
with 80 layers: 20971520
```

Figure: Compact MHA, GQA, and MQA K/V head-sharing comparison.
MHA keeps one K/V pair per query head. GQA and MQA share K/V heads to shrink decode cache.

In standard MHA, each of the $h$ attention heads maintains its own Key ($K$) and Value ($V$) projections. The attention computation for a single layer is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Here $Q \in \mathbb{R}^{h \times n \times d_k}$, $K \in \mathbb{R}^{h \times n \times d_k}$, and $V \in \mathbb{R}^{h \times n \times d_k}$ for a sequence of length $n$. Each head learns a different attention pattern by attending to different aspects of the input. The KV cache stores $K$ and $V$ for all previously computed tokens so the model doesn't have to recompute them on each decoding step.
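To make "stores $K$ and $V$ for all previously computed tokens" concrete, here is a minimal single-head sketch in plain Python (one layer, no batching). Random vectors stand in for the learned Q/K/V projections, and the helper names `dot` and `attend` are illustrative, not from any library. Each decode step appends one K and one V to the cache and attends over everything cached so far.

```python
import math
import random

def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def attend(q: list[float], keys: list[list[float]], values: list[list[float]]) -> list[float]:
    """Single-head scaled dot-product attention for one query over cached keys/values."""
    scores = [dot(q, k) / math.sqrt(len(q)) for k in keys]
    m = max(scores)  # subtract the max before exp for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]

random.seed(0)
d_k, steps = 4, 6
# Random stand-ins for the projected q, k, v of each decode step (hypothetical data).
qs = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(steps)]
ks = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(steps)]
vs = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(steps)]

k_cache: list[list[float]] = []
v_cache: list[list[float]] = []
outputs = []
for t in range(steps):
    # Each step writes its K and V once; earlier entries are reused, never recomputed.
    k_cache.append(ks[t])
    v_cache.append(vs[t])
    outputs.append(attend(qs[t], k_cache, v_cache))

# Recomputing the last position from scratch over all keys gives the same output.
recomputed = attend(qs[-1], ks, vs)
print("cached K/V entries:", len(k_cache))
print("matches recompute:", outputs[-1] == recomputed)
```

The cache grows by one K/V entry per generated token. That is the linear growth the formulas below count, and MHA keeps one such cache per head in every layer.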

KV cache memory with standard MHA

For standard Multi-Head Attention (MHA) with $h$ heads, each with dimension $d_k$, here is how to read the cache size formula before you apply it:

  • 2 counts the Key and Value tensors separately.
  • $h$ is the number of attention heads.
  • $d_k$ is the dimension inside each head.
  • bytes per element depends on dtype: 2 bytes for FP16, 1 byte for INT8, and so on.

Multiply those four numbers and you get the cache for a single token in a single layer:

$$\text{KV cache per token per layer} = 2 \times h \times d_k \times \text{bytes per element}$$

For the full cache across all layers, sequence length, and batch:

$$\text{Total KV cache} = 2 \times L \times S \times B \times h \times d_k \times \text{bytes per element}$$

where $L$ is the number of layers, $S$ is sequence length, and $B$ is batch size.

Common trap: $2 \times h \times d_k \times \text{bytes}$ is only the cache for one token in one layer. For a full serving estimate, you still need to multiply by layers, sequence length, and batch size.
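The trap is easy to check in code. The sketch below computes the per-token, per-layer figure first and only then scales it by $L \times S \times B$, using the 72B-style dimensions from the next example. The helper names are hypothetical. The dtype table is an assumption for illustration, and quantized caches in real runtimes may also store scales.

```python
# Bytes per element for common cache dtypes (assumed for illustration; quantized
# caches in real runtimes may also store per-block scales).
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}

def kv_bytes_per_token_per_layer(kv_heads: int, head_dim: int, dtype: str) -> int:
    # The factor 2 covers the separate K and V tensors.
    return 2 * kv_heads * head_dim * BYTES_PER_ELEMENT[dtype]

def total_kv_bytes(layers: int, seq_len: int, batch: int,
                   kv_heads: int, head_dim: int, dtype: str) -> int:
    # Scale the per-token, per-layer figure by L x S x B.
    return layers * seq_len * batch * kv_bytes_per_token_per_layer(kv_heads, head_dim, dtype)

per_token_layer = kv_bytes_per_token_per_layer(kv_heads=64, head_dim=128, dtype="fp16")
full_fp16 = total_kv_bytes(80, 4096, 32, kv_heads=64, head_dim=128, dtype="fp16")
full_int8 = total_kv_bytes(80, 4096, 32, kv_heads=64, head_dim=128, dtype="int8")
print("per token, per layer:", per_token_layer, "bytes")
print(f"full FP16 cache: {full_fp16 / 1e9:.1f} GB")
print(f"full INT8 cache: {full_int8 / 1e9:.1f} GB")
```

On its own, the 32,768-byte per-token figure badly understates the roughly 344 GB total, which is exactly the mistake the trap warns about.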

Concrete example (72B-style decoder dimensions)

Qwen2.5-72B[2] is a useful reference point because it uses 80 layers, 64 query heads, and a head dimension of 128. The published model uses GQA-8, not full MHA, but if a model with those same dimensions used FP16 MHA, the KV cache would be:

| Parameter | Value |
|---|---|
| $d_{\text{model}}$ | 8192 |
| $h$ (heads) | 64 |
| $d_k$ | 128 |
| Layers | 80 |
| Sequence length | 4096 |
| Batch size | 32 |

$$\text{KV cache} = 2 \times 80 \times 4096 \times 32 \times 64 \times 128 \times 2 \text{ bytes} \approx \mathbf{344 \text{ GB}}$$

(Equivalent formulation using $d_{\text{model}} = h \times d_k = 8192$: $2 \times 80 \times 4096 \times 32 \times 8192 \times 2$ bytes.)

Here $2$ counts the K and V tensors, $80$ is layers, $4096$ is tokens per request, $32$ is batch size, $64$ is heads, $128$ is head dimension ($d_k$), and the final $2$ is FP16 bytes per element.

That's more than twice the model weights themselves (~144 GB in FP16). For long-context or high-concurrency decoding, moving that cache through HBM can dominate the incremental attention path.[3]

The next figure uses the same 72B-style dimensions, but shifts the shape of the workload from 32 parallel 4K requests to one 128K request. Both cases contain 131,072 cached token positions, so full MHA lands in the same 344 GB range.

Figure: KV cache budget chart for a 72B-style decoder at 128K context. MHA uses about 344 GB per request, GQA with eight KV heads uses about 43 GB, and MQA uses about 5.4 GB.
At long context, head sharing can decide whether one request fits at all. These numbers are cache only; weights and scheduler overhead still count.

In an e-commerce setting, this is the difference between a full-MHA cache that doesn't fit on a single 80 GB GPU and a GQA-8 cache around 43 GB before weights, runtime buffers, and fragmentation. Whether either configuration is viable still depends on weights, quantization, hardware layout, scheduler policy, and measured traffic.

compare-attention-kv-footprints.py

```python
def kv_cache_gb(kv_heads: int, tokens: int = 4096, batch: int = 32) -> float:
    bytes_used = 2 * 80 * tokens * batch * kv_heads * 128 * 2
    return bytes_used / 1e9

for name, kv_heads in (("MHA", 64), ("GQA-8", 8), ("MQA", 1)):
    print(f"{name}: {kv_cache_gb(kv_heads):.2f} GB")
```

Output

```
MHA: 343.60 GB
GQA-8: 42.95 GB
MQA: 5.37 GB
```


Multi-Query Attention (MQA)

The extreme compression

Back to the warehouse analogy: MQA takes the extreme approach. Instead of every picking lane keeping its own copy of the route notes, all lanes share one copy. This saves enormous storage, but some lanes might lose details that would have helped their specific work.

MQA[4] uses one shared K head and one shared V head across all query heads:

  • $h$ separate query projections (same as standard MHA)
  • 1 shared key projection for all $h$ query heads
  • 1 shared value projection for all $h$ query heads

In the head-sharing figure above, MQA is the rightmost design: every query head routes through the same cached K/V pair.

How much memory MQA actually saves

Return to our tiny example: 2 heads, dimension 4, one token. Under MHA we stored 16 numbers. Under MQA we store only 1 Key and 1 Value head, so the count drops to $2 \times 1 \times 4 = 8$ numbers. That's a 2x reduction for 2 heads. At 64 heads the reduction becomes 64x.

The memory footprint of MQA is drastically smaller because the number of KV heads ($h$ in the standard MHA formula) is replaced by $1$:

$$\text{MQA KV cache per token per layer} = 2 \times 1 \times d_k \times \text{bytes per element}$$

where $1$ means a single shared KV head (instead of $h$ heads), so the cache per token per layer scales with $d_k$ rather than with $h \cdot d_k$.

Instead of storing separate K and V for each of the 64 heads, MQA stores just one shared K and one shared V, cutting cache by 64x. All query heads look up the same key-value pair, like every picking lane using the same route note.

This scaling is proportional to the number of query heads: with $h$ query heads and a single KV head, you get an $h\times$ reduction. The 8-head illustration above shows this with an 8x reduction; the same principle applies at scale with 64 heads for a 64x reduction.

| Method | KV heads | Cache per token per layer (FP16) | Savings |
|---|---|---|---|
| MHA | 64 | $2 \times 64 \times 128 \times 2 = 32\text{ KB}$ | 1x |
| MQA | 1 | $2 \times 1 \times 128 \times 2 = 512\text{ B}$ | 64x |

For the same 72B-style dimensions with batch=32 and seq=4096, MHA uses 344 GB of KV cache. MQA cuts that by 64x to ~5.4 GB.

Performance vs. quality tradeoffs in MQA

MQA can reduce quality because all heads share the same key-value representation. Query heads can still ask different questions, but they no longer get separate K/V subspaces.

Common mistake: Treating MQA quality loss as either zero or catastrophic. Real impact depends on model size and task. Shazeer's original paper motivates MQA as a serving optimization, and later GQA work shows why many larger models want more than one KV head.[4][5]

MQA is like projecting a single telescope's view onto a screen that all astronomers in an observatory can see simultaneously. They can each ask different questions about what they see (separate queries), but they're all looking at the same image (shared K/V). Most of the time, the single view is good enough, but occasionally, one astronomer misses a detail they would have caught with their own specialized lens.

A minimal MQA implementation

Here is a basic PyTorch implementation of Multi-Query Attention. It projects the input into multiple query heads, but uses only a single shared key and value projection across all query heads. Notice how the key and value tensors are broadcast across the query heads during the attention score calculation to produce the final output.

a-minimal-mqa-implementation.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA).
    Q gets h separate heads; K and V share a single head.
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # h separate query projections
        self.W_q = nn.Linear(d_model, d_model)
        # ONE shared key and value projection (output dim is d_k, not d_model)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, D = x.shape

        # Q: [batch, seq, d_model] -> [batch, heads, seq, d_k]
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # K and V: [batch, seq, d_k] -> [batch, 1, seq, d_k]
        # The singleton head dimension broadcasts across all query heads automatically
        K = self.W_k(x).unsqueeze(1)  # [B, 1, N, d_k]
        V = self.W_v(x).unsqueeze(1)  # [B, 1, N, d_k]

        # Attention scores: Q @ K^T
        # PyTorch broadcasts the singleton head dimension to match h
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # shape: [B, h, N, N]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)

        # Weighted sum: attn @ V
        # V's singleton head dimension broadcasts to h
        out = torch.matmul(attn, V)  # [B, h, N, d_k]
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

# Shape smoke test
mqa = MultiQueryAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)  # batch=2, seq=10
out = mqa(x)
print("MQA output shape:", out.shape)
```

Output

```
MQA output shape: torch.Size([2, 10, 64])
```


Grouped-Query Attention (GQA)

The balanced compromise

MQA saves the most memory, but forcing all query heads to share one K/V pair can reduce quality on some workloads. The compromise is to split the lanes into groups, with each group sharing one route note. That's GQA: less aggressive than MQA (one note for everyone), but more memory-efficient than MHA (one note per query head).

GQA[5] is the compromise between MHA and MQA. Instead of 1 KV head (MQA) or $h$ KV heads (MHA), use $g$ KV groups where $1 < g < h$:

$$\frac{h}{g} = \text{queries per group}$$

With $g$ groups and $h$ total query heads, each group of $h/g$ queries shares one K and one V. For example, Mistral 7B's 32 query heads with 8 KV groups means 4 queries share each KV pair. That's a 4:1 ratio: 4x less KV cache than MHA, while keeping 8 distinct K/V heads instead of MQA's single shared pair.
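The same 4:1 ratio carries through to whole-model numbers. This sketch assumes a Mistral 7B-style shape: 32 layers, 32 query heads, 8 KV heads, head dimension 128, and an FP16 cache. That shape matches the model's public configuration but is not stated in this article. The helper name and the 8 GiB budget are hypothetical.

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    # K and V (the factor 2) for one token across every layer.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

# Assumed Mistral 7B-style shape: 32 layers, 32 query heads, 8 KV heads, head_dim 128, FP16.
mha_bytes = kv_bytes_per_token(layers=32, kv_heads=32, head_dim=128)  # hypothetical MHA variant
gqa_bytes = kv_bytes_per_token(layers=32, kv_heads=8, head_dim=128)   # GQA with 8 KV heads

budget = 8 * 1024**3  # hypothetical 8 GiB cache budget
print("MHA-equivalent KiB per token:", mha_bytes // 1024)
print("GQA-8 KiB per token:", gqa_bytes // 1024)
print("tokens in 8 GiB, MHA vs GQA-8:", budget // mha_bytes, budget // gqa_bytes)
```

Under these assumptions, the same budget holds 16,384 cached tokens with an MHA layout and 65,536 with GQA-8. That 4x gap is the group ratio.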

Think of GQA like a logistics hub: instead of every driver carrying their own copy of the delivery manifest (MHA) or one manifest for the entire fleet (MQA), you organize drivers into pods that share a manifest. Each pod can still take different routes (different queries), but it shares cached reference data.

map-query-heads-to-kv-groups.py

```python
def kv_group_for_query(query_head: int, query_heads: int, kv_heads: int) -> int:
    assert query_heads % kv_heads == 0
    return query_head // (query_heads // kv_heads)

assignments = [kv_group_for_query(head, query_heads=8, kv_heads=2) for head in range(8)]
print("query to KV group:", assignments)
print("cache reduction:", 8 // 2, "x")
```

Output

```
query to KV group: [0, 0, 0, 0, 1, 1, 1, 1]
cache reduction: 4 x
```

GQA in practice

To understand how this looks in production, let's look at the configurations used by several popular models:

| Model | Total heads ($h$) | KV groups ($g$) | Ratio |
|---|---|---|---|
| Qwen2.5-72B[2] | 64 | 8 | 8:1 |
| Llama 2 70B[6] | 64 | 8 | 8:1 |
| Mistral 7B[7] | 32 | 8 | 4:1 |
| Gemma 2 9B[8] | 16 | 8 | 2:1 |

Beyond memory savings, GQA can make tensor-parallel serving cleaner. If num_kv_heads is divisible by the tensor-parallel degree, each shard can own an even slice of KV heads. If it doesn't divide cleanly, runtimes often replicate some KV heads instead, which wastes memory.

check-kv-head-sharding.py

```python
def sharding_plan(kv_heads: int, tensor_parallel: int) -> str:
    if kv_heads % tensor_parallel == 0:
        return f"even: {kv_heads // tensor_parallel} KV heads per shard"
    return "runtime may need replication or uneven ownership"

print("TP=4:", sharding_plan(kv_heads=8, tensor_parallel=4))
print("TP=6:", sharding_plan(kv_heads=8, tensor_parallel=6))
```

Output

```
TP=4: even: 2 KV heads per shard
TP=6: runtime may need replication or uneven ownership
```

Memory comparison (72B model, seq=4096)

These are per-request KV cache sizes at the given sequence length:

| Method | KV groups | KV cache per request | Quality posture |
|---|---|---|---|
| MHA | 64 | ~10.7 GB | Reference architecture |
| GQA-8 | 8 | ~1.34 GB | Validate converted or trained model on task evals |
| MQA | 1 | ~0.17 GB | Strongest sharing constraint; evaluate carefully |

In Ainslie et al., intermediate group counts recovered much of MHA quality while retaining MQA-like inference benefits after uptraining.[5] The correct group count for another model remains an architecture and evaluation choice, not a guarantee inherited from that experiment.

Real-world adoption across model families

Different model families have made distinct architectural choices based on their serving requirements and quality targets:

| Architecture | Typical usage | Design rationale |
|---|---|---|
| MHA | Older decoder designs, or deployments where KV memory is less constrained | Maximum per-head flexibility, highest KV-cache cost |
| MQA | Serving-first deployments that need the smallest possible KV cache | Aggressive memory reduction with the strongest sharing constraint |
| GQA | Published models such as Llama 2 34B/70B, Qwen2.5-72B, Mistral 7B, and Gemma 2 9B | Intermediate KV-head count; evaluate quality and serving together |

Key observations

  • MQA is best understood as a serving optimization: shared KV heads cut cache size and reduce memory traffic during incremental decoding.[4][3]
  • Many modern decoder-only LLMs use grouped KV heads, but the choice is model-specific rather than family-wide. Published examples in this article include Llama 2 34B/70B, Mistral 7B, Qwen2.5-72B, and Gemma 2 9B.[6][7][2][8]
  • There's no universal best group count. Model size, quality target, and serving stack decide whether 2:1, 4:1, or 8:1 is right.[5]

Converting MHA to GQA via uptraining

A key practical insight for teams adapting older models: you don't need to train a GQA architecture from scratch. You can convert an existing MHA model to GQA through a process called "uptraining."

Key insight: Converting MHA to GQA is like a company merger. Several departments each had their own filing system (KV heads). Instead of rebuilding from scratch, you merge departments into groups, average their files, and then give the new organization time to adapt through continued training. Quality must be evaluated after the conversion.

The conversion has two main steps:

  1. Mean-pool KV heads: Group the existing Key and Value heads into the desired number of partitions (e.g. merging 64 heads into 8 groups of 8). Take the mean of the weights for the $K$ and $V$ projection matrices within each group. This initializes each new GQA head to the average of its group's original heads, which gives uptraining a strong starting point that approximates the model's existing attention patterns.
  2. Fine-tune (uptrain): Train the converted model for a small fraction of the original pretraining budget using the standard next-token prediction objective. This allows the model to adapt its internal representations to the newly merged KV projections.
Figure: Compact GQA uptraining flow from MHA checkpoint to validation.
GQA uptraining mean-pools old K/V heads, trains briefly, then validates quality and serving metrics.

The original GQA paper[5] showed that this recipe recovers most of the lost quality with about 5% of the original pretraining compute. For a fresh pretraining run, teams usually bake the desired KV-head pattern into the architecture from day one. Uptraining matters when you inherit an older MHA checkpoint and want serving gains without restarting pretraining.

mean-pool-kv-heads-for-uptraining.py

```python
def mean_pool_heads(heads: list[list[float]], group_size: int) -> list[list[float]]:
    assert len(heads) % group_size == 0
    pooled: list[list[float]] = []
    for start in range(0, len(heads), group_size):
        group = heads[start : start + group_size]
        pooled.append([
            sum(values) / len(group)
            for values in zip(*group)
        ])
    return pooled

mha_k_heads = [[1.0, 3.0], [3.0, 5.0], [10.0, 12.0], [14.0, 16.0]]
print("GQA initial K heads:", mean_pool_heads(mha_k_heads, group_size=2))
```

Output

```
GQA initial K heads: [[2.0, 4.0], [12.0, 14.0]]
```
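Why is averaging projection weights a sensible initialization? The K and V projections are linear, so the pooled projection produces exactly the average of the keys (or values) that the original heads would have produced for the same input. Each head's individual attention pattern is still lost, and recovering quality from that is the job of uptraining. The plain-Python sketch below checks the identity on two hypothetical 2x3 projection matrices, without bias terms.

```python
def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    # Row-major weight matrix times input vector: one output per row.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def mean_matrices(mats: list[list[list[float]]]) -> list[list[float]]:
    # Element-wise average of same-shaped weight matrices (the mean-pooling step).
    n = len(mats)
    return [[sum(m[r][c] for m in mats) / n for c in range(len(mats[0][0]))]
            for r in range(len(mats[0]))]

# Two hypothetical K-projection matrices that belong to the same group.
w_head_a = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
w_head_b = [[3.0, 1.0, 0.0], [1.5, 1.0, 4.0]]
x = [1.0, 2.0, 3.0]  # one hidden-state vector

pooled_key = matvec(mean_matrices([w_head_a, w_head_b]), x)
avg_of_keys = [(a + b) / 2 for a, b in zip(matvec(w_head_a, x), matvec(w_head_b, x))]
print(pooled_key, avg_of_keys)
```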


Beyond GQA: Multi-Head Latent Attention (MLA)

GQA reduces the number of distinct KV heads stored in the cache. Multi-Head Latent Attention (MLA) takes a fundamentally different route: it reduces the dimensionality of what must be stored per position by learning a low-rank latent representation.

How MLA works (high-level)

Instead of caching full per-head Key and Value content vectors for every token, the model does the following:

  1. Input hidden states are projected into a compact latent vector of dimension $d_c$.

  2. Only this compact latent state (plus a small amount of decoupled positional information) is written into the KV cache.

  3. The architecture defines learned up-projections from that latent state. In an optimized inference implementation, those projection matrices can be absorbed into the query and output paths so decode does not materialize full per-head cached K/V again.[9]

DeepSeek-V2 also separates a positional RoPE component from compressed content so the cache can retain the required position-dependent term without preventing content compression.[9]
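The "absorb the up-projection" step rests on a simple identity: $q^\top (W_{UK} c_t) = (W_{UK}^\top q)^\top c_t$. Likewise, the value up-projection can be applied once after the attention-weighted sum instead of once per cached token. The plain-Python sketch below checks both identities for one head with tiny, made-up dimensions and weights. The names `W_dkv`, `W_uk`, and `W_uv` are illustrative. The sketch deliberately omits the decoupled RoPE component, multiple heads, and DeepSeek-V2's real dimensions.

```python
import math

def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def transpose(w: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*w)]

def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

# Hypothetical tiny sizes: d_model=4, latent d_c=2, head dims d_k=d_v=3, one head.
W_dkv = [[0.2, -0.1, 0.4, 0.0],
         [0.3, 0.5, -0.2, 0.1]]               # d_c x d_model down-projection
W_uk = [[1.0, 0.5], [-0.5, 2.0], [0.3, 0.3]]  # d_k x d_c key up-projection
W_uv = [[0.7, -0.2], [0.1, 0.9], [1.1, 0.4]]  # d_v x d_c value up-projection

hidden = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.5, -0.5, 2.0], [-1.0, 1.0, 0.0, 0.5]]
latent_cache = [matvec(W_dkv, h) for h in hidden]  # only d_c = 2 numbers cached per token
q = [0.4, -0.3, 1.2]
scale = math.sqrt(len(q))

# Path 1: materialize per-head keys from the latents, then score.
scores_full = [dot(q, matvec(W_uk, c)) / scale for c in latent_cache]
# Path 2: fold W_uk into the query once, then score directly against the latents.
q_latent = matvec(transpose(W_uk), q)
scores_absorbed = [dot(q_latent, c) / scale for c in latent_cache]

m = max(scores_absorbed)
exps = [math.exp(s - m) for s in scores_absorbed]
total = sum(exps)
weights = [e / total for e in exps]

# Values: up-project every cached latent, or mix the latents first and up-project once.
values = [matvec(W_uv, c) for c in latent_cache]
out_full = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(W_uv))]
mixed_latent = [sum(w * c[i] for w, c in zip(weights, latent_cache)) for i in range(len(W_dkv))]
out_absorbed = matvec(W_uv, mixed_latent)

print("scores match:", all(abs(a - b) < 1e-9 for a, b in zip(scores_full, scores_absorbed)))
print("outputs match:", all(abs(a - b) < 1e-9 for a, b in zip(out_full, out_absorbed)))
print("cached numbers per token:", len(latent_cache[0]))
```

Only `latent_cache` would need to live in the KV cache here. In DeepSeek-V2, the decoupled positional component described above is cached alongside it.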

DeepSeek-V2 reports a 93.3% reduction in KV cache memory footprint compared with a 67B baseline model, alongside its reported benchmark results.[9] The savings are not a simple head-count ratio; they depend on the chosen latent width $d_c$, the positional component, and the execution path.

Figure: GQA versus MLA comparison. GQA shrinks KV-head count, while MLA shrinks stored state.
GQA reduces stored heads; MLA changes the cached representation and therefore requires a compatible runtime path.

Trade-offs and adoption

| Aspect | GQA | MLA |
|---|---|---|
| Mechanism | Fewer KV heads ($g < h$) | Low-rank latent compression + up-projection |
| Kernel compatibility | Fits GQA-aware attention kernels | Needs an MLA-aware latent-cache implementation |
| Compression measure | KV-head ratio gives exact cache-factor comparison | Cached latent width and positional component determine savings |
| Published examples | Llama 2 70B, Qwen2.5-72B, Mistral 7B, Gemma 2 9B | DeepSeek-V2 |

Why the runtime distinction matters. GQA slots into grouped-attention implementations and familiar tensor-parallel strategies. MLA changes the cached representation and needs a compatible execution path. A smaller cache on paper is only useful when the selected runtime implements that path efficiently.

This architectural fork appears in production interviews: "When would you choose GQA over a more aggressive compression scheme like MLA?" The defensible answer is that GQA has a simpler K/V-head contract, while MLA can compress harder if the selected model and runtime both support its latent path.

Using DeepSeek-V2's documented 512-dimensional compressed K/V latent and 64-dimensional decoupled key component, the raw cached-value comparison looks like this.[9]

compare-cached-state-widths.py

```python
def cached_values_per_token(kv_heads: int, head_dim: int) -> int:
    return 2 * kv_heads * head_dim

gqa_values = cached_values_per_token(kv_heads=8, head_dim=128)
mla_example_values = 512 + 64  # compact content latent plus positional component
print("GQA cached values:", gqa_values)
print("MLA-style cached values:", mla_example_values)
print("compare actual model dimensions before deployment:", True)
```

Output

```
GQA cached values: 2048
MLA-style cached values: 576
compare actual model dimensions before deployment: True
```


Real-world serving impact

Concurrency impact (72B-style dimensions, 4K context)

Reducing KV cache doesn't guarantee an equal throughput gain, but it does raise the memory-limited concurrency ceiling. Using the same 80-layer, 64-query-head, 128-dim-head example:

Figure: Memory-limited concurrency ceiling chart. Compared with MHA at 1x, GQA-8 raises the KV-cache-limited ceiling to 8x and MQA raises it to 64x, with caveats for weights, compute, schedulers, and kernels.
Head sharing raises the memory-limited batch ceiling. Real throughput still depends on weights, FFN compute, kernel support, scheduler behavior, and prefill traffic.

| Config | KV cache per request | Relative memory-limited concurrency ceiling |
|---|---|---|
| MHA | ~10.7 GB | 1x |
| GQA-8 | ~1.34 GB | ~8x |
| MQA | ~0.17 GB | ~64x |

If KV cache is what limits batch size, those ratios are good first-order estimates. Real systems land below that theoretical ceiling because weights, scheduler overhead, kernel efficiency, and interconnect traffic still matter.[3]
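A quick way to see the ceiling is to divide a cache budget by the per-request KV size. The sketch below assumes a hypothetical node of 4 x 80 GB GPUs holding roughly 144 GB of FP16 weights. It ignores activations, allocator fragmentation, and runtime buffers, so the counts are cache-only upper bounds, not throughput predictions.

```python
def kv_bytes_per_request(tokens: int, kv_heads: int, layers: int = 80,
                         head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    # 2 = K and V; defaults follow the 72B-style example dimensions.
    return 2 * layers * tokens * kv_heads * head_dim * bytes_per_elem

# Hypothetical node: 4 x 80 GB GPUs with ~144 GB of FP16 weights; everything else ignored.
cache_budget = 4 * 80e9 - 144e9

ceilings = {}
for name, kv_heads in (("MHA", 64), ("GQA-8", 8), ("MQA", 1)):
    ceilings[name] = int(cache_budget // kv_bytes_per_request(tokens=4096, kv_heads=kv_heads))
    print(f"{name}: at most {ceilings[name]} concurrent 4K requests (cache-only)")
```

The ratios land near 8x and 64x because the same budget is divided by an 8x or 64x smaller per-request size. Integer rounding shifts them slightly.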

Long-context impact

KV savings become larger in absolute bytes with long-context models. For the same 72B-style dimensions, per-request cache grows linearly with sequence length:

| Sequence length | MHA KV cache (approx.) | GQA-8 KV cache (approx.) | Savings |
|---|---|---|---|
| 4K | ~10.7 GB | ~1.34 GB | 8x |
| 32K | ~85.9 GB | ~10.7 GB | 8x |
| 128K | ~343.6 GB | ~42.9 GB | 8x |
| 1M | ~2.75 TB | ~343.6 GB | 8x |

At long context, head-sharing or another cache compression strategy becomes an explicit capacity decision. At 128K, full MHA uses roughly 344 GB of KV state per request in this example, before weights, allocator fragmentation, or extra concurrency. GQA-8 drops that cache to about 43 GB, which is still expensive and still needs an end-to-end memory budget.

budget-long-context-cache.py

```python
def kv_cache_gb(kv_heads: int, tokens: int) -> float:
    return 2 * 80 * tokens * kv_heads * 128 * 2 / 1e9

mha_128k = kv_cache_gb(kv_heads=64, tokens=131_072)
gqa_128k = kv_cache_gb(kv_heads=8, tokens=131_072)
print(f"MHA 128K cache: {mha_128k:.1f} GB")
print(f"GQA-8 128K cache: {gqa_128k:.1f} GB")
# ~144 GB matches the FP16 weight estimate used earlier in this article.
print("GQA cache plus 144 GB weights fits in 4x80 GB raw:", gqa_128k + 144 <= 320)
```

Output

```
MHA 128K cache: 343.6 GB
GQA-8 128K cache: 42.9 GB
GQA cache plus 144 GB weights fits in 4x80 GB raw: True
```

Common mistake: Treating long-context support as an architectural claim about the context window rather than as a serving budget. A long support history, a codebase prompt, and a retrieved document bundle all consume KV cache the same way, token for token. Calculate active-token memory before promising concurrency.
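A sketch of that active-token calculation, assuming GQA-8 with the 72B-style dimensions above, a hypothetical 180 GB KV budget (the 320 GB raw HBM minus 140 GB of weights from the previous example, ignoring other overheads), and three made-up in-flight requests:

```python
# Sketch: budget KV cache from active tokens, not from the advertised context window.
# Assumptions (illustrative): GQA-8, 80 layers, head_dim 128, FP16 cache,
# hypothetical 180 GB KV budget.
LAYERS, KV_HEADS, HEAD_DIM, BYTES_PER_ELEM = 80, 8, 128, 2
BYTES_PER_TOKEN = 2 * LAYERS * KV_HEADS * HEAD_DIM * BYTES_PER_ELEM  # K and V
KV_BUDGET_BYTES = 180e9

# Hypothetical in-flight requests: prompt tokens plus tokens generated so far.
active_requests = {
    "support-history": 20_000,
    "codebase-prompt": 100_000,
    "rag-bundle": 8_000,
}

used = sum(active_requests.values()) * BYTES_PER_TOKEN
print(f"bytes per token: {BYTES_PER_TOKEN:,}")
print(f"in-flight KV cache: {used / 1e9:.1f} GB of {KV_BUDGET_BYTES / 1e9:.0f} GB")

# How many more 32K-token requests fit in the remaining budget?
headroom = KV_BUDGET_BYTES - used
print("additional 32K requests:", int(headroom // (32_768 * BYTES_PER_TOKEN)))
```

The advertised window of each request does not appear anywhere in the calculation; only tokens actually held in the cache do.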


How serving engines handle GQA head expansion

When building or working with inference engines, you must handle the KV head expansion efficiently. Here is a conceptual implementation of how an engine matches the smaller number of KV heads to the larger number of query heads before computing attention.

how-serving-engines-handle-gqa-head.py
```python
import math
import torch
import torch.nn.functional as F

def gqa_attention_reference(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                            n_heads: int, n_kv_heads: int) -> torch.Tensor:
    """
    Readable GQA implementation concept.

    Args:
        Q: [batch, n_heads, seq, d_k]
        K: [batch, n_kv_heads, seq, d_k]
        V: [batch, n_kv_heads, seq, d_k]
        n_heads: Number of query heads
        n_kv_heads: Number of KV heads
    """
    # Each KV head serves a contiguous group of query heads.
    assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    heads_per_group = n_heads // n_kv_heads

    # Expand KV heads to match Q heads.
    # This materializes repeated K/V for clarity. Production serving kernels
    # compute grouped attention without building these larger tensors.
    # [batch, n_kv_heads, seq, d_k] -> [batch, n_kv_heads * heads_per_group, seq, d_k]
    K_expanded = K.repeat_interleave(heads_per_group, dim=1)
    V_expanded = V.repeat_interleave(heads_per_group, dim=1)

    # Standard scaled dot-product attention from here (no causal mask in this sketch)
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / math.sqrt(d_k)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V_expanded)

# Shape smoke test
B, N, h, h_kv, d = 1, 8, 32, 8, 64
Q = torch.randn(B, h, N, d)
K = torch.randn(B, h_kv, N, d)
V = torch.randn(B, h_kv, N, d)
out = gqa_attention_reference(Q, K, V, h, h_kv)
print("GQA output shape:", out.shape)
```
Output
```text
GQA output shape: torch.Size([1, 32, 8, 64])
```

Modern kernels (FlashAttention[10], FlashInfer[11]) handle GQA without materializing the expanded K/V tensors. This avoids the memory overhead of repeat_interleave and computes grouped attention directly in the fused kernel.
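One way to get grouped attention without repeat_interleave in plain PyTorch is to fold each group's query heads into the row dimension, so each KV head is multiplied against all of its query heads in a single matmul. This is an illustrative sketch of the idea, not how FlashAttention or FlashInfer implement it internally; it omits causal masking, and `gqa_attention_grouped` is a hypothetical name.

```python
import math

import torch
import torch.nn.functional as F


def gqa_attention_grouped(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Grouped attention without repeating K/V (sketch, no causal mask).

    Q: [batch, n_heads, seq_q, d_k]; K, V: [batch, n_kv_heads, seq_k, d_k].
    Query heads kv*group .. kv*group + group - 1 share KV head kv, matching
    the repeat_interleave ordering used in the reference implementation.
    """
    B, n_heads, seq_q, d_k = Q.shape
    n_kv_heads = K.size(1)
    group = n_heads // n_kv_heads
    # Fold each group's query heads into rows: [B, n_kv, group * seq_q, d_k].
    # For a contiguous Q this is a view; K and V are never expanded.
    q = Q.reshape(B, n_kv_heads, group * seq_q, d_k)
    scores = torch.matmul(q, K.transpose(-2, -1)) / math.sqrt(d_k)  # [B, n_kv, group*seq_q, seq_k]
    out = torch.matmul(F.softmax(scores, dim=-1), V)                # [B, n_kv, group*seq_q, d_k]
    return out.reshape(B, n_heads, seq_q, d_k)


# Compare against the explicit-expansion formulation on random inputs.
torch.manual_seed(0)
B, N, h, h_kv, d = 2, 8, 32, 8, 64
Q, K, V = torch.randn(B, h, N, d), torch.randn(B, h_kv, N, d), torch.randn(B, h_kv, N, d)
g = h // h_kv
K_rep, V_rep = K.repeat_interleave(g, dim=1), V.repeat_interleave(g, dim=1)
reference = F.softmax(Q @ K_rep.transpose(-2, -1) / math.sqrt(d), dim=-1) @ V_rep
grouped = gqa_attention_grouped(Q, K, V)
print("max abs diff:", (grouped - reference).abs().max().item())
```

The fold works because query head kv * group + j sits at the same memory offset as row j * seq_q of the folded tensor, so no data moves.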

Production tip: When evaluating models for serving, total parameter count isn't enough. num_kv_heads, context length, and KV-cache dtype often dominate concurrency. GQA-8 raises the memory-limited concurrency ceiling by roughly 8x versus same-dimension MHA, but realized cost per query still depends on batching, quantization, and kernels.
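A sketch of reading those fields from a model config. The key names follow the Hugging Face Llama/Qwen2-style config.json convention; the values are the illustrative 72B-style dimensions used in this article, not a copy of any published config, and `kv_bytes_per_token` is a hypothetical helper.

```python
# Sketch: derive KV-cache cost per token from config fields instead of parameter count.
config = {
    "num_hidden_layers": 80,
    "num_attention_heads": 64,
    "num_key_value_heads": 8,
    "hidden_size": 8192,
}


def kv_bytes_per_token(cfg: dict, bytes_per_elem: int = 2) -> int:
    # Use an explicit head_dim if present, else hidden_size / num_attention_heads.
    head_dim = cfg.get("head_dim", cfg["hidden_size"] // cfg["num_attention_heads"])
    # If num_key_value_heads is missing, this sketch treats the model as MHA.
    kv_heads = cfg.get("num_key_value_heads", cfg["num_attention_heads"])
    return 2 * cfg["num_hidden_layers"] * kv_heads * head_dim * bytes_per_elem


fp16 = kv_bytes_per_token(config)
fp8 = kv_bytes_per_token(config, bytes_per_elem=1)
print(f"FP16 KV cache: {fp16 / 2**10:.0f} KiB per token")
print(f"FP8 KV cache:  {fp8 / 2**10:.0f} KiB per token")
print("query heads per KV head:", config["num_attention_heads"] // config["num_key_value_heads"])
```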

KV cache update correctness

In serving code, the hardest GQA bug is often not the attention formula. It's updating the cache correctly one token at a time.

GQA cache update invariant showing narrow KV writes, all query heads reading, and parity test against full recompute.
GQA cache correctness has one useful invariant: write narrow KV heads, read with all query heads, then prove incremental decode matches full-prefix recompute.

For one decode step, the runtime should:

  1. Project only the new token into Q, K, and V.
  2. Apply RoPE or another positional transform using the absolute position of that token.
  3. Append the new K/V slice at the next cache position.
  4. Run attention with Q from the new token against all cached K/V positions.
  5. Track cache length per request, because different requests finish at different times.
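Those steps can be sketched for a batched, single-layer decode in PyTorch. This is a minimal illustration, assuming rotate-half RoPE, bias-free projections, and a preallocated cache of shape (batch, n_kv_heads, max_seq, d_k); `rope` and `decode_step` are hypothetical helpers, not a serving engine's API.

```python
import torch


def rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
    """Rotate-half RoPE. positions must broadcast against x.shape[:-1]."""
    d = x.size(-1)
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # [d/2]
    angles = positions.float()[..., None] * inv_freq                     # [..., d/2]
    angles = torch.cat([angles, angles], dim=-1)                         # [..., d]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return x * angles.cos() + torch.cat([-x2, x1], dim=-1) * angles.sin()


@torch.no_grad()
def decode_step(x_new, Wq, Wk, Wv, K_cache, V_cache, lengths, n_heads):
    """x_new: [batch, d_model]; lengths: [batch] tokens already cached per request."""
    batch, n_kv_heads, max_seq, d_k = K_cache.shape
    group = n_heads // n_kv_heads
    # Step 1: project only the newest token of each request.
    q = (x_new @ Wq).view(batch, n_heads, 1, d_k)
    k = (x_new @ Wk).view(batch, n_kv_heads, 1, d_k)
    v = (x_new @ Wv).view(batch, n_kv_heads, 1, d_k)
    # Step 2: RoPE at each request's absolute position (= its current length).
    pos = lengths.view(batch, 1, 1)
    q, k = rope(q, pos), rope(k, pos)
    # Step 3: append at each request's own cursor, never a shared global one.
    for b in range(batch):
        K_cache[b, :, int(lengths[b])] = k[b, :, 0]
        V_cache[b, :, int(lengths[b])] = v[b, :, 0]
    # Step 4: attend from the new query to every valid cached slot, including the new one.
    q = q.reshape(batch, n_kv_heads, group, d_k)  # query heads grouped under their KV head
    scores = torch.einsum("bkgd,bksd->bkgs", q, K_cache) / d_k**0.5
    valid = torch.arange(max_seq)[None, :] <= lengths[:, None]  # [batch, max_seq]
    scores = scores.masked_fill(~valid[:, None, None, :], float("-inf"))
    out = torch.einsum("bkgs,bksd->bkgd", scores.softmax(dim=-1), V_cache)
    # Step 5: each request advances its own length.
    return out.reshape(batch, n_heads * d_k), lengths + 1


torch.manual_seed(0)
batch, d_model, n_heads, n_kv_heads, d_k, max_seq = 2, 64, 8, 2, 16, 32
Wq = torch.randn(d_model, n_heads * d_k) / d_model**0.5
Wk = torch.randn(d_model, n_kv_heads * d_k) / d_model**0.5
Wv = torch.randn(d_model, n_kv_heads * d_k) / d_model**0.5
K_cache = torch.zeros(batch, n_kv_heads, max_seq, d_k)  # sized by KV heads
V_cache = torch.zeros_like(K_cache)
# Pretend request 0 already holds 3 cached tokens; request 1 is fresh.
K_cache[0, :, :3] = torch.randn(n_kv_heads, 3, d_k)
V_cache[0, :, :3] = torch.randn(n_kv_heads, 3, d_k)
lengths = torch.tensor([3, 0])

out, lengths = decode_step(torch.randn(batch, d_model), Wq, Wk, Wv, K_cache, V_cache, lengths, n_heads)
print(out.shape, lengths.tolist())
```

The mask is what lets requests at different lengths share one preallocated tensor; real engines typically replace the fixed max_seq slab with paged blocks.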

The cache shape should use KV heads, not query heads:

```text
K_cache: (batch, n_kv_heads, max_seq, d_k)
V_cache: (batch, n_kv_heads, max_seq, d_k)
```

For an MHA model, n_kv_heads == n_heads. For MQA, n_kv_heads == 1. For GQA, 1 < n_kv_heads < n_heads.

enforce-cache-append-cursor.py
```python
def append_kv(cache: list[str], cursor: int, token_value: str) -> int:
    # The cursor must point exactly at the next free slot; anything else
    # would overwrite existing state or leave a gap.
    if cursor != len(cache):
        raise ValueError("cursor would overwrite or skip cache state")
    cache.append(token_value)
    return cursor + 1

cache: list[str] = []
cursor = append_kv(cache, cursor=0, token_value="token-0-kv")
cursor = append_kv(cache, cursor=cursor, token_value="token-1-kv")
print("cache length:", len(cache))
try:
    # A stale cursor (1 instead of 2) would overwrite token 1's K/V.
    append_kv(cache, cursor=1, token_value="stale-write")
except ValueError as exc:
    print("blocked:", exc)
```
Output
```text
cache length: 2
blocked: cursor would overwrite or skip cache state
```

Common cache-update bugs look like this:

| Symptom | Likely bug | Check |
|---|---|---|
| Answer quality degrades after a few tokens | Overwrote position t - 1 instead of appending at t | Print cache length after every decode step |
| Works at batch size 1, fails under batching | Used one global cache length for all requests | Store per-request positions |
| GQA memory is still huge | Allocated cache with query heads instead of KV heads | assert K_cache.size(1) == n_kv_heads |
| Long-context output becomes incoherent | Applied RoPE with local chunk position instead of absolute position | Log the position id used for each appended token |

That's why cache tests should compare a cached decode path against a full-prefix recompute path on the same tiny prompt. The logits should match closely for the next token. If they don't, the bug is usually position IDs, mask shape, or cache append order.
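A minimal parity test in that spirit, assuming a single GQA attention layer with rotate-half RoPE and float64 math so the comparison can use a tight tolerance. It also runs a deliberately buggy variant that applies RoPE at position 0 for every token, which should diverge. All helper names are hypothetical.

```python
import torch


def rope(x, positions, base=10_000.0):
    # Rotate-half RoPE; positions must broadcast against x.shape[:-1].
    d = x.size(-1)
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
    angles = positions.to(torch.float64)[..., None] * inv_freq
    angles = torch.cat([angles, angles], dim=-1)
    return x * angles.cos() + torch.cat([-x[..., d // 2 :], x[..., : d // 2]], dim=-1) * angles.sin()


def project(x, W, n):
    # [seq, d_model] -> [n, seq, d_k]
    return (x @ W).view(x.size(0), n, d_k).transpose(0, 1)


def attend(q, K, V, mask=None):
    # Query head h reads KV head h // group; explicit expansion is fine at toy size.
    group = q.size(0) // K.size(0)
    K, V = K.repeat_interleave(group, dim=0), V.repeat_interleave(group, dim=0)
    scores = q @ K.transpose(-2, -1) / q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ V


def full_recompute(x):
    pos = torch.arange(x.size(0))
    q = rope(project(x, Wq, n_heads), pos)
    k = rope(project(x, Wk, n_kv_heads), pos)
    causal = torch.ones(x.size(0), x.size(0), dtype=torch.bool).tril()
    return attend(q, k, project(x, Wv, n_kv_heads), causal)  # [n_heads, T, d_k]


def cached_decode(x, absolute_positions=True):
    K_cache = torch.empty(n_kv_heads, 0, d_k, dtype=x.dtype)  # sized by KV heads
    V_cache = torch.empty_like(K_cache)
    outputs = []
    for t in range(x.size(0)):
        # Correct: absolute position t. Buggy variant: every token thinks it is at position 0.
        pos = torch.tensor([t if absolute_positions else 0])
        xt = x[t : t + 1]
        q = rope(project(xt, Wq, n_heads), pos)
        K_cache = torch.cat([K_cache, rope(project(xt, Wk, n_kv_heads), pos)], dim=1)  # append at t
        V_cache = torch.cat([V_cache, project(xt, Wv, n_kv_heads)], dim=1)
        outputs.append(attend(q, K_cache, V_cache))  # new token sees all cached positions
    return torch.cat(outputs, dim=1)


torch.manual_seed(0)
d_model, n_heads, n_kv_heads, d_k, T = 32, 8, 2, 16, 12
Wq = torch.randn(d_model, n_heads * d_k, dtype=torch.float64) / d_model**0.5
Wk = torch.randn(d_model, n_kv_heads * d_k, dtype=torch.float64) / d_model**0.5
Wv = torch.randn(d_model, n_kv_heads * d_k, dtype=torch.float64) / d_model**0.5
x = torch.randn(T, d_model, dtype=torch.float64)

full = full_recompute(x)
print("correct cache, max abs diff:", (full - cached_decode(x)).abs().max().item())
print("position bug, max abs diff:", (full - cached_decode(x, absolute_positions=False)).abs().max().item())
```

In a real model the same comparison would run on next-token logits; a position bug like the one above typically shows up only after the first token, which is why the test checks every step rather than just the first.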


What you should be able to defend

By this point, you should be able to reason from the architecture diagram to serving cost, not just name the attention variant.

| Skill | What a strong answer includes |
|---|---|
| KV-cache sizing | Full cache math multiplies 2 (one K and one V tensor), layers, sequence length, batch size, KV heads, head dimension, and bytes per element. |
| MHA vs MQA vs GQA | MHA stores one K/V pair per query head, MQA stores one shared K/V pair, and GQA stores an intermediate number of K/V groups. |
| Memory reduction factor | KV-cache savings follow num_query_heads / num_key_value_heads, such as 4x for Mistral 7B and 8x for Qwen2.5-72B. |
| Quality tradeoff | Fewer KV heads reduce memory traffic, but they also reduce K/V representational capacity. |
| Uptraining | Existing MHA checkpoints can be converted by mean-pooling K/V heads and continuing language-model training for a small compute budget. |
| Serving fit | GQA appears in the cited model architectures and can raise memory-limited concurrency when the runtime supports its head layout efficiently. |
| GQA vs MLA choice | GQA shrinks the number of full K/V heads. MLA stores a compact latent and needs an MLA-aware inference path that can avoid full cached K/V materialization. The right choice depends on runtime support and cache pressure. |
| Memory win vs throughput win | KV-cache savings can raise the memory-limited batch ceiling without translating into the same wall-clock speedup if compute, prefill, or scheduler overhead still dominates. |

Follow-up questions

A 70B-style chat model must serve 128K-token requests. What memory story do you defend?

For 80 layers, 64 query heads, head dimension 128, FP16 cache, and full MHA, the per-token-per-layer cache is 2 * 64 * 128 * 2 = 32,768 bytes (32 KiB). Across all 80 layers, that is about 2.5 MiB per token. At 128K context, one request needs about 344 GB of KV cache under full MHA. GQA-8 cuts that to about 43 GB per request, which is much better but still too large for comfortable single-node serving. Strong answers mention head sharing plus paging, scheduling, retrieval, or KV-cache quantization.
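A quick sketch that checks those numbers and adds one assumed lever: storing the GQA-8 cache in FP8 (1 byte per element). Whether an FP8 KV cache is acceptable is a per-model quality question that this arithmetic does not answer.

```python
# Sketch: the 128K-token memory story for 80 layers, 64 query heads, head_dim 128.
# FP8 KV cache (1 byte per element) is an assumed option, not part of the answer above.
LAYERS, QUERY_HEADS, HEAD_DIM, TOKENS = 80, 64, 128, 131_072


def kv_bytes(kv_heads: int, bytes_per_elem: int) -> int:
    return 2 * LAYERS * kv_heads * HEAD_DIM * bytes_per_elem * TOKENS


per_token_per_layer_mha = 2 * QUERY_HEADS * HEAD_DIM * 2  # K and V, FP16
print(f"MHA per token per layer: {per_token_per_layer_mha / 2**10:.0f} KiB")
print(f"MHA per token, all layers: {per_token_per_layer_mha * LAYERS / 2**20:.1f} MiB")
for name, kv_heads, nbytes in [("MHA FP16", 64, 2), ("GQA-8 FP16", 8, 2), ("GQA-8 FP8", 8, 1)]:
    print(f"{name:>10}: {kv_bytes(kv_heads, nbytes) / 1e9:6.1f} GB per 128K request")
```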

A team wants GQA but cannot afford full retraining. What rollout plan makes sense?

Use uptraining. Group old K/V heads, mean-pool each group into one new K head and one new V head, then continue language-model training for a small fraction of the original pretraining budget so the model adapts to reduced K/V capacity. After that, validate both task quality and serving behavior: perplexity or downstream evals, KV-cache size, tensor-parallel sharding fit, and decode throughput under realistic batch and context settings. Ainslie et al. used this path to recover much of the quality lost by reducing KV heads.[5]
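A sketch of the mean-pooling step alone, assuming a bias-free K or V projection stored as an nn.Linear-style weight of shape [n_heads * head_dim, d_model] with rows laid out head by head. Real checkpoints may order or interleave heads differently, so check the layout first; continued training and evaluation are not shown, and `mean_pool_kv_heads` is a hypothetical helper.

```python
import torch


def mean_pool_kv_heads(W: torch.Tensor, n_heads: int, n_kv_heads: int, head_dim: int) -> torch.Tensor:
    """Convert an MHA K or V projection weight to GQA by averaging each head group.

    W: [n_heads * head_dim, d_model], rows assumed to be laid out head by head.
    Returns: [n_kv_heads * head_dim, d_model].
    """
    assert n_heads % n_kv_heads == 0
    group = n_heads // n_kv_heads
    # [n_kv_heads, group, head_dim, d_model]: heads kv*group .. kv*group+group-1 form one group.
    W_grouped = W.reshape(n_kv_heads, group, head_dim, W.size(-1))
    return W_grouped.mean(dim=1).reshape(n_kv_heads * head_dim, W.size(-1))


torch.manual_seed(0)
n_heads, n_kv_heads, head_dim, d_model = 8, 2, 4, 16
W_k = torch.randn(n_heads * head_dim, d_model)
W_k_gqa = mean_pool_kv_heads(W_k, n_heads, n_kv_heads, head_dim)
# The first GQA head should equal the average of MHA heads 0-3.
manual = W_k[: 4 * head_dim].reshape(4, head_dim, d_model).mean(dim=0)
print(W_k_gqa.shape, torch.allclose(W_k_gqa[:head_dim], manual))
```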

num_key_value_heads = 8 and tensor parallel degree is 6. Why is that awkward?

GQA is easiest when KV heads divide cleanly across tensor-parallel shards. With 8 KV heads and TP degree 6, some runtimes must replicate KV data or use uneven ownership. That wastes memory and cuts into the concurrency gain you expected from GQA.
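A sketch of a simplified placement policy that makes the divisibility problem concrete. The three-way split (even split, whole-head replication, or uneven ownership) is an assumption for illustration; actual runtimes differ, and some reject uneven layouts outright.

```python
# Sketch of a simplified KV-head placement policy across tensor-parallel ranks (assumed).
def kv_layout(n_kv_heads: int, tp: int) -> str:
    if n_kv_heads % tp == 0:
        return f"split: {n_kv_heads // tp} KV head(s) per rank, no replication"
    if tp % n_kv_heads == 0:
        # More ranks than KV heads: each head lives on several ranks.
        return f"replicate: each KV head stored on {tp // n_kv_heads} ranks"
    return "uneven: needs padding, partial replication, or unequal ownership"


for tp in (1, 2, 4, 6, 8, 16):
    print(f"TP={tp:>2}: {kv_layout(8, tp)}")
```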

Your runtime offers both GQA and MLA paths today. How do you choose between them?

Choose GQA when your selected runtime supports grouped heads and its sharding layout fits the model. Choose MLA when your model is MLA-native, your stack ships its latent-cache path, and cache pressure is dominant enough to justify that execution contract.

Cached decode diverges from full recompute after 200 tokens. What do you test first?

Check that cache shape uses n_kv_heads, not query heads, and that absolute positions, cache cursors, and append order all stay correct. The quickest proof is a tiny-prompt parity test: cached decode logits should match full-prefix recompute closely at each step.


Common pitfalls

"Model weights are the same thing as KV cache"

Symptom: You calculate that a quantized model fits in GPU memory, then the server still runs out of memory under long context or high concurrency.

Cause: Weights are static model parameters. The KV cache is dynamic per-request state. GQA mainly shrinks dynamic K/V activations, not the feed-forward layers or the full weight tensor.

Fix: Budget weights and KV cache separately. Then add activation buffers, allocator slack, scheduler state, and fragmentation.
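A sketch of that separate budgeting, with every number hypothetical: roughly 35 GB for a ~70B model quantized to 4 bits, a GQA-8 FP16 cache at 32K tokens per request, a fixed 4 GB activation buffer, and 10% of an 80 GB device held back as slack. It reproduces the symptom above: the weights fit easily, but concurrency 4 does not.

```python
# Sketch: budget static weights and dynamic KV cache separately (all values in GB).
# Every number is a hypothetical placeholder; measure your own runtime.
def serving_budget(hbm_gb: float, weights_gb: float, kv_gb_per_request: float,
                   concurrency: int, activation_gb: float = 4.0,
                   slack_fraction: float = 0.10) -> tuple[bool, float]:
    kv_total = kv_gb_per_request * concurrency  # dynamic: grows with context and batch
    reserved = hbm_gb * slack_fraction          # allocator slack, fragmentation, scheduler state
    needed = weights_gb + kv_total + activation_gb + reserved
    return needed <= hbm_gb, needed


weights_gb = 70e9 * 0.5 / 1e9                # ~70B params at 4 bits per param, about 35 GB
kv_gb = 2 * 80 * 8 * 128 * 2 * 32_768 / 1e9  # GQA-8, FP16, 32K tokens, about 10.7 GB
for concurrency in (1, 2, 4):
    fits, needed = serving_budget(80, weights_gb, kv_gb, concurrency)
    print(f"concurrency {concurrency}: need {needed:.1f} of 80 GB -> {'fits' if fits else 'OOM'}")
```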

"GQA speeds up training"

Symptom: You benchmark training throughput and see almost no change after switching from MHA to GQA.

Cause: Training often has a different bottleneck from incremental decoding. GQA reduces decode cache state, while training processes whole sequences and may remain dominated by large attention and feed-forward computations.

Fix: Measure decode throughput (tokens per second during autoregressive generation), not training throughput. GQA's wins show up when you're repeatedly loading the KV cache for each new token, not when you're computing the full attention matrix once.
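A rough sketch of why decode is where GQA matters: each generated token's attention reads the request's whole KV cache. The bandwidth constant is an assumed round number, not a spec for any particular GPU, and the bound ignores weight reads and compute.

```python
# Sketch: KV bytes read per decode step and a bandwidth-only lower bound on step time.
LAYERS, HEAD_DIM, BYTES_PER_ELEM = 80, 128, 2
HBM_BYTES_PER_S = 3e12  # assumption: ~3 TB/s aggregate memory bandwidth


def kv_read_per_decode_step(kv_heads: int, context_tokens: int) -> int:
    # Every cached K and V for every layer is read once per generated token.
    return 2 * LAYERS * kv_heads * HEAD_DIM * BYTES_PER_ELEM * context_tokens


for name, kv_heads in [("MHA", 64), ("GQA-8", 8)]:
    nbytes = kv_read_per_decode_step(kv_heads, 32_768)
    # Lower bound from bandwidth alone; ignores weights, compute, and kernel efficiency.
    print(f"{name}: {nbytes / 1e9:.1f} GB read per token, >= {nbytes / HBM_BYTES_PER_S * 1e3:.1f} ms")
```

A training step, by contrast, processes the whole sequence in one pass, so the per-token re-read that this sketch counts never happens there.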

"MQA always destroys quality"

Symptom: You reject MQA outright even for small, latency-sensitive models where the quality difference isn't visible on your task.

Cause: MQA imposes the strongest sharing constraint, but impact depends on model size, task, and training recipe. Treating it as always catastrophic is as wrong as treating it as free.

Fix: Compare task evals and serving metrics for your workload. Use GQA when you need a safer quality/serving compromise.

"I saved 8x on KV cache, so my serving cost dropped 8x"

Symptom: You calculate a huge KV cache reduction, but real-world latency or cost only improves by a fraction of that.

Cause: KV cache is one piece of the puzzle. Model weights, FFN compute, attention arithmetic, scheduler overhead, and interconnect traffic still matter. If your batch was previously limited by compute rather than memory, shrinking the cache won't move the needle as much.

Fix: Profile end-to-end latency with a realistic batch size. Use a tool like NVIDIA Nsight Systems or vLLM's built-in metrics to see whether you're memory-bound or compute-bound before betting on GQA alone.
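An Amdahl-style sketch of that point, assuming a per-step latency model in which only the KV-cache-bound share of each decode step shrinks by 8x. The fractions are hypothetical; throughput gains from fitting larger batches into the freed memory are a separate effect this does not model.

```python
# Sketch: end-to-end decode-step speedup when only the KV-bound share shrinks.
def step_speedup(kv_fraction: float, kv_reduction: float = 8.0) -> float:
    """kv_fraction: share of step time spent moving KV cache before the change."""
    return 1.0 / ((1.0 - kv_fraction) + kv_fraction / kv_reduction)


for kv_fraction in (0.2, 0.5, 0.9):
    print(f"KV share {kv_fraction:.0%}: {step_speedup(kv_fraction):.2f}x faster per step")
```

Even when KV traffic is half of each step, an 8x cache reduction yields well under 2x per step.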

"repeat_interleave exploded my memory during GQA training"

Symptom: Out-of-memory errors when you naively expand KV heads with repeat_interleave inside a training loop.

Cause: repeat_interleave materializes a larger tensor in memory. For 32 query heads and 8 KV groups, that temporarily creates a 4x larger K/V tensor before the matmul.

Fix: Use FlashAttention or FlashInfer, which handle GQA natively without materializing the expanded tensor. If you must write a custom kernel, use an implicit broadcast or fused kernel rather than explicit expansion.
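The size of the overhead is easy to estimate. This sketch assumes batch 8, 4,096 tokens, 32 query heads, 8 KV heads, head dimension 128, and 2-byte elements, and counts only the K and V tensors for one layer call. During training, autograd may also keep the expanded tensors alive for the backward pass, which multiplies the cost across layers.

```python
# Sketch: extra memory allocated by explicit K/V expansion for one layer call (assumed shapes).
def expansion_overhead_bytes(batch, seq, n_heads, n_kv_heads, head_dim, nbytes=2):
    kv_original = 2 * batch * n_kv_heads * seq * head_dim * nbytes  # K and V as stored
    kv_expanded = 2 * batch * n_heads * seq * head_dim * nbytes     # K and V after repeat_interleave
    return kv_original, kv_expanded


orig, expanded = expansion_overhead_bytes(8, 4_096, 32, 8, 128)
print(f"K/V before expansion: {orig / 2**20:.0f} MiB")
print(f"K/V after expansion:  {expanded / 2**20:.0f} MiB ({expanded // orig}x)")
```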

"I can regroup heads any way I want during GQA conversion"

Symptom: The converted checkpoint loads, but quality drops sharply or tensor-parallel shards disagree after serving rollout.

Cause: Head grouping is not arbitrary. Query and KV heads often follow a fixed ordering in the weight matrices, and tensor-parallel sharding can interleave or chunk that order in specific ways. Regrouping without respecting the original layout silently changes which heads share a K/V projection.

Fix: Preserve the model's published head ordering during regrouping, then run a tiny parity test and shard-level smoke test before broader evals. Don't assume head 0-7 is always the intended first group.

"The KV-head count doesn't affect sharding"

Symptom: Tensor-parallel serving uses more memory than expected, or some shards replicate K/V state.

Cause: num_key_value_heads interacts with the tensor-parallel degree. Awkward divisibility can force replication or uneven ownership in some runtimes.

Fix: Check model config before choosing TP degree. Prefer serving layouts where KV heads divide cleanly across shards.


A quick check

Try this without looking back at the tables:

A model has 40 layers, 32 query heads, 8 KV groups, head dimension 128, and FP16 KV-cache elements. You serve a batch of 8 requests, each at 2,048 tokens. How large is the KV cache under GQA? Show your work.

Solution sketch
  1. Per token per layer: 2 (K and V) x 8 (KV groups) x 128 (head dim) x 2 (FP16 bytes) = 4,096 bytes.
  2. Per request: 2,048 tokens x 40 layers x 4,096 bytes = 335,544,320 bytes ≈ 320 MiB.
  3. Batch of 8: 8 x 320 MiB ≈ 2.5 GiB.

Compare to MHA: 32 KV heads instead of 8 gives a 4x larger cache, so the same batch would need roughly 10 GiB just for the KV cache.
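The same arithmetic as a short script, useful for checking your own numbers (MiB = 2^20 bytes, GiB = 2^30 bytes):

```python
# Quick-check arithmetic: 40 layers, 32 query heads, 8 KV groups, head_dim 128,
# FP16 (2 bytes), batch of 8 requests at 2,048 tokens each.
layers, q_heads, kv_heads, head_dim, nbytes = 40, 32, 8, 128, 2
batch, tokens = 8, 2_048

per_token_per_layer = 2 * kv_heads * head_dim * nbytes  # K and V
per_request = tokens * layers * per_token_per_layer
gqa_batch = batch * per_request
mha_batch = gqa_batch * q_heads // kv_heads  # MHA keeps one KV head per query head

print(per_token_per_layer, "bytes per token per layer")
print(f"{per_request / 2**20:.0f} MiB per request")
print(f"GQA batch: {gqa_batch / 2**30:.1f} GiB, MHA batch: {mha_batch / 2**30:.1f} GiB")
```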


What to remember and where to go next

  1. KV Cache Bottleneck: For long contexts or high concurrency, KV cache can consume more memory than model weights.
  2. MQA vs GQA: MQA uses 1 KV head (largest cache saving and strongest sharing constraint). GQA uses g KV heads, with 1 < g < number of query heads, to trade cache size against representational capacity; measure quality.
  3. Savings follow the head ratio: KV cache savings are num_query_heads / num_kv_heads. A 64-query-head model with 8 KV heads gets 8x savings; a 32-query-head model with 8 KV heads gets 4x.
  4. Uptraining: You can convert MHA checkpoints to GQA by mean-pooling KV heads and continuing training for a small compute budget, then validating quality and serving behavior.
  5. Serving Support: Modern GQA-aware kernels such as FlashAttention[10] and FlashInfer[11] avoid explicit K/V expansion in the hot path.
Next Step
Continue to KV Cache & PagedAttention

There, you'll understand KV cache storage strategies for multi-tenant LLM inference, including PagedAttention, memory fragmentation, and vLLM architecture.

References

[1] Vaswani, A., et al. (2017). Attention Is All You Need.

[2] Qwen Team (2024). Qwen2.5 Technical Report.

[3] Pope, R., et al. (2023). Efficiently Scaling Transformer Inference. arXiv preprint.

[4] Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint.

[5] Ainslie, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.

[6] Touvron, H., et al. (2023). Llama 2: Open Foundation and Fine-Tuned Chat Models. arXiv preprint.

[7] Jiang, A. Q., et al. (2023). Mistral 7B.

[8] Gemma Team, Google DeepMind (2024). Gemma 2: Improving Open Language Models at a Practical Size.

[9] DeepSeek-AI (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.

[10] Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024.

[11] Ye, Z., et al. (2025). FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving.