Build and train a tiny GPT end to end on Shakespeare: tokenize with GPT-style subwords, remap active token IDs, run causal self-attention, track validation loss, save a checkpoint, and sample text.
Token shards are only useful once a model can learn from them. This lab shows that next step: a tiny decoder-only GPT (Generative Pre-trained Transformer) consumes token blocks, predicts next tokens, writes a checkpoint, and produces sample text you can judge.[1]
Karpathy's nanoGPT is the main reference inspiration here. It provides a compact end-to-end Transformer training codebase, but its fastest Shakespeare on-ramp is character-level. This lab keeps the tiny scale while switching to GPT-2 byte-level byte pair encoding (BPE), matching GPT-2's published input representation.[2][3]
The corpus is bundled with this lab as shakespeare.txt, so readers can download the same raw text and rerun the experiment locally. That file is a Princeton course mirror of Shakespeare's text.[4] We train on the full bundled corpus, not a single-play excerpt. The run stays CPU-friendly because the model and update budget are tiny.
Why use Shakespeare instead of a technical-log corpus? Because here the learning target is pure next-token mechanics: tokenization, causal masking, loss, checkpointing, and sampling. Once you understand the tiny GPT loop on a public corpus, the same pipeline transfers to code, runbooks, chat transcripts, and other domain text where you would usually fine-tune or continue pretraining.
Small GPT runs still have one non-negotiable contract:

- the output at position t may attend only to positions <= t
- that output is scored against the token at position t + 1

If any piece is wrong, the loop still runs and still prints numbers. Tokenizer, batching, model, and generation need to agree with each other instead of being studied as isolated concepts.
| Piece | What it does | Failure mode |
|---|---|---|
| Corpus | supplies the raw language distribution | using too much text for a CPU lab, or too little for recognizable output |
| Tokenizer | turns text into GPT-style subword IDs | forgetting that token IDs depend on the exact tokenizer |
| Active-vocab remap | shrinks sparse GPT-2 token IDs into a small local range | keeping a 50k-row tied vocabulary matrix in a toy CPU lab |
| Block packer | slices token stream into fixed contexts and shifted labels | off-by-one windows or broken train/val split |
| Decoder loop | runs masked self-attention, loss, checkpoint, and generation | missing causal mask or trusting loss without sampling |
That middle row matters. GPT-2 BPE can emit IDs anywhere in a 50k vocabulary. Our tiny lab still only activates a subset of those IDs, so we remap active token IDs into contiguous local IDs like 0..17484. Same subword tokenization. Smaller tied embedding/output matrix.
```python
gpt2_vocab = 50_257
active_vocab = 17_485
d_model = 96

# GPT-2 reuses the token-embedding weights for output logits, so one
# vocabulary-sized matrix determines this part of the parameter cost.
full_weights = gpt2_vocab * d_model
compact_weights = active_vocab * d_model
reduction = 1 - compact_weights / full_weights

print(f"full vocab tied weights: {full_weights:,}")
print(f"compact tied weights: {compact_weights:,}")
print(f"toy-lab reduction: {reduction:.1%}")
```

```text
full vocab tied weights: 4,824,672
compact tied weights: 1,678,560
toy-lab reduction: 65.2%
```

The remap is an efficiency device for this fixed corpus, not a replacement tokenizer. The model below also follows GPT-2's tied token-embedding/output-logit matrix.[5] This toy mapping is built from the full corpus before the split so validation IDs remain representable; a full GPT-2-vocabulary run wouldn't need that compromise. Compact remapping also means generated prompts must be expressible using token IDs present in the corpus.
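A toy version of the remap makes the round trip concrete. The sparse IDs below are hypothetical stand-ins for GPT-2 tokenizer output (no tokenizer is called); the construction mirrors the lab's sorted-unique-IDs approach, and the last lines show why a prompt token that never appears in the corpus has no local slot:

```python
# Hypothetical sparse "GPT-2" token IDs standing in for real tokenizer output.
corpus_token_ids = [15496, 995, 11, 262, 995, 13, 11]

# Same construction as the lab: sorted unique IDs -> contiguous local IDs 0..N-1.
active_token_ids = sorted(set(corpus_token_ids))
token_to_local = {token_id: idx for idx, token_id in enumerate(active_token_ids)}
local_to_token = {idx: token_id for token_id, idx in token_to_local.items()}

local_ids = [token_to_local[token_id] for token_id in corpus_token_ids]
round_trip = [local_to_token[idx] for idx in local_ids]

print("active ids:", active_token_ids)
print("local ids:", local_ids)
print("round trip ok:", round_trip == corpus_token_ids)

# A prompt token that never appeared in the corpus has no row in the compact matrix.
prompt_token_ids = [15496, 318]
missing = sorted(set(prompt_token_ids) - set(active_token_ids))
print("missing from compact vocab:", missing)
```

Sorting the unique IDs makes the mapping deterministic for a fixed corpus, which is why the checkpoint only needs to store active_token_ids to rebuild both dictionaries.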
A costly silent error in a language-model lab is training each token to reproduce itself instead of predicting its successor. Check the shift on a tiny block before looking at a full training loop.
```python
block = ["Good", " sir", ",", "\n", "Speak", " plain", "."]
x = block[:-1]
y = block[1:]

for current, target in zip(x, y):
    print(f"{current!r:>8} -> {target!r}")

assert y[0] == " sir" and y[-1] == "."
```

```text
  'Good' -> ' sir'
  ' sir' -> ','
     ',' -> '\n'
    '\n' -> 'Speak'
 'Speak' -> ' plain'
' plain' -> '.'
```

A shifted target isn't enough: attention at each position must also be unable to read future inputs. During each forward pass, the model emits logits: one unnormalized score for every local vocabulary ID at every position. Cross-entropy compares those scores with the shifted targets. Inside attention, the causal mask is applied before softmax turns allowed scores into weights, so future positions receive no probability mass.
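One way to test the mask half of that contract is a perturbation check: change only the last input position and confirm that outputs at earlier positions don't move. The sketch below uses a standalone single-head attention function with random, illustrative projection matrices rather than the lab's module; without the mask, the same perturbation leaks backward into every earlier position:

```python
import math
import torch

torch.manual_seed(0)
seqlen, width = 5, 8
# Illustrative random projections, scaled by 1/sqrt(width) so softmax stays soft.
w_q, w_k, w_v = (torch.randn(width, width) / math.sqrt(width) for _ in range(3))

def attention(x: torch.Tensor, causal: bool) -> torch.Tensor:
    # Single-head scaled dot-product attention over x of shape [time, width].
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / math.sqrt(width)
    if causal:
        mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
        # Mask before softmax so future positions receive exactly zero weight.
        scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ v

x = torch.randn(seqlen, width)
x_changed = x.clone()
x_changed[-1] += 1.0  # perturb only the final position

results = {}
for causal in (True, False):
    before, after = attention(x, causal), attention(x_changed, causal)
    # Earlier positions should not notice a change to the last input when causal.
    results[causal] = torch.allclose(before[:-1], after[:-1])
    print(f"causal={causal}: earlier outputs unchanged = {results[causal]}")
```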
```python
import torch

mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
for row in mask.int().tolist():
    print(" ".join(map(str, row)))

assert mask[2].tolist() == [True, True, True, False]
print("position 2 can read positions:", [index for index, visible in enumerate(mask[2]) if visible])
```

```text
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1
position 2 can read positions: [0, 1, 2]
```

This lab is more useful if the first result is still half-baked, so the notebook is staged: train briefly, sample, keep training from the same state, sample again, then save and reload a checkpoint.
The first cell builds the whole pipeline in plain PyTorch, then trains only through step 80. That's enough to learn something, but not enough to look good.[6]
```python
from pathlib import Path
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken

# Fix the random seed so the article output is reproducible.
torch.manual_seed(7)

# Load the raw text exactly as the learner will download it.
all_text = Path("assets/shakespeare.txt").read_text(encoding="utf-8")
corpus = all_text

# Use GPT-2 BPE so tokenization matches GPT-2's published input representation.
encoding_name = "gpt2"
encoder = tiktoken.get_encoding(encoding_name)
corpus_token_ids = encoder.encode(corpus)

# GPT-2 token ids are sparse across a 50k vocabulary. This run only needs
# tokens that actually appear in the bundled Shakespeare corpus, so remap them to 0..N-1.
# That keeps the tied embedding/output weights much smaller without changing
# which subword pieces the tokenizer produced.
active_token_ids = sorted(set(corpus_token_ids))
token_to_local = {token_id: idx for idx, token_id in enumerate(active_token_ids)}
local_to_token = {idx: token_id for token_id, idx in token_to_local.items()}
ids = [token_to_local[token_id] for token_id in corpus_token_ids]

# Each training example needs block_size input tokens plus 1 next-token label.
block_size = 48

# Split once along the original stream. Concatenating interleaved held-out chunks
# would create fake transitions where non-adjacent Shakespeare passages meet.
split_index = int(0.9 * len(ids))
train_ids = torch.tensor(ids[:split_index], dtype=torch.long)
val_ids = torch.tensor(ids[split_index:], dtype=torch.long)
batch_size = 12

# Keep training, validation, and text-sampling randomness independent. Logging
# one sample should never change which training windows the model sees next.
train_generator = torch.Generator().manual_seed(101)
val_generator = torch.Generator().manual_seed(202)

print(
    f"tokens={len(ids)} active_vocab={len(active_token_ids)} "
    f"train={len(train_ids)} val={len(val_ids)}"
)

def sample_batch(
    source: torch.Tensor,
    *,
    generator: torch.Generator,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Pick random starting positions from the long token stream.
    starts = torch.randint(
        0,
        len(source) - block_size,
        (batch_size,),
        generator=generator,
    )

    # x is the current context window.
    x = torch.stack([source[s:s + block_size] for s in starts])

    # y is the same window shifted one token ahead. This is the whole learning target.
    y = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return x, y

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int = 96, n_heads: int = 4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # One linear layer projects each position into query, key, and value vectors.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

        # Lower-triangular mask blocks attention to future positions.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size, dtype=torch.bool)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seqlen, width = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            # Turn [batch, time, width] into [batch, heads, time, head_dim].
            return tensor.view(batch_size, seqlen, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # Attention score = query-key similarity, scaled to keep softmax stable.
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Any position above the diagonal is a future token, so hide it from the model.
        attn = attn.masked_fill(~self.mask[:seqlen, :seqlen], float("-inf"))
        attn = attn.softmax(dim=-1)

        # A weighted sum of value vectors produces the contextualized representation.
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(batch_size, seqlen, width)
        return self.proj(out)

class Block(nn.Module):
    def __init__(self, d_model: int = 96, n_heads: int = 4):
        super().__init__()

        # Pre-LN transformer block: normalize, attend, add residual, then MLP.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

class TinyGPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 96, n_heads: int = 4, n_layers: int = 2):
        super().__init__()

        # Token embeddings say "which subword is this?".
        self.token_emb = nn.Embedding(vocab_size, d_model)

        # Position embeddings say "where is this token inside the current window?".
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.blocks = nn.ModuleList([Block(d_model, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)

        # GPT-2 reuses token embedding weights for its output logits.
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.apply(self._init_weights)
        self.head.weight = self.token_emb.weight

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        # GPT-style small initialization is important once output weights are tied.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seqlen = x.shape
        positions = torch.arange(seqlen, device=x.device)

        # GPT adds token meaning and position meaning before any attention happens.
        h = self.token_emb(x) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            h = block(h)
        h = self.ln_f(h)
        return self.head(h)

# Build the model and optimizer.
model_config = {"d_model": 96, "n_heads": 4, "n_layers": 2}
model = TinyGPT(vocab_size=len(active_token_ids), **model_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)

def evaluate(source: torch.Tensor) -> tuple[float, float]:
    # Average across a few validation batches so loss and accuracy are less noisy.
    was_training = model.training
    model.eval()
    losses = []
    accuracies = []
    with torch.no_grad():
        for _ in range(8):
            x, y = sample_batch(source, generator=val_generator)
            logits = model(x)
            losses.append(F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1)).item())
            accuracies.append((logits.argmax(dim=-1) == y).float().mean().item())
    model.train(was_training)
    return sum(losses) / len(losses), sum(accuracies) / len(accuracies)

def sample_completion(prompt: str, steps: int = 80) -> str:
    # Use a local generator so fixed-prompt monitoring cannot perturb training.
    sampling_generator = torch.Generator().manual_seed(17)

    prompt_token_ids = encoder.encode(prompt)
    missing_ids = sorted(set(prompt_token_ids) - set(active_token_ids))
    if missing_ids:
        raise ValueError(f"Prompt uses token ids outside compact corpus vocabulary: {missing_ids}")
    prompt_local_ids = [token_to_local[token_id] for token_id in prompt_token_ids]
    context = torch.tensor([prompt_local_ids], dtype=torch.long)

    was_training = model.training
    model.eval()
    with torch.no_grad():
        for _ in range(steps):
            # If the sample grows longer than block_size, GPT only sees the most recent window.
            x = context[:, -block_size:]
            logits = model(x)

            # Only the final position matters for next-token sampling.
            next_logits = logits[:, -1, :]

            # Restrict to the top candidates so the toy model doesn't wander too wildly.
            top_values, top_indices = torch.topk(next_logits, k=8, dim=-1)
            probs = torch.softmax(top_values / 0.9, dim=-1)
            sampled_index = torch.multinomial(
                probs,
                num_samples=1,
                generator=sampling_generator,
            )
            next_local_id = top_indices.gather(-1, sampled_index)

            # Append the sampled token and continue the autoregressive loop.
            context = torch.cat([context, next_local_id], dim=1)

    # Convert local ids back to original GPT-2 token ids, then decode to text.
    sample = encoder.decode([local_to_token[int(idx)] for idx in context[0]])
    model.train(was_training)
    return sample

for step in range(81):
    # 1. Draw a random training batch.
    x, y = sample_batch(train_ids, generator=train_generator)

    # 2. Predict next-token logits for every position.
    logits = model(x)

    # 3. Compare logits against shifted targets.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

    # 4. Backpropagate and update weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print occasional train/val snapshots so the learner can watch the first useful learning happen.
    if step % 40 == 0:
        val_loss, val_acc = evaluate(val_ids)
        print(
            f"step={step:03d} train={loss.item():.3f} "
            f"val={val_loss:.3f} val_acc={val_acc:.3f}"
        )

# Save first-checkpoint metrics so a later cell can compare improvement directly.
early_val_loss = val_loss
early_val_acc = val_acc

# Save enough state to resume training and keep sampling compatible with the trained weights.
checkpoint = {
    "encoding_name": encoding_name,
    "active_token_ids": active_token_ids,
    "model_config": model_config,
    "block_size": block_size,
    "split_index": split_index,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "train_generator_state": train_generator.get_state(),
    "val_generator_state": val_generator.get_state(),
}
```

```text
tokens=1255253 active_vocab=17485 train=1129727 val=125526
step=000 train=9.760 val=9.536 val_acc=0.102
step=040 train=6.444 val=6.439 val_acc=0.146
step=080 train=6.103 val=6.076 val_acc=0.169
```

The second cell samples that early checkpoint. Because the prompt doesn't appear verbatim in the corpus, generation doesn't start from a memorized heading. The continuation can still contain familiar or copied spans, so this is a sanity check rather than a memorization audit.
```python
# The prompt is intentionally not copied verbatim from the training corpus.
prompt = "Good sir,\nSpeak plain.\n"

# This first sample should still look rough and undertrained.
sample = sample_completion(prompt)
print(f"prompt_seen_verbatim={prompt in corpus}")
print("sample:")
print("\n".join(line.rstrip() for line in sample.splitlines()))
```

```text
prompt_seen_verbatim=False
sample:
Good sir,
Speak plain.

I and not

 and ,

To

 and a I

To ,

And

I the ,
 the
 , . the . , I , the , I . and ,
 . . ,

 ,
 .
 ,
 a ; , ,
```

That checkpoint is still half-baked. It knows Shakespeare-ish punctuation and function words, but it doesn't yet hold a stable thought.
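Both samples use the same rule inside sample_completion: keep the 8 highest-scoring candidates, divide their logits by a temperature of 0.9, and draw from the resulting softmax. A standalone sketch of one such step, with a hypothetical helper name (top_k_sample) and an arbitrary random logits vector over a made-up 50-token vocabulary, shows that tokens outside the top k can never be drawn:

```python
import torch

def top_k_sample(next_logits: torch.Tensor, k: int, temperature: float,
                 generator: torch.Generator) -> torch.Tensor:
    # Keep only the k largest logits per row; every other token is unreachable.
    top_values, top_indices = torch.topk(next_logits, k=k, dim=-1)
    # Temperature below 1 sharpens the distribution over the surviving candidates.
    probs = torch.softmax(top_values / temperature, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1, generator=generator)
    # Map the position inside the top-k slice back to a vocabulary ID.
    return top_indices.gather(-1, sampled)

vocab_size = 50
logits = torch.randn(1, vocab_size, generator=torch.Generator().manual_seed(0))
allowed = set(torch.topk(logits, k=8, dim=-1).indices[0].tolist())

generator = torch.Generator().manual_seed(17)
draws = [int(top_k_sample(logits, k=8, temperature=0.9, generator=generator)) for _ in range(200)]

print("distinct tokens drawn:", len(set(draws)))
print("all draws inside top-8:", set(draws) <= allowed)
```

Lower temperatures push probability toward the very top candidates; setting k to 1 would collapse this into greedy decoding.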
The third cell keeps the exact same model and optimizer state, trains longer, and prints whether held-out metrics improved.
```python
# Continue from the exact same checkpoint instead of restarting from scratch.
for step in range(81, 321):
    x, y = sample_batch(train_ids, generator=train_generator)
    logits = model(x)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 80 == 0:
        late_val_loss, late_val_acc = evaluate(val_ids)
        print(
            f"step={step:03d} train={loss.item():.3f} "
            f"val={late_val_loss:.3f} val_acc={late_val_acc:.3f}"
        )

print(f"val_loss_improved_by={early_val_loss - late_val_loss:.3f}")
print(f"val_acc_improved_by={late_val_acc - early_val_acc:.3f}")

checkpoint = {
    "encoding_name": encoding_name,
    "active_token_ids": active_token_ids,
    "model_config": model_config,
    "block_size": block_size,
    "split_index": split_index,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "train_generator_state": train_generator.get_state(),
    "val_generator_state": val_generator.get_state(),
}
```

```text
step=160 train=5.785 val=5.788 val_acc=0.174
step=240 train=5.670 val=5.730 val_acc=0.163
step=320 train=5.475 val=5.612 val_acc=0.176
val_loss_improved_by=0.464
val_acc_improved_by=0.007
```

The fourth cell samples again from the same prompt so you can compare the rough checkpoint against the longer-trained one.
```python
# Same prompt, same sampling settings. Only the model weights changed.
sample = sample_completion(prompt)
print(f"prompt_seen_verbatim={prompt in corpus}")
print("sample:")
print("\n".join(line.rstrip() for line in sample.splitlines()))
```

```text
prompt_seen_verbatim=False
sample:
Good sir,
Speak plain.
To the very king

And you you will not you not a lord .

And the king .

And , my lord of a good hand , you , to a lord : I have you .
And have not the king ; and be not at , my good good lord .
I pray .
And have be be a lord . But have the lord
```

The state dictionary isn't enough for an exact training resume. A checkpoint also needs the vocabulary mapping, corpus split, model configuration, optimizer state, and random-generator states that give its tensors meaning and determine the next update. This cell writes a real checkpoint file, reloads it, and checks that the restored model produces identical logits and the same next training batch.
```python
checkpoint_path = Path("tiny_gpt_checkpoint.pt")
torch.save(checkpoint, checkpoint_path)
restored = torch.load(checkpoint_path, weights_only=True)

assert restored["encoding_name"] == encoding_name
assert restored["block_size"] == block_size
assert restored["split_index"] == split_index
resumed_model = TinyGPT(vocab_size=len(restored["active_token_ids"]), **restored["model_config"])
resumed_model.load_state_dict(restored["model_state_dict"])
resumed_optimizer = torch.optim.AdamW(resumed_model.parameters(), lr=2e-3)
resumed_optimizer.load_state_dict(restored["optimizer_state_dict"])

resumed_train_generator = torch.Generator()
resumed_train_generator.set_state(restored["train_generator_state"])
resumed_val_generator = torch.Generator()
resumed_val_generator.set_state(restored["val_generator_state"])

resumed_token_to_local = {
    token_id: idx for idx, token_id in enumerate(restored["active_token_ids"])
}
assert resumed_token_to_local == token_to_local
assert torch.equal(resumed_val_generator.get_state(), restored["val_generator_state"])

probe_generator = torch.Generator().manual_seed(303)
probe_x, _ = sample_batch(val_ids, generator=probe_generator)
model.eval()
resumed_model.eval()
with torch.no_grad():
    original_logits = model(probe_x)
    resumed_logits = resumed_model(probe_x)

expected_train_generator = torch.Generator()
expected_train_generator.set_state(checkpoint["train_generator_state"])
expected_x, expected_y = sample_batch(train_ids, generator=expected_train_generator)
resumed_x, resumed_y = sample_batch(train_ids, generator=resumed_train_generator)

print(f"saved checkpoint={checkpoint_path}")
print(f"restored encoding={restored['encoding_name']} active_vocab={len(restored['active_token_ids'])}")
print("same logits after round trip:", torch.equal(original_logits, resumed_logits))
print(
    "same next training batch after round trip:",
    torch.equal(expected_x, resumed_x) and torch.equal(expected_y, resumed_y),
)
```

```text
saved checkpoint=tiny_gpt_checkpoint.pt
restored encoding=gpt2 active_vocab=17485
same logits after round trip: True
same next training batch after round trip: True
```

For a GPU or distributed run, checkpoint the device RNG, scheduler, and sampler state too. The exact state list grows with the training system, but the rule stays simple: a resume is exact only if the next update sees the same model, optimizer, corpus split, data window, and randomness.
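A minimal sketch of what that larger state might look like, assuming a learning-rate scheduler and optional CUDA devices. The helper names collect_resume_state and restore_rng and the dictionary keys are hypothetical; torch.get_rng_state, torch.cuda.get_rng_state_all, and the scheduler's state_dict are standard PyTorch calls. A distributed sampler's epoch or position would belong in the same dictionary but isn't shown:

```python
import torch
import torch.nn as nn

def collect_resume_state(model, optimizer, scheduler, data_generator, step):
    # Hypothetical helper: bundle everything the next update depends on.
    state = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "data_generator_state": data_generator.get_state(),
        "cpu_rng_state": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        # One RNG state per visible GPU (used by dropout or on-device sampling).
        state["cuda_rng_states"] = torch.cuda.get_rng_state_all()
    return state

def restore_rng(state, data_generator):
    # Put every random stream back where it was when the state was collected.
    data_generator.set_state(state["data_generator_state"])
    torch.set_rng_state(state["cpu_rng_state"])
    if "cuda_rng_states" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["cuda_rng_states"])

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
data_generator = torch.Generator().manual_seed(101)

state = collect_resume_state(model, optimizer, scheduler, data_generator, step=320)
expected_next = torch.randint(0, 100, (3,), generator=data_generator)

# Simulate more training consuming randomness, then roll back to the saved state.
torch.randint(0, 100, (5,), generator=data_generator)
restore_rng(state, data_generator)
resumed_next = torch.randint(0, 100, (3,), generator=data_generator)

print("same next draw after restore:", torch.equal(expected_next, resumed_next))
print("saved keys:", sorted(state))
```

The final check mirrors the lab's round trip: restoring the saved generator state reproduces the next random draw even after extra randomness was consumed.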
That later checkpoint still isn't good writing, but its held-out loss is clearly better than the early checkpoint's. Accuracy moves only slightly (and dipped to 0.163 at step 240), so treat it as a noisy signal:

- validation loss: 6.076 to 5.612
- validation accuracy: 0.169 to 0.176

This is also where many readers misread toy-model output. You aren't asking, "is this fluent enough to publish?" You're checking whether held-out metrics improve and whether fixed-prompt output begins to show local corpus structure, without claiming fluent generation.
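F.cross_entropy reports loss in nats, so exp(loss) gives perplexity: roughly the number of equally likely next tokens the model is still choosing between. A quick sketch using the losses printed above; the uniform baseline is ln of the active vocabulary size, which is why the step-0 training loss of 9.760 sits so close to it:

```python
import math

# Values copied from the printed run above; 17,485 is the lab's active vocabulary.
active_vocab = 17_485
uniform_loss = math.log(active_vocab)  # loss if every local ID were equally likely
early_val_loss = 6.076
late_val_loss = 5.612

for name, loss in [("uniform", uniform_loss), ("early", early_val_loss), ("late", late_val_loss)]:
    # Perplexity = exp(mean cross-entropy in nats): an "effective number of choices".
    print(f"{name:>7}: loss={loss:.3f} perplexity={math.exp(loss):,.0f}")
```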
Character tokens are useful for first contact because the vocabulary is tiny and the code is easy. GPT-2 instead uses a byte-level BPE input representation, so this lab adopts that specific tokenizer design.[3]
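Switching to subwords changes a few practical things: the full vocabulary grows to 50,257 IDs, token IDs are only meaningful for the exact tokenizer that produced them, and a small corpus activates just a sparse subset of that range. So the lab uses GPT-2 BPE plus a local remap to contiguous IDs.

If you skip the remap and keep the full 50k-way tied vocabulary matrix, the toy model becomes needlessly heavy without teaching anything extra.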
Nothing about subwords changes the core causal objective. Loss still compares logits at position t against the ground-truth token at t + 1.
If you forget the shift, the model stops learning continuation and starts learning reconstruction. Loss may still fall, but the checkpoint is optimizing the wrong problem.
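A degenerate copy "model" makes that failure concrete. Because the causal mask lets position t read its own input token, copying is trivially learnable. The sketch below uses a made-up 10-token vocabulary and hand-built logits that simply echo each input token: against unshifted targets the loss is near zero, while against the real next-token targets it is terrible:

```python
import torch
import torch.nn.functional as F

vocab_size = 10
block = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6, 5, 3])  # illustrative token IDs
x = block[:-1]
y_shifted = block[1:]  # correct target: the successor of each input token
y_unshifted = x        # bug: the "target" is the token already in the input

# A copy "model": put a large score on whatever token sits at each position.
copy_logits = 20.0 * F.one_hot(x, num_classes=vocab_size).float()

loss_unshifted = F.cross_entropy(copy_logits, y_unshifted)
loss_shifted = F.cross_entropy(copy_logits, y_shifted)
print(f"copy model vs unshifted targets: {loss_unshifted.item():.4f}")
print(f"copy model vs shifted targets:   {loss_shifted.item():.4f}")
```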
For a packed token block:
```text
x: [15496, 995, 11, 262]
y: [995, 11, 262, 995]
```

Those are still "predict next token" pairs, even though each integer now stands for a BPE token instead of a character. (These are raw GPT-2 IDs; the lab remaps them to local IDs before training.)
This toy corpus is represented as one token stream, and fixed windows are cut from that stream rather than sentence rows. Larger pre-training pipelines may pack multiple documents with separate boundary handling, as covered in the preceding chapter.
This lab keeps the same idea: one token stream, cut into fixed block_size windows.

It splits the original stream once: the first 90% is training data and the final 10% is validation data. Avoid building a validation stream by concatenating interleaved chunks: random windows could then train or evaluate on invented transitions between passages that were never adjacent.
```python
tokens = list(range(20))
chunks = [tokens[start:start + 4] for start in range(0, len(tokens), 4)]

interleaved_train = chunks[0] + chunks[2]
fake_transition = (interleaved_train[3], interleaved_train[4])

split_at = int(0.8 * len(tokens))
contiguous_train = tokens[:split_at]
contiguous_validation = tokens[split_at:]

print("interleaved concatenation transition:", fake_transition)
print("was adjacent in source:", fake_transition[1] == fake_transition[0] + 1)
print("contiguous split sizes:", len(contiguous_train), len(contiguous_validation))
```

```text
interleaved concatenation transition: (3, 8)
was adjacent in source: False
contiguous split sizes: 16 4
```

Notice the progression across the staged cells: the first sample is mostly scattered punctuation and function words, while the later sample starts forming short clauses around words like "lord" and "king".
That's the point of the staged notebook. You can see the model move from "barely shaped" to "somewhat useful" instead of pretending the first checkpoint was already good.
In bigger runs you still don't choose between metrics and generations. You need both.
This lab follows the relevant GPT-2 architectural choices: learned absolute position embeddings, layer normalization at each sub-block input plus a final normalization, full multi-head attention, and a GELU MLP.[3] Later decoder-only families such as Llama 2 keep the same causal next-token contract while changing several components for their training and serving goals.[7]
| Lab component | Modern replacement | Why it changed |
|---|---|---|
| Learned pos_emb table | Rotary position embeddings (RoPE) | rotates queries and keys by position instead of learning an absolute position table; Llama 2 uses RoPE[8][7] |
| nn.LayerNorm | RMSNorm | removes mean-centering from normalization; Llama 2 uses RMSNorm before transformer sublayers[9][7] |
| GELU MLP | SwiGLU (gated GLU variant) | uses a gated feed-forward activation selected in Llama 2's architecture[10][7] |
| Full multi-head attention | Grouped-query attention (GQA) | shares key/value heads within query-head groups to reduce KV-cache load; Llama 2 uses GQA for its 34B and 70B models[11][7] |
Llama 2 reports RoPE, RMSNorm, and SwiGLU throughout its family, while GQA is specific to its larger 34B and 70B configurations.[7]
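To make two of those replacements concrete, here is a hand-written sketch of RMSNorm and a SwiGLU feed-forward block at this lab's width (d_model = 96). It follows the published formulas (RMSNorm rescales by the root mean square without subtracting the mean; SwiGLU multiplies a SiLU-gated projection by a second linear projection), but the epsilon and the hidden width of 256 are assumptions, not Llama 2's settings. The hidden width is picked so the weight count matches the lab's 4x GELU MLP, ignoring biases: 3 × 96 × 256 = 2 × 96 × 384.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):  # eps is an assumed value
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root mean square over features; no mean subtraction, no bias.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)

class SwiGLU(nn.Module):
    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.up = nn.Linear(d_model, hidden, bias=False)
        self.down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU(x W_gate) acts as a learned per-feature gate on x W_up.
        return self.down(F.silu(self.gate(x)) * self.up(x))

torch.manual_seed(0)
d_model = 96
x = torch.randn(2, 5, d_model) + 3.0  # nonzero feature mean to expose the difference

rms_out = RMSNorm(d_model)(x)
ln_out = nn.LayerNorm(d_model)(x)
print("RMSNorm keeps a nonzero feature mean:", rms_out.mean(dim=-1).abs().min().item() > 0.1)
print("LayerNorm re-centers to ~0 mean:", ln_out.mean(dim=-1).abs().max().item() < 1e-5)

# Assumed hidden width: 256 = 8/3 * 96, matching the 4x GELU MLP's weight count.
ff = SwiGLU(d_model, hidden=256)
print("SwiGLU output shape:", tuple(ff(x).shape))
```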
One more change is computational, not architectural. Our forward pass materializes the full [batch, heads, time, time] score matrix, then masks and softmaxes it. PyTorch provides torch.nn.functional.scaled_dot_product_attention, which can dispatch to optimized backends when conditions allow. You can drop it into this lab without changing the attention result beyond floating-point rounding:
```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(3)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
manual = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1) @ v
fused_api = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print("output shape:", tuple(fused_api.shape))
print("numerically close:", torch.allclose(manual, fused_api, atol=1e-6))
```

```text
output shape: (1, 2, 4, 8)
numerically close: True
```

The manual version stays in the lab because seeing the score matrix get masked is the point. PyTorch's API can select optimized kernels when supported, while FlashAttention gives the IO-aware algorithmic basis for avoiding full attention-matrix materialization.[12]
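To use the fused call inside a lab-style model, one option is to keep the qkv projection and head split and replace only the score, mask, and softmax lines. In the sketch below, ManualCausalSelfAttention mirrors the lab's module and SDPACausalSelfAttention is a hypothetical subclass that overrides just that step; copying weights between them checks that outputs agree within floating-point tolerance:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualCausalSelfAttention(nn.Module):
    def __init__(self, d_model: int = 96, n_heads: int = 4, block_size: int = 48):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size, dtype=torch.bool)))

    def split_heads(self, t: torch.Tensor) -> torch.Tensor:
        b, s, _ = t.shape
        # [batch, time, width] -> [batch, heads, time, head_dim]
        return t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)

    def attend(self, q, k, v):
        # Explicit path: materialize scores, mask the future, softmax, mix values.
        s = q.size(-2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(~self.mask[:s, :s], float("-inf"))
        return scores.softmax(dim=-1) @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, width = x.shape
        q, k, v = (self.split_heads(t) for t in self.qkv(x).chunk(3, dim=-1))
        out = self.attend(q, k, v)
        return self.proj(out.transpose(1, 2).contiguous().view(b, s, width))

class SDPACausalSelfAttention(ManualCausalSelfAttention):
    def attend(self, q, k, v):
        # Same causal attention; PyTorch may pick an optimized backend when supported.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.manual_seed(0)
manual = ManualCausalSelfAttention()
fused = SDPACausalSelfAttention()
fused.load_state_dict(manual.state_dict())

x = torch.randn(2, 48, 96)
with torch.no_grad():
    close = torch.allclose(manual(x), fused(x), atol=1e-5)
print("manual vs SDPA close:", close)
```

The mask buffer stays registered in the subclass only so both modules share the same state_dict keys; the fused path relies on is_causal=True instead.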
- Logits at position t were trained against token t instead of token t + 1. Fix: Audit the shifted-label path with one tiny hand-checked batch before training longer.
- block_size was chosen only for speed, not for pattern length. Fix: Increase the context window until the model can see enough preceding tokens to learn the structure you care about.
1. CS336: Language Modeling from Scratch. Stanford University · 2026.
2. nanoGPT. Karpathy, A. · 2025.
3. Language Models are Unsupervised Multitask Learners. Radford, A., et al. · 2019.
4. shakespeare.txt. Princeton University COS 302 / SML 305 · 2020.
5. GPT-2 Source Implementation. OpenAI · 2019.
6. PyTorch: An Imperative Style, High-Performance Deep Learning Library. Paszke, A., et al. · 2019 · NeurIPS 2019.
7. Llama 2: Open Foundation and Fine-Tuned Chat Models. Touvron, H., et al. · 2023 · arXiv preprint.
8. RoFormer: Enhanced Transformer with Rotary Position Embedding. Su, J., et al. · 2021.
9. Root Mean Square Layer Normalization. Zhang, B. & Sennrich, R. · 2019 · NeurIPS 2019.
10. GLU Variants Improve Transformer. Shazeer, N. · 2020.
11. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. Ainslie, J., et al. · 2023 · EMNLP 2023.
12. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. · 2022 · NeurIPS 2022.