LeetLLM

PyTorch Training Loops

Build a PyTorch classifier from raw logits through autograd, validation, and reloadable checkpoints.


The previous chapter asked whether a model-system change helps customers. This chapter answers an earlier question: how does a model acquire new behavior at all?

Suppose an e-commerce support queue receives refund tickets. Each ticket has two simple signals: urgency score and negative-sentiment score. A label says whether a specialist later judged that ticket to need escalation. At first, a small classifier guesses badly. A training loop repeatedly shows it examples, measures wrongness, computes which parameters contributed to that wrongness, and updates those parameters.

PyTorch makes that process explicit. That matters when you later fine-tune a text classifier or read a transformer training script: wrappers may add logging and distributed execution, but learning still comes from batches, logits, loss, gradients, optimizer steps, validation, and saved model state.[1][2]

Figure: Exact PyTorch optimizer update showing a two-ticket input matrix, zero initial weights, cross-entropy loss 0.6931, the gradient matrix written by backward, unchanged weights after backward, SGD weights after a 0.2 learning-rate step, and the resulting correct route logits.
In the exact zero-initialized fixture, backward() fills dW but leaves W unchanged. With learning rate 0.2, optimizer.step() creates weights [[-0.1, -0.1], [0.1, 0.1]], which route the positive ticket to escalation and the negative ticket to the standard queue.

By the end, you will be able to:

  • Build and train a small nn.Module with TensorDataset and DataLoader.
  • Explain why classification loss receives raw logits and integer labels.
  • Distinguish gradient computation from parameter updates.
  • Measure held-out performance without accidentally training on validation rows.
  • Save and reload a checkpoint, then diagnose common loop failures.

Turn a ticket into tensors

A model can't consume the phrase "angry customer with delayed refund" directly in this small lab. Pretend an earlier feature pipeline already produced two normalized numbers:

| Feature | Meaning | Example range |
|---|---|---|
| urgency | Language indicating an immediate deadline or financial risk | -2.0 to 2.0 |
| sentiment | Negative tone and complaint intensity | -2.0 to 2.0 |
| label | 0 = standard queue, 1 = specialist escalation | 0 or 1 |

Each training row becomes a vector of two floating-point inputs and one integer class target. A batch stacks those rows into an input tensor. Six tickets have feature shape (6, 2) and label shape (6,).

01-ticket-tensors.py
```python
import torch

features = torch.tensor(
    [
        [-2.0, -1.2],
        [-1.3, -0.8],
        [-0.8, -1.4],
        [0.9, 1.1],
        [1.4, 0.8],
        [1.8, 1.6],
    ],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long)

print("features:", tuple(features.shape), features.dtype)
print("labels:", tuple(labels.shape), labels.dtype)
print("first ticket:", features[0].tolist(), "route:", labels[0].item())
```

Tensor contract

```text
features: (6, 2) torch.float32
labels: (6,) torch.int64
first ticket: [-2.0, -1.2000000476837158] route: 0
```

float32 inputs can flow through linear layers. torch.long labels are class indices for this classification objective. Keeping the shape and dtype contract visible prevents a large class of training bugs.

Logits are scores, not probabilities

A classifier ending in nn.Linear(2, 2) produces two scores for each ticket:

| Output column | Route represented |
|---|---|
| 0 | standard queue |
| 1 | specialist escalation |

These scores are logits. They may be negative and they don't have to sum to one. For ordinary multi-class classification, PyTorch's CrossEntropyLoss is designed to receive those raw logits plus target class indices. It handles the log-softmax calculation in a numerically stable form.[2]

02-raw-logits.py
```python
import torch
from torch import nn

logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5]], dtype=torch.float32)
labels = torch.tensor([0, 1], dtype=torch.long)
loss = nn.CrossEntropyLoss()(logits, labels)

routes = ["standard", "escalate"]
predictions = logits.argmax(dim=1).tolist()
print("predicted routes:", [routes[index] for index in predictions])
print("cross entropy:", round(loss.item(), 4))
```

Raw logits and loss

```text
predicted routes: ['standard', 'escalate']
cross entropy: 0.0878
```

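Failure case: Don't call softmax() and then feed those probabilities into CrossEntropyLoss for this setup. Pass raw logits. CrossEntropyLoss applies log-softmax internally, so pre-normalized inputs get normalized a second time: nothing raises an error, but the loss and its gradients are distorted. A hand-written log(softmax(logits)) calculation is also more prone to numerical instability when a score becomes very large.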
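To make the failure case concrete, here is a minimal pure-Python sketch that recomputes the loss for the two tickets above with a log-sum-exp formulation, then repeats the calculation after an extra softmax. The helper names `cross_entropy` and `softmax` are illustrative, not PyTorch functions, and the sketch only approximates what the library does internally.

```python
import math


def cross_entropy(logits, label):
    """Negative log-probability of the target class via a stable log-sum-exp."""
    peak = max(logits)  # subtracting the max keeps exp() from overflowing
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


def softmax(logits):
    """Normalize scores into probabilities that sum to one."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


batch_logits = [[2.0, -1.0], [-0.5, 1.5]]
batch_labels = [0, 1]
rows = len(batch_labels)

# Correct usage: raw logits go straight into the loss.
correct = sum(cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)) / rows

# Mistake: normalize first, then treat the probabilities as if they were logits.
doubled = sum(cross_entropy(softmax(z), y) for z, y in zip(batch_logits, batch_labels)) / rows

print("loss from raw logits:", round(correct, 4))
print("loss after an extra softmax:", round(doubled, 4))
```

The first value should match the 0.0878 reported by the PyTorch example. The second is larger even though the predictions are identical, because squashing the scores into the 0 to 1 range shrinks the gap between classes before the loss sees it.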

One update has five actions

For one mini-batch, training follows a fixed order:

| Code | Purpose | Parameter values change? |
|---|---|---|
| optimizer.zero_grad() | Clear gradients left by earlier work. | No |
| logits = model(xb) | Run the current model forward. | No |
| loss = loss_fn(logits, yb) | Score current mistakes. | No |
| loss.backward() | Use autograd to fill each parameter's .grad. | No |
| optimizer.step() | Apply an update derived from stored gradients. | Yes |

Backpropagation is credit assignment: if raising one weight would have increased loss, its gradient points in the direction the optimizer should avoid. Autograd performs this chain-rule bookkeeping for every parameter used in the forward computation.[3][1]

Figure: Mini-batch xb, yb; Forward logits = model(xb); Loss cross_entropy(logits, yb); Backward parameter .grad fields.

The following lab sets a simple linear model to known starting weights, runs exactly one update, and shows that backward() fills gradients before step() moves weights.

03-one-update.py
```python
import torch
from torch import nn

features = torch.tensor([[1.0, 1.0], [-1.0, -1.0]], dtype=torch.float32)
labels = torch.tensor([1, 0], dtype=torch.long)
model = nn.Linear(2, 2)
with torch.no_grad():
    model.weight.zero_()
    model.bias.zero_()

optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
loss_fn = nn.CrossEntropyLoss()

before = model.weight.detach().clone()
optimizer.zero_grad()
loss = loss_fn(model(features), labels)
loss.backward()
gradient = model.weight.grad.detach().clone()
after_backward = model.weight.detach().clone()
optimizer.step()
after_step = model.weight.detach().clone()

print("loss:", round(loss.item(), 4))
print("gradient filled:", bool(gradient.abs().sum() > 0))
print("changed by backward:", not torch.equal(before, after_backward))
print("changed by step:", not torch.equal(before, after_step))
```

One optimizer update

```text
loss: 0.6931
gradient filled: True
changed by backward: False
changed by step: True
```

If you can point to the line where model state changes, training scripts stop looking like rituals.
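The numbers in this zero-initialized fixture are small enough to check by hand. The sketch below is a pure-Python recomputation under stated assumptions: it uses the standard result that the gradient of mean cross-entropy with respect to a logit is (probability minus one-hot target) divided by batch size, and it omits the bias because the bias gradient sums to zero for this particular fixture. Variable names are illustrative.

```python
import math

features = [[1.0, 1.0], [-1.0, -1.0]]
labels = [1, 0]
lr = 0.2
weights = [[0.0, 0.0], [0.0, 0.0]]  # rows = output classes, columns = input features


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


n = len(features)
grad = [[0.0, 0.0], [0.0, 0.0]]
loss = 0.0
for x, y in zip(features, labels):
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    probs = softmax(logits)
    loss += -math.log(probs[y]) / n
    for k in range(2):
        # d(mean loss) / d(logit_k) = (p_k - 1[k == y]) / n
        dlogit = (probs[k] - (1.0 if k == y else 0.0)) / n
        for j in range(2):
            grad[k][j] += dlogit * x[j]  # chain rule through logit_k = sum_j W[k][j] * x[j]

# Plain SGD: move each weight against its gradient.
updated = [[w - lr * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(weights, grad)]

print("loss:", round(loss, 4))
print("gradient:", grad)
print("weights after step:", updated)
```

The loss is ln(2), about 0.6931, because zero weights give both classes probability 0.5. The gradient comes out as [[0.5, 0.5], [-0.5, -0.5]], and one step with learning rate 0.2 yields [[-0.1, -0.1], [0.1, 0.1]], the same weights described in the figure caption.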

Why clearing gradients isn't optional

PyTorch adds newly computed gradients into the existing .grad fields. This enables deliberate gradient accumulation, where several micro-batches contribute to one larger effective update. In a normal one-update-per-batch loop, stale gradients are accidental.

04-zero-grad.py
```python
import torch

weight = torch.tensor(1.0, requires_grad=True)

(weight**2).backward()
first_grad = weight.grad.item()

(weight**2).backward()
accumulated_grad = weight.grad.item()

weight.grad = None
(weight**2).backward()
cleared_grad = weight.grad.item()

print("first backward:", first_grad)
print("without clearing:", accumulated_grad)
print("after clearing:", cleared_grad)
```

Accumulating gradients

```text
first backward: 2.0
without clearing: 4.0
after clearing: 2.0
```

Here the same derivative is silently doubled when it isn't cleared. optimizer.zero_grad() is the standard loop form; setting parameter gradients to None expresses the same reset idea in this minimal demonstration.
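The same accumulation behavior is what makes deliberate gradient accumulation work. The following sketch is one possible arrangement, not the chapter's own loop: it assumes two equally sized micro-batches and divides each micro-batch loss by the hypothetical `accumulation_steps` so the summed gradients match one backward pass over all four rows.

```python
import torch
from torch import nn

torch.manual_seed(0)
features = torch.tensor([[-2.0, -1.2], [-1.3, -0.8], [0.9, 1.1], [1.4, 0.8]])
labels = torch.tensor([0, 0, 1, 1], dtype=torch.long)
model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()

# Reference: one backward pass over the full batch of four rows.
model.zero_grad()
loss_fn(model(features), labels).backward()
full_batch_grad = model.weight.grad.detach().clone()

# Deliberate accumulation: two micro-batches of two rows each.
accumulation_steps = 2
model.zero_grad()
for xb, yb in zip(features.chunk(accumulation_steps), labels.chunk(accumulation_steps)):
    # Divide so the summed gradients equal the mean over all four rows.
    (loss_fn(model(xb), yb) / accumulation_steps).backward()
accumulated_grad = model.weight.grad.detach().clone()

gradients_match = torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6)
print("gradients match:", gradients_match)
```

In a real loop, optimizer.step() and the next gradient reset would run once after the last micro-batch rather than after every backward pass. With unequal micro-batch sizes, the simple division used here would no longer reproduce the full-batch mean.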

Mini-batches make a loop

A dataset owns examples. A data loader groups them into mini-batches and optionally changes row order between epochs. A mini-batch is large enough to estimate a useful gradient but small enough to fit memory.

05-dataloader.py
```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.tensor(
    [[-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [0.9, 1.1], [1.4, 0.8]],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
loader = DataLoader(TensorDataset(features, labels), batch_size=2, shuffle=False)

for batch_index, (xb, yb) in enumerate(loader, start=1):
    print(f"batch {batch_index}: shape={tuple(xb.shape)} labels={yb.tolist()}")
```

Mini-batches

```text
batch 1: shape=(2, 2) labels=[0, 0]
batch 2: shape=(2, 2) labels=[0, 1]
batch 3: shape=(1, 2) labels=[1]
```

Notice the final mini-batch contains one row. When you average epoch loss, weight each batch loss by its row count so a short last batch doesn't count as much as a full one.
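A short arithmetic sketch shows how much the weighting can matter. The per-batch losses below are hypothetical values chosen for illustration; only the batch sizes of 2, 2, and 1 come from the loader above.

```python
batch_sizes = [2, 2, 1]
batch_mean_losses = [0.20, 0.40, 1.00]  # hypothetical per-batch mean losses

total_rows = sum(batch_sizes)

# Each batch mean counts in proportion to the rows it summarizes.
weighted = sum(n * loss for n, loss in zip(batch_sizes, batch_mean_losses)) / total_rows

# Naive version: the single-row batch counts as much as a full batch.
naive = sum(batch_mean_losses) / len(batch_mean_losses)

print("row-weighted epoch loss:", round(weighted, 4))
print("naive mean of batch means:", round(naive, 4))
```

The row-weighted result is 0.44, while the naive average is about 0.5333 because the one-row batch with loss 1.00 is over-counted.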

Build the smallest training loop

Our small model has a hidden layer so it resembles the structure of larger neural classifiers, while staying quick enough to inspect. Before measuring held-out behavior, start with an essential diagnostic: can it memorize one clean batch?

An overfit-one-batch test isn't a product metric. It is a wiring test. If a model can't drive training loss down on eight separable rows, investigate labels, shapes, loss, optimizer order, and learning rate before you scale up.

06-overfit-one-batch.py
```python
import torch
from torch import nn

torch.manual_seed(7)
features = torch.tensor(
    [
        [-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [-1.6, -0.4],
        [0.9, 1.1], [1.4, 0.8], [1.8, 1.6], [0.7, 1.7],
    ],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long)
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.15)

with torch.no_grad():
    initial_loss = loss_fn(model(features), labels).item()

for _ in range(120):
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_logits = model(features)
    final_loss = loss_fn(final_logits, labels).item()
    accuracy = (final_logits.argmax(dim=1) == labels).float().mean().item()

print("initial loss:", round(initial_loss, 4))
print("final loss:", round(final_loss, 4))
print("memorized batch:", accuracy == 1.0)
```

Overfit one batch

```text
initial loss: 0.6521
final loss: 0.0056
memorized batch: True
```

Loss should fall sharply and memorized batch should be true. Only now is it sensible to ask whether training generalizes.

Separate training from validation

Training rows are allowed to change weights. Validation rows are held out from updates and answer a different question: does the current model work on unseen labeled tickets?

Use two switches during validation:

| Switch | What it controls |
|---|---|
| model.eval() | Evaluation behavior for modules such as Dropout and BatchNorm. |
| torch.no_grad() | Stops building gradient history for the validation computation. |

The next complete loop trains on eight tickets and selects a checkpoint using four held-out tickets. The data is deliberately easy so you can focus on mechanics rather than feature engineering.[2]

07-train-and-validate.py
```python
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(7)
train_x = torch.tensor(
    [
        [-2.0, -1.2], [-1.3, -0.8], [-0.8, -1.4], [-1.6, -0.4],
        [0.9, 1.1], [1.4, 0.8], [1.8, 1.6], [0.7, 1.7],
    ],
    dtype=torch.float32,
)
train_y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long)
val_x = torch.tensor([[-1.1, -0.6], [-0.5, -1.7], [1.0, 0.6], [1.7, 0.4]])
val_y = torch.tensor([0, 0, 1, 1], dtype=torch.long)

loader = DataLoader(TensorDataset(train_x, train_y), batch_size=4, shuffle=True)
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.15)
best_loss = float("inf")
best_state = None

for epoch in range(1, 61):
    model.train()
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        val_logits = model(val_x)
        val_loss = loss_fn(val_logits, val_y).item()
    if val_loss < best_loss:
        best_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())

model.load_state_dict(best_state)
model.eval()
with torch.no_grad():
    selected_accuracy = (model(val_x).argmax(dim=1) == val_y).float().mean().item()
print("best validation loss:", round(best_loss, 4))
print("selected validation accuracy:", round(selected_accuracy, 3))
print("saved tensors:", len(best_state))
```

Training and validation

```text
best validation loss: 0.0118
selected validation accuracy: 1.0
saved tensors: 4
```

copy.deepcopy matters here. A plain best_state = model.state_dict() would keep references to tensors that later epochs continue to change. Saving immediately with torch.save() is another reliable checkpoint-selection pattern.
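The aliasing is easy to observe directly. This minimal sketch stands in for later optimizer steps with an in-place addition; the names `alias` and `snapshot` are illustrative.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(2, 2)

alias = model.state_dict()                    # tensors share storage with the live parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the current values
original = model.weight.detach().clone()

# Stand-in for later optimizer steps: change the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

alias_kept_old = torch.equal(alias["weight"], original)
snapshot_kept_old = torch.equal(snapshot["weight"], original)
print("alias still holds old weights:", alias_kept_old)
print("snapshot still holds old weights:", snapshot_kept_old)
```

The alias follows the live weights, so it no longer matches the original values, while the deep copy does.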

Modes change model behavior

The ticket classifier above would produce the same result in train() and eval() because it contains no mode-sensitive layer. Real models often contain Dropout. During training, Dropout randomly hides activations to regularize learning. During evaluation, it must be disabled so one ticket receives stable scores.

08-train-eval-modes.py
```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(6, 2))
ticket = torch.tensor([[1.2, 0.9]])

model.train()
train_a = model(ticket)
train_b = model(ticket)

model.eval()
with torch.no_grad():
    eval_a = model(ticket)
    eval_b = model(ticket)

print("training outputs equal:", torch.equal(train_a, train_b))
print("evaluation outputs equal:", torch.equal(eval_a, eval_b))
```

Train mode versus eval mode

```text
training outputs equal: False
evaluation outputs equal: True
```

This is why calling only torch.no_grad() isn't enough for evaluation: gradient tracking and layer behavior are separate concerns.
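One way to keep both switches together is a small helper. The sketch below is an assumption about how validation might be packaged, not one of the chapter's numbered examples: `evaluate`, `val_loader`, and the batch size of 3 are hypothetical choices for illustration. It also weights each batch loss by its row count, since the last batch here holds a single row.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, loss_fn):
    """Return (row-weighted mean loss, accuracy) without touching gradients or weights."""
    was_training = model.training
    model.eval()                   # fix Dropout / BatchNorm behavior
    total_loss, correct, rows = 0.0, 0, 0
    with torch.no_grad():          # skip building autograd history
        for xb, yb in loader:
            logits = model(xb)
            total_loss += loss_fn(logits, yb).item() * len(yb)
            correct += (logits.argmax(dim=1) == yb).sum().item()
            rows += len(yb)
    model.train(was_training)      # restore whichever mode the caller was in
    return total_loss / rows, correct / rows


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 6), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(6, 2))
val_x = torch.tensor([[-1.1, -0.6], [-0.5, -1.7], [1.0, 0.6], [1.7, 0.4]])
val_y = torch.tensor([0, 0, 1, 1], dtype=torch.long)
val_loader = DataLoader(TensorDataset(val_x, val_y), batch_size=3)

model.train()
first = evaluate(model, val_loader, nn.CrossEntropyLoss())
second = evaluate(model, val_loader, nn.CrossEntropyLoss())
print("repeatable:", first == second)
print("mode restored:", model.training)
```

Because Dropout is disabled inside the helper, two consecutive calls return identical numbers even though the model was left in training mode by the caller.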

Figure: Log-scaled training and validation loss curves from the seeded 60-epoch ticket classifier, with measured points from epochs 1 through 60, validation performed with eval and no_grad, and epoch 60 selected as the best checkpoint at validation loss 0.0118 and accuracy 1.0.
In the seeded fixture, validation loss falls from 0.535825 at epoch 1 to 0.011815 at epoch 60, which becomes the selected checkpoint with held-out accuracy 1.0. Validation uses eval() and no_grad(); it never calls backward() or step().

Save evidence you can reload

Training isn't complete when the process exits. You need a portable artifact and enough configuration to reproduce how it was created. For inference, saving a model's state_dict is the standard flexible pattern; for resuming training, also save optimizer state and current epoch.

09-save-and-reload.py
```python
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from torch import nn

torch.manual_seed(11)
features = torch.tensor([[-1.5, -1.0], [-0.9, -1.2], [1.1, 0.8], [1.6, 1.4]])
labels = torch.tensor([0, 0, 1, 1], dtype=torch.long)
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
loss_fn = nn.CrossEntropyLoss()

for _ in range(80):
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    original = model(features).argmax(dim=1)

with TemporaryDirectory() as directory:
    path = Path(directory) / "ticket-router.pt"
    torch.save({"model_state_dict": model.state_dict(), "label_names": ["standard", "escalate"]}, path)
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    restored = nn.Linear(2, 2)
    restored.load_state_dict(checkpoint["model_state_dict"])
    restored.eval()
    with torch.no_grad():
        reloaded = restored(features).argmax(dim=1)

print("routes:", original.tolist())
print("reload agrees:", torch.equal(original, reloaded))
print("labels:", checkpoint["label_names"])
```

Reloaded checkpoint

```text
routes: [0, 0, 1, 1]
reload agrees: True
labels: ['standard', 'escalate']
```

In production, save the feature schema, label mapping, split or dataset version, model architecture configuration, and metric used to choose the checkpoint. A file of weights without that receipt is difficult to audit.
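For the resume-training case mentioned earlier in this section, a checkpoint might also carry optimizer state and the epoch counter. The sketch below is one possible layout under stated assumptions: the dictionary keys, the momentum setting, and the in-memory `io.BytesIO` buffer (standing in for a file on disk) are illustrative choices, and the `weights_only=True` load assumes a PyTorch version that supports that argument.

```python
import io

import torch
from torch import nn

torch.manual_seed(11)
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9)
features = torch.tensor([[-1.5, -1.0], [1.1, 0.8]])
labels = torch.tensor([0, 1], dtype=torch.long)

# One update so the optimizer holds momentum buffers worth saving.
optimizer.zero_grad()
nn.CrossEntropyLoss()(model(features), labels).backward()
optimizer.step()

buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(
    {
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    buffer,
)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu", weights_only=True)
resumed_model = nn.Linear(2, 2)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.2, momentum=0.9)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue with the epoch after the saved one

weights_restored = torch.equal(model.weight, resumed_model.weight)
buffers_restored = len(resumed_optimizer.state_dict()["state"]) == 2
print("weights restored:", weights_restored)
print("momentum state restored:", buffers_restored)
print("resume at epoch:", start_epoch)
```

Without the optimizer state, a resumed run with momentum or an adaptive optimizer would restart its running statistics from zero, so the first updates after resuming would differ from an uninterrupted run.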

Debug failures before scaling

When a loop runs without crashing, it can still be wrong. Debug in a narrow order:

| Symptom | Likely cause | Check |
|---|---|---|
| CrossEntropyLoss rejects labels | Class targets have floating dtype or wrong shape. | Use integer indices with shape (batch,). |
| Loss falls in one-batch test but validation is poor | Split contamination, weak features, or overfitting. | Inspect the data and splits; the next chapter covers dataset pipelines and data quality. |
| Results vary during validation | Model remains in training mode with Dropout or BatchNorm. | Pair model.eval() with validation. |
| Gradients grow each batch unexpectedly | Gradients aren't cleared before the next backward pass. | Put optimizer.zero_grad() inside the batch loop. |
| GPU memory grows while logging batch losses | Metric code retains graph-connected loss tensors. | Accumulate loss.item() or detached scalars, not loss. |
| Loss becomes nan or inf | Invalid input, unstable loss, or exploding update. | Check tensors and gradients for finite values. |

Gradient clipping doesn't fix bad data or a broken loss, but it can cap an extreme update once you have identified that gradient norms are the issue.

Numerical triage when loss becomes NaN

When loss becomes nan or inf, locate the first invalid number instead of immediately lowering the learning rate. Test inputs first, then logits, then loss, then gradients after backward(), then parameters after step(). That sequence tells you whether failure entered through data, unstable arithmetic, gradient scaling, or an oversized update.
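That ordered check can be written as a small helper. This is a sketch with a hypothetical function name, `first_nonfinite`, and a simulated failure injected into the logits; a real loop would pass its own tensors in the order described above.

```python
import torch


def first_nonfinite(stages):
    """Return the name of the first stage holding a nan or inf, or None if all are finite."""
    for name, tensor in stages:
        if not torch.isfinite(tensor).all():
            return name
    return None


xb = torch.tensor([[1.2, 0.8], [-1.0, -0.7]])
logits = torch.tensor([[0.3, float("nan")], [0.1, -0.2]])  # simulated failure
loss = logits.sum()  # nan propagates from the logits into the loss

stages = [("inputs", xb), ("logits", logits), ("loss", loss)]
culprit = first_nonfinite(stages)
print("first invalid stage:", culprit)
```

Here the inputs are clean and the first bad value appears in the logits, which points at the forward computation rather than the data or the optimizer.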

10-gradient-clipping.py
```python
import torch
from torch import nn

parameter = nn.Parameter(torch.tensor([3.0, 4.0]))
parameter.grad = torch.tensor([6.0, 8.0])

before = torch.linalg.vector_norm(parameter.grad).item()
reported_norm = nn.utils.clip_grad_norm_([parameter], max_norm=5.0, error_if_nonfinite=True).item()
after = torch.linalg.vector_norm(parameter.grad).item()

print("norm before:", round(before, 1))
print("reported norm:", round(reported_norm, 1))
print("norm after:", round(after, 1))
print("finite:", torch.isfinite(parameter.grad).all().item())
```

Gradient clipping check

```text
norm before: 10.0
reported norm: 10.0
norm after: 5.0
finite: True
```

clip_grad_norm_() returns the total norm before clipping. In a live loop, pass error_if_nonfinite=True as above so a non-finite total norm raises before it can enter optimizer state. Also check torch.isfinite(xb), torch.isfinite(logits), torch.isfinite(loss), and parameter gradients at the first failing step. Find where bad numbers appear before changing hyperparameters blindly.
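The arithmetic behind norm clipping is short enough to reproduce in pure Python. This is an approximation of the idea rather than the library's implementation: it assumes a single scale factor applied to every gradient element, capped at 1.0, with a small epsilon in the denominator to guard against division by zero.

```python
import math

grad = [6.0, 8.0]
max_norm = 5.0

total_norm = math.sqrt(sum(g * g for g in grad))

# Capped at 1.0 so gradients already inside the limit are left untouched.
scale = min(1.0, max_norm / (total_norm + 1e-6))
clipped = [g * scale for g in grad]
clipped_norm = math.sqrt(sum(g * g for g in clipped))

print("norm before:", round(total_norm, 1))
print("clipped gradient:", [round(g, 3) for g in clipped])
print("norm after:", round(clipped_norm, 1))
```

The gradient [6.0, 8.0] has norm 10.0 and is rescaled to roughly [3.0, 4.0] with norm 5.0. The direction is preserved; only the length changes.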

Inference shouldn't build training history

When serving the selected ticket router, you aren't going to call backward(). torch.inference_mode() disables gradient tracking and additional autograd bookkeeping for computations that won't re-enter a gradient-tracked training graph. Use torch.no_grad() during simple validation when you want the familiar training-script pattern; use inference_mode() for committed inference paths after testing that contract.

11-inference-mode.py
```python
import torch
from torch import nn

torch.manual_seed(17)
model = nn.Linear(2, 2)
model.eval()
ticket = torch.tensor([[1.5, 1.2]])

with torch.inference_mode():
    logits = model(ticket)
    prediction = logits.argmax(dim=1).item()

print("gradient tracking:", logits.requires_grad)
print("route class:", prediction)
```

Inference receipt

```text
gradient tracking: False
route class: 1
```

Optional GPU extension: automatic mixed precision

The loop above is the core contract. On compatible accelerators, automatic mixed precision (AMP) may reduce memory use and increase throughput by letting suitable operations use lower precision while preserving higher precision where needed. With float16 training, gradient scaling through GradScaler reduces the risk that small gradients underflow before the update. PyTorch's AMP examples place autocast around forward and loss computation, scale before backward(), unscale before clipping, then step and update the scaler.[4][5]

12-amp-ordering-pattern.py
```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"
model = nn.Linear(2, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
xb = torch.tensor([[1.2, 0.8], [-1.0, -0.7]], device=device)
yb = torch.tensor([1, 0], dtype=torch.long, device=device)
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

optimizer.zero_grad()
with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
    logits = model(xb)
    loss = loss_fn(logits, yb)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()

print("step completed with finite loss:", bool(torch.isfinite(loss)))
```

AMP-compatible step

```text
step completed with finite loss: True
```

GradScaler.step() checks for non-finite gradients after unscaling and skips the optimizer update when it finds them. GradScaler.update() then adjusts the scale for later iterations. Treat AMP as an extension after a full-precision loop passes its correctness checks. Faster incorrect training is still incorrect training.

A training receipt for the next chapter

This chapter made a model learn from already prepared tensors. A real ticket-routing model still needs evidence for where its labeled rows came from, whether duplicates leaked across splits, and which preprocessing produced those two features.

Before declaring a training run acceptable, retain:

| Receipt item | Why it matters |
|---|---|
| Model and optimizer configuration | Explains how updates were produced. |
| Training and validation loss curves | Shows optimization and held-out behavior separately. |
| Best checkpoint and selection metric | Identifies which weights are being served. |
| Feature schema and label mapping | Makes inference compatible with training. |
| Dataset version and split identity | Lets reviewers investigate leakage or relabeling. |

Practice: Break and repair the loop

Use the runnable examples above as a controlled failure lab. Make one change at a time, predict the result, then run the example and explain the observed behavior.

  1. In 03-one-update.py, comment out optimizer.step(). Predict whether changed by step remains True.
  2. In 04-zero-grad.py, call backward() a third time before clearing gradients. Predict the gradient value.
  3. In 07-train-and-validate.py, replace copy.deepcopy(model.state_dict()) with model.state_dict(). Explain why best_state no longer preserves the best epoch after later updates.
  4. In 08-train-eval-modes.py, leave the model in training mode for the second pair of predictions. Predict whether the outputs must match.
  5. In 12-amp-ordering-pattern.py, explain why scaler.unscale_(optimizer) belongs before gradient clipping.

Expected observations

  1. Without optimizer.step(), gradients are populated but parameter values don't change.
  2. The gradient becomes 6.0: each backward() adds another derivative of 2.0 until gradients are cleared.
  3. A plain state_dict() result keeps references to parameter storage. Later training updates overwrite the intended snapshot.
  4. Dropout remains active in training mode, so repeated outputs can vary. A matching pair can still occur by chance.
  5. AMP scales the loss before backward(), so the gradients it produces are scaled by the same factor. Clipping before unscale_() would apply the threshold to those scaled values instead of the gradients used for the real update.

Explain the loop without looking back

Explain the complete route from one ticket batch to a saved model without using the word "magic." Then answer these checks.

Evaluation rubric

  • Foundational: Identifies logits, loss, gradients, and optimizer updates, including the difference between backward() and step().
  • Intermediate: Builds a runnable mini-batch loop with held-out validation, correct modes, and a reloadable checkpoint.
  • Advanced: Diagnoses gradient, numerical, or split-quality failures and explains when AMP or clipping belongs in a validated training system.

Common Pitfalls

  • Applying softmax() before CrossEntropyLoss instead of giving the loss raw logits.
  • Calling backward() batch after batch without clearing gradients or intentionally scheduling accumulation.
  • Reporting validation while the model remains in training mode or while held-out tickets influence updates.
  • Serving a checkpoint without its label mapping, feature schema, split identity, and selection metric.

Reuse this training pattern

  • A small ticket-escalation classifier with an explicit PyTorch optimization loop.
  • A one-batch diagnostic that fails early when a new loop is wired incorrectly.
  • A validation-and-checkpoint path that selects weights without learning from held-out tickets.
  • A training receipt that the data-quality pipeline can make reproducible.

Key Terms

  • Logit: Raw class score emitted before a probability normalization.
  • Loss: Scalar objective measuring prediction error for optimization.
  • Autograd: PyTorch system that computes derivatives through recorded operations.
  • Gradient: Partial derivative of the loss with respect to a parameter, stored in that parameter's .grad after backward().
  • Mini-batch: Group of examples used for one gradient estimate.
  • Epoch: One traversal through training examples.
  • Checkpoint: Saved model state, often paired with configuration and training metadata.

Mastery Check

Answer every question, then check your score. Score above 75% to mark this lesson complete.

  1. A ticket router ends with nn.Linear(2, 2). For a mini-batch, logits = model(xb) has shape (4, 2) and yb = torch.tensor([0, 1, 0, 1], dtype=torch.long). For ordinary class-index classification in PyTorch, which loss call correctly passes raw unnormalized class scores and integer targets?
  2. In a one-batch lab, before = model.weight.detach().clone() is saved, then the code runs optimizer.zero_grad(); loss = loss_fn(model(xb), yb); loss.backward(). At this point model.weight.grad is nonzero, but the weight still equals before. Which statement explains this state?
  3. A PyTorch model contains Dropout, and validation must measure stable held-out loss without updating weights or recording gradient history. Which validation pattern is correct?
  4. You run an overfit-one-batch diagnostic on eight clean separable ticket rows. After 120 optimizer updates, loss has not fallen and accuracy is still poor. What should you check before scaling up or judging generalization?
  5. A saved ticket-router file contains only model_state_dict. A teammate can load weights but must also audit how to run the model and why that checkpoint was chosen. Which receipt should accompany the weights?
  6. In a one-update-per-batch training loop, optimizer.zero_grad() is called once before the epoch instead of inside the mini-batch loop. What happens after the second loss.backward()?
  7. A loader produces batches of 2, 2, and 1 rows. Their reported mean losses are 0.20, 0.40, and 1.00. What is the correctly row-weighted mean loss for the epoch?
  8. Validation loss reaches its minimum at epoch 12, but training will continue. Which in-memory assignment preserves the exact model tensors from epoch 12?
  9. A previously finite training run begins producing nan loss. Which investigation most directly locates where the first invalid value enters the update?
  10. A checkpoint stores a trained model's state_dict. Which sequence correctly restores it for inference?


Next Step
Continue to Dataset Pipelines and Data Quality

You can now train and validate a model from prepared tensors. Next you will build trustworthy train, validation, and test artifacts from messy raw records.

References

[1] Paszke, A., et al. (2019). PyTorch: An Imperative Style, High-Performance Deep Learning Library. NeurIPS 2019.

[2] PyTorch Contributors (2026). Optimizing Model Parameters. Official tutorial.

[3] Goodfellow, I., Bengio, Y., Courville, A. (2016). Deep Learning.

[4] Micikevicius, P., et al. (2018). Mixed Precision Training.

[5] PyTorch Contributors (2026). Automatic Mixed Precision package - torch.amp.