Build a PyTorch classifier from raw logits through autograd, validation, and reloadable checkpoints.
The previous chapter asked whether a model-system change helps customers. This chapter answers an earlier question: how does a model acquire new behavior at all?
Suppose an e-commerce support queue receives refund tickets. Each ticket has two simple signals: urgency score and negative-sentiment score. A label says whether a specialist later judged that ticket to need escalation. At first, a small classifier guesses badly. A training loop repeatedly shows it examples, measures wrongness, computes which parameters contributed to that wrongness, and updates those parameters.
PyTorch makes that process explicit. That matters when you later fine-tune a text classifier or read a transformer training script: wrappers may add logging and distributed execution, but learning still comes from batches, logits, loss, gradients, optimizer steps, validation, and saved model state.[1][2]
backward() fills dW but leaves W unchanged. With learning rate 0.2, optimizer.step() creates weights [[-0.1, -0.1], [0.1, 0.1]], which route the positive ticket to escalation and the negative ticket to the standard queue.

By the end, you will be able to:

- nn.Module with TensorDataset and DataLoader.

A model can't consume the phrase "angry customer with delayed refund" directly in this small lab. Pretend an earlier feature pipeline already produced two normalized numbers:
| Feature | Meaning | Example range |
|---|---|---|
| urgency | Language indicating an immediate deadline or financial risk | -2.0 to 2.0 |
| sentiment | Negative tone and complaint intensity | -2.0 to 2.0 |
| label | 0 = standard queue, 1 = specialist escalation | 0 or 1 |
Each training row becomes a vector of two floating-point inputs and one integer class target. A batch stacks those rows into an input tensor. Six tickets have feature shape (6, 2) and label shape (6,).
```python
import torch

features = torch.tensor(
    [
        [-2.0, -1.2],
        [-1.3, -0.8],
        [-0.8, -1.4],
        [0.9, 1.1],
        [1.4, 0.8],
        [1.8, 1.6],
    ],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long)

print("features:", tuple(features.shape), features.dtype)
print("labels:", tuple(labels.shape), labels.dtype)
print("first ticket:", features[0].tolist(), "route:", labels[0].item())
```

```text
features: (6, 2) torch.float32
labels: (6,) torch.int64
first ticket: [-2.0, -1.2000000476837158] route: 0
```

float32 inputs can flow through linear layers. torch.long labels are class indices for this classification objective. Keeping the shape and dtype contract visible prevents a large class of training bugs.
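One way to keep that contract visible is a small guard that runs before a batch reaches the model. The sketch below is illustrative only: `check_batch_contract` and its `num_features` parameter are hypothetical names, not part of PyTorch or of this lesson's files.

```python
import torch


def check_batch_contract(features: torch.Tensor, labels: torch.Tensor, num_features: int = 2) -> None:
    """Hypothetical helper: raise ValueError when a batch breaks the shape/dtype contract."""
    if features.dtype != torch.float32:
        raise ValueError(f"features must be float32, got {features.dtype}")
    if labels.dtype != torch.long:
        raise ValueError(f"labels must be torch.long, got {labels.dtype}")
    if features.ndim != 2 or features.shape[1] != num_features:
        raise ValueError(f"features must have shape (batch, {num_features}), got {tuple(features.shape)}")
    if labels.shape != (features.shape[0],):
        raise ValueError(f"labels must have shape ({features.shape[0]},), got {tuple(labels.shape)}")


good_x = torch.tensor([[-2.0, -1.2], [0.9, 1.1]], dtype=torch.float32)
good_y = torch.tensor([0, 1], dtype=torch.long)
check_batch_contract(good_x, good_y)  # a valid batch passes silently

bad_y = torch.tensor([[0], [1]], dtype=torch.long)  # shape (2, 1) instead of (2,)
try:
    check_batch_contract(good_x, bad_y)
    caught = False
except ValueError as error:
    caught = True
    print("rejected:", error)
```

A check like this costs almost nothing per batch and turns a confusing loss-function error into a message that names the broken field.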
A classifier ending in nn.Linear(2, 2) produces two scores for each ticket:
| Output column | Route represented |
|---|---|
| 0 | standard queue |
| 1 | specialist escalation |
These scores are logits. They may be negative and they don't have to sum to one. For ordinary multi-class classification, PyTorch's CrossEntropyLoss is designed to receive those raw logits plus target class indices. It handles the log-softmax calculation in a numerically stable form.[2]
```python
import torch
from torch import nn

logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5]], dtype=torch.float32)
labels = torch.tensor([0, 1], dtype=torch.long)
loss = nn.CrossEntropyLoss()(logits, labels)

routes = ["standard", "escalate"]
predictions = logits.argmax(dim=1).tolist()
print("predicted routes:", [routes[index] for index in predictions])
print("cross entropy:", round(loss.item(), 4))
```

```text
predicted routes: ['standard', 'escalate']
cross entropy: 0.0878
```

Failure case: Don't call softmax() and then feed those probabilities into CrossEntropyLoss for this setup. Pass raw logits. A hand-written log(softmax(logits)) calculation is also easier to make numerically unstable when a score becomes very large.
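The failure case can be observed directly. The following sketch reuses the logits from the example above and compares three quantities: the loss on raw logits, the equivalent log-softmax plus negative-log-likelihood form, and the loss after a mistaken softmax() call. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn

logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5]], dtype=torch.float32)
labels = torch.tensor([0, 1], dtype=torch.long)
loss_fn = nn.CrossEntropyLoss()

# Correct usage: raw logits go straight into the loss.
correct = loss_fn(logits, labels)

# Same quantity written as log-softmax followed by negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

# Mistake: probabilities are treated as if they were logits, so softmax is applied twice.
double_softmax = loss_fn(torch.softmax(logits, dim=1), labels)

print("matches log_softmax + nll_loss:", torch.allclose(correct, manual))
print("double softmax inflates loss:", double_softmax.item() > correct.item())
```

Because probabilities lie between 0 and 1, the gap between the two "logits" can never exceed 1 after the extra softmax. With two classes that puts a floor of roughly 0.313 under the loss, so the model looks stuck even when its predictions are right.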
For one mini-batch, training follows a fixed order:
| Code | Purpose | Parameter values change? |
|---|---|---|
| optimizer.zero_grad() | Clear gradients left by earlier work. | No |
| logits = model(xb) | Run the current model forward. | No |
| loss = loss_fn(logits, yb) | Score current mistakes. | No |
| loss.backward() | Use autograd to fill each parameter's .grad. | No |
| optimizer.step() | Apply an update derived from stored gradients. | Yes |
Backpropagation is credit assignment: if raising one weight would increase loss, that weight's gradient is positive, and a gradient-descent optimizer moves the weight in the opposite direction. Autograd performs this chain-rule bookkeeping for every parameter used in the forward computation.[3][1]
The following lab sets a simple linear model to known starting weights (all zeros), runs exactly one update, and shows that backward() fills gradients before step() moves weights.
```python
import torch
from torch import nn

features = torch.tensor([[1.0, 1.0], [-1.0, -1.0]], dtype=torch.float32)
labels = torch.tensor([1, 0], dtype=torch.long)
model = nn.Linear(2, 2)
with torch.no_grad():
    model.weight.zero_()
    model.bias.zero_()

optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
loss_fn = nn.CrossEntropyLoss()

before = model.weight.detach().clone()
optimizer.zero_grad()
loss = loss_fn(model(features), labels)
loss.backward()
gradient = model.weight.grad.detach().clone()
after_backward = model.weight.detach().clone()
optimizer.step()
after_step = model.weight.detach().clone()

print("loss:", round(loss.item(), 4))
print("gradient filled:", bool(gradient.abs().sum() > 0))
print("changed by backward:", not torch.equal(before, after_backward))
print("changed by step:", not torch.equal(before, after_step))
```

```text
loss: 0.6931
gradient filled: True
changed by backward: False
changed by step: True
```

If you can point to the line where model state changes, training scripts stop looking like rituals.
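To see what autograd computed in that lab, the same update can be reproduced by hand. This is a sketch under stated assumptions: it ignores the bias (which starts at zero and does not affect the weight gradient), uses the standard result that the gradient of mean-reduced cross entropy with respect to the logits is (softmax(logits) - one_hot) / N, and introduces illustrative names such as `d_logits`.

```python
import torch
import torch.nn.functional as F

features = torch.tensor([[1.0, 1.0], [-1.0, -1.0]])
labels = torch.tensor([1, 0])
weight = torch.zeros(2, 2)  # same zero start as the lab; rows are output classes
lr = 0.2

# Forward pass by hand: zero weights give zero logits and probabilities of 0.5.
logits = features @ weight.T
probs = torch.softmax(logits, dim=1)
one_hot = F.one_hot(labels, num_classes=2).float()

# Chain rule: d(loss)/d(logits), then d(loss)/d(weight).
d_logits = (probs - one_hot) / features.shape[0]
d_weight = d_logits.T @ features

# Plain SGD update.
new_weight = weight - lr * d_weight

# Cross-check the hand-derived gradient against autograd.
w = torch.zeros(2, 2, requires_grad=True)
F.cross_entropy(features @ w.T, labels).backward()

print("dW:", d_weight.tolist())
print("matches autograd:", torch.allclose(w.grad, d_weight))
print("W after step:", new_weight.tolist())
```

The hand-derived dW is [[0.5, 0.5], [-0.5, -0.5]], and one step with learning rate 0.2 gives weights of approximately [[-0.1, -0.1], [0.1, 0.1]] (up to float32 rounding).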
PyTorch adds newly computed gradients into the existing .grad fields. This enables deliberate gradient accumulation, where several micro-batches contribute to one larger effective update. In a normal one-update-per-batch loop, stale gradients are accidental.
```python
import torch

weight = torch.tensor(1.0, requires_grad=True)

(weight**2).backward()
first_grad = weight.grad.item()

(weight**2).backward()
accumulated_grad = weight.grad.item()

weight.grad = None
(weight**2).backward()
cleared_grad = weight.grad.item()

print("first backward:", first_grad)
print("without clearing:", accumulated_grad)
print("after clearing:", cleared_grad)
```

```text
first backward: 2.0
without clearing: 4.0
after clearing: 2.0
```

Here the same derivative is silently doubled when it isn't cleared. optimizer.zero_grad() is the standard loop form; setting parameter gradients to None expresses the same reset idea in this minimal demonstration.
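The same accumulation behavior is useful when it is deliberate. The sketch below assumes equal-sized micro-batches and divides each micro-batch loss by the number of micro-batches, so the summed gradients match one pass over the full batch. The data and the `micro_batches` name are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
features = torch.tensor([[-2.0, -1.2], [-1.3, -0.8], [0.9, 1.1], [1.4, 0.8]])
labels = torch.tensor([0, 0, 1, 1])
model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()  # mean reduction

# Reference: one backward pass over all four rows.
model.zero_grad()
loss_fn(model(features), labels).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulation: two micro-batches of two rows each, no clearing in between.
model.zero_grad()
micro_batches = 2
for xb, yb in zip(features.chunk(micro_batches), labels.chunk(micro_batches)):
    # Scale each loss so the accumulated sum equals the full-batch mean.
    (loss_fn(model(xb), yb) / micro_batches).backward()
accumulated_grad = model.weight.grad.clone()

print("gradients match:", torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```

In a real loop, optimizer.step() and optimizer.zero_grad() would run once after the last micro-batch rather than after every backward() call.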
A dataset owns examples. A data loader groups them into mini-batches and optionally changes row order between epochs. A mini-batch is large enough to estimate a useful gradient but small enough to fit memory.
```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.tensor(
    [[-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [0.9, 1.1], [1.4, 0.8]],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
loader = DataLoader(TensorDataset(features, labels), batch_size=2, shuffle=False)

for batch_index, (xb, yb) in enumerate(loader, start=1):
    print(f"batch {batch_index}: shape={tuple(xb.shape)} labels={yb.tolist()}")
```

```text
batch 1: shape=(2, 2) labels=[0, 0]
batch 2: shape=(2, 2) labels=[0, 1]
batch 3: shape=(1, 2) labels=[1]
```

Notice the final mini-batch contains one row. When you average epoch loss, weight each batch loss by its row count so a short last batch doesn't count as much as a full one.
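A minimal sketch of that weighting, assuming the same five rows and an untrained linear model used only for measurement. The accumulator names `total_loss` and `total_rows` are illustrative.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
features = torch.tensor([[-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [0.9, 1.1], [1.4, 0.8]])
labels = torch.tensor([0, 0, 0, 1, 1])
loader = DataLoader(TensorDataset(features, labels), batch_size=2, shuffle=False)
model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()  # returns the mean loss within each batch

total_loss, total_rows = 0.0, 0
with torch.no_grad():
    for xb, yb in loader:
        batch_loss = loss_fn(model(xb), yb).item()  # plain float, no graph retained
        total_loss += batch_loss * xb.shape[0]      # undo the per-batch mean
        total_rows += xb.shape[0]
    epoch_loss = total_loss / total_rows

    # Reference value: mean loss over all five rows in one pass.
    full_loss = loss_fn(model(features), labels).item()

print("rows seen:", total_rows)
print("weighted average matches full pass:", abs(epoch_loss - full_loss) < 1e-5)
```

Averaging the three batch losses without weights would instead give the single-row batch a third of the total influence.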
Our small model has a hidden layer so it resembles the structure of larger neural classifiers, while staying quick enough to inspect. Before measuring held-out behavior, start with an essential diagnostic: can it memorize one clean batch?
An overfit-one-batch test isn't a product metric. It is a wiring test. If a model can't drive training loss down on eight separable rows, investigate labels, shapes, loss, optimizer order, and learning rate before you scale up.
```python
import torch
from torch import nn

torch.manual_seed(7)
features = torch.tensor(
    [
        [-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [-1.6, -0.4],
        [0.9, 1.1], [1.4, 0.8], [1.8, 1.6], [0.7, 1.7],
    ],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long)
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.15)

with torch.no_grad():
    initial_loss = loss_fn(model(features), labels).item()

for _ in range(120):
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_logits = model(features)
    final_loss = loss_fn(final_logits, labels).item()
    accuracy = (final_logits.argmax(dim=1) == labels).float().mean().item()

print("initial loss:", round(initial_loss, 4))
print("final loss:", round(final_loss, 4))
print("memorized batch:", accuracy == 1.0)
```

```text
initial loss: 0.6521
final loss: 0.0056
memorized batch: True
```

Loss should fall sharply and memorized batch should be true. Only now is it sensible to ask whether training generalizes.
Training rows are allowed to change weights. Validation rows are held out from updates and answer a different question: does the current model work on unseen labeled tickets?
Use two switches during validation:
| Switch | What it controls |
|---|---|
| model.eval() | Evaluation behavior for modules such as Dropout and BatchNorm. |
| torch.no_grad() | Stops building gradient history for the validation computation. |
The next complete loop trains on eight tickets and selects a checkpoint using four held-out tickets. The data is deliberately easy so you can focus on mechanics rather than feature engineering.[2]
```python
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(7)
train_x = torch.tensor(
    [
        [-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [-1.6, -0.4],
        [0.9, 1.1], [1.4, 0.8], [1.8, 1.6], [0.7, 1.7],
    ],
    dtype=torch.float32,
)
train_y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long)
val_x = torch.tensor([[-1.1, -0.6], [-0.5, -1.7], [1.0, 0.6], [1.7, 0.4]])
val_y = torch.tensor([0, 0, 1, 1], dtype=torch.long)

loader = DataLoader(TensorDataset(train_x, train_y), batch_size=4, shuffle=True)
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.15)
best_loss = float("inf")
best_state = None

for epoch in range(1, 61):
    model.train()
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        val_logits = model(val_x)
        val_loss = loss_fn(val_logits, val_y).item()
    if val_loss < best_loss:
        best_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())

model.load_state_dict(best_state)
model.eval()
with torch.no_grad():
    selected_accuracy = (model(val_x).argmax(dim=1) == val_y).float().mean().item()
print("best validation loss:", round(best_loss, 4))
print("selected validation accuracy:", round(selected_accuracy, 3))
print("saved tensors:", len(best_state))
```

```text
best validation loss: 0.0118
selected validation accuracy: 1.0
saved tensors: 4
```

copy.deepcopy matters here. A plain best_state = model.state_dict() would keep references to tensors that later epochs continue to change. Saving immediately with torch.save() is another reliable checkpoint-selection pattern.
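The aliasing is easy to reproduce in isolation. In this sketch an in-place add_() stands in for a later optimizer step; `alias` and `snapshot` are illustrative names.

```python
import copy

import torch
from torch import nn

model = nn.Linear(2, 2)
with torch.no_grad():
    model.weight.fill_(1.0)

alias = model.state_dict()                    # tensors share storage with live parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the values

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for a later in-place optimizer update

print("alias weight:", alias["weight"][0, 0].item())        # follows the model: 2.0
print("snapshot weight:", snapshot["weight"][0, 0].item())  # keeps the old value: 1.0
```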
The small classifier above would produce the same result in train() and eval() because it contains no mode-sensitive layer. Real models often contain Dropout. During training, Dropout randomly zeroes activations to regularize learning. During evaluation, it must be disabled so the same ticket always receives the same scores.
```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(6, 2))
ticket = torch.tensor([[1.2, 0.9]])

model.train()
train_a = model(ticket)
train_b = model(ticket)

model.eval()
with torch.no_grad():
    eval_a = model(ticket)
    eval_b = model(ticket)

print("training outputs equal:", torch.equal(train_a, train_b))
print("evaluation outputs equal:", torch.equal(eval_a, eval_b))
```

```text
training outputs equal: False
evaluation outputs equal: True
```

This is why calling only torch.no_grad() isn't enough for evaluation: gradient tracking and layer behavior are separate concerns.
Validation loss falls from 0.535825 at epoch 1 to 0.011815 at epoch 60, which becomes the selected checkpoint with held-out accuracy 1.0. Validation uses eval() and no_grad(); it never calls backward() or step().

Training isn't complete when the process exits. You need a portable artifact and enough configuration to reproduce how it was created. For inference, saving a model's state_dict is the standard flexible pattern; for resuming training, also save optimizer state and current epoch.
```python
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from torch import nn

torch.manual_seed(11)
features = torch.tensor([[-1.5, -1.0], [-0.9, -1.2], [1.1, 0.8], [1.6, 1.4]])
labels = torch.tensor([0, 0, 1, 1], dtype=torch.long)
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
loss_fn = nn.CrossEntropyLoss()

for _ in range(80):
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    original = model(features).argmax(dim=1)

with TemporaryDirectory() as directory:
    path = Path(directory) / "ticket-router.pt"
    torch.save({"model_state_dict": model.state_dict(), "label_names": ["standard", "escalate"]}, path)
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    restored = nn.Linear(2, 2)
    restored.load_state_dict(checkpoint["model_state_dict"])
    restored.eval()
    with torch.no_grad():
        reloaded = restored(features).argmax(dim=1)

print("routes:", original.tolist())
print("reload agrees:", torch.equal(original, reloaded))
print("labels:", checkpoint["label_names"])
```

```text
routes: [0, 0, 1, 1]
reload agrees: True
labels: ['standard', 'escalate']
```

In production, save the feature schema, label mapping, split or dataset version, model architecture configuration, and metric used to choose the checkpoint. A file of weights without that receipt is difficult to audit.
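The example above saves only what inference needs. A resumable checkpoint also carries optimizer state and the epoch counter. The sketch below makes two assumptions that go beyond the lesson: it adds momentum=0.9 so the optimizer has internal buffers worth saving, and it uses illustrative dictionary keys such as "optimizer_state_dict" and a file name "resume.pt".

```python
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from torch import nn

torch.manual_seed(11)
features = torch.tensor([[-1.5, -1.0], [1.1, 0.8]])
labels = torch.tensor([0, 1])
model = nn.Linear(2, 2)
# Assumption for illustration: momentum gives the optimizer per-parameter buffers.
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9)

# One update so the momentum buffers exist.
optimizer.zero_grad()
nn.CrossEntropyLoss()(model(features), labels).backward()
optimizer.step()

with TemporaryDirectory() as directory:
    path = Path(directory) / "resume.pt"
    torch.save(
        {
            "epoch": 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)

# Rebuild the same architecture and optimizer, then restore both states.
resumed_model = nn.Linear(2, 2)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.2, momentum=0.9)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

print("resume from epoch:", start_epoch)
print("weights match:", torch.equal(model.weight, resumed_model.weight))
```

Restoring weights without the optimizer state would still run, but any momentum or adaptive statistics would restart from zero, so the resumed run would not continue the original trajectory.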
When a loop runs without crashing, it can still be wrong. Debug in a narrow order:
| Symptom | Likely cause | Check |
|---|---|---|
| CrossEntropyLoss rejects labels | Class targets have floating dtype or wrong shape. | Use integer indices with shape (batch,). |
| Loss falls in one-batch test but validation is poor | Split contamination, weak features, or overfitting. | Inspect the data, then continue to the next chapter on datasets. |
| Results vary during validation | Model remains in training mode with Dropout or BatchNorm. | Pair model.eval() with validation. |
| Gradients grow each batch unexpectedly | Gradients aren't cleared before the next backward pass. | Put optimizer.zero_grad() inside batch loop. |
| GPU memory grows while logging batch losses | Metric code retains graph-connected loss tensors. | Accumulate loss.item() or detached scalars, not loss. |
| Loss becomes nan or inf | Invalid input, unstable loss, or exploding update. | Check tensors and gradients for finite values. |
Gradient clipping doesn't fix bad data or a broken loss, but it can cap an extreme update once you have identified that gradient norms are the issue.
When loss becomes nan or inf, locate the first invalid number instead of immediately lowering the learning rate. Test inputs first, then logits, then loss, then gradients after backward(), then parameters after step(). That sequence tells you whether failure entered through data, unstable arithmetic, gradient scaling, or an oversized update.
```python
import torch
from torch import nn

parameter = nn.Parameter(torch.tensor([3.0, 4.0]))
parameter.grad = torch.tensor([6.0, 8.0])

before = torch.linalg.vector_norm(parameter.grad).item()
reported_norm = nn.utils.clip_grad_norm_([parameter], max_norm=5.0, error_if_nonfinite=True).item()
after = torch.linalg.vector_norm(parameter.grad).item()

print("norm before:", round(before, 1))
print("reported norm:", round(reported_norm, 1))
print("norm after:", round(after, 1))
print("finite:", torch.isfinite(parameter.grad).all().item())
```

```text
norm before: 10.0
reported norm: 10.0
norm after: 5.0
finite: True
```

clip_grad_norm_() returns the total norm before clipping. In a live loop, pass error_if_nonfinite=True as above so a non-finite total norm raises before it can enter optimizer state. Also check torch.isfinite(xb), torch.isfinite(logits), torch.isfinite(loss), and parameter gradients at the first failing step. Find where bad numbers appear before changing hyperparameters blindly.
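That ordered search can be written as a small helper. The sketch below is one possible shape for it: `first_nonfinite` is a hypothetical function, not a PyTorch API, and the input batch is deliberately corrupted with a nan so the search has something to find.

```python
import torch
from torch import nn


def first_nonfinite(stages):
    """Hypothetical helper: name of the first stage containing nan or inf, else None."""
    for name, tensor in stages:
        if not torch.isfinite(tensor).all():
            return name
    return None


torch.manual_seed(0)
model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()
xb = torch.tensor([[1.2, 0.8], [float("nan"), -0.7]])  # deliberately corrupted input
yb = torch.tensor([1, 0])

logits = model(xb)
loss = loss_fn(logits, yb)
loss.backward()

# Check in the order the numbers were produced: data, forward pass, loss, gradients.
stages = [("inputs", xb), ("logits", logits), ("loss", loss)]
stages += [(f"grad:{name}", parameter.grad) for name, parameter in model.named_parameters()]

print("first non-finite stage:", first_nonfinite(stages))
```

Here the report is "inputs": the nan entered through the data, so lowering the learning rate would not have helped.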
When serving the selected ticket router, you aren't going to call backward(). torch.inference_mode() disables gradient tracking and additional autograd bookkeeping for computations that won't re-enter a gradient-tracked training graph. Use torch.no_grad() during simple validation when you want the familiar training-script pattern; use inference_mode() for committed inference paths after testing that contract.
```python
import torch
from torch import nn

torch.manual_seed(17)
model = nn.Linear(2, 2)
model.eval()
ticket = torch.tensor([[1.5, 1.2]])

with torch.inference_mode():
    logits = model(ticket)
    prediction = logits.argmax(dim=1).item()

print("gradient tracking:", logits.requires_grad)
print("route class:", prediction)
```

```text
gradient tracking: False
route class: 1
```

The full-precision training loop shown earlier is the core contract. On compatible accelerators, automatic mixed precision (AMP) may reduce memory use and increase throughput by letting suitable operations use lower precision while preserving higher precision where needed. With float16 training, gradient scaling through GradScaler reduces the risk that small gradients underflow before the update. PyTorch's AMP examples place autocast around forward and loss computation, scale before backward(), unscale before clipping, then step and update the scaler.[4][5]
```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"
model = nn.Linear(2, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
xb = torch.tensor([[1.2, 0.8], [-1.0, -0.7]], device=device)
yb = torch.tensor([1, 0], dtype=torch.long, device=device)
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

optimizer.zero_grad()
with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
    logits = model(xb)
    loss = loss_fn(logits, yb)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()

print("step completed with finite loss:", bool(torch.isfinite(loss)))
```

```text
step completed with finite loss: True
```

GradScaler.step() checks for non-finite gradients after unscaling and skips the optimizer update when it finds them. GradScaler.update() then adjusts the scale for later iterations. Treat AMP as an extension after a full-precision loop passes its correctness checks. Faster incorrect training is still incorrect training.
This chapter made a model learn from already prepared tensors. A real ticket-routing model still needs evidence for where its labeled rows came from, whether duplicates leaked across splits, and which preprocessing produced those two features.
Before declaring a training run acceptable, retain:
| Receipt item | Why it matters |
|---|---|
| Model and optimizer configuration | Explains how updates were produced. |
| Training and validation loss curves | Shows optimization and held-out behavior separately. |
| Best checkpoint and selection metric | Identifies which weights are being served. |
| Feature schema and label mapping | Makes inference compatible with training. |
| Dataset version and split identity | Lets reviewers investigate leakage or relabeling. |
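A receipt can be as simple as a JSON file written next to the checkpoint. The sketch below is illustrative: every key name, the "receipt.json" file name, and the dataset identifiers are hypothetical placeholders, while the optimizer settings and selection values echo the train-and-validate example in this chapter. Loss curves would be stored alongside in the same way.

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

# Hypothetical receipt layout; adapt the keys to your own project conventions.
receipt = {
    "model": {"architecture": "Linear(2, 8) -> ReLU -> Linear(8, 2)"},
    "optimizer": {"name": "SGD", "lr": 0.15},
    "selection": {"metric": "validation_loss", "best_value": 0.0118, "epoch": 60},
    "feature_schema": ["urgency", "sentiment"],
    "label_mapping": {"0": "standard", "1": "escalate"},
    # Placeholder identifiers: a real run would record its actual dataset version.
    "dataset": {"version": "hypothetical-v1", "split_id": "hypothetical-split"},
}

with TemporaryDirectory() as directory:
    path = Path(directory) / "receipt.json"
    path.write_text(json.dumps(receipt, indent=2))
    restored = json.loads(path.read_text())

print("receipt round-trips:", restored == receipt)
print("served labels:", restored["label_mapping"])
```

Keeping the receipt as plain text means a reviewer can read it without loading the model or installing PyTorch.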
Use the runnable examples above as a controlled failure lab. Make one change at a time, predict the result, then run the example and explain the observed behavior.
- In 03-one-update.py, comment out optimizer.step(). Predict whether changed by step remains True.
- In 04-zero-grad.py, call backward() a third time before clearing gradients. Predict the gradient value.
- In 07-train-and-validate.py, replace copy.deepcopy(model.state_dict()) with model.state_dict(). Explain why best_state no longer preserves the best epoch after later updates.
- In 08-train-eval-modes.py, leave the model in training mode for the second pair of predictions. Predict whether the outputs must match.
- In 12-amp-ordering-pattern.py, explain why scaler.unscale_(optimizer) belongs before gradient clipping.

- 03-one-update.py: Without optimizer.step(), gradients are populated but parameter values don't change.
- 04-zero-grad.py: 6.0. Each backward() adds another derivative of 2.0 until gradients are cleared.
- 07-train-and-validate.py: A plain state_dict() result keeps references to parameter storage. Later training updates overwrite the intended snapshot.
- 12-amp-ordering-pattern.py: Gradients are still scaled after backward(). Clipping before unscale_() would apply the threshold to scaled values instead of the gradients used for the real update.

Explain the complete route from one ticket batch to a saved model without using the word "magic." Then answer these checks.

- backward() and step().
- softmax() before CrossEntropyLoss instead of giving the loss raw logits.
- backward() batch after batch without clearing gradients or intentionally scheduling accumulation.
- backward().
[1] Paszke, A., et al. (2019). PyTorch: An Imperative Style, High-Performance Deep Learning Library. NeurIPS 2019.
[2] PyTorch Contributors (2026). Optimizing Model Parameters. Official tutorial.
[3] Goodfellow, I., Bengio, Y., Courville, A. (2016). Deep Learning.
[4] Micikevicius, P., et al. (2018). Mixed Precision Training.
[5] PyTorch Contributors (2026). Automatic Mixed Precision package - torch.amp.