Understand the main forms of knowledge distillation for LLMs, from logit matching and response-based supervision to on-policy KD. Learn when distillation helps, where student capacity becomes the bottleneck, and how to implement a correct teacher-student training loop.
RLVR trained a policy against checked outcomes where a verifier exists. Knowledge distillation asks a different deployment question: after you have a useful large language model (LLM) teacher, how do you transfer selected behavior into a smaller model that's cheaper to serve?
Knowledge distillation trains a smaller student model to imitate useful behavior from a larger or otherwise more capable teacher. Imagine a return-policy chatbot: a large model performs well on an evaluated question set, but running it for every user costs too much. A smaller student is viable only if it retains enough answer quality under a latency and serving-cost budget.
The transfer signal depends on teacher access. In white-box settings, the student can match softened token probabilities or internal features. In black-box settings, it can fine-tune on selected teacher-written answers, solution traces, or synthetic corpora. Each channel carries different information and different failure modes. The goal isn't to assume the student inherits the teacher; it's to transfer useful behavior and verify the resulting quality-cost tradeoff.
The flow below shows the white-box KD case. Response distillation uses selected teacher text instead of the teacher-probability branch.
In machine learning, the core idea (introduced by Hinton et al. in 2015) is to train a student model to mimic a teacher model's behavior, rather than only the ground truth labels. The student learns from the teacher's full probability distribution, which contains richer information than simple one-hot labels:[1]

$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{KD}} + (1 - \alpha) \, \mathcal{L}_{\text{CE}}$$

Where $\mathcal{L}_{\text{KD}}$ matches teacher behavior, $\mathcal{L}_{\text{CE}}$ matches ground-truth labels, and $\alpha$ controls the tradeoff. This combined loss makes sure the student learns from both the teacher's soft predictions and the true labels.
Suppose a customer asks, "How long do I have to return an electronic item?" The teacher's output distribution contains dark knowledge: relationships between classes that hard labels don't capture.
The hard label only says "30 days" is correct. In this simplified answer-choice example, the soft label also records that the teacher ranks "14 days" above "1 year." That extra signal is useful only if the teacher's ranking is itself useful; it is not proof that the student learned a general return-policy rule.
The probabilities above describe one simplified prediction decision. For a causal LLM, logit distillation applies this idea at each predicted token position. A higher temperature $T$ exposes more of the teacher's ranking, but it can also flatten the distribution until very little preference signal remains.
```python
import math

logits = {"30": 4.0, "14": 2.0, "90": 1.0, "1": 0.7}

def softened_probabilities(temperature: float) -> dict[str, float]:
    # Divide each logit by the temperature before the softmax.
    scaled = {token: math.exp(logit / temperature) for token, logit in logits.items()}
    total = sum(scaled.values())
    return {token: value / total for token, value in scaled.items()}

for temperature in (1.0, 4.0, 40.0):
    probs = softened_probabilities(temperature)
    rounded = {token: round(probability, 3) for token, probability in probs.items()}
    top_gap = probs["30"] - probs["14"]
    print(f"T={temperature:g} probabilities:", rounded, "top_gap:", round(top_gap, 3))
```

```text
T=1 probabilities: {'30': 0.818, '14': 0.111, '90': 0.041, '1': 0.03} top_gap: 0.708
T=4 probabilities: {'30': 0.397, '14': 0.241, '90': 0.188, '1': 0.174} top_gap: 0.156
T=40 probabilities: {'30': 0.263, '14': 0.25, '90': 0.244, '1': 0.242} top_gap: 0.013
```

Distillation is most useful when you already have a strong teacher, legal access to its signal, and a clear smaller deployment target. It does not win automatically over training from scratch. The Gemma 2 report gives a controlled example: its authors train the 2B and 9B models with token-probability distillation, and a 2B ablation trained for 500B tokens scores 67.7 when distilled from a 7B teacher versus 60.3 from scratch on their three-benchmark average.[2]
Different recipes expose different supervision channels: Orca trains on explanation traces, phi-1.5 uses curated synthetic textbook-like data, and Gemma 2 uses teacher token probabilities for small models.[3][4][2] They motivate careful data and signal selection; they do not establish one universally best recipe.
This method requires access to the teacher model's logits (the raw, unnormalized scores output by the final layer of the network before the softmax function). We minimize the KL Divergence (Kullback-Leibler divergence), a mathematical measure of how one probability distribution differs from another, to align the student's probability distribution with the teacher's.
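First, we apply temperature scaling (dividing logits by a temperature before softmax to soften the probability distribution) to both models' logits to get softened probability distributions:

$$p_{\text{teacher}} = \operatorname{softmax}\!\left(\frac{z_{\text{teacher}}}{T}\right), \qquad p_{\text{student}} = \operatorname{softmax}\!\left(\frac{z_{\text{student}}}{T}\right)$$

Where $z_{\text{teacher}}$ are the teacher's logits, $z_{\text{student}}$ are the student's logits, and $T$ is the temperature. Then we compute the KL divergence loss:

$$\mathcal{L}_{\text{KD}} = T^2 \cdot \mathrm{KL}\!\left(p_{\text{teacher}} \,\Vert\, p_{\text{student}}\right)$$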
For causal LMs, two implementation details matter. First, next-token training requires a one-token shift: logits at position $t$ train against the label at position $t+1$. Second, direct logit KD assumes teacher and student use the same token-to-id output mapping. Equal vocabulary sizes alone are insufficient: token id 42 must denote the same token for both models. If output spaces differ, plain token-level KL no longer lines up cleanly and you usually fall back to response distillation or design an explicit mapping. The snippet below handles the shift, masks ignored positions, and fails fast on a reordered vocabulary.
Common mistake: Running logit distillation without verifying tokenizer alignment. Two models can have the same vocabulary size and different token-id mappings. Compare the complete output mapping, not only `vocab_size`, before training.
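To make that comparison concrete, here is a minimal sketch in isolation. The two toy vocabularies and the helper name `mapping_mismatches` are illustrative assumptions, not part of any tokenizer library; with a real tokenizer you would first export its token-to-id dictionary.

```python
def mapping_mismatches(
    student_vocab: dict[str, int],
    teacher_vocab: dict[str, int],
) -> list[str]:
    """Return tokens whose id differs or that exist in only one mapping."""
    all_tokens = set(student_vocab) | set(teacher_vocab)
    return sorted(
        token for token in all_tokens
        if student_vocab.get(token) != teacher_vocab.get(token)
    )

# Hypothetical toy vocabularies: same size, but two ids are swapped.
teacher_vocab = {"<pad>": 0, "return": 1, "30": 2, "days": 3}
student_vocab = {"<pad>": 0, "return": 1, "days": 2, "30": 3}

same_size = len(student_vocab) == len(teacher_vocab)
mismatches = mapping_mismatches(student_vocab, teacher_vocab)

print("same vocab size:", same_size)
print("mismatched tokens:", mismatches)
print("safe for plain logit KD:", not mismatches)
```

Listing the mismatched tokens, rather than returning a single boolean, makes it easier to see whether an explicit mapping is feasible. The full loss function that follows applies the same idea as a fail-fast equality check.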
```python
import torch
import torch.nn.functional as F

def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    student_vocabulary: tuple[str, ...],
    teacher_vocabulary: tuple[str, ...],
    temperature: float = 3.0,
    alpha: float = 0.5,
    ignore_index: int = -100,
) -> torch.Tensor:
    """
    Computes the weighted sum of Knowledge Distillation (KD) loss and Cross-Entropy loss.
    """
    if student_vocabulary != teacher_vocabulary:
        raise ValueError(
            "Logit KD requires identical token-to-id mappings. "
            "Use response KD or design an explicit mapping when output spaces differ."
        )
    if student_logits.size(-1) != teacher_logits.size(-1) or student_logits.size(-1) != len(student_vocabulary):
        raise ValueError("Logit tensors and vocabulary dimensions must agree.")

    # Causal LMs predict token t+1 from positions up to t.
    shift_student = student_logits[..., :-1, :].contiguous()
    shift_teacher = teacher_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    vocab_size = shift_student.size(-1)
    flat_student = shift_student.reshape(-1, vocab_size)
    flat_teacher = shift_teacher.reshape(-1, vocab_size)
    flat_labels = shift_labels.reshape(-1)

    # Drop positions marked with ignore_index; return a graph-connected zero if none remain.
    valid_mask = flat_labels != ignore_index
    if not valid_mask.any():
        return student_logits.sum() * 0

    student_valid = flat_student[valid_mask]
    teacher_valid = flat_teacher[valid_mask]
    labels_valid = flat_labels[valid_mask]

    # Soft loss: KL divergence between softened distributions
    soft_teacher = F.softmax(teacher_valid.detach() / temperature, dim=-1)
    soft_student = F.log_softmax(student_valid / temperature, dim=-1)

    # KLDivLoss expects log-probabilities for the input (student)
    # and standard probabilities for the target (teacher)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
    soft_loss *= temperature ** 2  # Scale by T² for gradient magnitude

    # Hard loss: standard cross-entropy with ground truth
    hard_loss = F.cross_entropy(student_valid, labels_valid)

    return alpha * soft_loss + (1 - alpha) * hard_loss

torch.manual_seed(7)
batch, seq_len, vocab = 2, 5, 8
student_logits = torch.nn.Parameter(torch.randn(batch, seq_len, vocab))
teacher_logits = torch.randn(batch, seq_len, vocab)
labels = torch.tensor([
    [0, 1, 2, 3, 4],
    [0, 2, -100, 5, 6],
])
vocabulary = tuple(f"token_{index}" for index in range(vocab))

loss = distillation_loss(
    student_logits,
    teacher_logits,
    labels,
    vocabulary,
    vocabulary,
    temperature=3.0,
    alpha=0.6,
)
loss.backward()

loss_is_scalar = loss.ndim == 0
loss_is_finite = bool(torch.isfinite(loss))
grad_exists = student_logits.grad is not None
grad_is_finite = bool(torch.isfinite(student_logits.grad).all()) if grad_exists else False
mismatch_failed = False

# A reordered vocabulary must be rejected before any loss is computed.
try:
    distillation_loss(student_logits, teacher_logits, labels, vocabulary, tuple(reversed(vocabulary)))
except ValueError as exc:
    mismatch_failed = "token-to-id mappings" in str(exc)

print("loss:", round(float(loss), 4))
print("grad norm:", round(float(student_logits.grad.norm()), 4))
print("loss_is_scalar:", loss_is_scalar)
print("loss_is_finite:", loss_is_finite)
print("grad_is_finite:", grad_is_finite)
print("mismatch failed:", mismatch_failed)
```

```text
loss: 1.4959
grad norm: 0.1891
loss_is_scalar: True
loss_is_finite: True
grad_is_finite: True
mismatch failed: True
```

When you don't have access to the teacher's weights or logits, but can obtain generated text, response distillation is the available KD channel. In this approach, the teacher generates text responses for prompts, and the student fine-tunes on selected (prompt, response) pairs.
This is Supervised Fine-Tuning (SFT) on teacher-generated targets, not ground truth. Because an API commonly provides text rather than full token probabilities, the training channel is less detailed than direct logit access. A teacher can provide worked solutions, decomposed subproblems, critiques, or multiple candidate answers, but these outputs should be filtered with task-specific checks where possible. For a warehouse-routing assistant, a generated route is useful training data only after checks for capacity, delivery-window, and policy constraints.
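As a rough sketch of what fine-tuning on a selected pair can look like at the token level, the helper below concatenates prompt and response ids and masks the prompt positions so that only response tokens contribute to the loss. Masking the prompt is one common choice assumed here, and the token ids and the name `build_sft_example` are hypothetical.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the loss

def build_sft_example(prompt_ids: list[int], response_ids: list[int]) -> dict[str, list[int]]:
    """Concatenate prompt and response; supervise only the response tokens."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
    return {"input_ids": input_ids, "labels": labels}

# Hypothetical token ids for a checked (prompt, response) pair.
example = build_sft_example(prompt_ids=[11, 12, 13], response_ids=[21, 22])
supervised = sum(label != IGNORE_INDEX for label in example["labels"])

print("input_ids:", example["input_ids"])
print("labels:", example["labels"])
print("supervised positions:", supervised)
```

The one-token shift for causal LMs still applies when the loss is computed; this sketch only decides which positions carry a target.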
Response distillation and synthetic-data training overlap when a stronger model generates the selected targets. Calling a dataset "distillation" should not bypass evaluation: generated traces can be incorrect, stylistically misleading, contaminated, or out of scope for the intended student.
```python
generated = [
    {"prompt": "sealed electronics return", "teacher": "30 days", "verified": "30 days"},
    {"prompt": "final-sale item return", "teacher": "30 days", "verified": "not eligible"},
    {"prompt": "defective item warranty", "teacher": "warranty process", "verified": "warranty process"},
]

# Keep only examples where the teacher text matches the checked answer.
accepted = [
    example for example in generated
    if example["teacher"] == example["verified"]
]
rejected = [
    example["prompt"] for example in generated
    if example["teacher"] != example["verified"]
]

print("generated:", len(generated))
print("accepted:", len(accepted))
print("rejected prompts:", rejected)
print("teacher text is trusted label:", len(rejected) == 0)
```

```text
generated: 3
accepted: 2
rejected prompts: ['final-sale item return']
teacher text is trusted label: False
```

Logit distillation matches output distributions. With white-box access, a training objective can also match selected student hidden states to selected teacher hidden states through a learned projection.
$$\mathcal{L}_{\text{feat}} = \left\lVert h^{(l)}_{\text{teacher}} - W \, h^{(l)}_{\text{student}} \right\rVert_2^2$$

Where $h^{(l)}_{\text{student}}$ and $h^{(l)}_{\text{teacher}}$ are hidden states at layer $l$, and $W$ projects student features into the teacher feature space before comparison.
Since the student commonly has different hidden dimensions, $W$ maps the student's representation into the teacher comparison space. Feature matching introduces extra decisions: which layers correspond, how the projection is trained, and whether its additional compute improves held-out outcomes. Hidden-state access is a richer interface, not a guarantee of a better student.
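A minimal sketch of that feature-matching loss is shown here, assuming random tensors as stand-ins for hidden states from one chosen layer pair and a bias-free `nn.Linear` as the learned projection. The hidden sizes are hypothetical, and the mean-squared error used here is a scaled version of a squared-norm objective.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len = 2, 4
student_dim, teacher_dim = 6, 16  # hypothetical hidden sizes

# Stand-ins for hidden states taken from one chosen layer pair.
student_hidden = torch.randn(batch, seq_len, student_dim, requires_grad=True)
teacher_hidden = torch.randn(batch, seq_len, teacher_dim)

# Learned projection from the student space into the teacher space.
projection = nn.Linear(student_dim, teacher_dim, bias=False)

projected = projection(student_hidden)
# The teacher side is detached so only the student and projection receive gradients.
feature_loss = F.mse_loss(projected, teacher_hidden.detach())
feature_loss.backward()

print("projected shape:", tuple(projected.shape))
print("loss is finite:", bool(torch.isfinite(feature_loss)))
print("projection has grad:", projection.weight.grad is not None)
print("student has grad:", student_hidden.grad is not None)
```

In a real setup the projection's parameters would be added to the optimizer alongside the student's, and the feature term would typically be weighted against the output-level losses. Both choices need their own ablation.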
| Method | Teacher signal | Main advantage | Main constraint |
|---|---|---|---|
| Response KD | Selected text outputs | Works without white-box access | Teacher errors become SFT targets unless filtered |
| Logit KD | Token probabilities | Preserves distribution information | Requires aligned output space or an explicit mapping |
| Feature KD | Selected hidden states | Exposes intermediate representations | Needs layer/projection design and more storage or compute |
| On-policy KD | Teacher scores on student samples | Visits prefixes the student actually produces | Requires online sampling and teacher evaluation |
When minimizing KL divergence for language generation, direction matters. Classical distillation commonly minimizes Forward KL (teacher || student), which penalizes a student for missing probability mass that the teacher assigns to continuations. When a small student cannot model the teacher distribution well, this pressure may be costly.
Reverse KL (student || teacher) places more pressure on probability mass the student assigns where the teacher assigns little. MiniLLM reports improvements over its studied standard-KD baselines using reverse KL with an on-policy optimization algorithm in instruction-following experiments.[8] GKD evaluates multiple divergences and explicitly reports that the best divergence depends on the task and diversity-performance tradeoff.[9]
| Direction | Formula | Behavior | Common fit |
|---|---|---|---|
| Forward KL | $\mathrm{KL}(p_{\text{teacher}} \,\Vert\, p_{\text{student}})$ | Mean-seeking, covers more of the teacher distribution | Classic KD when broad coverage matters |
| Reverse KL | $\mathrm{KL}(p_{\text{student}} \,\Vert\, p_{\text{teacher}})$ | Penalizes student mass in teacher-low-probability regions | Candidate objective to evaluate for generation |
No divergence is the default winner for every task. Measure task quality, diversity, calibration, and failure rates under the actual decoding setup.
```python
import math

teacher = {"safe": 0.58, "alternate": 0.40, "bad": 0.02}
students = {
    "covers_teacher": {"safe": 0.54, "alternate": 0.36, "bad": 0.10},
    "adds_bad_mass": {"safe": 0.40, "alternate": 0.35, "bad": 0.25},
}

def kl(left: dict[str, float], right: dict[str, float]) -> float:
    # KL(left || right): expectation under `left` of the log-probability ratio.
    return sum(prob * math.log(prob / right[token]) for token, prob in left.items())

for name, student in students.items():
    forward = kl(teacher, student)
    reverse = kl(student, teacher)
    print(name, "forward:", round(forward, 3), "reverse:", round(reverse, 3))

print("choose objective from evaluation, not slogan")
```

```text
covers_teacher forward: 0.051 reverse: 0.084
adds_bad_mass forward: 0.218 reverse: 0.436
choose objective from evaluation, not slogan
```

Off-policy (standard) distillation trains the student on teacher outputs for a static dataset. During inference, the student generates its own tokens, causing distribution shift. Errors compound as the student drifts from training data (exposure bias).
On-policy methods like Generalized Knowledge Distillation (GKD) sample sequences from the student, then compare student and teacher token distributions on the prefixes the student produced. GKD can mix fixed outputs and student-generated outputs through a student-data fraction $\lambda$; it does not require a natural-language critique.[9] The tradeoff is computational: both student sampling and teacher scoring run during training. It is useful when fixed teacher data misses prefixes that the deployed student commonly enters, but the benefit must be measured per task.
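The mechanics of one such update can be sketched with toy models. This is an illustration under stated assumptions, not the GKD reference implementation: `TinyLM`, the prompts, the fixed sequences, and the value of `student_fraction` are hypothetical, and forward KL is chosen only for simplicity, since the method supports other divergences.

```python
import torch
from torch import nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    """Toy next-token model: each position's logits depend only on that token."""
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.output(self.embedding(input_ids))

@torch.no_grad()
def sample_from_student(student: nn.Module, prompt: torch.Tensor, new_tokens: int) -> torch.Tensor:
    """Sample continuations from the student; no gradient flows through sampling."""
    sequence = prompt
    for _ in range(new_tokens):
        next_logits = student(sequence)[:, -1, :]
        next_token = torch.multinomial(F.softmax(next_logits, dim=-1), num_samples=1)
        sequence = torch.cat([sequence, next_token], dim=1)
    return sequence

def token_level_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor, prompt_len: int) -> torch.Tensor:
    """Forward KL(teacher || student) on positions that predict continuation tokens."""
    # Logits at position t predict token t+1, so start at prompt_len - 1.
    student_logp = F.log_softmax(student_logits[:, prompt_len - 1:-1, :], dim=-1)
    teacher_logp = F.log_softmax(teacher_logits[:, prompt_len - 1:-1, :], dim=-1)
    per_position = F.kl_div(student_logp, teacher_logp, reduction="none", log_target=True).sum(-1)
    return per_position.mean()

torch.manual_seed(0)
vocab_size = 12
teacher = TinyLM(vocab_size, hidden_size=16).requires_grad_(False).eval()
student = TinyLM(vocab_size, hidden_size=6)
optimizer = torch.optim.AdamW(student.parameters(), lr=0.05)

prompts = torch.tensor([[1, 2, 3], [4, 5, 6]])
fixed_sequences = torch.tensor([[1, 2, 3, 7, 8], [4, 5, 6, 9, 10]])  # hypothetical static data
student_fraction = 0.5  # illustrative value for the student-data fraction

# With probability student_fraction, train on sequences the student sampled itself.
use_student_data = bool(torch.rand(()) < student_fraction)
sequences = sample_from_student(student, prompts, new_tokens=2) if use_student_data else fixed_sequences

with torch.no_grad():
    teacher_logits = teacher(sequences)  # the teacher scores whatever sequences were chosen
loss = token_level_kl(student(sequences), teacher_logits, prompt_len=prompts.size(1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("used student samples:", use_student_data)
print("sequence shape:", tuple(sequences.shape))
print("loss is finite:", bool(torch.isfinite(loss)))
```

The extra cost is visible in the structure: every on-policy step pays for a sampling loop and a fresh teacher forward pass, which a fixed dataset would have amortized.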
```python
fixed_teacher_prefixes = {
    "return sealed electronics",
    "return unopened clothing",
}
student_generated_prefixes = {
    "return sealed electronics",
    "return opened final-sale electronics",
    "return item without receipt",
}

# Prefixes the student reaches that the static teacher data never covered.
unseen_in_fixed_data = student_generated_prefixes - fixed_teacher_prefixes
teacher_scored_prefixes = fixed_teacher_prefixes | student_generated_prefixes

print("fixed prefixes:", len(fixed_teacher_prefixes))
print("student prefixes needing new teacher scores:", sorted(unseen_in_fixed_data))
print("scored after on-policy collection:", len(teacher_scored_prefixes))
```

```text
fixed prefixes: 2
student prefixes needing new teacher scores: ['return item without receipt', 'return opened final-sale electronics']
scored after on-policy collection: 4
```

The temperature parameter controls how much we smooth the teacher distribution used in a logit loss. It changes which relative token preferences are visible to the loss; it does not guarantee that those preferences are useful or that the student can represent them.
Mathematically, the softmax function converts raw logits into probabilities using an exponential function. When one logit is significantly larger than the rest, the exponential makes it dominate the entire probability mass, crushing the others to near-zero. By dividing all logits by a temperature before applying the exponential, we reduce the relative difference between the largest logit and the smaller ones. This prevents the top choice from monopolizing the probability space and allows the relative scores of the "incorrect" choices to emerge.
Imagine the teacher sees the prompt "Return policy for electronics?" and produces these raw logits for the next token:
| Token | Raw logit |
|---|---|
| 30 | 4.0 |
| 14 | 2.0 |
| 90 | 1.0 |
| 1 | 0.7 |
At $T=1$, the probabilities are sharp: 30 gets about 0.82, 14 gets 0.11, 90 gets 0.04, and 1 gets 0.03. The student sees only a weak signal that 90 is a more reasonable guess than 1.
At $T=4$, after dividing each logit by 4 and applying softmax, the probabilities spread out: 30 gets about 0.40, 14 gets 0.24, 90 gets 0.19, and 1 gets 0.17. The distillation loss now exposes more of the teacher's ranking among alternatives.
Key insight: Temperature changes the training target, not the trusted answer. If the teacher ranks an invalid option highly, a softened loss makes that error easier to copy. Use labels or checks and held-out evaluation alongside KD.
When the distribution is sharp, the relationships between incorrect choices (14 vs 90) are hidden because their probabilities are near zero.

In practice, tune temperature rather than assuming a universal constant. The right value depends on the teacher distribution, soft-loss weight, task, and evaluation metrics. Inspect probability spreads as in the executable example above, then compare held-out quality.
A typical distillation training loop involves a frozen teacher model and a trainable student model. We forward the same input through both models and compute the combined loss. The local example below uses tiny PyTorch models so you can test the mechanics without downloading a real teacher.
```python
import torch
from torch import nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.output(self.embedding(input_ids))

def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    temperature = 3.0
    alpha = 0.7

    # One-token shift: logits at position t are scored against the token at t+1.
    shift_student = student_logits[:, :-1, :]
    shift_teacher = teacher_logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    student_flat = shift_student.reshape(-1, shift_student.size(-1))
    teacher_flat = shift_teacher.reshape(-1, shift_teacher.size(-1))
    labels_flat = shift_labels.reshape(-1)

    soft_teacher = F.softmax(teacher_flat / temperature, dim=-1)
    soft_student = F.log_softmax(student_flat / temperature, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
    hard_loss = F.cross_entropy(student_flat, labels_flat)
    return alpha * soft_loss + (1 - alpha) * hard_loss

torch.manual_seed(0)
vocab_size = 12
teacher = TinyLM(vocab_size=vocab_size, hidden_size=16)
student = TinyLM(vocab_size=vocab_size, hidden_size=6)
# Freeze the teacher: no gradients, evaluation mode.
teacher.requires_grad_(False)
teacher.eval()

input_ids = torch.tensor([
    [1, 2, 3, 4, 5],
    [1, 3, 5, 7, 9],
])
labels = input_ids.clone()
optimizer = torch.optim.AdamW(student.parameters(), lr=0.05)

with torch.no_grad():
    teacher_logits = teacher(input_ids)

# One optimizer step on the student, then re-measure the same loss.
before = kd_loss(student(input_ids), teacher_logits, labels)
optimizer.zero_grad()
before.backward()
has_grad = any(parameter.grad is not None for parameter in student.parameters())
optimizer.step()

after = kd_loss(student(input_ids), teacher_logits, labels)

print("before:", round(float(before), 4))
print("after:", round(float(after), 4))
print("has_grad:", has_grad)
print("after_is_finite:", bool(torch.isfinite(after)))
print("improved:", bool(after < before))
```

```text
before: 1.1513
after: 0.9793
has_grad: True
after_is_finite: True
improved: True
```

In production, the same pattern usually lives inside a framework trainer, with aligned-vocabulary models such as a larger Gemma teacher and smaller Gemma student when direct token-probability distillation is required. Teams may pre-compute some teacher signal offline to avoid running the teacher inside every student update. Dense next-token logits across a long corpus are costly to store, so a design may consider top-k logits, teacher responses, or online scoring, then measure the quality effect of compression. The payload-only estimate below ignores metadata and storage-format overhead, so treat it as a lower-bound sizing exercise.
```python
tokens = 50_000_000
vocab_size = 32_000
bytes_per_logit = 2  # bf16
top_k = 64
bytes_per_topk_item = 2 + 4  # bf16 value plus int32 token id

dense_bytes = tokens * vocab_size * bytes_per_logit
topk_bytes = tokens * top_k * bytes_per_topk_item
gib = 1024 ** 3

print("dense cache GiB:", round(dense_bytes / gib, 1))
print("top-k cache GiB:", round(topk_bytes / gib, 1))
print("storage reduction:", round(dense_bytes / topk_bytes, 1), "x")
print("quality must still be evaluated:", True)
```

```text
dense cache GiB: 2980.2
top-k cache GiB: 17.9
storage reduction: 166.7 x
quality must still be evaluated: True
```

If an online-distillation pilot is feasible, compare it with an offline baseline before committing to large-scale data generation. The result can expose whether fresh teacher scoring is worth its compute cost for this task.
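Returning to the top-k cache option, the sketch below shows what truncation does to a single position's teacher distribution. The logits, the value of `k`, and the helper names are hypothetical, and renormalizing over the kept entries is only one possible way to rebuild a target from the cache.

```python
import math

def softmax(logits: list[float]) -> list[float]:
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]

def top_k_entries(logits: list[float], k: int) -> list[tuple[int, float]]:
    """Keep the k largest logits as (token_id, logit) pairs."""
    ranked = sorted(enumerate(logits), key=lambda item: item[1], reverse=True)
    return ranked[:k]

# Hypothetical teacher logits for one position over an 8-token vocabulary.
teacher_logits = [4.0, 2.0, 1.0, 0.7, -1.0, -1.5, -2.0, -3.0]
k = 3

dense_probs = softmax(teacher_logits)
cached = top_k_entries(teacher_logits, k)
kept_ids = [token_id for token_id, _ in cached]
kept_mass = sum(dense_probs[token_id] for token_id in kept_ids)

# Renormalizing over the kept logits is one possible reconstruction, assumed here.
truncated_probs = softmax([logit for _, logit in cached])

print("kept token ids:", kept_ids)
print("teacher mass kept:", round(kept_mass, 3))
print("renormalized top-k:", [round(prob, 3) for prob in truncated_probs])
```

The kept mass is a quick diagnostic for how much of the teacher distribution a given `k` discards, but it does not replace measuring the student's held-out quality with and without compression.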
When the teacher is only available through generated text, response distillation becomes the default. When you control both models and can inspect logits or hidden states, white-box distillation exposes a richer training signal. Recent systems use both patterns depending on what the teacher exposes.
What changes from system to system isn't the core idea, but the supervision channel: raw logits, hidden states, instruction-response pairs, rationale traces, or synthetic corpora generated by a stronger model.
| Student / recipe | Teacher | Signal transferred | Why it matters |
|---|---|---|---|
| Alpaca 7B[5] | text-davinci-003 | 52K generated instruction-response examples | The repository reports preliminary instruction-following evaluation and clear non-commercial dataset terms. |
| Orca 13B[3] | GPT-4 + ChatGPT | Explanation traces and task instructions | Evaluates a richer generated-trace training recipe, rather than logit KD. |
| phi-1.5[4] | Existing LLMs + curated synthetic data | Textbook-like synthetic corpora | Adjacent synthetic-data recipe, not a direct teacher-distribution KD comparison. |
| Gemma 2 2B / 9B[2] | Larger Gemma teachers | Token-probability distillation during pretraining | Reports a controlled 2B distilled-versus-from-scratch ablation. |
| DeepSeek-R1-Distill 1.5B-70B[7] | DeepSeek-R1 | Roughly 800K selected SFT examples | Reports results from text-target fine-tuning; transfer remains benchmark-scoped. |
When using response-based distillation, the selected dataset is one of the main controllable inputs, alongside student capacity and training budget. Build a generation and selection pipeline that can reject incorrect, duplicate, contaminated, or irrelevant examples.
A structured approach to generating teacher data involves a Seed-Expand-Filter pipeline. It does not ensure quality by itself; each filter needs a measurable contract and a separate evaluation split.
This is close to the spirit of Alpaca's Self-Instruct-style pipeline and Orca's richer explanation-trace data generation, even though real systems add deduplication, safety filters, and task balancing.[5][3]
This logic can be implemented by creating a small wrapper around a teacher client. The example below focuses on two easy-to-miss requirements: deduplicate prompts before paying for generation, and verify teacher answers before they become student targets.
```python
from collections.abc import Callable

class DistillationDataGenerator:
    def __init__(
        self,
        teacher_generate: Callable[[str], str],
        verify_response: Callable[[str, str], bool],
    ):
        self.teacher_generate = teacher_generate
        self.verify_response = verify_response

    def generate_dataset(self, prompts: list[str]) -> list[dict[str, str]]:
        selected: list[dict[str, str]] = []
        seen: set[str] = set()
        for prompt in prompts:
            # Deduplicate on a normalized prompt before paying for generation.
            normalized = " ".join(prompt.lower().split())
            if normalized in seen:
                continue
            seen.add(normalized)
            response = self.teacher_generate(normalized).strip()
            # Only verified responses become student targets.
            if self.verify_response(normalized, response):
                selected.append({"prompt": normalized, "response": response})
        return selected

def fake_teacher(prompt: str) -> str:
    return "30 days"

trusted_answers = {
    "sealed electronics return": "30 days",
    "final-sale electronics return": "not eligible",
}

def verify_response(prompt: str, response: str) -> bool:
    return trusted_answers[prompt] == response

generator = DistillationDataGenerator(fake_teacher, verify_response)
examples = generator.generate_dataset([
    "sealed electronics return",
    " Sealed electronics return ",
    "final-sale electronics return",
])

print("selected prompts:", [example["prompt"] for example in examples])
print("selected count:", len(examples))
print("bad response retained:", any("final-sale" in example["prompt"] for example in examples))
```

```text
selected prompts: ['sealed electronics return']
selected count: 1
bad response retained: False
```

Teacher generation can quietly contaminate a benchmark when prompts, reference solutions, or close rewrites enter the student training set. At minimum, block exact normalized overlap before training. For real releases, extend the gate with near-duplicate and reference-solution checks.
```python
def normalize(prompt: str) -> str:
    # Lowercase, drop question marks, and collapse whitespace.
    return " ".join(prompt.lower().replace("?", "").split())

candidate_training_prompts = [
    "Compute shipping refund for a late parcel",
    "Return sealed electronics within 30 days",
    "Can a FINAL-SALE item be returned?",
]
held_out_prompts = [
    "can a final-sale item be returned",
    "Estimate delivery window for a remote zip code",
]

held_out_keys = {normalize(prompt) for prompt in held_out_prompts}
accepted = [
    prompt for prompt in candidate_training_prompts
    if normalize(prompt) not in held_out_keys
]
blocked = [
    prompt for prompt in candidate_training_prompts
    if normalize(prompt) in held_out_keys
]

print("accepted training prompts:", len(accepted))
print("blocked overlap:", blocked)
print("held-out exact overlap after gate:", any(normalize(p) in held_out_keys for p in accepted))
```

```text
accepted training prompts: 2
blocked overlap: ['Can a FINAL-SALE item be returned?']
held-out exact overlap after gate: False
```

Distillation does not make capacity, context-window, or data-coverage constraints disappear. A student may beat its teacher on a narrow checked metric after filtering or task-specific training, while regressing on other behavior. Treat the teacher and student as separate artifacts to evaluate.
Before investing in a distillation pipeline, define which behaviors matter and how they will be tested.
| Behavior | Regression risk to test | Useful held-out gate |
|---|---|---|
| Domain answers | Generated targets can repeat teacher errors | Checked answer accuracy and abstention rate |
| Instruction following | Narrow traces can miss new constraints | Fresh constraint-following prompts |
| Multi-step solutions | Final answers can hide invalid steps | Step checks where available plus final-answer accuracy |
| Long-context use | Student architecture or context limit may differ | Retrieval and long-context slices at deployment length |
| Safety and policy behavior | Filtered corpus may omit refusals or edge cases | Safety-policy evaluation separate from task benchmark |
Beyond technical constraints, distillation introduces unique licensing and evaluation challenges. Because the student model closely mirrors the teacher's outputs, the origins of that training data are important.
Deciding whether to distill, and which method to use, comes down to measured quality and economics. Richer teacher access can enable different losses; it does not rank final models without evaluation.
| Approach | Required access | Main training cost | Release gate |
|---|---|---|---|
| Use teacher directly | Teacher inference | No student training | Baseline quality, latency, and cost |
| Response KD | Generated outputs and permitted use | Generation plus SFT | Output filtering and held-out task quality |
| Logit KD | Aligned teacher token probabilities | Teacher scoring or cache storage | Task quality plus cache/online cost |
| Feature KD | Hidden states and layer mapping | Extra projections and state transfer | Ablation against simpler KD baseline |
Also, don't stop at distillation loss. A student can match teacher probabilities on training batches and still regress on held-out generation quality, long-context behavior, or latency targets. Measure task metrics, pairwise win rate, and real serving cost together.
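As a small illustration of the pairwise metric, the sketch below computes a win rate from judged comparisons. The judgment list is hypothetical data, and counting a tie as half a win is one common convention assumed here rather than a fixed standard.

```python
# Hypothetical judged comparisons on held-out prompts: whose answer was preferred.
judgments = ["student", "teacher", "tie", "teacher", "student", "teacher", "tie", "student"]

wins = judgments.count("student")
losses = judgments.count("teacher")
ties = judgments.count("tie")

# Counting a tie as half a win is an assumption of this sketch.
win_rate = (wins + 0.5 * ties) / len(judgments)

print("wins / losses / ties:", wins, losses, ties)
print("student win rate:", win_rate)
```

A win rate near parity on a small sample says little by itself; it is most informative when reported per evaluation slice and alongside the checked task metrics.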
```python
teacher = {"checked_accuracy": 0.94, "policy_error_rate": 0.01, "latency_ms": 180}
student = {"checked_accuracy": 0.92, "policy_error_rate": 0.04, "latency_ms": 42}
requirements = {"checked_accuracy": 0.90, "max_policy_error_rate": 0.02, "max_latency_ms": 60}

# Every gate must pass; a latency win cannot offset a policy failure.
checks = {
    "quality": student["checked_accuracy"] >= requirements["checked_accuracy"],
    "policy": student["policy_error_rate"] <= requirements["max_policy_error_rate"],
    "latency": student["latency_ms"] <= requirements["max_latency_ms"],
}

print("student faster:", student["latency_ms"] < teacher["latency_ms"])
print("release checks:", checks)
print("deploy student:", all(checks.values()))
```

```text
student faster: True
release checks: {'quality': True, 'policy': False, 'latency': True}
deploy student: False
```

How does temperature affect distillation? Temperature controls how soft the teacher distribution becomes. A higher $T$ spreads probability mass across alternatives, exposing more of the teacher ranking to the loss. Too high a $T$ makes the distribution nearly uniform and loses ranking signal; any setting still needs held-out validation.
Can you distill from a proprietary API model? If its permitted interface and terms allow generated outputs for training, the available KD channel is response distillation. You generate candidate prompt-response pairs or checked solution traces and fine-tune the student on selected text. You don't get token probabilities, so filtering, use rights, and held-out evaluation are central.
When does distillation fail to transfer capability? It fails on a target behavior when student capacity, context length, architecture, or training coverage cannot reproduce that behavior at the required threshold. Define slices such as long-context tasks, policy edge cases, and checked multi-step solutions so the failure is visible.
How do you decide whether to deploy a distilled student? Set quality, policy, latency, and cost gates before training. Lower latency is not enough if a student fails a policy or accuracy gate, as the executable deployment example shows.
Can you directly distill between different tokenizers? Not with plain token-level KL unless token-to-id output mappings agree. Same vocabulary size is not enough. If output spaces differ, use response distillation or design and validate an explicit mapping.
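A quick check of the mappings can catch this before any KL loss is computed. The sketch below assumes each vocabulary is available as a plain token-to-id dictionary; the two toy vocabularies and both helper names are hypothetical.

```python
def mappings_agree(teacher_vocab, student_vocab):
    """True only if both vocabularies map exactly the same tokens to the same ids."""
    return teacher_vocab == student_vocab

def mismatched_tokens(teacher_vocab, student_vocab):
    """List tokens that are missing on one side or mapped to different ids."""
    tokens = set(teacher_vocab) | set(student_vocab)
    return sorted(t for t in tokens if teacher_vocab.get(t) != student_vocab.get(t))

# Hypothetical vocabularies: same size, same tokens, different id assignment.
teacher_vocab = {"<pad>": 0, "return": 1, "policy": 2, "refund": 3}
student_vocab = {"<pad>": 0, "policy": 1, "return": 2, "refund": 3}

print("same size:", len(teacher_vocab) == len(student_vocab))             # True
print("mappings agree:", mappings_agree(teacher_vocab, student_vocab))    # False
print("mismatched:", mismatched_tokens(teacher_vocab, student_vocab))     # ['policy', 'return']
```

The size check passes while the mapping check fails, so a token-level KL between these two output spaces would compare probabilities for different tokens.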
Before moving on, try to answer these questions without looking back at the article.
Why does a soft label teach more than a hard label?
Hint: Think about what the student learns about the near-miss answers, not only the correct one.
Why might $T = 1$ provide less ranking signal than a higher temperature?
Hint: Consider what a sharp teacher distribution exposes about alternatives.
Why do causal language models need a one-token shift when computing distillation loss?
Hint: Remember that a causal LM predicts the next token from all previous tokens.
Why should forward versus reverse KL be chosen through evaluation?
Hint: Think about mode coverage, low-teacher-probability mass, and task-dependent diversity requirements.
A hard label only tells the student which answer is selected. A soft label exposes the teacher's full ranking over alternatives. That extra signal can help, but a trusted check still has to catch bad teacher rankings.
If the teacher distribution is sharp at $T = 1$, the top token receives most of the probability mass and alternatives contribute little signal. Raising temperature can expose their relative ranking, until excessive smoothing removes useful separation.
In a causal LM, the logits produced at position $t$ predict the token at position $t + 1$. If you don't shift the labels by one, you are training the model to predict the current token from itself, which is trivial and wrong.
Forward KL penalizes missing teacher mass; reverse KL penalizes student mass where the teacher is low. Their quality and diversity tradeoffs vary by task and decoding setup, so evaluate both where the objective choice matters.
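The forward versus reverse KL answer can be checked on a toy example. The sketch below assumes a hypothetical teacher distribution with two likely tokens and two hypothetical students: one that commits to a single teacher mode and one that spreads mass evenly. The `kl_divergence` helper assumes every probability in its second argument is positive.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical distributions over three tokens.
teacher_probs = [0.49, 0.49, 0.02]        # two likely tokens, one unlikely
single_mode_student = [0.96, 0.02, 0.02]  # commits to one teacher mode
spread_student = [0.34, 0.33, 0.33]       # covers both modes, plus the unlikely token

for name, student_probs in [("single-mode", single_mode_student), ("spread", spread_student)]:
    forward = kl_divergence(teacher_probs, student_probs)  # KL(teacher || student)
    reverse = kl_divergence(student_probs, teacher_probs)  # KL(student || teacher)
    print(f"{name}: forward={forward:.3f} reverse={reverse:.3f}")
```

With these numbers, forward KL scores the spread student lower because the single-mode student misses teacher mass, while reverse KL scores the single-mode student lower because the spread student puts mass where the teacher is low. The ranking flips between objectives, which is why the choice is settled by evaluation on the target task.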
| Symptom | Cause | Fix |
|---|---|---|
| Validation loss barely changes as you raise temperature. | The softened teacher distribution may be too flat or the soft-loss weight may be ineffective. | Inspect teacher probabilities and tune temperature and loss weight on held-out tasks. |
| Student looks strong on training prompts but weak on held-out tasks. | Distillation corpus is too narrow, repetitive, or too close to evaluation data. | Broaden prompt coverage, filter duplicates, and keep a separate held-out evaluation slice. |
| Student predicts current token instead of next token during logit KD. | Causal LM loss forgot the one-token shift. | Shift logits at position $t$ against labels at position $t + 1$ before KL or cross-entropy. |
| Student copies teacher hallucinations and policy mistakes. | Distillation blindly transferred bad teacher outputs. | Filter teacher generations, add task loss, and evaluate against trusted labels or reward checks. |
| KL loss runs but student quality stays random. | Teacher and student token-to-id mappings do not align, even if sizes match. | Compare mappings exactly, use response distillation, or design an explicit output-space mapping. |
| Tiny student misses required checked behaviors. | Student capacity, context, or data coverage is insufficient for this release target. | Narrow task scope, increase student size, or revise training and evaluation design. |
| Offline metrics look great but production quality collapses. | Distillation and evaluation data leaked into each other. | Split generation, tuning, and evaluation sets cleanly before training starts. |
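The one-token shift from the table can be sketched without any framework. The example below uses strings as stand-ins for per-position logit rows; the helper name `shift_for_next_token` and the token ids are hypothetical. In a real loop, teacher and student logits at the same position are already aligned with each other, so the shift applies to the labels and to any mask derived from them.

```python
def shift_for_next_token(per_position_outputs, labels):
    """Pair the output at position t with the label at position t + 1."""
    if len(per_position_outputs) != len(labels):
        raise ValueError("expected one output per input position")
    # Drop the last output (it has no next token) and the first label
    # (no position predicts it).
    return list(zip(per_position_outputs[:-1], labels[1:]))

token_ids = [11, 42, 7, 99]  # hypothetical token ids for one sequence
outputs = ["logits@0", "logits@1", "logits@2", "logits@3"]  # stand-ins for logit rows

aligned = shift_for_next_token(outputs, token_ids)
print(aligned)  # [('logits@0', 42), ('logits@1', 7), ('logits@2', 99)]
```

A sequence of four tokens yields three training pairs, and no output is ever scored against the token that produced it.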
Distilling the Knowledge in a Neural Network
Hinton, G., Vinyals, O., & Dean, J. · 2015
Gemma 2: Improving Open Language Models at a Practical Size
Gemma Team, Google DeepMind · 2024
Orca: Progressive Learning from Complex Explanation Traces of GPT-4
Mukherjee, S., et al. · 2023
Textbooks Are All You Need II: phi-1.5 technical report
Li, Y., et al. · 2023
Stanford Alpaca: An Instruction-following LLaMA Model
Taori, R., et al. · 2023 · GitHub
Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data
Hsieh, C., et al. · 2023
DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning
DeepSeek-AI · 2025
MiniLLM: On-Policy Distillation of Large Language Models
Gu, Y., et al. · 2024
On-Policy Distillation for Language Models
Agarwal, R., et al. · 2024