Understand ZeRO stages, current FSDP1 vs FSDP2 guidance, and when native PyTorch or DeepSpeed is the right choice for large-model training.
The SFT pipeline chapter ended with a trustworthy single-device training recipe. This chapter handles the next limit: full-weight updates may be correct for the task, but one GPU can't hold the states and activations needed to execute them. Fully Sharded Data Parallel (FSDP) and ZeRO (Zero Redundancy Optimizer) reduce repeated model-state storage across data-parallel ranks, then make communication part of the training cost.
Start with the memory contract. One useful sizing convention, used in the ZeRO analysis, assumes low-precision parameters and gradients plus an FP32 master parameter copy and FP32 Adam moments.[1] Under that specific recipe, each parameter costs 16 bytes of model-state memory before activations, temporary buffers, or allocator overhead:
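- 2 bytes: BF16 (bfloat16, short for Brain Floating Point) / FP16 (16-bit half precision) weight
- 2 bytes: BF16/FP16 gradient
- 4 bytes: FP32 master weight
- 4 bytes: FP32 Adam first moment (momentum)
- 4 bytes: FP32 Adam second moment (variance)

For a 1B model, that is 16 GB of model states alone. Change the optimizer, parameter dtype, master-weight policy, or offload policy and the byte count changes. Always write down the recipe before using a bytes-per-parameter number.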
Scale this recipe to 7B parameters and it reaches 112 GB of model states. At 70B parameters, it exceeds 1 TB before activations. If you need full-parameter training at that scale, replicated data parallelism alone won't make it fit.
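A bytes-per-parameter figure only means something next to its recipe. The minimal sketch below keeps each component explicit, so a change in optimizer or master-weight policy shows up in the total. The `StateRecipe` class and the three recipes are illustrative assumptions, not an exhaustive catalogue.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StateRecipe:
    """Bytes of persistent model state per parameter (illustrative recipes)."""
    name: str
    param_bytes: int      # stored training weights
    grad_bytes: int       # gradients
    master_bytes: int     # FP32 master copy, 0 if none is kept
    optimizer_bytes: int  # optimizer state (Adam keeps two moments)

    def bytes_per_param(self) -> int:
        return self.param_bytes + self.grad_bytes + self.master_bytes + self.optimizer_bytes


RECIPES = [
    # Low-precision params/grads + FP32 master + FP32 Adam moments (the ZeRO convention).
    StateRecipe("bf16_params_fp32_master_adam", 2, 2, 4, 8),
    # Everything in FP32 with Adam: no separate master copy is needed.
    StateRecipe("fp32_adam", 4, 4, 0, 8),
    # Low-precision params/grads + FP32 master + one FP32 SGD momentum buffer.
    StateRecipe("bf16_params_fp32_master_sgd_momentum", 2, 2, 4, 4),
]


def model_state_gb(recipe: StateRecipe, params_billions: float) -> float:
    # Decimal GB: (params_billions * 1e9) params * bytes/param / 1e9 bytes per GB.
    return params_billions * recipe.bytes_per_param()


for recipe in RECIPES:
    print(f"{recipe.name}: {recipe.bytes_per_param()} B/param, "
          f"{model_state_gb(recipe, 7):.0f} GB for 7B")
```

The all-FP32 Adam recipe also lands at 16 bytes per parameter. The number alone doesn't identify the recipe, which is why the recipe has to be written down next to it.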
FSDP and ZeRO don't make memory disappear. They remove redundant copies of model states across GPUs, then pay the price in communication.
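For a 70B parameter model under the FP16/BF16 parameter-and-gradient plus FP32-master-and-Adam accounting convention, the model-state math looks like this:

$$
\underbrace{70\times10^{9}\times 2\,\text{B}}_{\text{weights: }140\text{ GB}}
+ \underbrace{70\times10^{9}\times 2\,\text{B}}_{\text{gradients: }140\text{ GB}}
+ \underbrace{70\times10^{9}\times 12\,\text{B}}_{\text{FP32 master + Adam: }840\text{ GB}}
= 1120\text{ GB}
$$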
That is about 1.12 TB of model states before activations, temporary buffers, or allocator overhead. It is an assumption-bound calculation, not a promise about every optimizer implementation.
The arithmetic is small enough to turn into a quick check. This example uses decimal GB to match the table below.
```python
GB = 10**9

def state_memory_gb(params_billions: float, bytes_per_param: int) -> float:
    return params_billions * 1_000_000_000 * bytes_per_param / GB

recipe_bytes = {
    "low_precision_parameters": 2,
    "low_precision_gradients": 2,
    "fp32_master_and_adam": 12,
}
weights = state_memory_gb(70, recipe_bytes["low_precision_parameters"])
gradients = state_memory_gb(70, recipe_bytes["low_precision_gradients"])
optimizer_and_master = state_memory_gb(70, recipe_bytes["fp32_master_and_adam"])
total_model_states = weights + gradients + optimizer_and_master
zero3_per_gpu = total_model_states / 256

print("bytes_per_parameter=", sum(recipe_bytes.values()))
print(f"weights={weights:.0f} GB")
print(f"optimizer_and_master={optimizer_and_master:.0f} GB")
print(f"total_model_states={total_model_states:.0f} GB")
print(f"ZeRO-3 model states per GPU on 256 GPUs: {zero3_per_gpu:.2f} GB")
```

```text
bytes_per_parameter= 16
weights=140 GB
optimizer_and_master=840 GB
total_model_states=1120 GB
ZeRO-3 model states per GPU on 256 GPUs: 4.38 GB
```
Before we look at the fix, recall how ordinary Distributed Data Parallel (DDP) works. DDP places a full copy of the model on every GPU, splits the training batch across them, and synchronizes gradients with an all-reduce (a collective operation that sums values across all GPUs and returns the result to every GPU) after each backward pass. It works well for models that fit on one GPU, because adding more GPUs improves throughput without changing the per-GPU memory requirement.
The problem appears when the model itself is too large for one GPU. DDP replicates the full model states on every rank (process/GPU), so adding more GPUs improves throughput but doesn't reduce the per-GPU memory footprint. FSDP and ZeRO attack that specific problem by sharding model states across the data-parallel group.
Think of DDP like every warehouse node keeping a full copy of the same giant catalog. Coordination is easy, but the storage is wasteful. Every GPU keeps a complete copy of the parameters, gradients, and optimizer states.
ZeRO changes that trade-off. Instead of duplicating the full catalog on every rank, it partitions the catalog across ranks and reconstructs only the pieces needed for computation.[1] In logistics terms, ZeRO is like a distributed warehouse network: each regional hub stores only a fraction of the catalog and pulls the remaining items on demand when an order arrives.
ZeRO (Zero Redundancy Optimizer) is easiest to understand as an incremental evolution. Each stage removes one more redundant copy of the model states.
Stage 1 shards optimizer states. Each GPU keeps optimizer state for only $1/N$ of the parameters, where $N$ is the number of data-parallel ranks. Parameters and gradients are still replicated.
Stage 2 also shards gradients, so each GPU owns only the gradient shard that matches its optimizer shard.
Stage 3 shards parameters as well. No rank owns a full persistent copy of the model states anymore.
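The "each rank owns 1/N" bookkeeping becomes concrete with a flat view of the parameters. The minimal sketch below assumes the parameters are flattened into one vector and padded so it divides evenly across ranks. That is similar in spirit to FSDP1's flat parameters, but it is not their actual implementation. `shard_bounds` is a hypothetical helper.

```python
def shard_bounds(numel: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the [start, end) slice of a flattened tensor owned by `rank`.

    The flat tensor is conceptually padded up to a multiple of world_size,
    so every rank owns an equally sized chunk; padding is clipped here.
    """
    chunk = -(-numel // world_size)  # ceiling division
    start = min(rank * chunk, numel)
    end = min(start + chunk, numel)
    return start, end


numel, world_size = 10, 4
chunk = -(-numel // world_size)
for rank in range(world_size):
    start, end = shard_bounds(numel, world_size, rank)
    print(f"rank {rank}: elements [{start}, {end}) -> {end - start} real, "
          f"{chunk - (end - start)} padding")
```

Under ZeRO-1, only optimizer state follows this slicing. ZeRO-2 applies the same slicing to gradients, and ZeRO-3 applies it to the parameters themselves.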
Key insight: ZeRO doesn't reduce the aggregate model-state bytes; it eliminates persistent redundant copies. A 70B model still needs roughly 1.12 TB of model-state memory in aggregate, but ZeRO-3 partitions its steady-state storage across GPUs. Temporary materialized units and communication buffers still add to each rank's peak. The trade-off is communication: you pay in network bandwidth for what you save in persistent per-GPU state.
The table below assumes the 70B model on 256 GPUs, with BF16/FP16 weights and gradients plus FP32 Adam states and FP32 master weights. The totals are for model states only. Activation memory is listed separately because it depends heavily on sequence length, hidden size, micro-batch size, checkpointing, and attention implementation.
| Component | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |
|---|---|---|---|---|
| Weights | 140 GB | 140 GB | 140 GB | 0.55 GB |
| Gradients | 140 GB | 140 GB | 0.55 GB | 0.55 GB |
| Optimizer + FP32 master weights | 840 GB | 3.3 GB | 3.3 GB | 3.3 GB |
| Activations | Extra and workload-dependent | Extra and workload-dependent | Extra and workload-dependent | Extra and workload-dependent |
| Model states / GPU | 1120 GB | 283 GB | 144 GB | 4.4 GB |
This is why ZeRO-3 or FULL_SHARD is often paired with activation checkpointing (recomputing activations during the backward pass to save memory, also called gradient checkpointing). Sharding fixes model-state memory, but activations can still dominate the actual training footprint.
The same table should be generated from a recipe rather than copied into a sizing document:
```python
state_gb = {"parameters": 140, "gradients": 140, "optimizer_and_master": 840}
world_size = 256

def per_rank_gb(stage):
    # Model-state components each ZeRO stage partitions across ranks.
    sharded = {
        0: set(),
        1: {"optimizer_and_master"},
        2: {"optimizer_and_master", "gradients"},
        3: {"optimizer_and_master", "gradients", "parameters"},
    }[stage]
    return sum(value / world_size if name in sharded else value for name, value in state_gb.items())

for stage in [0, 1, 2, 3]:
    print(f"ZeRO-{stage}: {per_rank_gb(stage):.2f} GB model states per rank")
```

```text
ZeRO-0: 1120.00 GB model states per rank
ZeRO-1: 283.28 GB model states per rank
ZeRO-2: 143.83 GB model states per rank
ZeRO-3: 4.38 GB model states per rank
```

ZeRO-3 buys memory by turning parameter access into communication. That makes the interconnect part of your training system design, not an afterthought. Just as a distributed warehouse system only works if the trucking network between hubs can handle replenishment traffic, ZeRO-3's memory savings depend on whether your GPU interconnect can keep up with all-gather volume.
| Primitive | What it does | Where it shows up |
|---|---|---|
| All-reduce | Sum across ranks, then give the full result back to every rank | Classic DDP gradient sync |
| Reduce-scatter | Sum across ranks, but return only each rank's shard | ZeRO-2/3 gradient sync |
| All-gather | Collect shards from all ranks to rebuild the full tensor | ZeRO-3 and FSDP parameter materialization |
Those names stop blurring once you track one tiny tensor through each collective. Read each panel left to right: same four ranks, different "after" state.
The operations become concrete on a two-rank gradient vector. All-reduce leaves the summed vector on both ranks; reduce-scatter leaves one summed slice per rank; all-gather reconstructs a parameter vector from shards.
```python
rank_gradients = [[1, 2], [10, 20]]
summed = [sum(values) for values in zip(*rank_gradients)]

all_reduce_result = [summed.copy(), summed.copy()]
reduce_scatter_result = [[summed[0]], [summed[1]]]

parameter_shards = [["p0"], ["p1"]]
full_parameters = [token for shard in parameter_shards for token in shard]
all_gather_result = [full_parameters.copy(), full_parameters.copy()]

print("all_reduce=", all_reduce_result)
print("reduce_scatter=", reduce_scatter_result)
print("all_gather=", all_gather_result)
```

```text
all_reduce= [[11, 22], [11, 22]]
reduce_scatter= [[11], [22]]
all_gather= [['p0', 'p1'], ['p0', 'p1']]
```

People often summarize ZeRO-3 communication as "about 1.5x DDP." That's a useful heuristic, but it's only a heuristic.
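For a model with $P$ parameter bytes, the approximate per-rank communication volume per step is:

$$
\text{DDP} \approx \underbrace{P}_{\text{reduce-scatter}} + \underbrace{P}_{\text{all-gather}} = 2P
$$

$$
\text{ZeRO-3} \approx \underbrace{P}_{\text{forward all-gather}} + \underbrace{P}_{\text{backward all-gather}} + \underbrace{P}_{\text{gradient reduce-scatter}} = 3P
$$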
That gives the familiar rule of thumb. The harder part in practice isn't only byte volume, but how that volume is scheduled. DDP tends to use a smaller number of large collectives. ZeRO-3 and FSDP break communication into many layer-level or unit-level operations. If those collectives are small and your network has high latency, throughput can fall off quickly.
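The DDP side of that rule of thumb, about twice the parameter bytes per rank, comes from the usual ring decomposition of all-reduce into a reduce-scatter followed by an all-gather. In a ring, each of those phases sends roughly (N-1)/N of the tensor per rank. The toy, single-process sketch below works under that assumption: lists stand in for rank buffers, no real communication happens, and the helper names are hypothetical.

```python
def reduce_scatter(rank_tensors: list[list[float]]) -> list[list[float]]:
    """Sum across ranks, then give rank r only the r-th contiguous chunk."""
    world = len(rank_tensors)
    summed = [sum(vals) for vals in zip(*rank_tensors)]
    chunk = len(summed) // world  # assume the length divides evenly in this toy
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world)]


def all_gather(shards: list[list[float]]) -> list[list[float]]:
    """Concatenate every rank's shard and hand the full tensor to each rank."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def ring_bytes_per_rank(tensor_bytes: float, world: int) -> float:
    """Approximate bytes each rank sends in one ring reduce-scatter or all-gather."""
    return tensor_bytes * (world - 1) / world


grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
composed = all_gather(reduce_scatter(grads))
direct = [[sum(v) for v in zip(*grads)] for _ in grads]
print("all_reduce_via_rs_ag_matches=", composed == direct)

P, world = 140.0, 256  # GB of low-precision parameters, data-parallel ranks
ddp = 2 * ring_bytes_per_rank(P, world)    # gradient reduce-scatter + all-gather
zero3 = 3 * ring_bytes_per_rank(P, world)  # fwd all-gather + bwd all-gather + grad reduce-scatter
print(f"ddp={ddp:.1f} GB, zero3={zero3:.1f} GB, ratio={zero3 / ddp:.2f}")
```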
ZeRO-3 is most attractive when memory is the hard constraint. If the model already fits and the network is weak, DDP or ZeRO-2 style sharding can be faster.
```python
parameter_bytes = 140  # GB of low-precision parameters in the 70B example
ddp_modeled_gb = 2 * parameter_bytes
full_shard_modeled_gb = 3 * parameter_bytes

print("DDP_modeled_traffic_GB=", ddp_modeled_gb)
print("FULL_SHARD_modeled_traffic_GB=", full_shard_modeled_gb)
print("byte_ratio=", full_shard_modeled_gb / ddp_modeled_gb)
print("warning=latency_and_overlap_still_determine_step_time")
```

```text
DDP_modeled_traffic_GB= 280
FULL_SHARD_modeled_traffic_GB= 420
byte_ratio= 1.5
warning=latency_and_overlap_still_determine_step_time
```

FSDP is PyTorch's native sharded data-parallel stack.[2] The classic API is FullyShardedDataParallel (often called FSDP1). Current PyTorch exposes a DTensor-based API via torch.distributed.fsdp.fully_shard (often called FSDP2), and its tutorial marks FSDP1 deprecated while giving a migration guide. Treat FSDP1 as an existing-code path and start new designs from fully_shard.[3][4]
The main reason people choose FSDP is operational simplicity. It stays close to PyTorch's native distributed stack, native profilers, and native checkpointing tools. PyTorch's general torch.compile guidance is to compile the highest-level train or evaluation step, or the top-level module. If a distributed wrapper such as FSDP causes issues, compile the inner module instead. If you're compiling FSDP1, current docs also require use_orig_params=True.[5][6]
FSDP solves model-state memory. It doesn't solve activation memory. Long sequences and large micro-batches still push activations up, which is why FSDP is often paired with activation checkpointing.
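For FULL_SHARD, or for an FSDP2 configuration that reshards after forward, one transformer-sized communication unit follows this lifecycle:

1. Before forward, all-gather the unit's parameter shards so every rank holds the full unit.
2. Run the unit's forward computation.
3. Free the non-local parameter shards (reshard).
4. Before backward, all-gather the unit's parameters again.
5. Run the unit's backward computation.
6. Reduce-scatter the gradients so each rank keeps only its gradient shard.
7. Free the non-local parameter shards again.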
This sequence diagram shows one wrapped unit in a two-GPU FSDP job.
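Following that lifecycle, a rank's resident parameter memory is its persistent 1/N shards plus the one unit currently materialized. The toy trace below rests on several assumptions: equal-sized units, FULL_SHARD-style resharding after both forward and backward, no prefetch (prefetch would briefly hold a second unit), and parameters only. `trace_full_shard` is a hypothetical helper, not an FSDP API.

```python
def trace_full_shard(unit_gb: list[float], world_size: int) -> list[tuple[str, float]]:
    """Return (event, resident_parameter_GB) for one rank over one training step."""
    shards = sum(g / world_size for g in unit_gb)  # persistent 1/N of every unit
    events = []
    # Forward visits units in order; backward visits them in reverse.
    for phase, order in (("forward", range(len(unit_gb))),
                         ("backward", reversed(range(len(unit_gb))))):
        for i in order:
            # all-gather materializes the full unit on top of the local shards
            events.append((f"{phase}:all_gather[{i}]", shards + unit_gb[i]))
            # the full unit is freed after its compute; in backward its gradients
            # are also reduce-scattered (gradient memory is not tracked here)
            events.append((f"{phase}:reshard[{i}]", shards))
    return events


trace = trace_full_shard([3.5, 3.5, 3.5, 3.5], world_size=8)
peak = max(gb for _, gb in trace)
for event, gb in trace[:4]:
    print(f"{event:<22} {gb:.2f} GB")
print(f"peak resident parameters: {peak:.2f} GB")
```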
The following code shows a practical FSDP1 setup. You will still meet this API in existing codebases, so it is worth reading even though new designs should begin with fully_shard. The important details are:

- transformer_auto_wrap_policy, which makes each decoder layer its own FSDP unit instead of wrapping the whole model as one root unit.
- MixedPrecision, which uses BF16 for parameters, gradient reduction, and buffers.
- ShardingStrategy.FULL_SHARD, the ZeRO-3-style strategy.
- backward_prefetch with limit_all_gathers. The first helps overlap communication with compute. The second is mainly a CPU-side rate limiter that caps in-flight all-gathers and peak memory.[6]

This is distributed setup code, not a single-process notebook cell. It assumes torch.distributed.init_process_group() has already run under torchrun and that each rank has selected its local CUDA device.
```python
import torch
import torch.nn as nn
import functools
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=1024, n_heads=16):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_out)
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

my_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerDecoderLayer},
)

mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = nn.Sequential(*[
    TransformerDecoderLayer(d_model=1024, n_heads=16)
    for _ in range(8)
])

fsdp_model = FSDP(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mixed_precision_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    device_id=torch.cuda.current_device(),
)

optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=2e-5)  # example SFT start; evaluate it
```

FSDP2 moves away from the flat-parameter wrapper model and toward per-parameter sharding with DTensor. The point isn't only elegance. It changes how composable and inspectable the system is.
For new code, the structural pattern is bottom-up application of fully_shard, then optimizer construction over the resulting DTensor parameters:[3]
```python
import torch
from torch.distributed.fsdp import fully_shard

model = TransformerModel()  # placeholder: your model class with a `.blocks` ModuleList
for block in model.blocks:
    fully_shard(block)  # one communication unit per transformer block
fully_shard(model)  # shard remaining root-level parameters

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
```

This snippet shows structure rather than a runnable local demo: fully_shard must execute inside a distributed process group on suitable accelerators. The learning rate remains an SFT baseline to evaluate, not a distributed-training requirement.
| Feature | FSDP1 (FullyShardedDataParallel) | FSDP2 (fully_shard) |
|---|---|---|
| How it applies sharding | Wrapper module with flat parameters | In-place hooks on original modules |
| Sharding representation | FlatParameter groups | Per-parameter DTensor |
| Parameter names | Flattened internals to reason about | Original FQNs stay intact |
| Composition with TP / frozen params | More awkward | Better fit with DeviceMesh and frozen params |
| Checkpointing style | Full and sharded state dict APIs | Sharded state dicts first, reshard to full when needed |
Cleaner composition: FSDP2 uses DeviceMesh, which makes it much easier to compose data-parallel sharding with tensor parallelism in the same topology.
Cleaner state dict story for large runs: Sharded state dicts are first-class, and distributed checkpointing avoids the old pattern of gathering everything onto rank 0 for saving.
Cleaner handling of frozen parameters: Per-parameter sharding is a better fit for PEFT and LoRA-style workflows where only a small subset of weights should update.
Different scheduling controls: fully_shard is no longer a wrapper configured through FSDP1's auto_wrap_policy, backward_prefetch, use_orig_params, or limit_all_gathers arguments. Apply it bottom-up, then use FSDP2 methods such as set_modules_to_forward_prefetch or set_modules_to_backward_prefetch when explicit prefetch control is needed.[3]
The practical takeaway is simple: FSDP1 still shows up in existing codebases, while current PyTorch's FSDP2 documentation provides the migration path and fully_shard API for new work.
The most important FSDP choice is usually sharding granularity. Use transformer blocks as communication units. If a single root unit owns the entire model, its all-gather materializes every parameter at once, and peak parameter memory climbs back toward a full replica.
This simplified calculation isolates why granularity matters. It models only the transient full-parameter materialization, not activations or prefetch overlap:
```python
block_parameter_gb = [3.5, 3.5, 3.5, 3.5]

root_only_peak_full_materialization = sum(block_parameter_gb)
blockwise_peak_full_materialization = max(block_parameter_gb)

print("root_only_active_parameters_GB=", root_only_peak_full_materialization)
print("blockwise_active_parameters_GB=", blockwise_peak_full_materialization)
print("modeled_reduction=", root_only_peak_full_materialization / blockwise_peak_full_materialization)
```

```text
root_only_active_parameters_GB= 14.0
blockwise_active_parameters_GB= 3.5
modeled_reduction= 4.0
```

| Strategy | Equivalent | Description |
|---|---|---|
| FULL_SHARD | ZeRO-3 | Shards parameters, gradients, and optimizer states. Most memory efficient. |
| SHARD_GRAD_OP | ZeRO-2-like | Unshards parameters for forward, keeps them unsharded until backward finishes, then reshards. Uses more memory than FULL_SHARD, but skips re-gathering parameters for backward. |
| NO_SHARD | DDP-like replication | Keeps model states replicated and avoids sharded parameter materialization, but doesn't reduce per-rank model-state memory. |
| HYBRID_SHARD | ZeRO-3 within node + replication across nodes | Shards within a node and replicates across nodes. Useful when intra-node links are fast and inter-node bandwidth is tighter. |
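These strategies differ in how many ranks split one copy of the model states. The rough per-rank estimate below rests on stated assumptions: the 70B, 16-byte recipe, a hypothetical cluster of 32 nodes with 8 GPUs each, and model states only. SHARD_GRAD_OP is left out because its extra cost is transient parameter memory rather than a different steady-state split.

```python
TOTAL_STATE_GB = 1120         # 70B params x 16 bytes (decimal GB)
NODES, GPUS_PER_NODE = 32, 8  # assumed cluster shape

# Number of ranks that split one copy of the model states under each strategy.
shard_group_size = {
    "FULL_SHARD": NODES * GPUS_PER_NODE,  # shard across every rank
    "HYBRID_SHARD": GPUS_PER_NODE,        # shard inside a node, replicate across nodes
    "NO_SHARD": 1,                        # full replica on every rank
}

for strategy, group in shard_group_size.items():
    per_rank = TOTAL_STATE_GB / group
    replicas = NODES * GPUS_PER_NODE // group
    print(f"{strategy:<13} {per_rank:8.2f} GB/rank, {replicas} replica(s) of the states")
```

Under these assumptions, HYBRID_SHARD trades inter-node traffic for per-rank memory. An 8-GPU shard group would still hold 140 GB of model states per rank, which doesn't fit this recipe on 80 GB GPUs.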
Large training runs fail in the middle all the time. A node reboots. A preemptible instance disappears. A job gets rescheduled. If your checkpoint only contains model weights, you didn't save the training run. You saved an export artifact.[7]
For a real resume point, the checkpoint usually needs at least:
| Must-save state | Why it matters on resume |
|---|---|
| Model weights | Restores current parameters |
| Optimizer state | Preserves momentum and adaptive moments |
| Scheduler state | Keeps warmup and decay aligned with training step |
| Gradient-scaler / AMP state, if used | Preserves mixed-precision control state across resume |
| Consumed step or token count | Lets logs, schedulers, and stop rules stay honest |
| Sampler / dataloader cursor | Prevents replaying or skipping large data regions |
| RNG state | Keeps dropout, shuffling, and sampling reproducible when needed |
In a sharded system, the safest default is to save sharded checkpoints, not to gather everything onto rank 0 first. FSDP2 and distributed checkpoint APIs are pushing in that direction because the "collect all weights on one machine and write a giant file" pattern eventually becomes the bottleneck or the failure point.[3][7]
DeepSpeed goes one step further with Universal Checkpointing, which is meant to make checkpoint artifacts more portable across different parallelism layouts instead of tying restore logic to the exact original topology.[8] That matters once you start changing world size between runs or resume on a different cluster shape.
The operational rule is simple:
latest-final.

If you can't answer "what exact state do we restore, and how do we know resume is faithful?", the training system isn't production-ready yet.
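One way to check whether resume is faithful is to compare an uninterrupted run against a run that saves and restores partway through. The toy, single-process sketch below uses a hypothetical `TrainState` that bundles a scalar weight, the step count, a sampler cursor, and Python RNG state. A real job would checkpoint model and optimizer shards, not a float.

```python
import copy
import random
from dataclasses import dataclass, field


@dataclass
class TrainState:
    """Hypothetical stand-in for everything a resume point must capture."""
    weight: float = 0.0
    step: int = 0
    sampler_cursor: int = 0
    rng: random.Random = field(default_factory=lambda: random.Random(1234))


DATA = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 3.0, -2.0]


def train_step(state: TrainState) -> None:
    x = DATA[state.sampler_cursor % len(DATA)]
    noise = state.rng.gauss(0.0, 0.01)  # stands in for dropout / sampling noise
    state.weight += 0.1 * (x + noise)   # toy update rule
    state.step += 1
    state.sampler_cursor += 1


def save(state: TrainState) -> dict:
    return {"weight": state.weight, "step": state.step,
            "sampler_cursor": state.sampler_cursor, "rng_state": state.rng.getstate()}


def load(ckpt: dict) -> TrainState:
    rng = random.Random()
    rng.setstate(ckpt["rng_state"])
    return TrainState(ckpt["weight"], ckpt["step"], ckpt["sampler_cursor"], rng)


reference = TrainState()
for _ in range(8):
    train_step(reference)

resumed = TrainState()
for _ in range(3):
    train_step(resumed)
resumed = load(copy.deepcopy(save(resumed)))  # simulate preemption + restore
for _ in range(5):
    train_step(resumed)

print("resume_faithful=", resumed.weight == reference.weight and resumed.step == reference.step)
```

Dropping `rng_state` or `sampler_cursor` from `save` typically makes the comparison fail. That is the point of the check: it catches incomplete checkpoints before a real preemption does.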
A distributed checkpoint manifest also has to record the sharding topology it was written under, even when a portable conversion path exists:
```python
required = {
    "model_shards",
    "optimizer_shards",
    "scheduler_state",
    "consumed_tokens",
    "sampler_cursor",
    "rng_state",
    "gradient_scaler_state",
    "template_tokenizer_version",
    "data_manifest",
    "evaluation_manifest",
    "best_metric",
    "world_size",
    "sharding_strategy",
}

manifest = {
    "model_shards": "ckpt/step_800/model/",
    "optimizer_shards": "ckpt/step_800/optimizer/",
    "scheduler_state": "ckpt/step_800/scheduler.pt",
    "consumed_tokens": 104_857_600,
    "sampler_cursor": {"epoch": 0, "batch": 800},
    "rng_state": "ckpt/step_800/rng.pt",
    "gradient_scaler_state": None,  # BF16 run; store scaler state for FP16 when used
    "template_tokenizer_version": "support-chat-v3",
    "data_manifest": "manifests/sft-train-v7.json",
    "evaluation_manifest": "manifests/support-eval-v4.json",
    "best_metric": {"name": "support_resolution_accuracy", "value": 0.78},
    "world_size": 32,
    "sharding_strategy": "FULL_SHARD",
}

missing = sorted(required - manifest.keys())
print("resume_manifest_complete=", not missing)
print("saved_topology=", manifest["sharding_strategy"], manifest["world_size"])
print("gradient_scaler_state=", manifest["gradient_scaler_state"])
print("best_metric=", manifest["best_metric"])
```

```text
resume_manifest_complete= True
saved_topology= FULL_SHARD 32
gradient_scaler_state= None
best_metric= {'name': 'support_resolution_accuracy', 'value': 0.78}
```

DeepSpeed is Microsoft's distributed training runtime that introduced ZeRO and provides heterogeneous-memory offload through ZeRO-Offload and ZeRO-Infinity.[1][9] It's got a steeper learning curve than native PyTorch, but it remains a strong choice when memory pressure is the real blocker or when the rest of your stack already depends on Megatron-DeepSpeed style training.
One important DeepSpeed-specific detail is that ZeRO-3 depends on consistent module execution order across ranks. Dynamic routing patterns such as MoE can deadlock parameter all-gathers if different ranks enter different submodules, which is why DeepSpeed exposes the idea of ZeRO-3 leaf modules.[10]
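The failure mode is easier to see as a per-rank sequence of parameter all-gathers: every rank must issue the same collectives in the same order. The toy check below assumes a hypothetical MoE layer whose tokens route to different experts on each rank. It models the idea behind leaf modules (gather the whole MoE block as one unit), not DeepSpeed's implementation.

```python
def gather_sequence(routed_experts: list[str], moe_as_leaf: bool) -> list[str]:
    """Modules whose parameters a rank would all-gather, in call order."""
    seq = ["attn"]
    if moe_as_leaf:
        seq.append("moe")  # one gather for the whole block, independent of routing
    else:
        seq.extend(f"moe.{e}" for e in routed_experts)  # depends on this rank's tokens
    seq.append("lm_head")
    return seq


def collectives_match(per_rank: list[list[str]]) -> bool:
    """Collectives only complete if every rank issues the same sequence."""
    return all(seq == per_rank[0] for seq in per_rank)


routing = {0: ["expert_1"], 1: ["expert_2", "expert_3"]}  # data-dependent routing
for leaf in (False, True):
    seqs = [gather_sequence(routing[r], moe_as_leaf=leaf) for r in sorted(routing)]
    print(f"moe_as_leaf={leaf}: ranks agree -> {collectives_match(seqs)}")
```

In a real job, the mismatched case doesn't return False. It hangs inside a collective until a timeout fires.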
The following distributed setup snippet shows how to configure and initialize a model using DeepSpeed. Run it through the DeepSpeed launcher in a real training job. The configuration dictionary defines the ZeRO stage, offloading settings, optimizer, and training batch parameters. The learning rate is an example SFT starting point to evaluate, and the bucket sizes are tuning knobs rather than canonical values:
```python
import deepspeed
import torch.nn as nn

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "bf16": {
        "enabled": True,
    },
    "optimizer": {
        # "AdamW" is a built-in DeepSpeed optimizer type. With offload_optimizer on
        # CPU, DeepSpeed runs the step with its CPU Adam kernel (DeepSpeedCPUAdam).
        "type": "AdamW",
        "params": {
            "lr": 2e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01,
        },
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True,
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": 500_000_000,
        "stage3_prefetch_bucket_size": 50_000_000,
        "stage3_param_persistence_threshold": 100_000,
    }
}

class SimpleTransformer(nn.Module):
    def __init__(self, d_model=1024, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead=16, batch_first=True)
            for _ in range(n_layers)
        ])
        self.proj = nn.Linear(d_model, 1000)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.proj(x)

model = SimpleTransformer(d_model=1024, n_layers=4)

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config,
    model_parameters=model.parameters(),
)

# for batch in dataloader:
#     outputs = model_engine(batch)
#     loss = compute_loss(outputs)
#     model_engine.backward(loss)
#     model_engine.step()
```

DeepSpeed extends ZeRO-3 with ZeRO-Infinity, which allows offloading sharded parameters and optimizer states to CPU RAM or even NVMe SSDs.[9]
| Tier | Memory Pool | Use Case |
|---|---|---|
| GPU | GPU HBM | Active compute, activations, temporary buffers |
| CPU | System RAM | Optimizer states or parameter shards when GPU memory is tight |
| NVMe | Local SSD / NVMe | Deep spillover tier when RAM still isn't enough |
This makes some otherwise impossible training runs feasible, but it doesn't make offload free. Once parameters or optimizer state spill into CPU or NVMe, throughput becomes constrained by PCIe, host memory bandwidth, storage latency, and how well the runtime can overlap data movement with compute.
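A back-of-the-envelope transfer estimate shows why offload placement matters. The bandwidth figures below are hypothetical placeholders, not measurements, so substitute numbers measured on your own hosts. The traffic assumption is one pass over a rank's BF16 parameter shard for the 70B model split across 32 ranks.

```python
def transfer_seconds(gigabytes: float, effective_gb_per_s: float) -> float:
    """Time to move `gigabytes` once at a sustained effective bandwidth."""
    return gigabytes / effective_gb_per_s


# Assumed: 70B params x 2 bytes (BF16) = 140 GB, sharded over 32 ranks.
per_rank_gb = 140 / 32
hypothetical_links = {"host_to_gpu_link": 20.0, "nvme_tier": 5.0}  # GB/s, placeholders

for tier, bandwidth in hypothetical_links.items():
    secs = transfer_seconds(per_rank_gb, bandwidth)
    print(f"{tier:<17} {secs:.2f} s per pass unless overlapped with compute")
```

Each offloaded pass adds this time to the step unless the runtime hides it behind compute. A slower tier turns directly into longer steps.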
DeepSpeed's feature matrix is broader, but it isn't "all combinations are valid." Current docs say AutoTP training supports ZeRO stages 0, 1, and 2, while PipelineModule isn't compatible with ZeRO-2 or ZeRO-3.[10][11]
Turn compatibility constraints into a config gate, not a surprise after scheduling a large job:
```python
requested_runs = [
    {"feature": "AutoTP", "zero_stage": 2},
    {"feature": "AutoTP", "zero_stage": 3},
    {"feature": "PipelineModule", "zero_stage": 2},
]

def supported(run):
    if run["feature"] == "AutoTP":
        return run["zero_stage"] in {0, 1, 2}
    if run["feature"] == "PipelineModule":
        return run["zero_stage"] in {0, 1}
    return False

for run in requested_runs:
    print(run["feature"], f"ZeRO-{run['zero_stage']}", "allowed=", supported(run))
```

```text
AutoTP ZeRO-2 allowed= True
AutoTP ZeRO-3 allowed= False
PipelineModule ZeRO-2 allowed= False
```

Both frameworks solve the same core problem, but they optimize for different operational constraints. FSDP optimizes for staying close to standard PyTorch. DeepSpeed optimizes for a wider set of large-scale runtime features.
| Feature | FSDP | DeepSpeed ZeRO |
|---|---|---|
| Runtime model | Native PyTorch distributed stack | Separate DeepSpeed runtime |
| ZeRO stages | ZeRO-2-like and ZeRO-3-like sharding modes | Native ZeRO-1/2/3 |
| CPU offloading | CPUOffload for params and gradients | Strong CPU and NVMe offload |
| NVMe offloading | No | Yes |
| Compile story | Compile the train step or top-level module first; fall back to the inner module if the wrapper causes issues. FSDP1 needs use_orig_params=True | Depends on integration path |
| Pipeline parallel | External | Built-in, but PipelineModule excludes ZeRO-2/3 |
| Tensor parallel | External | AutoTP or Megatron integration (AutoTP supports ZeRO-0/1/2) |
| Checkpointing / tooling | Standard PyTorch APIs | DeepSpeed-specific runtime APIs |
| Debugging | Standard PyTorch stack traces and profiler | More runtime-specific behavior |
For very large jobs, data-parallel sharding alone isn't enough. As you push the data-parallel group wider, global batch size, optimizer dynamics, and cross-node communication all start fighting you. The standard answer is to combine multiple forms of parallelism in the same training topology.[12]
That's what people mean by 3D parallelism: data parallelism for replica groups, tensor parallelism (splitting weight matrices within a layer across GPUs) inside layers, and pipeline parallelism (splitting sequential layer groups across different devices) across layer ranges.
The diagram below shows how data, tensor, and pipeline parallelism combine to distribute a model across a large cluster. Solid arrows show pipeline flow through layer ranges. Dotted arrows show synchronization between matching data-parallel replicas.
Tensor parallelism splits the work inside a layer. Instead of putting a full weight matrix on one GPU, it partitions the matrix across GPUs.
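Column-parallel splitting is the simplest case. Each rank holds a slice of the weight's output columns and computes its slice of the output. The slices are then concatenated, which is an all-gather in a real job. The pure-Python sketch below simulates two ranks and assumes a tiny dense matrix rather than any framework's sharded tensor.

```python
def matmul(x: list[float], w: list[list[float]]) -> list[float]:
    """y = x @ W for a single row vector x and weight W of shape [in, out]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def split_columns(w: list[list[float]], parts: int) -> list[list[list[float]]]:
    """Give each simulated rank a contiguous block of output columns."""
    cols = len(w[0]) // parts  # assume the output dim divides evenly
    return [[row[p * cols:(p + 1) * cols] for row in w] for p in range(parts)]


x = [1.0, 2.0, 3.0]
w = [[1.0, 0.0, 2.0, -1.0],
     [0.5, 1.0, 0.0, 2.0],
     [0.0, -1.0, 1.0, 1.0]]

shards = split_columns(w, parts=2)
partial_outputs = [matmul(x, shard) for shard in shards]  # computed independently per rank
gathered = [v for part in partial_outputs for v in part]  # stands in for an all-gather
print("tp_output=", gathered, "matches_dense=", gathered == matmul(x, w))
```

Row-parallel layers split the input dimension instead and need a sum (all-reduce) rather than a concatenation. Megatron-style MLPs pair a column-parallel layer with a row-parallel one, so the forward pass needs a single all-reduce per block.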
Pipeline parallelism splits the model across depth by assigning different layer ranges to different devices or nodes.
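The cost of splitting by depth is the pipeline bubble. Assume an idealized GPipe-style schedule with p equal-time stages and m micro-batches per step; interleaved schedules shrink the bubble further. A step then spans m + p - 1 stage slots, and each stage sits idle for p - 1 of them:

```python
def bubble_fraction(stages: int, micro_batches: int) -> float:
    """Idle share of a step for an idealized GPipe-style schedule."""
    return (stages - 1) / (micro_batches + stages - 1)


for m in (1, 4, 16, 64):
    print(f"p=4, m={m:>2}: bubble={bubble_fraction(4, m):.1%}")
```

This is why the debugging table later flags idle bubbles when micro-batching is too small.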
| Parallelism | Primary shard | Communication pattern | Best for |
|---|---|---|---|
| Data / ZeRO / FSDP | Batch across replicas, with model states optionally sharded | All-reduce, reduce-scatter, all-gather | General scaling and memory reduction |
| Tensor (TP) | Weight matrices within layers | Collectives inside each layer | Very wide layers on fast intra-node links |
| Pipeline (PP) | Sequential layer groups | Point-to-point activations and gradients | Very deep models and multi-node partitioning |
The escalation path should be deliberate.[13]
A common placement keeps TP on fast intra-node links, uses PP to partition depth across broader topology boundaries, and applies data parallelism across replica groups. The right degrees still depend on available links, memory, model shape, and global-batch constraints.
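One way to get that placement is to make tensor parallelism the fastest-varying rank coordinate. Each TP group then occupies consecutive ranks and, with contiguous rank-to-GPU assignment, a single node. The sketch below assumes 8 GPUs per node and a hypothetical (dp, pp, tp) ordering. Real launchers and DeviceMesh layouts may order dimensions differently.

```python
from itertools import product

GPUS_PER_NODE = 8
DP, PP, TP = 4, 2, 8  # assumed degrees: 64 ranks total


def rank_of(dp: int, pp: int, tp: int) -> int:
    """TP varies fastest, then PP, then DP."""
    return (dp * PP + pp) * TP + tp


tp_groups = [[rank_of(d, p, t) for t in range(TP)] for d, p in product(range(DP), range(PP))]
tp_within_node = all(len({r // GPUS_PER_NODE for r in group}) == 1 for group in tp_groups)

# Ranks holding the same pipeline stage and TP slice form one data-parallel group.
dp_group_0 = [rank_of(d, 0, 0) for d in range(DP)]
print("world_size=", DP * PP * TP)
print("tp_groups_stay_inside_one_node=", tp_within_node)
print("dp_group_for_stage0_tp0=", dp_group_0)
```

With TP=8, each pipeline stage of each replica fills exactly one 8-GPU node. Pipeline and data-parallel traffic therefore crosses node boundaries, while tensor-parallel traffic stays on intra-node links.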
In current PyTorch, a DeviceMesh is the abstraction used to express multiple device dimensions, which lets FSDP2 compose with tensor parallelism on the same hardware (often called FSDP + TP, or 2D parallelism). Megatron Core also documents context parallelism, which splits the sequence dimension for long-context workloads. Combining data, tensor, pipeline, and context parallelism creates a four-axis topology; runtimes don't always use the same "3D" or "4D" label for that composition.[3][13]
Each topology degree multiplies into the required world size. Calculate it before reserving hardware:
```python
topology = {
    "data_parallel": 8,
    "tensor_parallel": 4,
    "pipeline_parallel": 2,
    "context_parallel": 2,
}

world_size = 1
for degree in topology.values():
    world_size *= degree

print("world_size=", world_size)
print("ranks_per_data_replica=", topology["tensor_parallel"] * topology["pipeline_parallel"] * topology["context_parallel"])
assert world_size == 128
```

```text
world_size= 128
ranks_per_data_replica= 16
```

Distributed training fails in ways that single-GPU training doesn't. Jobs hang instead of crashing. One bad rank can stall everybody else. Performance regressions often come from communication scheduling, not math kernels.
- Hangs: TORCH_NCCL_BLOCKING_WAIT=1 can turn a silent hang into a visible failure.
- Gradient clipping: with FSDP1, use fsdp_model.clip_grad_norm_() because it computes the norm across sharded gradients. FSDP2's DTensor-based parameters instead support torch.nn.utils.clip_grad_norm_(model.parameters(), ...) in the current tutorial.[6][4]

To optimize throughput (tokens/sec), you must identify whether you're compute-bound or communication-bound.

- torch.profiler: Look at ncclAllGather, ncclReduceScatter, and gaps where GPUs wait for communication. For FSDP1, tune wrapping granularity and backward_prefetch; use limit_all_gathers=True for peak-memory rate limiting, not as the overlap knob. For FSDP2, use its module prefetch APIs when implicit scheduling is insufficient.[6][3]

For hard distributed bugs, capture one clean repro run with NCCL_DEBUG=INFO and TORCH_DISTRIBUTED_DEBUG=DETAIL instead of trying to reason from high-level symptoms alone.
Once the job runs, you still need to know why it is slow or unstable. Good training dashboards separate data, compute, communication, and memory instead of showing only one loss curve.[13]
| Signal | What it answers | Bad symptom |
|---|---|---|
| tokens/sec | Is end-to-end throughput improving? | flat or falling throughput after adding GPUs |
| step time split | Is time spent in data load, forward, backward, optimizer, or checkpoint save? | one phase dominates without explanation |
| all-gather / reduce-scatter share | Is sharding traffic now bottlenecking the job? | communication time grows faster than compute time |
| dataloader wait time | Are GPUs starved by CPU preprocessing or storage? | GPU idle gaps before forward |
| peak HBM and reserved memory | Are you close to OOM or fragmenting memory? | retries, allocator spikes, sudden OOM at checkpoint save |
| MFU or FLOP utilization | Are expensive kernels doing useful math? | very low utilization despite full memory use |
| checkpoint save duration | Is fault tolerance becoming a throughput tax? | long pauses every save interval |
If you only watch training loss, a run can look healthy while throughput quietly collapses. Distributed training needs systems telemetry, not only ML telemetry.
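A dashboard can turn the step-time split into a first-pass diagnosis. The sketch below uses a hypothetical `StepTimes` record and an arbitrary 40% dominance threshold. Both are assumptions, so tune them against your own baselines.

```python
from dataclasses import dataclass


@dataclass
class StepTimes:
    """Hypothetical per-step timing breakdown in seconds (exposed, not overlapped, time)."""
    dataloader_wait: float
    compute: float        # forward + backward + optimizer kernels
    exposed_comm: float   # collective time not hidden behind compute
    checkpoint: float = 0.0


def diagnose(t: StepTimes, dominance: float = 0.40) -> str:
    total = t.dataloader_wait + t.compute + t.exposed_comm + t.checkpoint
    shares = {
        "input-bound": t.dataloader_wait / total,
        "compute-bound": t.compute / total,
        "communication-bound": t.exposed_comm / total,
        "checkpoint-bound": t.checkpoint / total,
    }
    label, share = max(shares.items(), key=lambda kv: kv[1])
    if share >= dominance:
        return f"{label} ({share:.0%} of step)"
    return "mixed: no single phase dominates"


print(diagnose(StepTimes(dataloader_wait=0.05, compute=0.60, exposed_comm=0.90)))
print(diagnose(StepTimes(dataloader_wait=0.02, compute=0.90, exposed_comm=0.08)))
```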
Strong distributed-training debugging connects the abstraction to the bottleneck. A useful answer names the split, the communication primitive, and the symptom you would measure.
| Topic | What to explain | Debug signal |
|---|---|---|
| Data parallelism | each rank owns a batch shard and synchronizes gradients | all-reduce time rises with parameter size |
| Tensor parallelism | one layer is split across ranks | frequent collectives inside attention and MLP blocks |
| Pipeline parallelism | different ranks own different layer ranges | idle bubbles when micro-batching is too small |
| ZeRO / FSDP | optimizer states, gradients, and parameters are sharded | all-gather and reduce-scatter dominate step time |
| Activation checkpointing | save memory by recomputing activations in backward | lower peak memory, higher compute time |
| NCCL bottlenecks | collectives depend on rank order, topology, and matching calls | hangs, timeouts, or wide gaps in profiler traces |
| FlashAttention | reduce attention memory traffic with tiled kernels | better long-context throughput when attention is memory-bound |
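The data-parallel and ZeRO/FSDP rows can be sized before profiling. Under the ZeRO paper's accounting, plain data parallelism sends roughly 2(N-1)/N times the gradient bytes per rank for its ring all-reduce. ZeRO-3 issues a forward all-gather, a backward all-gather, and a gradient reduce-scatter, about 3(N-1)/N times the parameter bytes, or roughly 1.5x the traffic.[1] The sketch assumes 2-byte (BF16) communication and ignores bucketing, latency, and topology. Treat it as a bandwidth floor, not a step-time prediction. The function name is hypothetical.

```python
def per_rank_traffic_gb(n_params, world_size, bytes_per_elem=2):
    """Approximate GB each rank sends per optimizer step (ring collectives)."""
    size = n_params * bytes_per_elem
    frac = (world_size - 1) / world_size
    # Plain data parallel: one all-reduce of the gradients
    # (a ring all-reduce is a reduce-scatter followed by an all-gather).
    dp = 2 * frac * size
    # ZeRO-3 / FULL_SHARD: all-gather params in forward, again in backward,
    # then reduce-scatter the gradients.
    zero3 = 3 * frac * size
    return dp / 1e9, zero3 / 1e9


dp_gb, zero3_gb = per_rank_traffic_gb(7e9, 64)
print(f"data parallel: {dp_gb:.1f} GB/step, ZeRO-3: {zero3_gb:.1f} GB/step, ratio {zero3_gb / dp_gb:.2f}")
```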
A practical debugging loop is:
This is the bridge between ML and distributed systems. The right answer is rarely "use FSDP" or "use DeepSpeed." The stronger answer is "model states don't fit, so I shard them; then I measure whether communication or activation memory became the new bottleneck."
A common sizing prompt is: "How many 80 GB GPUs do you need to train a 70B model with ZeRO-3?" Walk through it step by step.
Step 1: count the model states.
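With mixed-precision Adam, each parameter carries:
- 2 bytes: BF16/FP16 weight
- 2 bytes: BF16/FP16 gradient
- 4 bytes: FP32 master weight
- 4 bytes: FP32 Adam first moment (momentum)
- 4 bytes: FP32 Adam second moment (variance)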
Total: 16 bytes per parameter.
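For 70B parameters: $70 \times 10^{9} \times 16\ \text{bytes} = 1.12 \times 10^{12}\ \text{bytes} \approx 1{,}120\ \text{GB}$ of model states.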
Step 2: compute the model-state lower bound.
If you looked only at model states, the absolute lower bound on 80 GB GPUs would be:

$$\left\lceil \frac{1120\ \text{GB}}{80\ \text{GB/GPU}} \right\rceil = 14\ \text{GPUs}$$

So 14 GPUs is the arithmetic floor for model states alone.
The same calculation is worth making executable:
```python
import math

# 16 bytes/param (BF16 weight + BF16 grad + FP32 master + two FP32 Adam moments)
# times 70B params, expressed in GB (1 GB = 1e9 bytes).
total_state_gb = 70 * 16
gpu_memory_gb = 80

# ZeRO-3 splits all model states evenly, so the floor is total / per-GPU capacity.
raw_floor = math.ceil(total_state_gb / gpu_memory_gb)
print("raw_lower_bound_gpus=", raw_floor)

for gpus in [14, 16, 24]:
    per_gpu = total_state_gb / gpus      # sharded model states per rank
    headroom = gpu_memory_gb - per_gpu   # what is left for everything else
    print(f"{gpus:>2} GPUs -> {per_gpu:5.1f} GB states/GPU, {headroom:5.1f} GB before activations")
```

```text
raw_lower_bound_gpus= 14
14 GPUs ->  80.0 GB states/GPU,   0.0 GB before activations
16 GPUs ->  70.0 GB states/GPU,  10.0 GB before activations
24 GPUs ->  46.7 GB states/GPU,  33.3 GB before activations
```

Step 3: add overhead.
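The math above is for model states only. Real training also needs:
- activations saved for the backward pass, which grow with micro-batch size and sequence length
- temporary and communication buffers, including the unsharded parameters of the FSDP/ZeRO-3 unit currently being computed (and any prefetched unit)
- allocator overhead and fragmentation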
No fixed headroom percentage holds across workloads: a change in context length or micro-batch size can dominate the non-state peak. Measure or estimate the remaining peak memory for the specific workload, then compare candidate topologies:
```python
state_total_gb = 1120
gpu_capacity_gb = 80
measured_non_state_peak_gb = 18  # activations + buffers + allocator margin for this trial workload

def fits(world_size):
    """Return (sharded states per GPU, estimated peak per GPU, whether the peak fits)."""
    state_per_gpu = state_total_gb / world_size
    peak = state_per_gpu + measured_non_state_peak_gb
    return state_per_gpu, peak, peak <= gpu_capacity_gb

for world_size in [16, 24, 32]:
    state, peak, ok = fits(world_size)
    print(f"{world_size} GPUs -> states={state:.1f} GB peak_estimate={peak:.1f} GB fits={ok}")
```

```text
16 GPUs -> states=70.0 GB peak_estimate=88.0 GB fits=False
24 GPUs -> states=46.7 GB peak_estimate=64.7 GB fits=True
32 GPUs -> states=35.0 GB peak_estimate=53.0 GB fits=True
```

Step 4: answer the question.
The clean answer is: 14 GPUs is only the raw ZeRO-3 lower bound for model states under this 16-byte recipe. It isn't a runnable cluster recommendation. A practical count requires a measured or modeled activation/buffer peak and an acceptable throughput result. In the example above, an 18 GB non-state peak rules out 16 GPUs but fits at 24; a longer sequence could invalidate that answer again. Activation checkpointing, a smaller micro-batch, tensor parallelism, or pipeline parallelism may change the result.
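The same reasoning can be turned around to search for the smallest usable world size. This sketch makes two assumptions to replace with your own numbers. First, the non-state peak stays constant as ranks are added, which holds to first order for pure sharded data parallelism because activations depend on the per-rank micro-batch. Second, jobs are allocated in whole 8-GPU nodes. `min_world_size` is a hypothetical helper.

```python
import math


def min_world_size(state_total_gb, gpu_capacity_gb, non_state_peak_gb, gpus_per_node=8):
    """Smallest multiple of gpus_per_node whose per-GPU peak fits in memory."""
    budget_gb = gpu_capacity_gb - non_state_peak_gb  # room left for sharded states
    if budget_gb <= 0:
        raise ValueError("non-state peak alone exceeds GPU memory; shrink activations first")
    raw = math.ceil(state_total_gb / budget_gb)  # GPUs needed ignoring node size
    return math.ceil(raw / gpus_per_node) * gpus_per_node


for non_state in [0, 18, 40]:
    print(f"non-state peak {non_state:>2} GB -> {min_world_size(1120, 80, non_state)} GPUs")
```

With the 18 GB trial peak, this returns the same 24 GPUs as the comparison above. With no non-state memory at all, node rounding alone moves the 14-GPU floor to 16.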
Symptom: A team says they "added ZeRO-3 for tensor parallelism," but profiler traces still show whole-layer matmuls on each rank and the real change is only lower model-state memory.
Cause: ZeRO and FSDP shard model states across data-parallel ranks. They do not split the matrix multiply itself. Tensor parallelism splits layer computation, and pipeline parallelism splits model depth.
Fix: Check what moved in the trace or memory profile before naming the technique. If weights, gradients, and optimizer states shrank per rank, that is sharding. If one layer's matmul is split across devices, that is tensor parallelism.
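The distinction is visible even in a toy, single-process NumPy model of one linear layer split across two hypothetical ranks. Sharded storage shrinks what each rank keeps, but after the all-gather every rank still multiplies by the whole weight. Column-parallel tensor parallelism instead gives each rank a slice of the output. This only illustrates the math and omits the communication of either technique.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 2
x = rng.standard_normal((4, 8))  # activations for one micro-batch
W = rng.standard_normal((8, 6))  # weight of one linear layer
reference = x @ W

# ZeRO-3 / FSDP: each rank stores a flat 1/N slice of the weight...
flat_shards = np.array_split(W.ravel(), world_size)
# ...but all-gathers it before compute, so each rank runs the full-size matmul.
W_gathered = np.concatenate(flat_shards).reshape(W.shape)
zero_out = x @ W_gathered  # (4, 6) on every rank

# Column-parallel tensor parallelism: each rank owns whole output columns
# and computes only its slice of the matmul.
col_shards = np.array_split(W, world_size, axis=1)
tp_partials = [x @ w for w in col_shards]     # (4, 3) per rank
tp_out = np.concatenate(tp_partials, axis=1)  # all-gather of outputs

print("stored per rank:", [s.size for s in flat_shards], "of", W.size)
print("ZeRO matmul output per rank:", zero_out.shape)
print("TP matmul output per rank:", [p.shape for p in tp_partials])
```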
Symptom: You enable full sharding (ZeRO-3 or FSDP FULL_SHARD), but peak memory is still unexpectedly high and you hit OOM during the forward pass.
Cause: You wrapped the entire model as a single FSDP unit instead of wrapping individual transformer layers. When the whole model is one unit, FSDP all-gathers every parameter at once. That defeats the purpose of layer-by-layer sharding.
Fix: Use auto_wrap_policy to wrap at the transformer-block level; the code example earlier shows how to do this with transformer_auto_wrap_policy. With FSDP2, apply fully_shard to each transformer block before applying it to the root module.
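A rough model shows why the single-unit wrap hurts. FSDP materializes one unit's full parameters at a time, so the unit size sets the transient peak. The sketch counts only all-gathered BF16 parameters and treats every block as equal-sized. It ignores the root unit's embeddings and uses a hypothetical `prefetch_depth` for how many upcoming units are gathered early. The 70B-style layer count is illustrative.

```python
def peak_unsharded_gb(params_per_block, n_blocks, one_unit, prefetch_depth=1, bytes_per_param=2):
    """Rough peak of all-gathered (unsharded) parameters live on one rank."""
    if one_unit:
        # The whole model is one FSDP unit: forward all-gathers everything at once.
        live_params = params_per_block * n_blocks
    else:
        # One unit per block: the block being computed plus prefetched neighbors.
        live_params = params_per_block * min(n_blocks, 1 + prefetch_depth)
    return live_params * bytes_per_param / 1e9


# Illustrative 70B-style stack: 80 blocks of 0.875B parameters each.
print("one unit:       ", peak_unsharded_gb(0.875e9, 80, one_unit=True), "GB")
print("per-block units:", peak_unsharded_gb(0.875e9, 80, one_unit=False), "GB")
```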
Assuming limit_all_gathers=True creates overlap
Symptom: You set limit_all_gathers=True and expect faster training, but throughput doesn't improve.
Cause: limit_all_gathers is a CPU-side rate limiter. It caps how many all-gathers are in flight at once to control peak memory. It isn't an overlap or performance knob.
Fix: In FSDP1, use backward_prefetch and communication-unit granularity for overlap. In FSDP2, use its prefetch methods when needed; it doesn't expose FSDP1's limit_all_gathers argument.
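In FSDP2, the explicit prefetch hooks are methods on each fully_shard-ed module. `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch` each take the list of modules whose all-gathers that module should issue early. The plan itself is plain list bookkeeping. It is sketched here with placeholder strings so it runs without GPUs. The helper names and the `depth` default are hypothetical. With real FSDP2 modules, the loops would make the commented calls instead of printing.

```python
def forward_prefetch_plan(layers, depth=1):
    """While layer i runs forward, prefetch the next `depth` layers."""
    return [(layers[i], layers[i + 1:i + 1 + depth]) for i in range(len(layers) - 1)]


def backward_prefetch_plan(layers, depth=1):
    """Backward visits layers in reverse, so prefetch the preceding `depth` layers."""
    return [(layers[i], layers[max(0, i - depth):i][::-1]) for i in range(len(layers) - 1, 0, -1)]


layers = ["block0", "block1", "block2", "block3"]  # stand-ins for model.layers
for layer, targets in forward_prefetch_plan(layers, depth=2):
    # With FSDP2: layer.set_modules_to_forward_prefetch(targets)
    print("forward ", layer, "->", targets)
for layer, targets in backward_prefetch_plan(layers):
    # With FSDP2: layer.set_modules_to_backward_prefetch(targets)
    print("backward", layer, "->", targets)
```

Deeper prefetch buys overlap at the cost of more unsharded parameters held at once, so raise `depth` only after the profiler shows exposed all-gathers.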
Treating SHARD_GRAD_OP as identical to ZeRO-2
Symptom: You switch from ZeRO-2-style reasoning to FSDP SHARD_GRAD_OP, then wonder why peak memory is higher than expected even though gradients are sharded.
Cause: The names are similar, but the parameter lifetime differs. PyTorch's SHARD_GRAD_OP keeps parameters unsharded through forward and backward, then reshards after backward.[6] That changes peak memory behavior even when the high-level mental model sounds like ZeRO-2.
Fix: Confirm the actual parameter lifetime in docs and profiler traces before sizing memory. If you need the lowest parameter footprint, compare against FULL_SHARD rather than assuming SHARD_GRAD_OP behaves like a DeepSpeed stage name.
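For reference, here is a sketch of the ZeRO paper's per-rank model-state accounting under the 16-byte recipe: 2 bytes of parameters, 2 of gradients, and 12 of optimizer states per parameter.[1] It covers DeepSpeed's stage definitions only. It does not model the transient unsharded parameters that change FSDP's SHARD_GRAD_OP peak, which is why the profiler check above still matters. The function name is hypothetical.

```python
def zero_states_gb_per_rank(n_params, world_size, stage):
    """Per-rank model-state GB for ZeRO stages 0-3 under the 16-byte recipe."""
    params = 2 * n_params   # BF16/FP16 weights
    grads = 2 * n_params    # BF16/FP16 gradients
    optim = 12 * n_params   # FP32 master copy + Adam first and second moments
    n = world_size
    if stage == 0:          # plain data parallel: everything replicated
        total = params + grads + optim
    elif stage == 1:        # shard optimizer states
        total = params + grads + optim / n
    elif stage == 2:        # also shard gradients
        total = params + (grads + optim) / n
    elif stage == 3:        # also shard parameters
        total = (params + grads + optim) / n
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9


for stage in range(4):
    print(f"ZeRO-{stage}: {zero_states_gb_per_rank(70e9, 64, stage):7.1f} GB per rank")
```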
Symptom: The model-state math says everything should fit, but you still OOM.
Cause: Activations can be larger than model states for long sequences. A single forward pass with a long context can generate gigabytes of intermediate tensors.
Fix: Pair ZeRO-3 or FSDP with activation checkpointing. Trade compute for memory by recomputing activations during the backward pass instead of storing them.
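To see how activations can outgrow sharded states, this sketch swaps in one published approximation from Korthikanti et al. (2022), "Reducing Activation Recomputation in Large Transformer Models." It estimates a GPT-style layer at roughly s·b·h·(34 + 5·a·s/h) bytes of 16-bit activations, with dropout and no tensor or sequence parallelism. When only each layer's input is checkpointed, the estimate drops to about 2·s·b·h bytes per layer. The shapes are illustrative. The estimate assumes attention scores are materialized; FlashAttention-style kernels remove most of the 5·a·s/h term. It also ignores the single layer's activations rebuilt during recomputation.

```python
def activation_gb(seq_len, micro_batch, hidden, n_heads, n_layers, checkpointed=False):
    """Approximate per-rank activation GB for a GPT-style decoder stack."""
    s, b, h, a = seq_len, micro_batch, hidden, n_heads
    if checkpointed:
        # Full activation checkpointing: keep only each layer's 16-bit input.
        per_layer = 2 * s * b * h
    else:
        # 34*s*b*h for linear/norm/dropout intermediates plus 5*a*s^2*b for
        # attention scores, softmax output, and its dropout mask.
        per_layer = s * b * h * (34 + 5 * a * s / h)
    return per_layer * n_layers / 1e9


# Illustrative 70B-like shapes: 80 layers, hidden 8192, 64 heads, 4K context, micro-batch 1.
shape = dict(seq_len=4096, micro_batch=1, hidden=8192, n_heads=64, n_layers=80)
print(f"stored activations: {activation_gb(**shape):6.1f} GB")
print(f"with checkpointing: {activation_gb(**shape, checkpointed=True):6.1f} GB")
```

ZeRO-3 does not shard these activations, so under this approximation the uncheckpointed figure lands on every data-parallel rank regardless of world size.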
Symptom: You configure AutoTP with ZeRO-3 and the job crashes or hangs.
Cause: DeepSpeed has documented compatibility limits. AutoTP currently supports ZeRO stages 0, 1, and 2, but not 3. PipelineModule isn't compatible with ZeRO-2 or ZeRO-3.
Fix: Check the stage-compatibility matrix before you design your training topology. Don't assume that because DeepSpeed supports both features, they work together.
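A small pre-launch check can encode the compatibility matrix so conflicts surface before a job hangs. This is a hypothetical helper, not a DeepSpeed API. The rule table reflects only the limits stated in this chapter and should be re-verified against the docs for the DeepSpeed version you run. The stage is read from the `zero_optimization.stage` entry of a DeepSpeed config dict.

```python
# ZeRO stages each DeepSpeed feature accepts, as summarized in this chapter.
# Assumption: re-verify against the docs for your DeepSpeed version.
SUPPORTED_ZERO_STAGES = {
    "AutoTP": {0, 1, 2},
    "PipelineModule": {0, 1},
}


def topology_problems(zero_stage, features):
    """Return human-readable conflicts between a ZeRO stage and requested features."""
    problems = []
    for feature in features:
        allowed = SUPPORTED_ZERO_STAGES.get(feature)
        if allowed is None:
            problems.append(f"{feature}: no compatibility rule recorded; check the docs")
        elif zero_stage not in allowed:
            problems.append(f"{feature} is not compatible with ZeRO-{zero_stage}")
    return problems


ds_config = {"zero_optimization": {"stage": 3}}
print(topology_problems(ds_config["zero_optimization"]["stage"], ["AutoTP"]))
print(topology_problems(2, ["PipelineModule"]))
print(topology_problems(1, ["AutoTP"]))
```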
By the end of this lesson, you should be able to explain and operate these parts of a distributed training setup:
- How FULL_SHARD, SHARD_GRAD_OP, NO_SHARD, and HYBRID_SHARD trade memory for communication, and why new code should start with FSDP2 fully_shard.
- How fully_shard is different from the FSDP1 wrapper API.

Use this sequence when sizing a real run:
| Scenario | Recommendation | Why |
|---|---|---|
| Model fits in aggregate GPU memory and you want native PyTorch tooling | FSDP2 fully_shard, applied bottom-up | Current native path with DTensor-backed sharding and PyTorch checkpointing |
| Maintaining FSDP1 code and willing to spend more memory for fewer reshard events | FSDP1 SHARD_GRAD_OP | Higher peak memory, less aggressive resharding |
| GPU memory is the blocker and CPU/NVMe offload is acceptable | DeepSpeed ZeRO-Infinity | Uses larger memory tiers to make the run feasible |
| Need DeepSpeed tensor parallel training with optimizer sharding | DeepSpeed AutoTP + ZeRO-1/2 | Current docs support AutoTP with ZeRO stages 0, 1, and 2, but not 3 |
| Need DeepSpeed PipelineModule | DeepSpeed pipeline + ZeRO-0/1 | Current docs say PipelineModule isn't compatible with ZeRO-2 or ZeRO-3 |
1. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. Rajbhandari, S., et al. · 2020 · SC 2020
2. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. Zhao, Y., et al. · 2023 · VLDB 2023
3. torch.distributed.fsdp.fully_shard. PyTorch Contributors · 2025
4. Getting Started with Fully Sharded Data Parallel (FSDP2). PyTorch Contributors · 2025
5. Where to apply torch.compile? PyTorch Contributors · 2025
6. FullyShardedDataParallel. PyTorch Contributors · 2025
7. Checkpointing in torchtune. PyTorch Contributors · 2026
8. Universal Checkpointing with DeepSpeed: A Practical Guide. DeepSpeed Team · 2026
9. ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. Rajbhandari, S., et al. · 2021 · SC 2021
10. Training API. DeepSpeed Team · 2026
11. Pipeline Parallelism. DeepSpeed Team · 2026
12. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. Shoeybi, M., et al. · 2019
13. Parallelism Strategies Guide. NVIDIA · 2026