Compare MHA, MQA, and GQA architectures, calculate their KV cache footprint, and reason about memory-limited serving tradeoffs.
The previous chapter made the KV cache cost concrete: every active token consumes GPU memory during decode. This chapter asks how the model architecture can shrink that cache before serving tricks like paging, scheduling, or quantization are applied.
Multi-query and grouped-query attention reduce KV cache memory by sharing key and value heads across attention heads. This chapter explains why that memory saving matters for long prompts, high concurrency, and production serving cost.
Imagine a fulfillment center where every picker keeps their own personal copy of the route sheet. With five pickers, the paperwork pile is small. With sixty-four pickers and thousands of orders, the storage room overflows. The fix is simple: have small teams share one route sheet instead of giving every picker a personal copy.
That's the idea behind Grouped-Query Attention (GQA) and Multi-Query Attention (MQA). In standard decoder-side multi-head attention[1], each query head has its own Key and Value projections, so incremental decoding keeps separate cached K/V state per head. As conversations get longer and you serve more users simultaneously, this cache can grow to hundreds of gigabytes. GQA and MQA reduce that dynamic memory cost by sharing cached data across heads. If KV memory is the admission bottleneck, that saving can increase concurrency; it is not an automatic throughput or quality guarantee.
As we saw in the previous chapter on inference mechanics, every time a model generates a token, it saves the Key and Value from that step so it doesn't have to recompute them later. This KV cache is essential for performance, but it grows linearly with sequence length, and in standard Multi-Head Attention (MHA), every single head keeps its own separate copy. For big models, this adds up fast.
Before we scale up to billions of parameters, let's walk through a miniature case you can count on your fingers.
Suppose a single layer has 2 attention heads and each head has a dimension of 4. For one token, we must store a Key vector and a Value vector for every head. That's 2 heads times 2 vectors (K and V) times 4 numbers each:
| Component | Count | Numbers stored |
|---|---|---|
| Key vectors | 2 heads | 2 x 4 = 8 |
| Value vectors | 2 heads | 2 x 4 = 8 |
| Total per token | 4 vectors | 16 |
If you serve 8 requests in parallel and each request reaches 2,048 tokens, the cache explodes to 16 x 2,048 x 8 = 262,144 numbers for that one layer alone. Scale that to 80 layers and the pile becomes enormous. The problem isn't the math itself; it's the memory needed to hold all those numbers while the GPU streams them through High Bandwidth Memory (HBM) on every decoding step.
```python
def kv_elements(layers: int, tokens: int, batch: int, kv_heads: int, head_dim: int) -> int:
    return 2 * layers * tokens * batch * kv_heads * head_dim

tiny_one_layer = kv_elements(layers=1, tokens=2048, batch=8, kv_heads=2, head_dim=4)
scaled_layers = kv_elements(layers=80, tokens=2048, batch=8, kv_heads=2, head_dim=4)
print("tiny one-layer elements:", tiny_one_layer)
print("with 80 layers:", scaled_layers)
```

```text
tiny one-layer elements: 262144
with 80 layers: 20971520
```
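In standard MHA, each of the $h$ attention heads maintains its own Key ($W_i^K$) and Value ($W_i^V$) projections. The attention computation for a single layer is:

$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \qquad \text{MHA}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

where $Q_i = X W_i^Q$, $K_i = X W_i^K$, and $V_i = X W_i^V$, each in $\mathbb{R}^{S \times d_k}$, for a sequence of length $S$. Each head learns a different attention pattern by attending to different aspects of the input. The KV cache stores $K_i$ and $V_i$ for all previously computed tokens so the model doesn't have to recompute them on each decoding step.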
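To make the per-head cache concrete, here is a minimal MHA sketch that returns the K and V tensors a decoder would keep. The `MHAWithCache` module is a hypothetical illustration, not code from any library; it assumes a causal mask and omits positional encoding. With `d_model=8` and `n_heads=2`, it reproduces the miniature example: 2 heads of dimension 4, so 16 cached numbers per token per layer.

```python
import torch
import torch.nn as nn


class MHAWithCache(nn.Module):
    """Minimal MHA sketch that also returns the per-head K/V a decoder would cache."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must split evenly across heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Each projection produces h * d_k = d_model features: one d_k slice per head.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor):
        B, S, D = x.shape

        def split(t: torch.Tensor) -> torch.Tensor:
            # [B, S, d_model] -> [B, h, S, d_k]
            return t.view(B, S, self.n_heads, self.d_k).transpose(1, 2)

        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
        causal = torch.tril(torch.ones(S, S, device=x.device)).bool()
        scores = scores.masked_fill(~causal, float("-inf"))
        heads = torch.softmax(scores, dim=-1) @ V  # [B, h, S, d_k]
        out = self.W_o(heads.transpose(1, 2).reshape(B, S, D))
        return out, (K, V)  # K and V keep one slice per head: this is the MHA cache


mha = MHAWithCache(d_model=8, n_heads=2)
out, (K, V) = mha(torch.randn(1, 3, 8))
print("K cache shape:", tuple(K.shape))
print("cached numbers per token per layer:", K.shape[1] * K.shape[3] + V.shape[1] * V.shape[3])
```

Every head contributes its own K and V slice, which is exactly the term that MQA and GQA shrink.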
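For standard Multi-Head Attention (MHA) with $h$ heads, each with dimension $d_k$, the cache for one token in one layer has four factors: 2 (one K and one V tensor), $h$ heads, $d_k$ values per head, and $b$ bytes per element ($b = 2$ for FP16):

$$\text{KV}_{\text{token, layer}} = 2 \times h \times d_k \times b$$

Multiply those four numbers and you get the cache for a single token in a single layer. For the full cache across all layers, sequence length, and batch:

$$\text{KV}_{\text{total}} = 2 \times L \times S \times B \times h \times d_k \times b$$

where $L$ is layers, $S$ is sequence length, and $B$ is batch size.

Common trap: $2 \times h \times d_k \times b$ is only the cache for one token in one layer. For a full serving estimate, you still need to multiply by layers, sequence length, and batch size.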
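A small helper makes the gap between the two formulas explicit. It is a sketch with an illustrative function name; it assumes FP16 (`bytes_per_element=2`) by default and uses the 80-layer, 64-head, 128-dimension shape of the Qwen2.5-72B-style example that follows.

```python
def kv_cache_bytes(layers: int, seq_len: int, batch: int, kv_heads: int,
                   head_dim: int, bytes_per_element: int = 2) -> int:
    """Full KV cache: 2 (K and V) x L x S x B x heads x d_k x bytes per element."""
    return 2 * layers * seq_len * batch * kv_heads * head_dim * bytes_per_element


# The "trap": the per-token, per-layer term looks tiny on its own.
per_token_per_layer = kv_cache_bytes(layers=1, seq_len=1, batch=1, kv_heads=64, head_dim=128)
# The same term multiplied out for 80 layers, 4,096 tokens, and 32 requests.
full_serving = kv_cache_bytes(layers=80, seq_len=4096, batch=32, kv_heads=64, head_dim=128)

print("per token per layer:", per_token_per_layer, "bytes")
print(f"full serving estimate: {full_serving / 1e9:.1f} GB")
```

The first number is tens of kilobytes; the second is hundreds of gigabytes. Only the multipliers changed.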
Qwen2.5-72B[2] is a useful reference point because it uses 80 layers, 64 query heads, and a head dimension of 128. The published model uses GQA-8, not full MHA, but if a model with those same dimensions used FP16 MHA, the KV cache would be:
| Parameter | Value |
|---|---|
| $d_{\text{model}}$ | 8192 |
| $h$ (heads) | 64 |
| $d_k$ | 128 |
| Layers | 80 |
| Sequence length | 4096 |
| Batch size | 32 |
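$$2 \times 80 \times 4096 \times 32 \times 64 \times 128 \times 2 = 343{,}597{,}383{,}680 \text{ bytes} \approx 344 \text{ GB}$$

(Equivalent formulation using $d_{\text{model}} = h \times d_k = 8192$: $2 \times 80 \times 4096 \times 32 \times 8192 \times 2 \approx 3.44 \times 10^{11}$ bytes.)

The first $2$ counts K and V tensors, $80$ is layers, $4096$ is tokens per request, $32$ is batch size, $64$ is heads, $128$ is head dimension ($d_k$), and the final $2$ is FP16 bytes per element.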
That's more than twice the model weights themselves (~144 GB in FP16). For long-context or high-concurrency decoding, moving that cache through HBM can dominate the incremental attention path.[3]
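One rough way to see why cache traffic matters is to divide the bytes read per decode step by HBM bandwidth. The sketch below assumes 8 GPUs with about 3.35 TB/s of HBM bandwidth each (roughly H100 SXM class), a cache split evenly across them, and perfect parallel reads. It ignores weight reads and compute, so treat the result as a lower bound, not a latency prediction.

```python
def min_read_time_ms(cache_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Lower bound on the time to stream the whole KV cache once from HBM."""
    return cache_bytes / bandwidth_bytes_per_s * 1e3


# Assumption: 8 GPUs x ~3.35 TB/s each, with every GPU reading its own cache shard.
aggregate_bw = 8 * 3.35e12
mha_cache = 2 * 80 * 4096 * 32 * 64 * 128 * 2  # bytes, full MHA in FP16
gqa8_cache = mha_cache // 8                    # bytes, GQA-8 keeps 1/8 of the KV heads

print(f"MHA:   >= {min_read_time_ms(mha_cache, aggregate_bw):.1f} ms per decode step")
print(f"GQA-8: >= {min_read_time_ms(gqa8_cache, aggregate_bw):.1f} ms per decode step")
```

Because every decode step rereads the whole cache, this floor applies to every generated token. Shrinking the cache lowers it proportionally.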
The next figure uses the same 72B-style dimensions, but shifts the shape of the workload from 32 parallel 4K requests to one 128K request. Both cases contain 131,072 cached token positions, so full MHA lands in the same 344 GB range.
In an e-commerce setting, this is the difference between a full-MHA cache that doesn't fit on a single 80 GB GPU and a GQA-8 cache around 43 GB before weights, runtime buffers, and fragmentation. Whether either configuration is viable still depends on weights, quantization, hardware layout, scheduler policy, and measured traffic.
```python
def kv_cache_gb(kv_heads: int, tokens: int = 4096, batch: int = 32) -> float:
    bytes_used = 2 * 80 * tokens * batch * kv_heads * 128 * 2
    return bytes_used / 1e9

for name, kv_heads in (("MHA", 64), ("GQA-8", 8), ("MQA", 1)):
    print(f"{name}: {kv_cache_gb(kv_heads):.2f} GB")
```

```text
MHA: 343.60 GB
GQA-8: 42.95 GB
MQA: 5.37 GB
```

Back to the warehouse analogy: MQA takes the extreme approach. Instead of every picking lane keeping its own copy of the route notes, all lanes share one copy. This saves enormous storage, but some lanes might lose details that would have helped their specific work.
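MQA[4] uses one shared K head and one shared V head across all query heads:

$$\text{head}_i = \text{softmax}\left(\frac{Q_i K^\top}{\sqrt{d_k}}\right) V, \qquad K = X W^K, \quad V = X W^V$$

Only $Q_i$ keeps a per-head index; $K$ and $V$ are shared by all $h$ query heads.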
In the head-sharing figure above, MQA is the rightmost design: every query head routes through the same cached K/V pair.
Return to our tiny example: 2 heads, dimension 4, one token. Under MHA we stored 16 numbers. Under MQA we store only 1 Key and 1 Value head, so the count drops to $2 \times 1 \times 4 = 8$ numbers. That's a 2x reduction for 2 heads. At 64 heads the reduction becomes 64x.
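The memory footprint of MQA is drastically smaller because the number of attention heads ($h$) in the standard MHA formula is replaced by 1:

$$\text{KV}_{\text{total}}^{\text{MQA}} = 2 \times L \times S \times B \times 1 \times d_k \times b$$

where the factor 1 means a single shared KV head (instead of $h$ heads), so cache per token per layer scales with $d_k$ rather than with $h \times d_k$.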
Instead of storing separate K and V for each of the 64 heads, MQA stores just one shared K and one shared V, cutting cache by 64x. All query heads look up the same key-value pair, like every picking lane using the same route note.
This scaling is proportional to the number of query heads: with $h$ query heads and a single KV head, you get an $h$x reduction. The 8-head illustration above shows this with 8x reduction; the same principle applies at scale with 64 heads for a 64x reduction.
| Method | KV heads | Cache per token per layer (FP16) | Savings |
|---|---|---|---|
| MHA | 64 | 2 x 64 x 128 x 2 = 32,768 bytes | 1x |
| MQA | 1 | 2 x 1 x 128 x 2 = 512 bytes | 64x |
For the same 72B-style dimensions with batch=32 and seq=4096, MHA uses 344 GB of KV cache. MQA cuts that by 64x to ~5.4 GB.
MQA can reduce quality because all heads share the same key-value representation. Query heads can still ask different questions, but they no longer get separate K/V subspaces.
Common mistake: Treating MQA quality loss as either zero or catastrophic. Real impact depends on model size and task. Shazeer's original paper motivates MQA as a serving optimization, and later GQA work shows why many larger models want more than one KV head.[4][5]
MQA is like projecting a single telescope's view onto a screen that all astronomers in an observatory can see simultaneously. They can each ask different questions about what they see (separate queries), but they're all looking at the same image (shared K/V). Most of the time, the single view is good enough, but occasionally, one astronomer misses a detail they would have caught with their own specialized lens.
Here is a basic PyTorch implementation of Multi-Query Attention. It projects the input into multiple query heads, but uses only a single shared key and value projection across all query heads. Notice how the key and value tensors are broadcast across the query heads during the attention score calculation to produce the final output.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA).
    Q gets h separate heads; K and V share a single head.
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # h separate query projections
        self.W_q = nn.Linear(d_model, d_model)
        # ONE shared key and value projection (output dim is d_k, not d_model)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, D = x.shape

        # Q: [batch, seq, d_model] -> [batch, heads, seq, d_k]
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # K and V: [batch, seq, d_k] -> [batch, 1, seq, d_k]
        # The singleton head dimension broadcasts across all query heads automatically
        K = self.W_k(x).unsqueeze(1)  # [B, 1, N, d_k]
        V = self.W_v(x).unsqueeze(1)  # [B, 1, N, d_k]

        # Attention scores: Q @ K^T
        # PyTorch broadcasts the singleton head dimension to match h
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # shape: [B, h, N, N]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)

        # Weighted sum: attn @ V
        # V's singleton head dimension broadcasts to h
        out = torch.matmul(attn, V)  # [B, h, N, d_k]
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

# Shape smoke test
mqa = MultiQueryAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)  # batch=2, seq=10
out = mqa(x)
print("MQA output shape:", out.shape)
```

```text
MQA output shape: torch.Size([2, 10, 64])
```

MQA saves the most memory, but forcing all query heads to share one K/V pair can reduce quality on some workloads. The compromise is to split the lanes into groups, with each group sharing one route note. That's GQA: less aggressive than MQA (one note for everyone), but more memory-efficient than MHA (one note per query head).
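GQA[5] is the compromise between MHA and MQA. Instead of 1 KV head (MQA) or $h$ KV heads (MHA), use $g$ KV groups where $1 < g < h$ and $g$ divides $h$:

$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_{j(i)}^\top}{\sqrt{d_k}}\right) V_{j(i)}, \qquad j(i) = \left\lfloor \frac{i}{h/g} \right\rfloor$$

with heads indexed from 0, so the KV cache becomes $2 \times L \times S \times B \times g \times d_k \times b$.

With $g$ groups and $h$ total query heads, each group of $h/g$ queries shares one K and one V. For example, Mistral 7B's 32 query heads with 8 KV groups means 4 queries share each KV pair. That's a 4:1 ratio: 4x less KV cache than MHA, with more K/V representational capacity than MQA's single KV pair.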
Think of GQA like a logistics hub: instead of every driver carrying their own copy of the delivery manifest (MHA) or one manifest for the entire fleet (MQA), you organize drivers into pods that share a manifest. Each pod can still take different routes (different queries), but it shares cached reference data.
```python
def kv_group_for_query(query_head: int, query_heads: int, kv_heads: int) -> int:
    assert query_heads % kv_heads == 0
    return query_head // (query_heads // kv_heads)

assignments = [kv_group_for_query(head, query_heads=8, kv_heads=2) for head in range(8)]
print("query to KV group:", assignments)
print("cache reduction:", 8 // 2, "x")
```

```text
query to KV group: [0, 0, 0, 0, 1, 1, 1, 1]
cache reduction: 4 x
```

To understand how this looks in production, let's look at the configurations used by several popular models:
| Model | Total heads ($h$) | KV groups ($g$) | Ratio |
|---|---|---|---|
| Qwen2.5-72B[2] | 64 | 8 | 8:1 |
| Llama 2 70B[6] | 64 | 8 | 8:1 |
| Mistral 7B[7] | 32 | 8 | 4:1 |
| Gemma 2 9B[8] | 16 | 8 | 2:1 |
Beyond memory savings, GQA can make tensor-parallel serving cleaner. If num_kv_heads is divisible by the tensor-parallel degree, each shard can own an even slice of KV heads. If it isn't, some runtimes replicate KV heads or fall back to uneven ownership, which wastes memory.
```python
def sharding_plan(kv_heads: int, tensor_parallel: int) -> str:
    if kv_heads % tensor_parallel == 0:
        return f"even: {kv_heads // tensor_parallel} KV heads per shard"
    return "runtime may need replication or uneven ownership"

print("TP=4:", sharding_plan(kv_heads=8, tensor_parallel=4))
print("TP=6:", sharding_plan(kv_heads=8, tensor_parallel=6))
```

```text
TP=4: even: 2 KV heads per shard
TP=6: runtime may need replication or uneven ownership
```

These are per-request KV cache sizes at 4,096 tokens, using the same 72B-style dimensions:
| Method | KV groups | KV cache per request | Quality posture |
|---|---|---|---|
| MHA | 64 | ~10.7 GB | Reference architecture |
| GQA-8 | 8 | ~1.34 GB | Validate converted or trained model on task evals |
| MQA | 1 | ~0.17 GB | Strongest sharing constraint; evaluate carefully |
In Ainslie et al., intermediate group counts recovered much of MHA quality while retaining MQA-like inference benefits after uptraining.[5] The correct group count for another model remains an architecture and evaluation choice, not a guarantee inherited from that experiment.
Different model families have made distinct architectural choices based on their serving requirements and quality targets:
| Architecture | Typical usage | Design Rationale |
|---|---|---|
| MHA | Older decoder designs, or deployments where KV memory is less constrained | Maximum per-head flexibility, highest KV-cache cost |
| MQA | Serving-first deployments that need the smallest possible KV cache | Aggressive memory reduction with the strongest sharing constraint |
| GQA | Published models such as Llama 2 34B/70B, Qwen2.5-72B, Mistral 7B, and Gemma 2 9B | Intermediate KV-head count; evaluate quality and serving together |
A key practical insight for teams adapting older models: you don't need to train a GQA architecture from scratch. You can convert an existing MHA model to GQA through a process called "uptraining."
Key insight: Converting MHA to GQA is like a company merger. Several departments each had their own filing system (KV heads). Instead of rebuilding from scratch, you merge departments into groups, average their files, and then give the new organization time to adapt through continued training. Quality must be evaluated after the conversion.
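The conversion has two main steps:

1. Mean-pool the original K and V projection heads within each group, producing one new K head and one new V head per group.
2. Uptrain: continue language-model pretraining for a small fraction of the original compute so the model adapts to the reduced K/V capacity.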
The original GQA paper[5] showed that this recipe recovers most of the lost quality with about 5% of the original pretraining compute. For a fresh pretraining run, teams usually bake the desired KV-head pattern into the architecture from day one. Uptraining matters when you inherit an older MHA checkpoint and want serving gains without restarting pretraining.
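Applied to real projection weights rather than lists of numbers, step 1 is a reshape and a mean. This is a sketch under stated assumptions: the weight follows the `nn.Linear` layout `[n_heads * d_k, d_model]`, output rows are stored head by head, and consecutive heads form a group. The helper name is hypothetical, and a model whose checkpoint orders or shards heads differently needs its own index mapping. A bias vector, if present, would be pooled the same way.

```python
import torch


def mean_pool_kv_projection(weight: torch.Tensor, n_heads: int, n_kv_heads: int) -> torch.Tensor:
    """
    Mean-pool an MHA K (or V) projection weight into GQA form.
    Assumes weight shape [n_heads * d_k, d_model] with rows laid out head by head.
    """
    out_features, d_model = weight.shape
    assert n_heads % n_kv_heads == 0 and out_features % n_heads == 0
    d_k = out_features // n_heads
    group = n_heads // n_kv_heads
    # [h * d_k, d_model] -> [g, h/g, d_k, d_model], then average the heads in each group
    pooled = weight.reshape(n_kv_heads, group, d_k, d_model).mean(dim=1)
    return pooled.reshape(n_kv_heads * d_k, d_model)


# Hypothetical MHA K projection: 8 heads, d_k = 16, d_model = 64
W_k = torch.randn(8 * 16, 64)
W_k_gqa = mean_pool_kv_projection(W_k, n_heads=8, n_kv_heads=2)
print("GQA K projection shape:", tuple(W_k_gqa.shape))
```

The pooled matrix only initializes the GQA model; the uptraining step is what lets it recover quality.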
```python
def mean_pool_heads(heads: list[list[float]], group_size: int) -> list[list[float]]:
    assert len(heads) % group_size == 0
    pooled: list[list[float]] = []
    for start in range(0, len(heads), group_size):
        group = heads[start : start + group_size]
        pooled.append([
            sum(values) / len(group)
            for values in zip(*group)
        ])
    return pooled

mha_k_heads = [[1.0, 3.0], [3.0, 5.0], [10.0, 12.0], [14.0, 16.0]]
print("GQA initial K heads:", mean_pool_heads(mha_k_heads, group_size=2))
```

```text
GQA initial K heads: [[2.0, 4.0], [12.0, 14.0]]
```

GQA reduces the number of distinct KV heads stored in the cache. Multi-Head Latent Attention (MLA) takes a fundamentally different route: it reduces the dimensionality of what must be stored per position by learning a low-rank latent representation.
Instead of caching full per-head Key and Value content vectors for every token, the model does the following:
1. Input hidden states are projected into a compact latent vector of dimension $d_c$.
2. Only this compact latent state (plus a small amount of decoupled positional information) is written into the KV cache.
3. The architecture defines learned up-projections from that latent state back to per-head keys and values. In an optimized inference implementation, those projection matrices can be absorbed into the query and output paths so decode does not materialize full per-head cached K/V again.[9]
DeepSeek-V2 also separates a positional RoPE component from compressed content so the cache can retain the required position-dependent term without preventing content compression.[9]
DeepSeek-V2 reports a 93.3% reduction in KV cache size compared with its dense DeepSeek 67B predecessor.[9] The savings are not a simple head-count ratio; they depend on the chosen latent width $d_c$, the positional component, and the execution path.
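The core idea, caching a small latent and rebuilding keys and values from it, fits in a few lines. This is a simplified sketch, not DeepSeek-V2's implementation: the `LatentKVSketch` module and its dimensions are hypothetical, it omits the decoupled RoPE key, and its `expand` method materializes full K/V, which is exactly what an optimized MLA path avoids through matrix absorption.

```python
import torch
import torch.nn as nn


class LatentKVSketch(nn.Module):
    """
    Simplified MLA-style KV path: cache a d_c latent per token, rebuild K/V on demand.
    Omits the decoupled RoPE key and the projection-absorption trick.
    """

    def __init__(self, d_model: int, n_heads: int, d_head: int, d_c: int):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_head
        self.down = nn.Linear(d_model, d_c, bias=False)           # hidden state -> latent
        self.up_k = nn.Linear(d_c, n_heads * d_head, bias=False)  # latent -> per-head K
        self.up_v = nn.Linear(d_c, n_heads * d_head, bias=False)  # latent -> per-head V

    def compress(self, x: torch.Tensor) -> torch.Tensor:
        # [B, S, d_model] -> [B, S, d_c]; only this tensor is written to the cache
        return self.down(x)

    def expand(self, latent: torch.Tensor):
        B, S, _ = latent.shape
        k = self.up_k(latent).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        v = self.up_v(latent).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        return k, v  # each [B, n_heads, S, d_head]


mla = LatentKVSketch(d_model=256, n_heads=8, d_head=32, d_c=64)
latent_cache = mla.compress(torch.randn(2, 5, 256))
k, v = mla.expand(latent_cache)
full_kv_values = 2 * 8 * 32  # per token per layer if per-head K and V were cached
print("cached values per token:", latent_cache.shape[-1], "vs", full_kv_values)
```

K and V here share one latent, so the cache stores `d_c` values per token per layer instead of `2 * n_heads * d_head`.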
| Aspect | GQA | MLA |
|---|---|---|
| Mechanism | Fewer KV heads ($g < h$) | Low-rank latent compression + up-projection |
| Kernel compatibility | Fits GQA-aware attention kernels | Needs an MLA-aware latent-cache implementation |
| Compression measure | KV-head ratio gives exact cache-factor comparison | Cached latent width and positional component determine savings |
| Published examples | Llama 2 70B, Qwen2.5-72B, Mistral 7B, Gemma 2 9B | DeepSeek-V2 |
Why the runtime distinction matters. GQA slots into grouped-attention implementations and familiar tensor-parallel strategies. MLA changes the cached representation and needs a compatible execution path. A smaller cache on paper is only useful when the selected runtime implements that path efficiently.
This architectural fork appears in production interviews: "When would you choose GQA over a more aggressive compression scheme like MLA?" The defensible answer is that GQA has a simpler K/V-head contract, while MLA can compress harder if the selected model and runtime both support its latent path.
Using DeepSeek-V2's documented 512-dimensional compressed K/V latent and 64-dimensional decoupled key component, the raw cached-value comparison looks like this.[9]
```python
def cached_values_per_token(kv_heads: int, head_dim: int) -> int:
    # K and V values for one token in one layer
    return 2 * kv_heads * head_dim

gqa_values = cached_values_per_token(kv_heads=8, head_dim=128)
mla_example_values = 512 + 64  # compact content latent plus positional component
print("GQA cached values:", gqa_values)
print("MLA-style cached values:", mla_example_values)
# Per-token, per-layer counts only; compare actual model dimensions before deployment.
```

```text
GQA cached values: 2048
MLA-style cached values: 576
```

Reducing KV cache doesn't guarantee an equal throughput gain, but it does raise the memory-limited concurrency ceiling. Using the same 80-layer, 64-query-head, 128-dim-head example:
| Config | KV Cache per Request | Relative Memory-Limited Concurrency Ceiling |
|---|---|---|
| MHA | ~10.7 GB | 1x |
| GQA-8 | ~1.34 GB | ~8x |
| MQA | ~0.17 GB | ~64x |
If KV cache is what limits batch size, those ratios are good first-order estimates. Real systems land below that theoretical ceiling because weights, scheduler overhead, kernel efficiency, and interconnect traffic still matter.[3]
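Turning the ratios into a request count requires a memory budget. The sketch below uses a hypothetical deployment: 8 x 80 GB GPUs, about 144 GB of FP16 weights, and 10% held back for activations, allocator slack, and fragmentation, with the rounded per-request cache sizes from the table. The function name and reserve fraction are illustrative assumptions.

```python
def max_concurrent_requests(total_mem_gb: float, weights_gb: float,
                            reserve_fraction: float, kv_per_request_gb: float) -> int:
    """Memory-limited concurrency ceiling: how many full-length requests fit."""
    usable = total_mem_gb * (1 - reserve_fraction) - weights_gb
    return max(0, int(usable // kv_per_request_gb))


# Hypothetical budget: 640 GB total, 144 GB weights, 10% reserve, 4,096-token requests.
per_request_gb = {"MHA": 10.74, "GQA-8": 1.34, "MQA": 0.168}
ceilings = {name: max_concurrent_requests(640, 144, 0.10, kv_gb)
            for name, kv_gb in per_request_gb.items()}
for name, count in ceilings.items():
    print(f"{name}: up to {count} concurrent requests")
```

Because weights and the reserve are fixed, the KV-head ratio carries through almost unchanged. Once the weights no longer fit alongside a useful batch, the ratio stops mattering.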
KV savings become larger in absolute bytes with long-context models. For the same 72B-style dimensions, per-request cache grows linearly with sequence length:
| Sequence Length | MHA KV Cache (approx.) | GQA-8 KV Cache (approx.) | Savings |
|---|---|---|---|
| 4K | ~10.7 GB | ~1.34 GB | 8x |
| 32K | ~85.9 GB | ~10.7 GB | 8x |
| 128K | ~343.6 GB | ~42.9 GB | 8x |
| 1M | ~2.75 TB | ~343.6 GB | 8x |
At long context, head-sharing or another cache compression strategy becomes an explicit capacity decision. At 128K, full MHA uses roughly 344 GB of KV state per request in this example, before weights, allocator fragmentation, or extra concurrency. GQA-8 drops that cache to about 43 GB, which is still expensive and still needs an end-to-end memory budget.
```python
def kv_cache_gb(kv_heads: int, tokens: int) -> float:
    return 2 * 80 * tokens * kv_heads * 128 * 2 / 1e9

mha_128k = kv_cache_gb(kv_heads=64, tokens=131_072)
gqa_128k = kv_cache_gb(kv_heads=8, tokens=131_072)
print(f"MHA 128K cache: {mha_128k:.1f} GB")
print(f"GQA-8 128K cache: {gqa_128k:.1f} GB")
print("GQA cache plus 144 GB weights fits in 4x80 GB raw:", gqa_128k + 144 <= 320)
```

```text
MHA 128K cache: 343.6 GB
GQA-8 128K cache: 42.9 GB
GQA cache plus 144 GB weights fits in 4x80 GB raw: True
```

Common mistake: Treating long-context support as an architectural context-window claim rather than a serving budget. A long support history, codebase prompt, or retrieved document bundle can consume the same KV allocation. Calculate active-token memory before promising concurrency.
When building or working with inference engines, you must handle the KV head expansion efficiently. Here is a conceptual implementation of how an engine matches the smaller number of KV heads to the larger number of query heads before computing attention.
```python
import math
import torch
import torch.nn.functional as F

def gqa_attention_reference(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                            n_heads: int, n_kv_heads: int) -> torch.Tensor:
    """
    Readable GQA implementation concept.

    Args:
        Q: [batch, n_heads, seq, d_k]
        K: [batch, n_kv_heads, seq, d_k]
        V: [batch, n_kv_heads, seq, d_k]
        n_heads: Number of query heads
        n_kv_heads: Number of KV heads
    """
    # Expand KV heads to match Q heads.
    # This materializes repeated K/V for clarity. Production serving kernels
    # compute grouped attention without building these larger tensors.
    heads_per_group = n_heads // n_kv_heads

    # [batch, n_kv_heads, seq, d_k] -> [batch, n_kv_heads * group_size, seq, d_k]
    K_expanded = K.repeat_interleave(heads_per_group, dim=1)
    V_expanded = V.repeat_interleave(heads_per_group, dim=1)

    # Standard attention from here
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / math.sqrt(d_k)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V_expanded)

# Shape smoke test
B, N, h, h_kv, d = 1, 8, 32, 8, 64
Q = torch.randn(B, h, N, d)
K = torch.randn(B, h_kv, N, d)
V = torch.randn(B, h_kv, N, d)
out = gqa_attention_reference(Q, K, V, h, h_kv)
print("GQA output shape:", out.shape)
```

```text
GQA output shape: torch.Size([1, 32, 8, 64])
```

Modern kernels (FlashAttention[10], FlashInfer[11]) handle GQA without materializing the expanded K/V tensors. This avoids the memory overhead of repeat_interleave and computes grouped attention directly in the fused kernel.
Production tip: When evaluating models for serving, total parameter count isn't enough. num_kv_heads, context length, and KV-cache dtype often dominate concurrency. GQA-8 raises the memory-limited concurrency ceiling by roughly 8x versus same-dimension MHA, but realized cost per query still depends on batching, quantization, and kernels.
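Reading those fields from a model config is a quick first check. This sketch assumes Hugging Face Llama-style field names (`num_hidden_layers`, `num_attention_heads`, `num_key_value_heads`, `hidden_size`, and an optional `head_dim`) and fills them with Llama 2 70B's published dimensions. The helper name is hypothetical, and the FP8 line assumes a runtime that stores the cache at one byte per element.

```python
def kv_bytes_per_token(config: dict, kv_dtype_bytes: int) -> int:
    """KV-cache bytes for one token across all layers, from Llama-style config fields."""
    head_dim = config.get("head_dim") or config["hidden_size"] // config["num_attention_heads"]
    # A config without num_key_value_heads is treated as full MHA.
    kv_heads = config.get("num_key_value_heads", config["num_attention_heads"])
    return 2 * config["num_hidden_layers"] * kv_heads * head_dim * kv_dtype_bytes


llama2_70b_like = {"num_hidden_layers": 80, "num_attention_heads": 64,
                   "num_key_value_heads": 8, "hidden_size": 8192}

for label, dtype_bytes in (("FP16/BF16", 2), ("FP8", 1)):
    per_token = kv_bytes_per_token(llama2_70b_like, dtype_bytes)
    print(f"{label}: {per_token // 1024} KiB per token, "
          f"{per_token * 4096 / 1e9:.2f} GB per 4,096-token request")
```

Halving the cache dtype and dividing KV heads by 8 compound, which is why both fields belong in any serving estimate.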
In serving code, the hardest GQA bug is often not the attention formula. It's updating the cache correctly one token at a time.
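For one decode step, the runtime should:

1. Project the new token into query heads and n_kv_heads key/value heads.
2. Apply positional encoding (RoPE, for the models discussed here) using the token's absolute position in its request.
3. Append the new K and V at that request's current cache position, without overwriting earlier entries.
4. Run grouped attention for the new query over the cached prefix, positions 0 through t.
5. Advance that request's position counter by one.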
The cache shape should use KV heads, not query heads:
```text
K_cache: (batch, n_kv_heads, max_seq, d_k)
V_cache: (batch, n_kv_heads, max_seq, d_k)
```

For an MHA model, n_kv_heads == n_heads. For MQA, n_kv_heads == 1. For GQA, 1 < n_kv_heads < n_heads.
```python
def append_kv(cache: list[str], cursor: int, token_value: str) -> int:
    if cursor != len(cache):
        raise ValueError("cursor would overwrite or skip cache state")
    cache.append(token_value)
    return cursor + 1

cache: list[str] = []
cursor = append_kv(cache, cursor=0, token_value="token-0-kv")
cursor = append_kv(cache, cursor=cursor, token_value="token-1-kv")
print("cache length:", len(cache))
try:
    append_kv(cache, cursor=1, token_value="stale-write")
except ValueError as exc:
    print("blocked:", exc)
```

```text
cache length: 2
blocked: cursor would overwrite or skip cache state
```

Common cache-update bugs look like this:
| Symptom | Likely bug | Check |
|---|---|---|
| answer quality degrades after a few tokens | overwrote position t - 1 instead of appending at t | print cache length after every decode step |
| works at batch size 1, fails under batching | used one global cache length for all requests | store per-request positions |
| GQA memory is still huge | allocated cache with query heads instead of KV heads | assert K_cache.size(1) == n_kv_heads |
| long-context output becomes incoherent | applied RoPE with local chunk position instead of absolute position | log the position id used for each appended token |
That's why cache tests should compare a cached decode path against a full-prefix recompute path on the same tiny prompt. The logits should match closely for the next token. If they don't, the bug is usually position IDs, mask shape, or cache append order.
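A parity test of that kind can be small. The sketch below is a hypothetical tiny GQA layer with random projections, no RoPE, and no output projection, so it compares attention outputs rather than logits. It allocates the cache with KV heads, appends one token per step at position `t`, and checks the result against a causal full-prefix recompute. All requests in this toy batch share one length, so a single loop index stands in for per-request positions.

```python
import torch

torch.manual_seed(0)
B, S, d_model, n_heads, n_kv_heads, d_k = 2, 6, 32, 4, 2, 8
group = n_heads // n_kv_heads

# Hypothetical tiny GQA layer: scaled random projections, no RoPE, no output projection.
W_q = torch.randn(d_model, n_heads * d_k) / d_model ** 0.5
W_k = torch.randn(d_model, n_kv_heads * d_k) / d_model ** 0.5
W_v = torch.randn(d_model, n_kv_heads * d_k) / d_model ** 0.5


def project(x: torch.Tensor, W: torch.Tensor, heads: int) -> torch.Tensor:
    # [B, T, d_model] -> [B, heads, T, d_k]
    return (x @ W).view(x.shape[0], x.shape[1], heads, d_k).transpose(1, 2)


def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool) -> torch.Tensor:
    k = k.repeat_interleave(group, dim=1)  # readable, not memory-efficient
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    if causal:
        Tq, Tk = scores.shape[-2:]
        mask = torch.ones(Tq, Tk).tril(diagonal=Tk - Tq).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


x = torch.randn(B, S, d_model)

# Reference path: recompute attention over the full prefix with a causal mask.
full = attend(project(x, W_q, n_heads), project(x, W_k, n_kv_heads),
              project(x, W_v, n_kv_heads), causal=True)

# Cached decode path: the cache is sized with KV heads, not query heads.
K_cache = torch.zeros(B, n_kv_heads, S, d_k)
V_cache = torch.zeros(B, n_kv_heads, S, d_k)
assert K_cache.size(1) == n_kv_heads
outputs = []
for t in range(S):
    x_t = x[:, t : t + 1]                                      # the one new token
    K_cache[:, :, t : t + 1] = project(x_t, W_k, n_kv_heads)   # append at position t
    V_cache[:, :, t : t + 1] = project(x_t, W_v, n_kv_heads)
    q_t = project(x_t, W_q, n_heads)
    # The new token attends only to cached positions 0..t.
    outputs.append(attend(q_t, K_cache[:, :, : t + 1], V_cache[:, :, : t + 1], causal=False))
cached = torch.cat(outputs, dim=2)

print("max abs difference:", (cached - full).abs().max().item())
print("parity:", torch.allclose(cached, full, atol=1e-5))
```

Writing at `t - 1` or skipping a position in the loop makes the parity check fail, which is the point of keeping it in the test suite.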
By this point, you should be able to reason from the architecture diagram to serving cost, not just name the attention variant.
| Skill | What a strong answer includes |
|---|---|
| KV-cache sizing | Full cache math multiplies K/V tensors, layers, sequence length, batch size, KV heads, head dimension, and bytes per element. |
| MHA vs MQA vs GQA | MHA stores one K/V pair per query head, MQA stores one shared K/V pair, and GQA stores an intermediate number of K/V groups. |
| Memory reduction factor | Savings follow num_query_heads / num_key_value_heads for the KV cache, such as 4x for Mistral 7B and 8x for Qwen2.5-72B. |
| Quality tradeoff | Fewer KV heads reduce memory traffic, but they also reduce K/V representational capacity. |
| Uptraining | Existing MHA checkpoints can be converted by mean-pooling K/V heads and continuing language-model training for a small compute budget. |
| Serving fit | GQA appears in the cited model architectures and can raise memory-limited concurrency when the runtime supports its head layout efficiently. |
| GQA vs MLA choice | GQA shrinks the number of full K/V heads. MLA stores a compact latent and needs an MLA-aware inference path that can avoid full cached K/V materialization. The right choice depends on runtime support and cache pressure. |
| Memory win vs throughput win | KV-cache savings can raise the memory-limited batch ceiling without translating into the same wall-clock speedup if compute, prefill, or scheduler overhead still dominates. |
For 80 layers, 64 query heads, head dimension 128, FP16 cache, and full MHA, the per-token-per-layer cache is 2 * 64 * 128 * 2 = 32,768 bytes (32 KiB). Across all 80 layers, that is about 2.5 MiB per token. At 128K context, one request needs about 344 GB of KV cache under full MHA. GQA-8 cuts that to about 43 GB per request, which is much better but still too large for comfortable single-node serving. Strong answers mention head sharing plus paging, scheduling, retrieval, or KV-cache quantization.
Use uptraining. Group old K/V heads, mean-pool each group into one new K head and one new V head, then continue language-model training for a small fraction of original pretraining budget so the model adapts to reduced K/V capacity. After that, validate both task quality and serving behavior: perplexity or downstream evals, KV-cache size, tensor-parallel sharding fit, and decode throughput under realistic batch and context settings. Ainslie et al. used this path to recover much of the quality lost by reducing KV heads.[5]
num_key_value_heads = 8 and tensor parallel degree is 6. Why is that awkward?

GQA is easiest when KV heads divide cleanly across tensor-parallel shards. With 8 KV heads and TP degree 6, some runtimes must replicate KV data or use uneven ownership. That wastes memory and cuts into the concurrency gain you expected from GQA.
Choose GQA when your selected runtime supports grouped heads and its sharding layout fits the model. Choose MLA when your model is MLA-native, your stack ships its latent-cache path, and cache pressure is dominant enough to justify that execution contract.
Check that cache shape uses n_kv_heads, not query heads, and that absolute positions, cache cursors, and append order all stay correct. The quickest proof is a tiny-prompt parity test: cached decode logits should match full-prefix recompute closely at each step.
Symptom: You calculate that a quantized model fits in GPU memory, then the server still runs out of memory under long context or high concurrency.
Cause: Weights are static model parameters. The KV cache is dynamic per-request state. GQA mainly shrinks dynamic K/V activations, not the feed-forward layers or the full weight tensor.
Fix: Budget weights and KV cache separately. Then add activation buffers, allocator slack, scheduler state, and fragmentation.
Symptom: You benchmark training throughput and see almost no change after switching from MHA to GQA.
Cause: Training often has a different bottleneck from incremental decoding. GQA reduces decode cache state, while training processes whole sequences and may remain dominated by large attention and feed-forward computations.
Fix: Measure decode throughput (tokens per second during autoregressive generation), not training throughput. GQA's wins show up when you're repeatedly loading the KV cache for each new token, not when you're computing the full attention matrix once.
Symptom: You reject MQA outright even for small, latency-sensitive models where the quality difference isn't visible on your task.
Cause: MQA imposes the strongest sharing constraint, but impact depends on model size, task, and training recipe. Treating it as always catastrophic is as wrong as treating it as free.
Fix: Compare task evals and serving metrics for your workload. Use GQA when you need a safer quality/serving compromise.
Symptom: You calculate a huge KV cache reduction, but real-world latency or cost only improves by a fraction of that.
Cause: KV cache is one piece of the puzzle. Model weights, FFN compute, attention arithmetic, scheduler overhead, and interconnect traffic still matter. If your batch was previously limited by compute rather than memory, shrinking the cache won't move the needle as much.
Fix: Profile end-to-end latency with a realistic batch size. Use a tool like NVIDIA Nsight Systems or vLLM's built-in metrics to see whether you're memory-bound or compute-bound before betting on GQA alone.
Symptom: Out-of-memory errors when you naively expand KV heads with repeat_interleave inside a training loop.
Cause: repeat_interleave materializes a larger tensor in memory. For 32 query heads and 8 KV groups, that temporarily creates a 4x larger K/V tensor before the matmul.
Fix: Use FlashAttention or FlashInfer, which handle GQA natively without materializing the expanded tensor. If you must write a custom kernel, use an implicit broadcast or fused kernel rather than explicit expansion.
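If you are stuck in eager PyTorch, one way to avoid expanding K/V is to fold each group's query heads into the query-row dimension, so an ordinary batched matmul runs against the `g` KV heads directly. This is a sketch, not a FlashAttention or FlashInfer API: the function name is hypothetical, it applies no mask (which is trivial for single-token decode), and it assumes `h` is divisible by `g`. It is checked against the `repeat_interleave` path.

```python
import torch


def gqa_attention_folded(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    Unmasked GQA without repeating K/V.
    Q: [B, h, N_q, d]; K, V: [B, g, N_kv, d], with h divisible by g.
    """
    B, h, N_q, d = Q.shape
    g = K.shape[1]
    # Query head i = j_group * (h/g) + r, so the h/g heads of one group are adjacent.
    # Folding them into the row dimension gives [B, g, (h/g) * N_q, d].
    Qf = Q.reshape(B, g, (h // g) * N_q, d)
    scores = Qf @ K.transpose(-2, -1) / d ** 0.5  # [B, g, (h/g) * N_q, N_kv]
    out = torch.softmax(scores, dim=-1) @ V       # [B, g, (h/g) * N_q, d]
    return out.reshape(B, h, N_q, d)


B, N, h, g, d = 1, 8, 32, 8, 64
Q, K, V = torch.randn(B, h, N, d), torch.randn(B, g, N, d), torch.randn(B, g, N, d)

# Reference path that materializes expanded K/V, used only for comparison.
K_rep, V_rep = K.repeat_interleave(h // g, dim=1), V.repeat_interleave(h // g, dim=1)
ref = torch.softmax(Q @ K_rep.transpose(-2, -1) / d ** 0.5, dim=-1) @ V_rep

out = gqa_attention_folded(Q, K, V)
print("matches repeat_interleave path:", torch.allclose(out, ref, atol=1e-5))
```

K and V are read with their original `g` heads, so no 4x-larger temporary is built. Adding a causal mask for prefill would require tracking each folded row's original position.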
Symptom: The converted checkpoint loads, but quality drops sharply or tensor-parallel shards disagree after serving rollout.
Cause: Head grouping is not arbitrary. Query and KV heads often follow a fixed ordering in the weight matrices, and tensor-parallel sharding can interleave or chunk that order in specific ways. Regrouping without respecting the original layout silently changes which heads share a K/V projection.
Fix: Preserve the model's published head ordering during regrouping, then run a tiny parity test and shard-level smoke test before broader evals. Don't assume heads 0-7 are always the intended first group.
Symptom: Tensor-parallel serving uses more memory than expected, or some shards replicate K/V state.
Cause: num_key_value_heads interacts with the tensor-parallel degree. Awkward divisibility can force replication or uneven ownership in some runtimes.
Fix: Check model config before choosing TP degree. Prefer serving layouts where KV heads divide cleanly across shards.
Try this without looking back at the tables:
A model has 40 layers, 32 query heads, 8 KV groups, head dimension 128, and FP16 KV-cache elements. You serve a batch of 8 requests, each at 2,048 tokens. How large is the KV cache under GQA? Show your work.
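One way to check your arithmetic is to multiply the factors step by step. This short script only restates the exercise's numbers:

```python
# Exercise values: 40 layers, 8 KV groups, head_dim 128, FP16, 8 requests x 2,048 tokens.
layers, kv_heads, head_dim, bytes_fp16 = 40, 8, 128, 2
batch, tokens = 8, 2048

per_token_per_layer = 2 * kv_heads * head_dim * bytes_fp16  # K and V for one token, one layer
gqa_bytes = per_token_per_layer * layers * tokens * batch
mha_bytes = gqa_bytes * (32 // kv_heads)                    # MHA keeps all 32 heads

print(f"per token per layer: {per_token_per_layer} bytes")
print(f"GQA total: {gqa_bytes / 2**30:.1f} GiB")
print(f"MHA total: {mha_bytes / 2**30:.1f} GiB")
```

The GQA cache works out to 4,096 bytes per token per layer and 2.5 GiB in total.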
Compare to MHA: 32 KV heads instead of 8 gives a 4x larger cache, so the same batch would need roughly 10 GiB just for the KV cache.
KV-cache savings follow num_query_heads / num_kv_heads. A 64-query-head model with 8 KV heads gets 8x savings; a 32-query-head model with 8 KV heads gets 4x.

[1] Vaswani, A., et al. Attention Is All You Need. 2017.
[2] Qwen Team. Qwen2.5 Technical Report. 2024.
[3] Pope, R., et al. Efficiently Scaling Transformer Inference. arXiv preprint, 2023.
[4] Shazeer, N. Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint, 2019.
[5] Ainslie, J., et al. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
[6] Touvron, H., et al. Llama 2: Open Foundation and Fine-Tuned Chat Models. arXiv preprint, 2023.
[7] Jiang, A. Q., et al. Mistral 7B. 2023.
[8] Gemma Team, Google DeepMind. Gemma 2: Improving Open Language Models at a Practical Size. 2024.
[9] DeepSeek-AI. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. 2024.
[10] Dao, T. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. 2023; ICLR 2024.
[11] Ye, Z., et al. FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving. 2025.