Measure how FP16 and BF16 affect training range, update precision, memory, and release evidence before enabling faster low-precision compute.
The previous lesson recorded an application fix as a controlled experiment. The same discipline applies when the changed artifact is a trained model: a faster run isn't an improvement if it corrupts updates, overflows, or lowers the declared held-out metric.
Suppose the order-support pipeline now needs a small carrier-event classifier that recognizes whether a delivery estimate is supported by a scan. You want to fine-tune it faster, so the proposed run changes its precision policy from FP32 to FP16 or BF16. This lesson builds the run evidence you would need before choosing that policy.
Mixed precision training runs selected expensive operations in a compact floating-point format while retaining higher precision where training is fragile. In modern PyTorch, model parameters normally remain FP32 while autocast chooses lower-precision computation for eligible operations.[1]
Start as you would in an experiment tracker: state the artifact under test and its guardrails before measuring candidates.

    from dataclasses import dataclass
    import torch

    # Guardrails for the precision experiment, fixed before any candidate is measured.
    @dataclass(frozen=True)
    class PrecisionContract:
        experiment: str
        artifact: str
        baseline: str
        candidates: tuple[str, ...]
        required_supported_evidence_f1: float
        permitted_nonfinite_steps: int

    contract = PrecisionContract(
        experiment="carrier-event-encoder-precision",
        artifact="carrier-evidence-classifier-v2",
        baseline="fp32",
        candidates=("fp16_unscaled", "fp16_scaled", "bf16"),
        required_supported_evidence_f1=0.92,
        permitted_nonfinite_steps=0,
    )

    print(f"experiment={contract.experiment}")
    print(f"artifact={contract.artifact}")
    print(f"baseline={contract.baseline}")
    print(f"candidates={','.join(contract.candidates)}")
    print(f"metric_gate=supported_evidence_f1>={contract.required_supported_evidence_f1:.2f}")
    print(f"nonfinite_steps_gate={contract.permitted_nonfinite_steps}")

Output:

    experiment=carrier-event-encoder-precision
    artifact=carrier-evidence-classifier-v2
    baseline=fp32
    candidates=fp16_unscaled,fp16_scaled,bf16
    metric_gate=supported_evidence_f1>=0.92
    nonfinite_steps_gate=0

The job isn't to declare BF16 good or FP16 bad in the abstract. It's to understand what each format can lose, then measure the acceptable candidates under the same validation and hardware conditions.
A floating-point value is similar to scientific notation: a sign, a scale, and significant digits. Its exponent controls range, meaning how tiny or large a magnitude it can represent. Its fraction (often called the mantissa in training discussions) controls resolution, meaning how close two neighboring values can be.
| Format | Bits | Exponent bits | Fraction bits | Main training consequence |
|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | Wide range and fine update resolution, with higher storage cost |
| FP16 | 16 | 5 | 10 | Compact, but small gradients can underflow and large values can overflow |
| BF16 | 16 | 8 | 7 | Compact with FP32-like range, but coarser nearby resolution |
The important correction is easy to miss: BF16 improves range relative to FP16; it doesn't improve nearby resolution. BF16 has fewer fraction bits than FP16. That is why BF16 compute still normally updates FP32 parameters.
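The bit counts in the table are enough to derive each format's headline limits. The sketch below assumes an IEEE-style encoding in which the highest exponent code is reserved for Inf and NaN; the helper name `format_limits` is made up for this illustration and isn't part of any library.

```python
def format_limits(exponent_bits: int, fraction_bits: int) -> dict[str, float]:
    """Derive resolution and range from bit widths (IEEE-style layout assumed)."""
    bias = 2 ** (exponent_bits - 1) - 1   # offset applied to the stored exponent
    max_exponent = bias                   # largest exponent that still encodes a finite value
    min_exponent = 1 - bias               # smallest exponent for a normal value
    return {
        "epsilon_at_1": 2.0 ** -fraction_bits,
        "smallest_normal": 2.0 ** min_exponent,
        "largest_finite": (2.0 - 2.0 ** -fraction_bits) * 2.0 ** max_exponent,
    }

# (exponent bits, fraction bits) copied from the table above.
layouts = {"FP32": (8, 23), "FP16": (5, 10), "BF16": (8, 7)}
limits = {name: format_limits(*bits) for name, bits in layouts.items()}

for name, values in limits.items():
    print(name, {key: f"{value:.1e}" for key, value in values.items()})
```

Exponent bits alone set the range, which is why BF16 and FP32 share the same smallest normal value, while fraction bits alone set the spacing near 1.0.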
PyTorch exposes the exact format limits with torch.finfo. Run this on CPU; no accelerator is required to inspect the number system.

    formats = (
        ("FP32", torch.float32),
        ("FP16", torch.float16),
        ("BF16", torch.bfloat16),
    )

    print("format epsilon_at_1 smallest_normal largest_finite")
    for label, dtype in formats:
        info = torch.finfo(dtype)
        print(f"{label:<6} {info.eps:>12.1e} {info.tiny:>15.1e} {info.max:>14.1e}")

    print(f"bf16_min_normal_matches_fp32={torch.finfo(torch.bfloat16).tiny == torch.finfo(torch.float32).tiny}")
    print(f"bf16_resolution_coarser_than_fp16={torch.finfo(torch.bfloat16).eps > torch.finfo(torch.float16).eps}")

Output:

    format epsilon_at_1 smallest_normal largest_finite
    FP32        1.2e-07         1.2e-38        3.4e+38
    FP16        9.8e-04         6.1e-05        6.6e+04
    BF16        7.8e-03         1.2e-38        3.4e+38
    bf16_min_normal_matches_fp32=True
    bf16_resolution_coarser_than_fp16=True

epsilon_at_1 is the spacing between 1.0 and the next larger representable value. smallest_normal and largest_finite describe range. FP16 gives finer resolution than BF16 near 1.0, but far less range.
Assume the classifier has a weight at 1.0. One optimizer step wants to subtract 0.0001. This is much larger than the smallest BF16 magnitude, but smaller than the spacing between BF16 values near 1.0.

    weight = torch.tensor([1.0], dtype=torch.float32)
    update = torch.tensor([1.0e-4], dtype=torch.float32)

    # Apply the same step with the weight stored in each format.
    for label, dtype in formats:
        before = weight.to(dtype)
        after = before - update.to(dtype)
        changed = bool(after.item() != before.item())
        print(f"{label}: stored_after_step={after.item():.8f}, update_survived={changed}")

    print("lesson=BF16 protects range; FP32 protects small accumulated updates")

Output:

    FP32: stored_after_step=0.99989998, update_survived=True
    FP16: stored_after_step=1.00000000, update_survived=False
    BF16: stored_after_step=1.00000000, update_survived=False
    lesson=BF16 protects range; FP32 protects small accumulated updates

Both 16-bit parameter values lose this update. The original mixed-precision recipe used an FP32 master copy of the weights so small updates accumulate instead of disappearing.[2] Modern PyTorch AMP normally gets the same protection by leaving parameters in FP32 and autocasting eligible forward operations rather than converting the parameter storage itself.[1]
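The cost of a lost update compounds over many steps. The following standard-library sketch repeats the same small step 1000 times. It emulates FP16 storage by round-tripping through the `struct` half-precision format, and a plain Python float stands in for the higher-precision master copy; both choices are simplifications made for this illustration, not how PyTorch stores parameters.

```python
import struct

def round_to_fp16(value: float) -> float:
    """Round a Python float to the nearest FP16 value using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", value))[0]

steps = 1000
update = 1.0e-4

fp16_weight = round_to_fp16(1.0)  # parameter kept in emulated FP16 storage
master_weight = 1.0               # higher-precision master copy (stand-in for FP32)

for _ in range(steps):
    # Each FP16 result is rounded back to storage precision, so the step vanishes.
    fp16_weight = round_to_fp16(fp16_weight - round_to_fp16(update))
    master_weight -= update

print(f"fp16_storage_after_{steps}_steps={fp16_weight:.4f}")
print(f"master_copy_after_{steps}_steps={master_weight:.4f}")
```

The emulated FP16 weight never leaves 1.0 because every individual step rounds away, while the master copy accumulates the full intended change.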
Resolution is one issue. Range is another. Compare an extremely small gradient and a very large activation-like value when stored in FP16 and BF16.

    values = torch.tensor([1.2e-8, 1.0e5], dtype=torch.float32)

    # Cast one tiny and one large magnitude into each 16-bit format.
    for label, dtype in (("FP16", torch.float16), ("BF16", torch.bfloat16)):
        cast = values.to(dtype)
        print(
            f"{label}: small={cast[0].item():.2e}, "
            f"large={cast[1].item():.2e}, "
            f"all_finite={bool(torch.isfinite(cast).all())}"
        )

    print("fp16_loses_small_and_large=True")
    print("bf16_keeps_range_in_this_example=True")

Output:

    FP16: small=0.00e+00, large=inf, all_finite=False
    BF16: small=1.20e-08, large=9.98e+04, all_finite=True
    fp16_loses_small_and_large=True
    bf16_keeps_range_in_this_example=True

FP16's smallest positive normal number is roughly $6.1 \times 10^{-5}$, and its subnormal floor is about $6.0 \times 10^{-8}$. A true gradient of $1.2 \times 10^{-8}$ becomes zero in FP16. At the other end, 100000 is beyond FP16's largest finite value of 65504, so it becomes Inf.
BF16 keeps the 8-bit exponent width of FP32, giving it a similar range and making these two magnitudes representable, although rounded. The BF16 training study documents that wider range as its main stability advantage over FP16.[3]
For FP16, loss scaling moves gradient magnitudes into a representable interval during backpropagation. Multiply the loss by a scale $S$; the chain rule multiplies each gradient by $S$ too. After backward, divide gradients by $S$ in FP32 before applying the optimizer step. The intended update has not changed.
For a true gradient of $g = 1.2 \times 10^{-8}$ and a scale of $S = 1024$:
| Operation | Value | FP16 outcome |
|---|---|---|
| Cast unscaled gradient | $1.2 \times 10^{-8}$ | Rounds to zero |
| Multiply by $S$ during backward | approximately $1.23 \times 10^{-5}$ | Representable |
| Convert to FP32 and divide by $S$ | approximately $1.2 \times 10^{-8}$ | Ready for FP32 update |

    true_grad = torch.tensor([1.2e-8], dtype=torch.float32)
    scale = 1024.0

    plain_fp16 = true_grad.to(torch.float16)                # underflows to zero
    scaled_fp16 = (true_grad * scale).to(torch.float16)     # survives the cast
    recovered_fp32 = scaled_fp16.to(torch.float32) / scale  # descale in FP32

    print(f"plain_underflowed={plain_fp16.item() == 0.0}")
    print(f"scaled_visible={scaled_fp16.item() > 0.0}")
    print(f"recovered_grad={recovered_fp32.item():.2e}")
    print(f"recovery_relative_error={abs(recovered_fp32.item() - true_grad.item()) / true_grad.item():.3%}")

Output:

    plain_underflowed=True
    scaled_visible=True
    recovered_grad=1.20e-08
    recovery_relative_error=0.077%

A fixed scale that saves the smallest gradient may overflow a larger gradient in the same step. Dynamic scaling therefore has two outcomes: apply a finite, descaled update, or skip an overflowed step and reduce the scale.

    def scaled_step_status(gradients: torch.Tensor, scale: float) -> tuple[str, float]:
        scaled = (gradients * scale).to(torch.float16)
        if not bool(torch.isfinite(scaled).all()):
            # Overflow: skip this step and halve the scale for the next one.
            return "SKIP_OVERFLOW", scale / 2
        return "APPLY_DESCALED_UPDATE", scale

    quiet_step = torch.tensor([1.2e-8, 2.0e-2], dtype=torch.float32)
    spiky_step = torch.tensor([1.2e-8, 1.0e2], dtype=torch.float32)

    quiet_status, quiet_next_scale = scaled_step_status(quiet_step, 1024.0)
    spiky_status, spiky_next_scale = scaled_step_status(spiky_step, 1024.0)

    print(f"quiet_step={quiet_status}, next_scale={quiet_next_scale:.0f}")
    print(f"spiky_step={spiky_status}, next_scale={spiky_next_scale:.0f}")
    print("invariant=never_apply_nonfinite_gradients")

Output:

    quiet_step=APPLY_DESCALED_UPDATE, next_scale=1024
    spiky_step=SKIP_OVERFLOW, next_scale=512
    invariant=never_apply_nonfinite_gradients

PyTorch's torch.amp.GradScaler performs this scale, unscale, finite-check, skip, and update control flow for FP16 training. PyTorch also documents that if you inspect or clip gradients, you must call scaler.unscale_(optimizer) before clipping so thresholds apply to true gradient magnitudes.[1]
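To see why the order matters, the following plain-Python sketch clips a two-element gradient whose true norm is 0.5, below a clipping threshold of 1.0. The helpers `l2_norm` and `clip_by_norm` are simplified stand-ins written for this illustration, not PyTorch functions, and the scale of 1024 is an assumed value.

```python
import math

def l2_norm(gradients: list[float]) -> float:
    """Euclidean norm of a flat list of gradient values."""
    return math.sqrt(sum(g * g for g in gradients))

def clip_by_norm(gradients: list[float], max_norm: float) -> list[float]:
    """Shrink gradients so their L2 norm does not exceed max_norm."""
    norm = l2_norm(gradients)
    factor = min(1.0, max_norm / norm) if norm > 0.0 else 1.0
    return [g * factor for g in gradients]

true_grads = [0.3, 0.4]  # true norm 0.5, so a max_norm of 1.0 should not clip
scale = 1024.0           # assumed loss scale
scaled_grads = [g * scale for g in true_grads]

# Wrong order: the threshold is compared against the scaled norm (about 512).
clip_then_unscale = [g / scale for g in clip_by_norm(scaled_grads, max_norm=1.0)]
# Right order: the threshold is compared against the true norm (0.5).
unscale_then_clip = clip_by_norm([g / scale for g in scaled_grads], max_norm=1.0)

print(f"true_norm={l2_norm(true_grads):.6f}")
print(f"clip_then_unscale_norm={l2_norm(clip_then_unscale):.6f}")
print(f"unscale_then_clip_norm={l2_norm(unscale_then_clip):.6f}")
```

Clipping the scaled gradients shrinks an update that never needed clipping by roughly the scale factor, which is the kind of silent damage the unscale-first rule prevents.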
Loss scaling isn't a general extension of FP16 range. It rescues small backward gradients that would underflow, but it can't make a forward activation above 65504 representable. PyTorch also warns that GradScaler may reduce its scale below 1 for overflow-prone models, so don't assume the scale always grows or stays above 1.[1]
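A toy state machine makes the scale's movement concrete. The field names below echo GradScaler's growth_factor, backoff_factor, and growth_interval arguments, but the class `ToyDynamicScale`, its starting scale of 2.0, and the tiny growth interval are assumptions chosen to keep the trace short; this isn't PyTorch's implementation.

```python
from dataclasses import dataclass

@dataclass
class ToyDynamicScale:
    """Illustrative loss-scale controller; not PyTorch's GradScaler."""
    scale: float = 65536.0
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 3  # deliberately tiny so the example stays short
    clean_steps: int = 0

    def update(self, found_nonfinite: bool) -> None:
        if found_nonfinite:
            # The step would be skipped; back off and restart the clean-step count.
            self.scale *= self.backoff_factor
            self.clean_steps = 0
            return
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            # Enough consecutive finite steps: try a larger scale again.
            self.scale *= self.growth_factor
            self.clean_steps = 0

toy_scaler = ToyDynamicScale(scale=2.0)
history = []
for overflowed in (True, True, False, False, False):
    toy_scaler.update(overflowed)
    history.append(toy_scaler.scale)

print(f"scale_history={history}")
print(f"dropped_below_one={min(history) < 1.0}")
```

Two consecutive overflows take this toy scale from 2.0 to 0.5, and it only recovers after a run of finite steps, so a logged scale is a measurement to inspect, not a constant to assume.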
Loss scaling protects FP16 gradients from range failure. It doesn't make 16-bit parameter storage appropriate for tiny updates. Preserve the FP32 update path separately.

    step = torch.tensor([1.0e-4], dtype=torch.float32)
    fp16_parameter = torch.tensor([1.0], dtype=torch.float16)
    fp32_parameter = torch.tensor([1.0], dtype=torch.float32)

    # Same step, applied to FP16 storage and to FP32 storage.
    fp16_after = fp16_parameter - step.to(torch.float16)
    fp32_after = fp32_parameter - step

    print(f"fp16_parameter_changed={fp16_after.item() != fp16_parameter.item()}")
    print(f"fp32_parameter_changed={fp32_after.item() != fp32_parameter.item()}")
    print(f"fp32_after={fp32_after.item():.8f}")
    print("policy=low_precision_compute_with_fp32_update_state")

Output:

    fp16_parameter_changed=False
    fp32_parameter_changed=True
    fp32_after=0.99989998
    policy=low_precision_compute_with_fp32_update_state

The original paper describes copying FP32 master weights into a low-precision compute copy.[2] With ordinary AMP, PyTorch parameters remain FP32, autocast selects lower precision for eligible compute, and the optimizer updates the FP32 parameters directly.[1]
This is the CUDA shape you would use for a real fine-tuning run. It isn't marked executable here because it needs an accelerator and a model workload:

    # Assumes model, optimizer, criterion, and dataloader already exist on a CUDA setup.
    dtype = torch.bfloat16  # compare against torch.float16 in a controlled run
    use_scaler = dtype == torch.float16
    scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)

    for batch, target in dataloader:
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", dtype=dtype):
            logits = model(batch.cuda())
            loss = criterion(logits, target.cuda())

        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

For BF16, skipping GradScaler is a common policy because BF16's exponent range avoids the FP16 failure that loss scaling targets. It isn't a guarantee that every BF16 training job is stable. Bad data, unstable losses, overly large learning rates, or sensitive kernels can still produce non-finite values.
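Because BF16 runs have no scaler to flag bad steps, the non-finite count has to be recorded explicitly. Here is a minimal plain-Python sketch of that bookkeeping, using made-up per-step gradient samples; in a PyTorch loop the same check would typically be a torch.isfinite test over the gradient tensors. The helper name `step_is_finite` is hypothetical.

```python
import math

def step_is_finite(gradients: list[float]) -> bool:
    """True only when every gradient value is neither NaN nor infinite."""
    return all(math.isfinite(g) for g in gradients)

# Hypothetical gradient samples, one inner list per training step.
recorded_steps = [
    [0.02, -0.5, 1.0e-8],
    [0.01, float("inf"), 0.3],  # for example, an unstable loss spike
    [float("nan"), 0.1, 0.2],   # for example, a corrupted batch
    [0.03, 0.04, -0.02],
]

nonfinite_steps = sum(not step_is_finite(step) for step in recorded_steps)
print(f"nonfinite_steps={nonfinite_steps}")
```

That count is what the nonfinite_steps gate in the contract is compared against.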
When autocast runs an eligible operation in BF16 or FP16, its saved low-precision activation values use two bytes rather than four. Other operations stay in or return FP32 for numerical safety. Total training memory doesn't necessarily halve because FP32 parameters and optimizer moments may remain unchanged.
The next accounting exercise uses a deliberately small inventory: 100 million parameters, their gradients, two Adam moment buffers, and 800 million stored activation values. To make the arithmetic visible, assume those inventoried activations were saved in low precision for the AMP candidate. A real profile may retain FP32 values for some operations. This is a budget calculation, not a measured GPU profile.

    def gib(values: int, bytes_per_value: int) -> float:
        return values * bytes_per_value / (1024 ** 3)

    parameter_values = 100_000_000
    activation_values = 800_000_000

    fp32_budget = {
        "parameters": gib(parameter_values, 4),
        "gradients": gib(parameter_values, 4),
        "adam_moments": gib(parameter_values * 2, 4),
        "activations": gib(activation_values, 4),
    }
    # Only the activation entry changes for the AMP candidate.
    amp_budget = {
        **{name: value for name, value in fp32_budget.items() if name != "activations"},
        "activations": gib(activation_values, 2),
    }

    print(f"fp32_total_gib={sum(fp32_budget.values()):.2f}")
    print(f"amp_total_gib={sum(amp_budget.values()):.2f}")
    print(f"activation_saving_gib={fp32_budget['activations'] - amp_budget['activations']:.2f}")
    print(f"total_reduction={(1 - sum(amp_budget.values()) / sum(fp32_budget.values())):.1%}")
    print("lesson=half_size_activations_do_not_imply_half_total_memory")

Output:

    fp32_total_gib=4.47
    amp_total_gib=2.98
    activation_saving_gib=1.49
    total_reduction=33.3%
    lesson=half_size_activations_do_not_imply_half_total_memory

For large models, sharding methods such as ZeRO and Fully Sharded Data Parallel (FSDP) address parameters, gradients, and optimizer-state memory that activation casting alone doesn't remove.[4][5]
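A rough per-worker budget shows what sharding changes. The sketch below follows the staged partitioning described for ZeRO (optimizer state first, then gradients, then parameters) and reuses the 100-million-parameter FP32 inventory from the accounting exercise. The helper `per_rank_gib` and the choice of 8 workers are assumptions, and the estimate ignores activations, temporary buffers, and communication overhead.

```python
# Bytes per parameter: FP32 value, FP32 gradient, and two FP32 Adam moments.
BYTES_PARAM, BYTES_GRAD, BYTES_ADAM = 4, 4, 8

def per_rank_gib(parameter_count: int, workers: int, stage: int) -> float:
    """Model-state memory per worker as later stages shard more state.

    stage 0: nothing sharded; 1: optimizer state; 2: plus gradients; 3: plus parameters.
    """
    optimizer = BYTES_ADAM / workers if stage >= 1 else BYTES_ADAM
    gradients = BYTES_GRAD / workers if stage >= 2 else BYTES_GRAD
    parameters = BYTES_PARAM / workers if stage >= 3 else BYTES_PARAM
    return parameter_count * (parameters + gradients + optimizer) / (1024 ** 3)

budget = {stage: per_rank_gib(100_000_000, workers=8, stage=stage) for stage in range(4)}
for stage, gib_value in budget.items():
    print(f"stage={stage} model_state_gib_per_worker={gib_value:.2f}")
```

This is a budget calculation rather than a profile, but it shows sharding attacking exactly the state that activation casting leaves unchanged.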
When workers exchange gradients, the network payload has its own precision policy. FSDP's MixedPrecision configuration exposes param_dtype for computation and reduce_dtype for gradient reduction; they need not match.[6]

    gradient_values = 100_000_000
    fp32_reduce_gib = gib(gradient_values, 4)
    bf16_reduce_gib = gib(gradient_values, 2)

    print(f"gradient_payload_fp32_gib={fp32_reduce_gib:.2f}")
    print(f"gradient_payload_bf16_gib={bf16_reduce_gib:.2f}")
    print(f"payload_reduction={(1 - bf16_reduce_gib / fp32_reduce_gib):.0%}")
    print("warning=compute_dtype_does_not_prove_reduce_dtype")

Output:

    gradient_payload_fp32_gib=0.37
    gradient_payload_bf16_gib=0.19
    payload_reduction=50%
    warning=compute_dtype_does_not_prove_reduce_dtype

A job can use BF16 for matrix computations and still communicate FP32 gradient payloads. Therefore a trustworthy run record separates compute_dtype, update_storage_dtype, and reduce_dtype rather than logging a single mixed_precision=true flag.
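One way to hold that separation is a small record type. The class below is a hypothetical shape for such a record: the three dtype field names come from the paragraph above, while `PrecisionPolicyRecord`, `loss_scaling`, and the example values are illustrative assumptions.

```python
from dataclasses import dataclass, asdict

@dataclass(frozen=True)
class PrecisionPolicyRecord:
    """Hypothetical run-record entry that keeps each dtype decision separate."""
    run_id: str
    compute_dtype: str
    update_storage_dtype: str
    reduce_dtype: str
    loss_scaling: bool

    def summary(self) -> str:
        # One key=value pair per field, in declaration order.
        return ",".join(f"{key}={value}" for key, value in asdict(self).items())

record = PrecisionPolicyRecord(
    run_id="run_bf16",
    compute_dtype="bfloat16",
    update_storage_dtype="float32",
    reduce_dtype="float32",  # communication can stay FP32 while compute is BF16
    loss_scaling=False,
)
print(record.summary())
```

A reviewer reading this record can see at a glance that compute moved to BF16 while gradient reduction did not.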
The final cell brings the lesson back to experiment tracking. The numbers below are illustrative recorded outcomes, not benchmark claims. They show the review rule you should apply after running the same classifier, data fingerprint, seed policy, held-out supported_evidence_f1 evaluation, and hardware profile for every precision configuration.

    @dataclass(frozen=True)
    class PrecisionRun:
        run_id: str
        policy: str
        supported_evidence_f1: float
        nonfinite_steps: int
        peak_memory_gib: float
        evidence: str

    runs = (
        PrecisionRun("run_fp32", "fp32", 0.93, 0, 4.47, "illustrative_fixture"),
        PrecisionRun("run_fp16_plain", "fp16_unscaled", 0.88, 3, 2.98, "illustrative_fixture"),
        PrecisionRun("run_fp16_scaled", "fp16_scaled", 0.93, 0, 2.98, "illustrative_fixture"),
        PrecisionRun("run_bf16", "bf16", 0.93, 0, 2.98, "illustrative_fixture"),
    )

    def passes(run: PrecisionRun) -> bool:
        # A run must clear both the metric gate and the non-finite-step gate.
        return (
            run.supported_evidence_f1 >= contract.required_supported_evidence_f1
            and run.nonfinite_steps <= contract.permitted_nonfinite_steps
        )

    eligible = [run.policy for run in runs if run.policy != contract.baseline and passes(run)]
    rejected = [run.policy for run in runs if not passes(run)]

    print(f"eligible_candidates={','.join(eligible)}")
    print(f"rejected_candidates={','.join(rejected)}")
    print("decision=BLOCKED_PENDING_TARGET_GPU_PROFILE")
    print("next_metrics=throughput,peak_memory,supported_evidence_f1,nonfinite_steps")

Output:

    eligible_candidates=fp16_scaled,bf16
    rejected_candidates=fp16_unscaled
    decision=BLOCKED_PENDING_TARGET_GPU_PROFILE
    next_metrics=throughput,peak_memory,supported_evidence_f1,nonfinite_steps

The right result isn't "BF16 wins because it's modern." Both scaled FP16 and BF16 pass this small fixture, and both require a real measurement on target hardware. BF16 is often simpler to operate because it commonly avoids loss scaling, but only the controlled run can justify promotion.
FP8 reduces compute storage again, but its reduced range and resolution require managed scaling recipes. FP8 isn't one layout: the FP8 formats paper specifies complementary E4M3 and E5M2 encodings for deep-learning workloads. NVIDIA Transformer Engine documents a hybrid recipe that uses E4M3 during the forward pass and E5M2 during the backward pass, plus delayed, current, and block-scaling recipes for supported accelerators.[7][8]
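The range difference between the two encodings can be worked out from their bit layouts. This sketch assumes the conventions described in the FP8 formats paper: E5M2 reserves its top exponent code for Inf and NaN in the IEEE style, while E4M3 has no infinities and spends only one mantissa pattern on NaN. Both function names are made up for this illustration.

```python
def ieee_style_max(exponent_bits: int, fraction_bits: int) -> float:
    """Largest finite value when the top exponent code is reserved for Inf and NaN."""
    bias = 2 ** (exponent_bits - 1) - 1
    return (2.0 - 2.0 ** -fraction_bits) * 2.0 ** bias

def e4m3_max() -> float:
    """Largest finite E4M3 value, assuming only the all-ones mantissa in the
    top exponent code is NaN, so that code still holds values up to 1.75 * 2**8."""
    bias = 2 ** (4 - 1) - 1             # 7
    top_exponent = (2 ** 4 - 1) - bias  # 8
    largest_mantissa = 2.0 - 2.0 * 2.0 ** -3  # 1.75: one step below the NaN pattern
    return largest_mantissa * 2.0 ** top_exponent

print(f"e5m2_largest_finite={ieee_style_max(5, 2):.0f}")
print(f"e4m3_largest_finite={e4m3_max():.0f}")
print(f"fp16_largest_finite={ieee_style_max(5, 10):.0f}")
```

E4M3 trades range for one extra fraction bit and E5M2 does the reverse, which fits a hybrid recipe that uses one encoding for the forward pass and the other for the backward pass.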
That is enough orientation here. Don't add FP8 to a training proposal until BF16 or scaled FP16 is measured, quality gates exist, and the team can operate the scaling policy. Precision work should reduce measured cost without creating unexplained convergence risk.
Symptom: A BF16-only parameter update stops improving loss even though gradients are finite.
Cause: Wide range was confused with fine resolution near current weights.
Fix: Keep FP32 update state under ordinary AMP and log the storage policy.
Symptom: Training appears stable but supported_evidence_f1 lags the FP32 baseline.
Cause: Small unscaled FP16 gradients underflow to zero.
Fix: Use GradScaler for FP16, track non-finite or skipped steps, and compare the declared held-out metric against the same baseline.
Symptom: Clipping behaves erratically or training diverges under FP16 AMP.
Cause: The run clips gradients before scaler.unscale_(optimizer).
Fix: Unscale first, then clip, then let the scaler perform or skip the optimizer step.
Symptom: "Half-memory" planning fails when the job is scheduled.
Cause: Only activation dtype changed while FP32 parameters and Adam moments remain large.
Fix: Log a component-level memory profile or accounting budget, not a dtype slogan.
Symptom: BF16 compute is enabled, but cross-worker traffic is still a bottleneck.
Cause: Reduction payloads remain FP32.
Fix: Inspect and record reduce_dtype separately, then measure held-out metric and communication changes before promotion.
[1] PyTorch Contributors. Automatic Mixed Precision package - torch.amp. 2026.
[2] Micikevicius, P., et al. Mixed Precision Training. 2018.
[3] Kalamkar, D., et al. A Study of BFLOAT16 for Deep Learning Training. 2019.
[4] Rajbhandari, S., et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC 2020, 2020.
[5] Zhao, Y., et al. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. VLDB 2023, 2023.
[6] PyTorch Contributors. FullyShardedDataParallel. 2025.
[7] Micikevicius, P., et al. FP8 Formats for Deep Learning. 2022.
[8] NVIDIA. Using FP8 and FP4 with Transformer Engine. 2026.