Build and train a tiny GPT end to end on Shakespeare: tokenize with GPT-style subwords, remap active token IDs, run causal self-attention, track validation loss, save a checkpoint, and sample text.
The data-pipeline chapter ended with token shards ready for training. This lab shows what happens next: a tiny decoder-only GPT consumes token blocks, predicts next tokens, writes a checkpoint, and produces sample text you can judge.[1]
The reference implementation that inspired this lab matters here. Karpathy's nanoGPT is still one of the clearest end-to-end GPT training codebases, but its fastest Shakespeare on-ramp is character-level. This lab keeps the same spirit and tiny scale while switching to GPT-2 BPE subwords, because that is closer to how real GPT-family models tokenize text.[2][3]
The corpus is bundled with the lesson as shakespeare.txt, so readers can download the same raw text and rerun the experiment locally. That file is Princeton's course mirror of the Shakespeare text. We train on the full bundled corpus, not a single play excerpt. The run still stays CPU-friendly because both the model and the update budget are tiny.[4]
Why use Shakespeare in a curriculum that normally stays anchored in orders, refunds, shipping, and customer-support flows? Because here the learning target is pure next-token mechanics. Once you understand the tiny GPT loop on a public corpus, the same pipeline transfers back to ecommerce support logs, delivery updates, and other domain text where you would usually fine-tune or continue pretraining.
Small GPT runs still have one non-negotiable contract:
- position t may attend only to positions <= t
- the label for position t is the token at position t + 1

If any piece is wrong, the loop still runs and still prints numbers. That is why this lesson is valuable. It forces you to make the tokenizer, batching, model, and generation agree with each other instead of studying them as isolated concepts.
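The attention half of that contract can be checked by hand. The sketch below uses plain Python rather than PyTorch, and the helper name `masked_softmax` and the score values are illustrative assumptions, not part of the lab code. It shows that once future scores are set to negative infinity, they receive exactly zero attention weight.

```python
import math


def masked_softmax(scores: list[float], allowed: list[bool]) -> list[float]:
    # Replace disallowed (future) scores with -inf so exp() maps them to 0.
    masked = [s if ok else float("-inf") for s, ok in zip(scores, allowed)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(m - peak) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


seq_len = 4
# Row t of the causal mask allows positions 0..t and hides t+1 onward.
causal_mask = [[s <= t for s in range(seq_len)] for t in range(seq_len)]

raw_scores = [0.5, 1.0, 2.0, 3.0]  # made-up similarity scores for one query
weights_at_t1 = masked_softmax(raw_scores, causal_mask[1])

print(causal_mask[1])
print(weights_at_t1)
```

The query at position 1 ends up with nonzero weight only on positions 0 and 1, even though the hidden positions had the largest raw scores.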
| Piece | What it does | Failure mode |
|---|---|---|
| Corpus | supplies raw language distribution | using too much text for CPU lab or too little text for recognizable output |
| Tokenizer | turns text into GPT-style subword IDs | forgetting that token IDs depend on exact tokenizer |
| Active-vocab remap | shrinks sparse GPT-2 token IDs into small local range | keeping 50k-way output head in toy CPU lab |
| Block packer | slices token stream into fixed contexts and shifted labels | off-by-one windows or broken train/val split |
| Decoder loop | runs masked self-attention, loss, checkpoint, and generation | missing causal mask or trusting loss without sampling |
That middle row matters. GPT-2 BPE can emit IDs anywhere in its 50k vocabulary. Our tiny lab only activates a subset of those IDs, so we remap the active token IDs into contiguous local IDs, here 0..17484. The subword tokenization stays the same; the embedding table and output head get much smaller.
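The remap is small enough to trace by hand. This is a minimal sketch on a five-token toy stream, using the same dictionary construction as the lab code; the toy IDs are chosen only for illustration.

```python
# Raw GPT-2 style IDs for a toy stream; note how sparse the values are.
raw_ids = [15496, 995, 11, 262, 995]

# Sort the distinct IDs and give each one a contiguous local index.
active_token_ids = sorted(set(raw_ids))
token_to_local = {token_id: idx for idx, token_id in enumerate(active_token_ids)}
local_to_token = {idx: token_id for token_id, idx in token_to_local.items()}

# Training uses local IDs; decoding maps them back to the original IDs.
local_ids = [token_to_local[t] for t in raw_ids]
restored = [local_to_token[i] for i in local_ids]

print(local_ids)
print(restored == raw_ids)
```

Four distinct tokens need only four embedding rows, and the round trip back to the original IDs is lossless.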
This lab is more useful if the first result is still half-baked, so the notebook flow is staged: train briefly, sample, train longer, sample again.
The first cell builds the whole pipeline in plain PyTorch, then trains for only 80 steps. That is enough to learn something, but not enough to look good.[5]

```python
from pathlib import Path
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken

# Fix random seed so article output is reproducible.
torch.manual_seed(7)

# Load raw text exactly as learner will download it.
all_text = Path("assets/shakespeare.txt").read_text(encoding="utf-8")
corpus = all_text

# Use GPT-2 BPE so tokenization matches real GPT-style subword models.
encoding_name = "gpt2"
encoder = tiktoken.get_encoding(encoding_name)
corpus_token_ids = encoder.encode(corpus)

# GPT-2 token ids are sparse across a 50k vocabulary. This run only needs
# tokens that actually appear inside bundled Shakespeare corpus, so remap them to 0..N-1.
# That keeps embedding table and output head much smaller without changing
# which subword pieces the tokenizer produced.
active_token_ids = sorted(set(corpus_token_ids))
token_to_local = {token_id: idx for idx, token_id in enumerate(active_token_ids)}
local_to_token = {idx: token_id for token_id, idx in token_to_local.items()}
ids = [token_to_local[token_id] for token_id in corpus_token_ids]

# Each training example needs block_size input tokens plus 1 next-token label.
block_size = 48
chunk_size = block_size + 1
train_flat = []
val_flat = []

# Walk through token stream in fixed windows. Every 10th chunk goes to
# validation so train and val still come from same overall distribution.
for chunk_index, start in enumerate(range(0, len(ids) - chunk_size, chunk_size)):
    chunk = ids[start:start + chunk_size]
    target = val_flat if chunk_index % 10 == 0 else train_flat
    target.extend(chunk)

# Convert Python lists into tensors once so batch sampling is cheap.
train_ids = torch.tensor(train_flat, dtype=torch.long)
val_ids = torch.tensor(val_flat, dtype=torch.long)
batch_size = 12

print(
    f"tokens={len(ids)} active_vocab={len(active_token_ids)} "
    f"train={len(train_ids)} val={len(val_ids)}"
)


def sample_batch(source: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Pick random starting positions from long token stream.
    starts = torch.randint(0, len(source) - block_size - 1, (batch_size,))

    # x is current context window.
    x = torch.stack([source[s:s + block_size] for s in starts])

    # y is the same window shifted one token ahead: y[t] is the token after x[t].
    # This is the whole learning target.
    y = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return x, y


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int = 96, n_heads: int = 4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # One linear layer projects each position into query, key, and value vectors.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

        # Lower-triangular mask blocks attention to future positions.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size, dtype=torch.bool)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seqlen, width = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            # Turn [batch, time, width] into [batch, heads, time, head_dim].
            return tensor.view(batch_size, seqlen, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # Attention score = query-key similarity, scaled to keep softmax stable.
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Any position above diagonal is future token, so hide it from model.
        attn = attn.masked_fill(~self.mask[:seqlen, :seqlen], float("-inf"))
        attn = attn.softmax(dim=-1)

        # Weighted sum of value vectors produces contextualized representation.
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(batch_size, seqlen, width)
        return self.proj(out)


class Block(nn.Module):
    def __init__(self, d_model: int = 96, n_heads: int = 4):
        super().__init__()

        # Pre-LN transformer block: normalize, attend, add residual, then MLP.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class TinyGPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 96, n_heads: int = 4, n_layers: int = 2):
        super().__init__()

        # Token embeddings say "which subword is this?".
        self.token_emb = nn.Embedding(vocab_size, d_model)

        # Position embeddings say "where is this token inside current window?".
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.blocks = nn.ModuleList([Block(d_model, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)

        # Final linear layer turns hidden state back into next-token logits.
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seqlen = x.shape
        # Build position indices on the same device as the input tokens.
        positions = torch.arange(seqlen, device=x.device)

        # GPT adds token meaning and position meaning before any attention happens.
        h = self.token_emb(x) + self.pos_emb(positions)[None, :, :]
        for block in self.blocks:
            h = block(h)
        h = self.ln_f(h)
        return self.head(h)


# Build model and optimizer.
model = TinyGPT(vocab_size=len(active_token_ids))
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)


def evaluate(source: torch.Tensor) -> tuple[float, float]:
    # Average across a few validation batches so accuracy is less noisy.
    model.eval()
    losses = []
    accuracies = []
    with torch.no_grad():
        for _ in range(8):
            x, y = sample_batch(source)
            logits = model(x)
            losses.append(F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1)).item())
            accuracies.append((logits.argmax(dim=-1) == y).float().mean().item())
    model.train()
    return sum(losses) / len(losses), sum(accuracies) / len(accuracies)


def sample_completion(prompt: str, steps: int = 80) -> str:
    # Reset sampling seed so article output stays reproducible.
    torch.manual_seed(17)

    # Note: every prompt token must already exist in the active vocabulary,
    # otherwise the token_to_local lookup raises KeyError.
    prompt_token_ids = encoder.encode(prompt)
    prompt_local_ids = [token_to_local[token_id] for token_id in prompt_token_ids]
    context = torch.tensor([prompt_local_ids], dtype=torch.long)

    model.eval()
    with torch.no_grad():
        for _ in range(steps):
            # If sample gets longer than block size, GPT only sees most recent window.
            x = context[:, -block_size:]
            logits = model(x)

            # Only final position matters for next-token sampling.
            next_logits = logits[:, -1, :]

            # Restrict to top candidates so toy model doesn't wander too wildly.
            top_values, top_indices = torch.topk(next_logits, k=8, dim=-1)
            probs = torch.softmax(top_values / 0.9, dim=-1)
            sampled_index = torch.multinomial(probs, num_samples=1)
            next_local_id = top_indices.gather(-1, sampled_index)

            # Append sampled token and continue autoregressive loop.
            context = torch.cat([context, next_local_id], dim=1)

    # Convert local ids back to original GPT-2 token ids, then decode to text.
    return encoder.decode([local_to_token[int(idx)] for idx in context[0]])


for step in range(81):
    # 1. Draw random training batch.
    x, y = sample_batch(train_ids)

    # 2. Predict next-token logits for every position.
    logits = model(x)

    # 3. Compare logits against shifted targets.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

    # 4. Backpropagate and update weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print occasional train/val snapshots so learner can watch first useful learning happen.
    if step % 40 == 0:
        val_loss, val_acc = evaluate(val_ids)
        print(
            f"step={step:03d} train={loss.item():.3f} "
            f"val={val_loss:.3f} val_acc={val_acc:.3f}"
        )

# Save first-checkpoint metrics so later cell can compare improvement directly.
early_val_loss = val_loss
early_val_acc = val_acc

# Bundle enough state to keep sampling compatible with trained weights.
# This dict lives in memory; pass it to torch.save to persist it on disk.
checkpoint = {
    "encoding_name": encoding_name,
    "active_token_ids": active_token_ids,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
```

```text
tokens=1255253 active_vocab=17485 train=1129695 val=125538
step=000 train=9.887 val=9.799 val_acc=0.000
step=040 train=6.522 val=6.415 val_acc=0.145
step=080 train=6.259 val=6.140 val_acc=0.154
```

The second cell samples that early checkpoint. The prompt is chosen so that it does not appear verbatim in the training corpus, which makes the sample less about copying a heading and more about genuine recombination.

```python
# Prompt is intentionally not copied from training corpus verbatim.
prompt = "Good sir,\nSpeak plain.\n"
assert prompt not in corpus

# This first sample should still look rough and undertrained.
sample = sample_completion(prompt)
print(f"prompt_seen_verbatim={prompt in corpus}")
print("sample:")
print("\n".join(line.rstrip() for line in sample.splitlines()))
```

```text
prompt_seen_verbatim=False
sample:
Good sir,
Speak plain.

I , a to , and , a . and

And .

And of ,

And

That .
But .

That .
The the , and ,
And .

To ,
And I , and the the , and , and ,

And ,
O
```

That checkpoint is still half-baked. It knows Shakespeare-ish punctuation and function words, but it doesn't yet hold a stable thought.
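The step=000 line from the first cell offers a cheap sanity check that complements eyeballing samples. A model that spreads probability evenly over all 17,485 local tokens has cross-entropy ln(17485), so an untrained model should start close to that value. A small sketch of the arithmetic, with illustrative variable names:

```python
import math

active_vocab = 17485          # active_vocab printed by the first cell
uniform_loss = math.log(active_vocab)

reported_train = 9.887        # step=000 values printed by the first cell
reported_val = 9.799

print(f"uniform_loss={uniform_loss:.3f}")
print(f"train_gap={reported_train - uniform_loss:.3f}")
print(f"val_gap={reported_val - uniform_loss:.3f}")
```

Both reported values sit within about 0.12 of that baseline. A starting loss far from ln(vocab_size) would hint at a vocabulary-size or label bug rather than a slow learner.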
The third cell keeps the exact same model and optimizer state, trains longer, and prints whether the held-out metrics improved.

```python
# Continue from exact same checkpoint instead of restarting from scratch.
for step in range(81, 321):
    x, y = sample_batch(train_ids)
    logits = model(x)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 80 == 0:
        late_val_loss, late_val_acc = evaluate(val_ids)
        print(
            f"step={step:03d} train={loss.item():.3f} "
            f"val={late_val_loss:.3f} val_acc={late_val_acc:.3f}"
        )

print(f"val_loss_improved_by={early_val_loss - late_val_loss:.3f}")
print(f"val_acc_improved_by={late_val_acc - early_val_acc:.3f}")

checkpoint = {
    "encoding_name": encoding_name,
    "active_token_ids": active_token_ids,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
```

```text
step=160 train=5.871 val=5.883 val_acc=0.166
step=240 train=6.057 val=5.718 val_acc=0.175
step=320 train=5.614 val=5.534 val_acc=0.179
val_loss_improved_by=0.606
val_acc_improved_by=0.025
```

The fourth cell samples again from the same prompt so you can compare the rough checkpoint against the longer-trained checkpoint.

```python
# Same prompt, same sampling settings. Only model weights changed.
sample = sample_completion(prompt)
print(f"prompt_seen_verbatim={prompt in corpus}")
print("sample:")
print("\n".join(line.rstrip() for line in sample.splitlines()))
```

```text
prompt_seen_verbatim=False
sample:
Good sir,
Speak plain.
What , he , my love , and the other man ?

He in my heart ,
And to her , and not in my life

And to you are his own time ,
That ,
What !

A father .

I am my lord ?

I have a good heart ;
I am the king .

He ?
```

That later checkpoint is still not good writing, but it is clearly more accurate than the warmup checkpoint:

- validation loss: 6.140 to 5.534
- validation accuracy: 0.154 to 0.179

This is also where many readers misread toy-model output. You aren't asking, "is this fluent enough to publish?" You're asking, "did tokenization, masking, and training teach the model enough local structure that the output now looks Shakespeare-shaped instead of random?" The later run answers yes much more clearly.
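One way to make those loss values more tangible, which the lab itself does not use, is to convert them to perplexity, exp(loss). It reads roughly as the effective number of equally likely next-token choices. A small sketch using the printed validation losses:

```python
import math

early_val_loss = 6.140   # step 80, from the first cell's output
late_val_loss = 5.534    # step 320, from the third cell's output

# Perplexity is the exponential of the mean cross-entropy (in nats).
early_ppl = math.exp(early_val_loss)
late_ppl = math.exp(late_val_loss)

print(f"early perplexity ~ {early_ppl:.0f}")
print(f"late perplexity  ~ {late_ppl:.0f}")
print(f"loss improvement = {early_val_loss - late_val_loss:.3f}")
```

A drop of about 0.6 nats roughly halves the effective branching factor, which lines up with the later sample producing short clauses instead of bare punctuation.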
Character tokens are useful for first contact because the vocabulary is tiny and the code is easy. But modern GPT-family models usually train on subword tokenizers, not raw characters.[3]
That changes three practical things:
This is why the lab uses GPT-2 BPE plus a local remap. If you skip the remap and keep the full 50k-way vocabulary head, the toy model becomes needlessly heavy without teaching anything extra.
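The cost of skipping the remap is easy to estimate. The sketch below counts only the vocabulary-sized parameters of the lab's `TinyGPT` (the token embedding table plus the untied, bias-free output head) at `d_model = 96`. The helper name `vocab_params` is illustrative; 50257 is the GPT-2 BPE vocabulary size.

```python
d_model = 96
full_vocab = 50257      # GPT-2 BPE vocabulary size
active_vocab = 17485    # distinct tokens in the bundled corpus


def vocab_params(vocab_size: int, width: int) -> int:
    # Token embedding table + output head (no bias), each vocab_size x width.
    return 2 * vocab_size * width


full = vocab_params(full_vocab, d_model)
local = vocab_params(active_vocab, d_model)

print(f"full vocab : {full:,} parameters")
print(f"local remap: {local:,} parameters")
print(f"ratio      : {full / local:.2f}x")
```

Under these assumptions the full vocabulary would carry nearly three times as many vocabulary-sized parameters, most of them attached to tokens the corpus never produces.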
Nothing about subwords changes the core causal objective. The loss still compares logits at position t against the ground-truth token at t + 1.
If you forget the shift, the model stops learning continuation and starts learning reconstruction. The loss may still fall, but the checkpoint is optimizing the wrong problem.
For a packed token block:

```text
x: [15496, 995, 11, 262]
y: [995, 11, 262, 995]
```

Those are still "predict next token" pairs, even though each integer now stands for a BPE token instead of a character.
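The same pair can be derived from a five-token stream, alongside the unshifted variant that the text warns about. This is a plain-Python sketch; the names `y_shifted` and `y_unshifted` are illustrative.

```python
stream = [15496, 995, 11, 262, 995]
block_size = 4

x = stream[:block_size]
y_shifted = stream[1:block_size + 1]   # correct: label at t is token t + 1
y_unshifted = stream[:block_size]      # bug: label at t is token t itself

# Every correct label is the token that follows its input position.
for t in range(block_size):
    assert y_shifted[t] == stream[t + 1]

print("x         :", x)
print("y_shifted :", y_shifted)
print("copy task?", y_unshifted == x)
```

With unshifted labels the targets equal the inputs, so the model only has to echo what it already sees, and the loss can drop without any continuation being learned.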
Real GPT pretraining treats tokenized text as one long stream. Fixed windows are cut from the stream, not from tidy sentence rows.
This lab keeps the same idea:

- the whole corpus is encoded into one token stream
- the stream is cut into chunks of block_size + 1 tokens
- batches are sampled as random block_size windows

The train/val split is chunk-based instead of one big tail split. That keeps the validation text closer in distribution to the training text for a tiny corpus, while still holding out separate windows.
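The split sizes printed by the first cell can be reproduced from the token count alone, without loading the corpus. This sketch repeats the chunking arithmetic of the lab's loop; the counter names are illustrative.

```python
n_tokens = 1255253           # tokens value printed by the first cell
block_size = 48
chunk_size = block_size + 1  # 48 inputs plus 1 next-token label

# Same start positions as the lab's chunking loop.
starts = range(0, n_tokens - chunk_size, chunk_size)
n_chunks = len(starts)

# Every 10th chunk (index 0, 10, 20, ...) is held out for validation.
n_val_chunks = sum(1 for i in range(n_chunks) if i % 10 == 0)
n_train_chunks = n_chunks - n_val_chunks

train_tokens = n_train_chunks * chunk_size
val_tokens = n_val_chunks * chunk_size
dropped_tail = n_tokens - n_chunks * chunk_size

print(f"train={train_tokens} val={val_tokens} dropped_tail={dropped_tail}")
```

The totals match the `train=1129695 val=125538` line in the lab output, and the arithmetic also shows that a short tail of tokens that does not fill a whole chunk is left out of both splits.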
Notice the progression from the first sample to the second. That is the point of the staged notebook: you can see the model move from "barely shaped" to "somewhat useful" instead of pretending the first checkpoint was already good.
In bigger runs you still don't choose between metrics and generations. You need both.
The tokenizer, batch packer, causal mask, loss, checkpoint, and sampler have to agree on the same ID space and the same context contract. Debugging them separately can miss bugs that only appear when the checkpoint is sampled.
The tokenizer defines the meaning of every token ID. If you change the tokenizer, embedding rows and output logits stop matching the intended text units.
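That is why the lab's checkpoint stores `encoding_name` and `active_token_ids` next to the weights. A hedged sketch of how a loader might rebuild the ID maps from those two fields follows; `restore_id_maps` is a hypothetical helper and the toy checkpoint holds made-up IDs, neither is part of the lab code.

```python
def restore_id_maps(checkpoint: dict, expected_encoding: str):
    # Refuse to continue if the checkpoint was built with a different tokenizer.
    if checkpoint["encoding_name"] != expected_encoding:
        raise ValueError(
            f"checkpoint uses {checkpoint['encoding_name']!r}, "
            f"but the sampler expects {expected_encoding!r}"
        )

    # Rebuild the maps from the saved list, not from a freshly tokenized corpus,
    # so local ID i still points at the embedding row it was trained with.
    active = checkpoint["active_token_ids"]
    token_to_local = {token_id: idx for idx, token_id in enumerate(active)}
    local_to_token = {idx: token_id for token_id, idx in token_to_local.items()}
    return token_to_local, local_to_token


# Toy checkpoint with only the two tokenizer-related fields.
toy_checkpoint = {"encoding_name": "gpt2", "active_token_ids": [11, 262, 995, 15496]}
token_to_local, local_to_token = restore_id_maps(toy_checkpoint, "gpt2")

print(token_to_local[995], local_to_token[3])
```

Passing a different encoding name raises immediately, which is cheaper than discovering the mismatch through garbled samples.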
GPT-2 BPE token 50256 is valid, but the tiny lab doesn't need to allocate all rows between 0 and 50256 if the corpus only touched 17,485 distinct tokens.
Train loss can keep improving while generations become repetitive or validation loss drifts up. That is overfitting, not success.
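A simple way to watch for that pattern is to compare consecutive snapshots. The sketch below uses the train and validation losses printed by the lab's cells; `diverging_steps` is an illustrative helper, not part of the lab code.

```python
# (step, train_loss, val_loss) as printed by the lab's cells.
history = [
    (80, 6.259, 6.140),
    (160, 5.871, 5.883),
    (240, 6.057, 5.718),
    (320, 5.614, 5.534),
]


def diverging_steps(rows: list[tuple[int, float, float]]) -> list[int]:
    # Flag snapshots where train loss fell but validation loss rose
    # compared with the previous snapshot.
    flagged = []
    for (_, prev_train, prev_val), (step, train, val) in zip(rows, rows[1:]):
        if train < prev_train and val > prev_val:
            flagged.append(step)
    return flagged


print(diverging_steps(history))
```

Nothing is flagged for this short run because validation loss falls at every snapshot. The printed train loss comes from a single batch, so it is noisy (it rises at step 240); a longer run would want an averaged train loss before trusting this check.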
block_size is a learning constraint, not only a speed knob. If the window is too short for the patterns you want the model to learn, samples can stay locally plausible while losing longer dialogue structure.
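The same limit applies at sampling time, where the lab's `sample_completion` crops the context with `context[:, -block_size:]`. A plain-Python sketch of that crop for a single sequence follows; `visible_context` is an illustrative name and the integers stand in for token IDs.

```python
block_size = 48


def visible_context(token_ids: list[int], window: int = block_size) -> list[int]:
    # Keep only the most recent `window` tokens; everything older is invisible.
    return token_ids[-window:]


generated = list(range(60))   # stand-in for a 60-token sequence
seen = visible_context(generated)

print(len(seen), seen[0], seen[-1])
```

Once the sequence grows past 48 tokens, the earliest ones (here the first 12) can no longer influence the next prediction, so anything the prompt established has to be re-inferred from the recent window.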
| Skill | What a strong answer includes |
|---|---|
| Coherent GPT loop | connects corpus, tokenizer, packed blocks, causal mask, shifted loss, checkpoint, and generation as one contract |
| Runnable small run | explains GPT-2 BPE tokenization, active-token remap, shifted targets, and held-out validation checks |
| Failure diagnosis | catches unshifted labels, train-only evaluation, tokenizer/checkpoint mismatch, tiny-context limits, and bad sample behavior |
| Prompt | Answer sketch |
|---|---|
| Why is this chapter different from earlier language-modeling and Transformer chapters? | Earlier chapters taught ingredients in isolation. This lab forces them to cooperate in a real GPT-style loop where tokenizer IDs, packed blocks, causal attention, validation, checkpointing, and generation all have to match. |
| What is the most common silent bug in causal language-model training code? | Forgetting the one-token shift between logits and labels. If logits at position t train against token t instead of token t + 1, the loss can move while the objective is wrong. |
| Why sample text after a checkpoint instead of trusting loss alone? | Lower loss doesn't guarantee healthy generations. Sampling exposes repetition, collapse, context-length bugs, and tokenizer mistakes that a scalar loss can hide. |
- CS336: Language Modeling from Scratch. Stanford University, 2026.
- nanoGPT. Karpathy, A., 2025.
- shakespeare.txt. Princeton University COS 302 / SML 305, 2020.
- Language Models are Unsupervised Multitask Learners. Radford, A., et al., 2019.
- PyTorch: An Imperative Style, High-Performance Deep Learning Library. Paszke, A., et al., NeurIPS 2019.