LeetLLM
© 2026 LeetLLM. All rights reserved.


Distributed Training: FSDP & ZeRO

Understand ZeRO stages, current FSDP1 vs FSDP2 guidance, and when native PyTorch or DeepSpeed is the right choice for large-model training.



The SFT pipeline chapter ended with a trustworthy single-device training recipe. This chapter handles the next limit: full-weight updates may be correct for the task, but one GPU can't hold the states and activations needed to execute them. Fully Sharded Data Parallel (FSDP) and ZeRO (Zero Redundancy Optimizer) reduce repeated model-state storage across data-parallel ranks, then make communication part of the training cost.

Start with the memory contract. One useful sizing convention, used in the ZeRO analysis, assumes low-precision parameters and gradients plus an FP32 master parameter copy and FP32 Adam moments.[1] Under that specific recipe, each parameter costs 16 bytes of model-state memory before activations, temporary buffers, or allocator overhead:

  • 2 bytes for the BF16 (bfloat16, short for Brain Floating Point) / FP16 (16-bit half precision) weight
  • 2 bytes for the BF16 / FP16 gradient
  • 4 bytes for the FP32 (32-bit single precision) master copy of the weight
  • 4 bytes for Adam's first moment (momentum)
  • 4 bytes for Adam's second moment (variance)

For a 1B model, that is 16 GB of model states alone. Change the optimizer, parameter dtype, master-weight policy, or offload policy and the byte count changes. Always write down the recipe before using a bytes-per-parameter number.
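A minimal sketch makes the "write down the recipe" rule concrete. The three recipes below are illustrative assumptions, not defaults of any framework: the first is the 16-byte accounting above, the second is a hypothetical pure-BF16 recipe with BF16 Adam moments and no FP32 master copy, and the third is a hypothetical FP32 SGD-with-momentum recipe.

```python
# Bytes of model state per parameter under a few illustrative recipes.
# Only the first recipe matches the ZeRO accounting used in this chapter;
# the other two are hypothetical alternatives for comparison.
RECIPES = {
    "bf16_mixed_adam_fp32_master": {
        "param": 2, "grad": 2, "fp32_master": 4, "adam_m": 4, "adam_v": 4,
    },
    "pure_bf16_adam_no_master": {
        "param": 2, "grad": 2, "adam_m": 2, "adam_v": 2,
    },
    "fp32_sgd_momentum": {
        "param": 4, "grad": 4, "momentum": 4,
    },
}


def bytes_per_param(recipe: dict) -> int:
    """Sum the per-parameter byte cost of every model-state tensor."""
    return sum(recipe.values())


def model_state_gb(n_params: float, recipe: dict) -> float:
    """Model-state memory in decimal GB, excluding activations and buffers."""
    return n_params * bytes_per_param(recipe) / 1e9


for name, recipe in RECIPES.items():
    print(
        f"{name}: {bytes_per_param(recipe)} B/param, "
        f"1B -> {model_state_gb(1e9, recipe):.0f} GB, "
        f"7B -> {model_state_gb(7e9, recipe):.0f} GB"
    )
```

Cheaper recipes shrink the footprint but change numerics, so whether one is acceptable is an empirical question for your run, not something the byte count answers.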

Scale this recipe to 7B parameters and it reaches 112 GB of model states. At 70B parameters, it exceeds 1 TB before activations. If you need full-parameter training at that scale, replicated data parallelism alone won't make it fit.

FSDP and ZeRO (a DeepSpeed technique that shards model states across GPUs) don't make memory disappear. They remove redundant copies of model states and pay for that in communication.

Memory layout comparison: DDP replicates model states, ZeRO-2 shards gradients and optimizer state, and ZeRO-3 or FSDP FULL_SHARD shards parameters, gradients, and optimizer state.
Compare the three columns. P, G, and O stand for parameters, gradients, and optimizer state. DDP puts a full copy of everything on every GPU. ZeRO-2 splits gradients and optimizer state. ZeRO-3 splits all three, so each GPU keeps only a thin shard.

The memory wall in numbers

For a 70B parameter model under the FP16/BF16 parameter-and-gradient plus FP32-master-and-Adam accounting convention, the model-state math looks like this:

  • Weights: $70 \times 10^9 \times 2 \text{ bytes} \approx 140$ GB
  • Gradients: $70 \times 10^9 \times 2 \text{ bytes} \approx 140$ GB
  • Adam state + FP32 master weights: $70 \times 10^9 \times (4 + 4 + 4) \text{ bytes} \approx 840$ GB

That is about 1.12 TB of model states before activations, temporary buffers, or allocator overhead. It is an assumption-bound calculation, not a promise about every optimizer implementation.

The arithmetic is small enough to turn into a quick check. This example uses decimal GB to match the table below.

the-memory-wall-in-numbers.py
```python
GB = 10**9


def state_memory_gb(params_billions: float, bytes_per_param: int) -> float:
    return params_billions * 1_000_000_000 * bytes_per_param / GB


recipe_bytes = {
    "low_precision_parameters": 2,
    "low_precision_gradients": 2,
    "fp32_master_and_adam": 12,
}
weights = state_memory_gb(70, recipe_bytes["low_precision_parameters"])
gradients = state_memory_gb(70, recipe_bytes["low_precision_gradients"])
optimizer_and_master = state_memory_gb(70, recipe_bytes["fp32_master_and_adam"])
total_model_states = weights + gradients + optimizer_and_master
zero3_per_gpu = total_model_states / 256

print("bytes_per_parameter=", sum(recipe_bytes.values()))
print(f"weights={weights:.0f} GB")
print(f"optimizer_and_master={optimizer_and_master:.0f} GB")
print(f"total_model_states={total_model_states:.0f} GB")
print(f"ZeRO-3 model states per GPU on 256 GPUs: {zero3_per_gpu:.2f} GB")
```
Model-state memory output
```text
bytes_per_parameter= 16
weights=140 GB
optimizer_and_master=840 GB
total_model_states=1120 GB
ZeRO-3 model states per GPU on 256 GPUs: 4.38 GB
```
70B model-state memory tally under a 16-byte mixed-precision Adam assumption, compared with one 80 GB GPU and ZeRO-3 sharding.
Under this mixed-precision Adam accounting recipe, model state is not just weights: gradients, FP32 master weights, and optimizer moments bring a 70B model to 1.12 TB before activations.

Why DDP hits a limit

Before we look at the fix, recall how ordinary Distributed Data Parallel (DDP) works. DDP places a full copy of the model on every GPU, splits the training batch across them, and synchronizes gradients with an all-reduce (a collective operation that sums values across all GPUs and returns the result to every GPU) after each backward pass. It works well for models that fit on one GPU, because adding more GPUs improves throughput without changing the per-GPU memory requirement.

The problem appears when the model itself is too large for one GPU. Because DDP replicates the full model states on every rank (one process per GPU), adding GPUs never shrinks the per-GPU memory footprint. FSDP and ZeRO attack that specific problem by sharding model states across the data-parallel group.
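The DDP contract is small enough to simulate in plain Python. This sketch is not DDP's implementation: a one-parameter linear model stands in for the network and a hand-written averaging function stands in for the NCCL all-reduce. It shows only the contract: each rank holds the full model, gradients are averaged, and every replica applies the same update, so the result matches one device processing the whole batch.

```python
# Toy DDP: 2 "ranks" each hold a FULL copy of the weight w of y_hat = w * x.
# Gradients are averaged with a simulated all-reduce, so replicas stay identical.

def local_grad(w, batch):
    """Gradient of mean squared error for y_hat = w * x over one batch shard."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Simulated all-reduce: every rank receives the average of all inputs."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


global_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
world_size = 2
shards = [global_batch[r::world_size] for r in range(world_size)]  # equal-size shards
lr = 0.01

replicas = [0.0] * world_size  # one full model copy per rank
for _ in range(3):
    grads = [local_grad(w, shard) for w, shard in zip(replicas, shards)]
    synced = all_reduce_mean(grads)
    replicas = [w - lr * g for w, g in zip(replicas, synced)]

single_device = 0.0  # reference: one device sees the whole batch
for _ in range(3):
    single_device -= lr * local_grad(single_device, global_batch)

print("replicas=", replicas)
print("single_device=", single_device)
```

The memory catch is visible in `replicas`: it holds `world_size` full copies of the model state, which is exactly the redundancy ZeRO removes.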

Think of DDP like every warehouse node keeping a full copy of the same giant catalog. Coordination is easy, but the storage is wasteful. Every GPU keeps a complete copy of the parameters, gradients, and optimizer states.

ZeRO changes that trade-off. Instead of duplicating the full catalog on every rank, it partitions the catalog across ranks and reconstructs only the pieces needed for computation.[1] In logistics terms, ZeRO is like a distributed warehouse network: each regional hub stores only a fraction of the catalog and pulls the remaining items on demand when an order arrives.

  • DDP: Replicate everything. Simpler communication pattern, much higher memory use.
  • ZeRO: Shard model states. Lower memory use, more communication.

ZeRO stages: removing redundancy one layer at a time

ZeRO (Zero Redundancy Optimizer) is easiest to understand as an incremental evolution. Each stage removes one more redundant copy of the model states.

ZeRO stage 1: shard optimizer states

Each GPU keeps optimizer state for only $1/N$ of the parameters. Parameters and gradients are still replicated.

  • Memory savings: Optimizer state drops from 840 GB total to roughly $840/N$ GB per GPU.
  • Communication: About the same aggregate volume as DDP's gradient synchronization. Each rank updates only the parameter partition whose optimizer state it owns, and the updated parameters are then all-gathered so every rank again holds the full weights.
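A toy version of stage 1 shows why the optimizer update can be partitioned without changing the result. To keep it short, the sketch swaps Adam for SGD with momentum (one state buffer instead of two) and assumes gradients have already been averaged across ranks; list slices stand in for real sharded tensors.

```python
# Toy ZeRO-1: each rank owns the momentum buffer for only its slice of parameters.
# SGD with momentum stands in for Adam; gradients are assumed already averaged.

def momentum_step(params, grads, momentum, lr=0.1, beta=0.9):
    """One SGD-with-momentum update; returns (new_params, new_momentum)."""
    new_m = [beta * m + g for m, g in zip(momentum, grads)]
    new_p = [p - lr * m for p, m in zip(params, new_m)]
    return new_p, new_m


world_size = 2
params = [1.0, 2.0, 3.0, 4.0]
grads = [0.5, -0.5, 1.0, -1.0]

# Baseline (DDP-style): every rank stores momentum for all 4 parameters.
reference, _ = momentum_step(params, grads, [0.0] * len(params))

# ZeRO-1: rank r stores momentum only for parameters [r * shard, (r + 1) * shard).
shard = len(params) // world_size
momentum_shards = [[0.0] * shard for _ in range(world_size)]
updated_slices = []
for r in range(world_size):
    lo, hi = r * shard, (r + 1) * shard
    new_slice, momentum_shards[r] = momentum_step(
        params[lo:hi], grads[lo:hi], momentum_shards[r]
    )
    updated_slices.append(new_slice)

# Simulated all-gather: every rank rebuilds the full updated parameter vector.
gathered = [value for piece in updated_slices for value in piece]

print("matches_unsharded_update=", gathered == reference)
print("momentum_entries_per_rank=", [len(m) for m in momentum_shards])
```

The update is elementwise, so splitting ownership by index changes where optimizer state lives, not what the update computes.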

ZeRO stage 2: shard optimizer + gradients

Stage 2 also shards gradients, so each GPU owns only the gradient shard that matches its optimizer shard.

  • Memory savings: Optimizer states and gradients both scale down by roughly $1/N$.
  • Communication: Gradient sync is typically implemented with reduce-scatter (a collective that sums values across GPUs and returns each GPU's shard) instead of all-reduce. The aggregate volume stays in the same ballpark as DDP, but the result is already partitioned.

ZeRO stage 3: shard everything

Stage 3 shards parameters as well. No rank owns a full persistent copy of the model states anymore.

  • Memory savings: Parameters, gradients, and optimizer states are all partitioned. For the 70B example on 256 GPUs, the model-state footprint falls to about 4.4 GB per GPU before activations.
  • Communication: Parameters now have to be materialized on demand for computation.
    1. Forward: all-gather (collect shards from all GPUs to rebuild the full tensor) the current layer or FSDP unit, run compute, then reshard.
    2. Backward: all-gather that unit again, compute gradients, then reduce-scatter gradient shards.
  • Trade-off: Memory scales almost linearly with the data-parallel group size, but communication becomes a first-class bottleneck.

Key insight: ZeRO doesn't reduce the aggregate model-state bytes; it eliminates persistent redundant copies. A 70B model still needs roughly 1.12 TB of model-state memory in aggregate, but ZeRO-3 partitions its steady-state storage across GPUs. Temporary materialized units and communication buffers still add to each rank's peak. The trade-off is communication: you pay in network bandwidth for what you save in persistent per-GPU state.
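A back-of-the-envelope sketch separates the persistent shard from the transient peak. The unit count (80 equal transformer blocks) and the number of units materialized at once (the current unit plus one prefetched unit) are hypothetical. The sketch also ignores unsharded gradient buffers, communication buffers, and activations, so treat the result as an optimistic floor on peak model-state memory, not a prediction.

```python
# Hypothetical peak model-state estimate for ZeRO-3 / FULL_SHARD.
total_model_state_gb = 1120.0  # 16-byte recipe, 70B parameters
bf16_param_gb = 140.0          # full low-precision parameters
world_size = 256
n_units = 80                   # hypothetical: equal-sized transformer blocks
units_in_flight = 2            # hypothetical: current unit + one prefetched unit

steady_state_gb = total_model_state_gb / world_size  # persistent shard per rank
full_unit_gb = bf16_param_gb / n_units               # one gathered unit
transient_gb = full_unit_gb * units_in_flight
peak_gb = steady_state_gb + transient_gb

print(f"persistent_shard_GB={steady_state_gb:.2f}")
print(f"one_full_unit_GB={full_unit_gb:.2f}")
print(f"peak_floor_GB={peak_gb:.2f}")
```

The transient term depends on unit size, not world size: adding GPUs shrinks the persistent shard, but only finer-grained units shrink the gathered copy.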

Memory per GPU (70B model, 256 GPUs)

The table below assumes BF16/FP16 weights and gradients with Adam states in FP32. The totals are for model states only. Activation memory is listed separately because it depends heavily on sequence length, hidden size, micro-batch size, checkpointing, and attention implementation.

| Component | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |
|---|---|---|---|---|
| Weights | 140 GB | 140 GB | 140 GB | 0.55 GB |
| Gradients | 140 GB | 140 GB | 0.55 GB | 0.55 GB |
| Optimizer + FP32 master weights | 840 GB | 3.3 GB | 3.3 GB | 3.3 GB |
| Activations | Extra and workload-dependent | Extra and workload-dependent | Extra and workload-dependent | Extra and workload-dependent |
| Model states / GPU | 1120 GB | 283 GB | 144 GB | 4.4 GB |

This is why ZeRO-3 or FULL_SHARD is often paired with activation checkpointing (recomputing activations during the backward pass to save memory, also called gradient checkpointing). Sharding fixes model-state memory, but activations can still dominate the actual training footprint.

The same table should be generated from a recipe rather than copied into a sizing document:

zero_stage_memory.py
```python
state_gb = {"parameters": 140, "gradients": 140, "optimizer_and_master": 840}
world_size = 256


def per_rank_gb(stage):
    sharded = {
        0: set(),
        1: {"optimizer_and_master"},
        2: {"optimizer_and_master", "gradients"},
        3: {"optimizer_and_master", "gradients", "parameters"},
    }[stage]
    return sum(value / world_size if name in sharded else value for name, value in state_gb.items())


for stage in [0, 1, 2, 3]:
    print(f"ZeRO-{stage}: {per_rank_gb(stage):.2f} GB/model-state rank")
```
ZeRO stage memory output
```text
ZeRO-0: 1120.00 GB/model-state rank
ZeRO-1: 283.28 GB/model-state rank
ZeRO-2: 143.83 GB/model-state rank
ZeRO-3: 4.38 GB/model-state rank
```
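The table leaves activations out on purpose. The toy model below shows why activation checkpointing pairs naturally with sharding: without it, every layer's intermediate activations are kept for backward; with it, only each layer's input is kept and one layer's intermediates are recomputed at a time. The per-layer sizes are hypothetical placeholders, not measurements; real values depend on sequence length, hidden size, micro-batch size, and attention implementation.

```python
# Toy activation-memory model; all sizes are hypothetical placeholders.

def activation_gb(n_layers, internal_gb, input_gb, checkpoint):
    """Activation memory held for the backward pass.

    internal_gb: activations one layer saves for backward.
    input_gb:    the layer input alone, which is all checkpointing keeps.
    """
    if not checkpoint:
        return n_layers * internal_gb
    # Keep every layer's input, plus one layer's internals while it is recomputed.
    return n_layers * input_gb + internal_gb


n_layers = 80
internal_gb = 0.8  # hypothetical
input_gb = 0.05    # hypothetical

without = activation_gb(n_layers, internal_gb, input_gb, checkpoint=False)
with_ckpt = activation_gb(n_layers, internal_gb, input_gb, checkpoint=True)
print(f"no_checkpointing_GB={without:.1f}")
print(f"checkpointing_GB={with_ckpt:.1f}")
```

The price is compute: every checkpointed layer runs its forward pass a second time during backward.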

Communication overhead analysis

ZeRO-3 buys memory by turning parameter access into communication. That makes the interconnect part of your training system design, not an afterthought. Just as a distributed warehouse system only works if the trucking network between hubs can handle replenishment traffic, ZeRO-3's memory savings depend on whether your GPU interconnect can keep up with all-gather volume.

Communication primitives you need to know

| Primitive | What it does | Where it shows up |
|---|---|---|
| All-reduce | Sum across ranks, then give the full result back to every rank | Classic DDP gradient sync |
| Reduce-scatter | Sum across ranks, but return only each rank's shard | ZeRO-2/3 gradient sync |
| All-gather | Collect shards from all ranks to rebuild the full tensor | ZeRO-3 and FSDP parameter materialization |

Those names stop blurring once you track one tiny tensor through each collective. Read each panel left to right: same four ranks, different "after" state.

Comparison of all-reduce, reduce-scatter, and all-gather showing whether each rank ends with a full tensor or one shard.
Focus on one question: after the collective, does each rank keep the full tensor or only a shard? That distinction is what makes DDP, ZeRO, and FSDP feel different in practice.

The operations become concrete on a two-rank gradient vector. All-reduce leaves the summed vector on both ranks; reduce-scatter leaves one summed slice per rank; all-gather reconstructs a parameter vector from shards.

collective_primitives.py
```python
rank_gradients = [[1, 2], [10, 20]]
summed = [sum(values) for values in zip(*rank_gradients)]

all_reduce_result = [summed.copy(), summed.copy()]
reduce_scatter_result = [[summed[0]], [summed[1]]]

parameter_shards = [["p0"], ["p1"]]
full_parameters = [token for shard in parameter_shards for token in shard]
all_gather_result = [full_parameters.copy(), full_parameters.copy()]

print("all_reduce=", all_reduce_result)
print("reduce_scatter=", reduce_scatter_result)
print("all_gather=", all_gather_result)
```
Collectives output
```text
all_reduce= [[11, 22], [11, 22]]
reduce_scatter= [[11], [22]]
all_gather= [['p0', 'p1'], ['p0', 'p1']]
```
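One identity ties the three collectives together: an all-reduce produces the same result as a reduce-scatter followed by an all-gather. Ring implementations run exactly those two phases, and each phase sends $(N-1)/N \cdot P$ bytes per rank, which is why ring all-reduce traffic is about $2P$ per rank. The sketch checks the identity on lists and computes that traffic; the helper names are illustrative, not a library API.

```python
# List-based stand-ins for collectives across simulated ranks.

def all_reduce(rank_vectors):
    """Every rank receives the elementwise sum of all ranks' vectors."""
    summed = [sum(col) for col in zip(*rank_vectors)]
    return [list(summed) for _ in rank_vectors]


def reduce_scatter(rank_vectors):
    """Rank r receives only its contiguous slice of the elementwise sum."""
    n = len(rank_vectors)
    size = len(rank_vectors[0]) // n  # assumes length divisible by n
    summed = [sum(col) for col in zip(*rank_vectors)]
    return [summed[r * size:(r + 1) * size] for r in range(n)]


def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def ring_gb_per_rank(op, payload_gb, n):
    """GB each rank sends in a ring implementation of the collective."""
    one_phase = (n - 1) / n * payload_gb
    return 2 * one_phase if op == "all_reduce" else one_phase


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]
print("all_reduce=", all_reduce(grads))
print("reduce_scatter_then_all_gather=", all_gather(reduce_scatter(grads)))
print(f"ring_all_reduce_GB_per_rank={ring_gb_per_rank('all_reduce', 140, 256):.1f}")
```

ZeRO-2 keeps only the first phase for gradients, which is why its gradient traffic stays in DDP's ballpark while the result arrives already partitioned.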

The 1.5x heuristic

People often summarize ZeRO-3 communication as "about 1.5x DDP." That's a useful heuristic, but it's only a heuristic.

For a model with $P$ parameter bytes:

  1. DDP does one gradient all-reduce per step, which is commonly modeled as about $2P$ bytes per rank.
  2. ZeRO-3 / FULL_SHARD does roughly two parameter all-gathers plus one gradient reduce-scatter, which is commonly modeled as about $3P$ bytes per rank.

That gives the familiar $3P / 2P = 1.5\times$ rule of thumb. The harder part in practice isn't only byte volume, but how that volume is scheduled. DDP tends to use a smaller number of large collectives. ZeRO-3 and FSDP break communication into many layer-level or unit-level operations. If those collectives are small and your network has high latency, throughput can fall off quickly.

ZeRO-3 is most attractive when memory is the hard constraint. If the model already fits and the network is weak, DDP or ZeRO-2 style sharding can be faster.

communication_heuristic.py
```python
parameter_bytes = 140  # GB of low-precision parameters in the 70B example
ddp_modeled_gb = 2 * parameter_bytes
full_shard_modeled_gb = 3 * parameter_bytes

print("DDP_modeled_traffic_GB=", ddp_modeled_gb)
print("FULL_SHARD_modeled_traffic_GB=", full_shard_modeled_gb)
print("byte_ratio=", full_shard_modeled_gb / ddp_modeled_gb)
print("warning=latency_and_overlap_still_determine_step_time")
```
Communication heuristic output
```text
DDP_modeled_traffic_GB= 280
FULL_SHARD_modeled_traffic_GB= 420
byte_ratio= 1.5
warning=latency_and_overlap_still_determine_step_time
```
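The latency warning can be made concrete with a simple latency-plus-bandwidth cost model: time equals the number of collectives times a fixed per-collective latency, plus bytes divided by bandwidth. Every number below (model size, bucket and unit counts, latency, bandwidth) is a hypothetical assumption chosen to show the shape of the effect, and the model ignores compute overlap entirely.

```python
# Hypothetical alpha-beta cost model for one training step's collectives.

def comm_time_ms(total_gb, n_collectives, latency_ms, bandwidth_gb_per_s):
    """Fixed latency per collective plus transfer time; no compute overlap."""
    return n_collectives * latency_ms + total_gb / bandwidth_gb_per_s * 1000


param_gb = 0.2    # hypothetical: ~100M parameters in BF16
latency_ms = 0.5  # hypothetical per-collective latency (e.g. across nodes)
bandwidth = 25.0  # hypothetical GB/s per rank

scenarios = {
    "DDP, 10 gradient buckets": (2 * param_gb, 10),
    "FULL_SHARD, 24 block units": (3 * param_gb, 3 * 24),
    "FULL_SHARD, 300 tiny units": (3 * param_gb, 3 * 300),
}
for name, (gb, n) in scenarios.items():
    print(f"{name}: {comm_time_ms(gb, n, latency_ms, bandwidth):.0f} ms")
```

Both FULL_SHARD scenarios move the same bytes, so the byte ratio against DDP stays 1.5x; the modeled time ratio is set by how many collectives carry those bytes. Real frameworks hide much of this behind compute, so profile before trusting any such model.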

FSDP (Fully Sharded Data Parallel)

FSDP is PyTorch's native sharded data-parallel stack.[2] The classic API is FullyShardedDataParallel (often called FSDP1). Current PyTorch exposes a DTensor-based API via torch.distributed.fsdp.fully_shard (often called FSDP2); its tutorial marks FSDP1 as deprecated and includes a migration guide. Treat FSDP1 as an existing-code path and start new designs from fully_shard.[3][4]

The main reason people choose FSDP is operational simplicity. It stays close to PyTorch's native distributed stack, native profilers, and native checkpointing tools. PyTorch's general torch.compile guidance is to compile the highest-level train or evaluation step, or the top-level module. If a distributed wrapper such as FSDP causes issues, compile the inner module instead. If you're compiling FSDP1, current docs also require use_orig_params=True.[5][6]

FSDP solves model-state memory. It doesn't solve activation memory. Long sequences and large micro-batches still push activations up, which is why FSDP is often paired with activation checkpointing.

How full sharding materializes one unit

For FULL_SHARD, or for an FSDP2 configuration that reshards after forward, one transformer-sized communication unit follows this lifecycle:

  1. Group: Parameters are grouped at transformer-block granularity.
  2. Shard: Each unit's parameters are sharded across the process group.
  3. Execution:
    • Before forward pass of a unit: All-gather full parameters.
    • After forward pass: Discard full parameters (keep only shard).
    • Before backward pass: All-gather full parameters again.
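    • After backward pass: Reduce-scatter gradients so each rank keeps only its gradient shard. The optimizer step, run once after the full backward pass, then updates each rank's parameter shard and sharded optimizer state.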
Full-sharding communication-unit lifecycle: persistent parameter shard, all-gather for forward, reshard, re-gather for backward, and reduce-scatter gradient shard.
With reshard-after-forward behavior, full parameters exist only around compute for the current block; persistent storage returns to shards after the unit finishes.

This sequence diagram shows one wrapped unit in a two-GPU FSDP job.

Sequence diagram: one wrapped FSDP unit on GPU 0 and GPU 1.
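The same lifecycle can be written as an event trace. This is a toy scheduler, not FSDP's implementation: it assumes FULL_SHARD-style reshard-after-forward, no prefetching, and units that run in a fixed order. Counting the events recovers the two all-gathers and one reduce-scatter per unit behind the earlier 1.5x communication heuristic.

```python
# Toy event trace for FULL_SHARD-style units (no prefetch, fixed execution order).

def full_shard_trace(units):
    events = []
    for unit in units:  # forward runs front to back
        events += [("all_gather", unit), ("forward", unit), ("reshard", unit)]
    for unit in reversed(units):  # backward runs back to front
        events += [
            ("all_gather", unit),
            ("backward", unit),
            ("reshard", unit),         # drop full parameters, keep the local shard
            ("reduce_scatter", unit),  # keep only this rank's gradient shard
        ]
    events.append(("optimizer_step", "local shards"))  # once, after all units
    return events


def max_unsharded(events):
    """Largest number of units holding full parameters at the same time."""
    live, peak = set(), 0
    for op, unit in events:
        if op == "all_gather":
            live.add(unit)
        elif op == "reshard":
            live.discard(unit)
        peak = max(peak, len(live))
    return peak


trace = full_shard_trace(["block0", "block1", "block2"])
ops = [op for op, _ in trace]
print("all_gathers=", ops.count("all_gather"))
print("reduce_scatters=", ops.count("reduce_scatter"))
print("max_units_unsharded=", max_unsharded(trace))
```

Prefetching would raise `max_units_unsharded` above one in exchange for overlapping communication with compute, which is the trade the FSDP prefetch settings control.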

The following code shows a practical FSDP1 setup. You will still meet this API in existing codebases, so it is worth reading even though new designs should begin with fully_shard. The important details are:

  1. Wrap at the transformer-block level instead of wrapping the whole model as one giant unit.
  2. Create the optimizer after FSDP wraps the model.
  3. Don't confuse backward_prefetch with limit_all_gathers. The first helps overlap communication with compute. The second is mainly a CPU-side rate limiter that caps in-flight all-gathers and peak memory.[6]

This is distributed setup code, not a single-process notebook cell. It assumes torch.distributed.init_process_group() has already run under torchrun and that each rank has selected its local CUDA device.

how-fsdp-works.py
```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=1024, n_heads=16):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_out)
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


# Each TransformerDecoderLayer becomes its own FSDP unit (one all-gather per block).
my_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerDecoderLayer},
)

# BF16 compute and gradient reduction; outside forward/backward the sharded
# parameters keep their original (FP32) dtype for the optimizer step.
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = nn.Sequential(*[
    TransformerDecoderLayer(d_model=1024, n_heads=16)
    for _ in range(8)
])

fsdp_model = FSDP(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3-style sharding
    mixed_precision=mixed_precision_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # overlap next all-gather with backward compute
    limit_all_gathers=True,  # rate-limit in-flight all-gathers to cap peak memory
    device_id=torch.cuda.current_device(),
)

# Create the optimizer only after wrapping so it sees the sharded parameters.
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=2e-5)  # example SFT start; evaluate it
```

FSDP2: the newer DTensor-based API

FSDP2 moves away from the flat-parameter wrapper model and toward per-parameter sharding with DTensor. The point isn't only elegance. It changes how composable and inspectable the system is.

For new code, the structural pattern is bottom-up application of fully_shard, then optimizer construction over the resulting DTensor parameters:[3]

fsdp2_structure.py
```python
import torch
from torch.distributed.fsdp import fully_shard

model = TransformerModel()  # placeholder: any model exposing a .blocks list of transformer blocks
for block in model.blocks:
    fully_shard(block)  # one communication group per transformer block
fully_shard(model)  # shard remaining root-level parameters

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
```

This snippet shows structure rather than a runnable local demo: fully_shard must execute inside a distributed process group on suitable accelerators. The learning rate remains an SFT baseline to evaluate, not a distributed-training requirement.

| Feature | FSDP1 (FullyShardedDataParallel) | FSDP2 (fully_shard) |
|---|---|---|
| How it applies sharding | Wrapper module with flat parameters | In-place hooks on original modules |
| Sharding representation | FlatParameter groups | Per-parameter DTensor |
| Parameter names | Flattened internals to reason about | Original FQNs stay intact |
| Composition with TP / frozen params | More awkward | Better fit with DeviceMesh and frozen params |
| Checkpointing style | Full and sharded state dict APIs | Sharded state dicts first, reshard to full when needed |

Key technical improvements

  1. Cleaner composition: FSDP2 uses DeviceMesh, which makes it much easier to compose data-parallel sharding with tensor parallelism in the same topology.

  2. Cleaner state dict story for large runs: Sharded state dicts are first-class, and distributed checkpointing avoids the old pattern of gathering everything onto rank 0 for saving.

  3. Cleaner handling of frozen parameters: Per-parameter sharding is a better fit for PEFT and LoRA-style workflows where only a small subset of weights should update.

  4. Different scheduling controls: fully_shard is no longer a wrapper configured through FSDP1's auto_wrap_policy, backward_prefetch, use_orig_params, or limit_all_gathers arguments. Apply it bottom-up, then use FSDP2 methods such as set_modules_to_forward_prefetch or set_modules_to_backward_prefetch when explicit prefetch control is needed.[3]

The practical takeaway is simple: FSDP1 still shows up in existing codebases, while current PyTorch's FSDP2 documentation provides the migration path and fully_shard API for new work.

The most important FSDP choice is usually sharding granularity. Use transformer blocks as communication units. If one root unit owns the entire model, every all-gather becomes too large and peak memory shoots back up.

This simplified calculation isolates why granularity matters. It models only the transient full-parameter materialization, not activations or prefetch overlap:

fsdp_unit_granularity.py
```python
block_parameter_gb = [3.5, 3.5, 3.5, 3.5]

root_only_peak_full_materialization = sum(block_parameter_gb)
blockwise_peak_full_materialization = max(block_parameter_gb)

print("root_only_active_parameters_GB=", root_only_peak_full_materialization)
print("blockwise_active_parameters_GB=", blockwise_peak_full_materialization)
print("modeled_reduction=", root_only_peak_full_materialization / blockwise_peak_full_materialization)
```
Sharding granularity output
```text
root_only_active_parameters_GB= 14.0
blockwise_active_parameters_GB= 3.5
modeled_reduction= 4.0
```

FSDP1 sharding strategies reference

| Strategy | Equivalent | Description |
|---|---|---|
| FULL_SHARD | ZeRO-3 | Shards parameters, gradients, and optimizer states. Most memory efficient. |
| SHARD_GRAD_OP | ZeRO-2-like | Unshards parameters before forward, keeps them unsharded until backward finishes, then reshards. Uses more memory than FULL_SHARD but skips the backward all-gather. |
| NO_SHARD | DDP-like replication | Keeps model states replicated and avoids sharded parameter materialization, but doesn't reduce per-rank model-state memory. |
| HYBRID_SHARD | ZeRO-3 within a node + replication across nodes | Shards within a node and replicates across nodes. Useful when intra-node links are fast and inter-node bandwidth is tighter. |
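Because HYBRID_SHARD shards only inside a node, its per-rank model-state memory divides by the node size rather than the full world size. A minimal sketch for the 70B example, assuming a hypothetical 8 GPUs per node and counting steady-state model states only:

```python
# Steady-state model-state memory per rank for three FSDP1 strategies.
# Assumes the 16-byte recipe (1120 GB for 70B) and a hypothetical 8 GPUs per node.

def per_rank_model_state_gb(total_gb, strategy, world_size, gpus_per_node=8):
    if strategy == "NO_SHARD":
        return total_gb                  # full replica on every rank
    if strategy == "FULL_SHARD":
        return total_gb / world_size     # shard across the whole job
    if strategy == "HYBRID_SHARD":
        return total_gb / gpus_per_node  # shard within a node, replicate across nodes
    raise ValueError(f"strategy not modeled in this sketch: {strategy}")


for strategy in ["NO_SHARD", "HYBRID_SHARD", "FULL_SHARD"]:
    gb = per_rank_model_state_gb(1120.0, strategy, world_size=256)
    print(f"{strategy}: {gb:.2f} GB/rank")
```

Under these assumptions HYBRID_SHARD still needs about 140 GB per rank for model states alone, more than one 80 GB GPU holds, so its inter-node bandwidth savings only help once the node-level shard fits.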

Checkpointing, resume, and fault recovery

Large training runs fail in the middle all the time. A node reboots. A preemptible instance disappears. A job gets rescheduled. If your checkpoint only contains model weights, you didn't save the training run. You saved an export artifact.[7]

For a real resume point, the checkpoint usually needs at least:

| Must-save state | Why it matters on resume |
|---|---|
| Model weights | Restores current parameters |
| Optimizer state | Preserves momentum and adaptive moments |
| Scheduler state | Keeps warmup and decay aligned with training step |
| Gradient-scaler / AMP state, if used | Preserves mixed-precision control state across resume |
| Consumed step or token count | Lets logs, schedulers, and stop rules stay honest |
| Sampler / dataloader cursor | Prevents replaying or skipping large data regions |
| RNG state | Keeps dropout, shuffling, and sampling reproducible when needed |

In a sharded system, the safest default is to save sharded checkpoints, not to gather everything onto rank 0 first. FSDP2 and distributed checkpoint APIs are pushing in that direction because the "collect all weights on one machine and write a giant file" pattern eventually becomes the bottleneck or the failure point.[3][7]

DeepSpeed goes one step further with Universal Checkpointing, which is meant to make checkpoint artifacts more portable across different parallelism layouts instead of tying restore logic to the exact original topology.[8] That matters once you start changing world size between runs or resume on a different cluster shape.

The operational rule is simple:

  1. Save checkpoints by training step and consumed tokens, not vague names like latest-final.
  2. Test one real resume path early in the run, before trusting a week-long job to it.
  3. Verify that resumed loss on the next batch is close to the non-resumed run.

If you can't answer "what exact state do we restore, and how do we know resume is faithful?", the training system isn't production-ready yet.
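The resume check in step 3 can be rehearsed on a toy before it matters. This sketch is pure Python with a one-parameter model: momentum stands in for optimizer state, the step counter doubles as the sampler cursor, and a seeded `random.Random` stands in for dropout or shuffling randomness. It compares an uninterrupted run with a save-and-resume run, then with a resume that drops optimizer state.

```python
import random

DATA = [1.0, 2.0, 3.0, 4.0]  # toy inputs; the target is y = 2 * x


def new_state(seed=0):
    return {"w": 0.0, "m": 0.0, "step": 0, "rng": random.Random(seed)}


def train(state, n_steps, lr=0.001, beta=0.9):
    for _ in range(n_steps):
        x = DATA[state["step"] % len(DATA)]           # step count acts as sampler cursor
        noise = (state["rng"].random() - 0.5) * 0.02  # stands in for dropout/shuffling
        grad = 2 * (state["w"] * x - 2 * x) * x + noise
        state["m"] = beta * state["m"] + grad         # optimizer (momentum) state
        state["w"] -= lr * state["m"]
        state["step"] += 1
    return state


def save(state):
    return {"w": state["w"], "m": state["m"], "step": state["step"],
            "rng_state": state["rng"].getstate()}


def load(ckpt, keep_optimizer=True):
    rng = random.Random()
    rng.setstate(ckpt["rng_state"])
    return {"w": ckpt["w"], "m": ckpt["m"] if keep_optimizer else 0.0,
            "step": ckpt["step"], "rng": rng}


uninterrupted = train(new_state(), 10)
checkpoint = save(train(new_state(), 5))
resumed = train(load(checkpoint), 5)
weights_only = train(load(checkpoint, keep_optimizer=False), 5)

print("faithful_resume=", resumed["w"] == uninterrupted["w"])
print("weights_only_drift=", abs(weights_only["w"] - uninterrupted["w"]))
```

In a real run bitwise equality is rarely achievable across different hardware or kernels, so the practical test is "close on the next batches", but the toy makes the failure mode of a weights-only checkpoint easy to see.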

A distributed checkpoint manifest also has to record the sharding topology it was written under, even when a portable conversion path exists:

distributed_checkpoint_contract.py
```python
required = {
    "model_shards",
    "optimizer_shards",
    "scheduler_state",
    "consumed_tokens",
    "sampler_cursor",
    "rng_state",
    "gradient_scaler_state",
    "template_tokenizer_version",
    "data_manifest",
    "evaluation_manifest",
    "best_metric",
    "world_size",
    "sharding_strategy",
}

manifest = {
    "model_shards": "ckpt/step_800/model/",
    "optimizer_shards": "ckpt/step_800/optimizer/",
    "scheduler_state": "ckpt/step_800/scheduler.pt",
    "consumed_tokens": 104_857_600,
    "sampler_cursor": {"epoch": 0, "batch": 800},
    "rng_state": "ckpt/step_800/rng.pt",
    "gradient_scaler_state": None,  # BF16 run; store scaler state for FP16 when used
    "template_tokenizer_version": "support-chat-v3",
    "data_manifest": "manifests/sft-train-v7.json",
    "evaluation_manifest": "manifests/support-eval-v4.json",
    "best_metric": {"name": "support_resolution_accuracy", "value": 0.78},
    "world_size": 32,
    "sharding_strategy": "FULL_SHARD",
}

missing = sorted(required - manifest.keys())
print("resume_manifest_complete=", not missing)
print("saved_topology=", manifest["sharding_strategy"], manifest["world_size"])
print("gradient_scaler_state=", manifest["gradient_scaler_state"])
print("best_metric=", manifest["best_metric"])
```
Distributed checkpoint output
```text
resume_manifest_complete= True
saved_topology= FULL_SHARD 32
gradient_scaler_state= None
best_metric= {'name': 'support_resolution_accuracy', 'value': 0.78}
```
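Recording the topology only helps if restore code actually reads it. Here is a minimal sketch using the `world_size` and `sharding_strategy` fields from the manifest above and a hypothetical `plan_restore` helper. If the current job matches the saved topology, shards can load directly. Any change should go through an explicit conversion step, such as the portable Universal Checkpointing path described earlier, instead of loading shards blindly.

```python
def plan_restore(saved, current):
    """Hypothetical helper: decide how to restore a sharded checkpoint."""
    same_world = saved["world_size"] == current["world_size"]
    same_strategy = saved["sharding_strategy"] == current["sharding_strategy"]
    if same_world and same_strategy:
        return "load_shards_directly"
    # Any topology change needs a conversion/reshard step before loading.
    return "convert_then_load"


saved = {"world_size": 32, "sharding_strategy": "FULL_SHARD"}
candidates = [
    {"world_size": 32, "sharding_strategy": "FULL_SHARD"},
    {"world_size": 16, "sharding_strategy": "FULL_SHARD"},
    {"world_size": 32, "sharding_strategy": "HYBRID_SHARD"},
]
for current in candidates:
    print(current["world_size"], current["sharding_strategy"], "->", plan_restore(saved, current))
```

Failing fast here is cheaper than discovering a shard-count mismatch after the job has already reserved a full cluster.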

DeepSpeed ZeRO

DeepSpeed is Microsoft's distributed training runtime that introduced ZeRO and provides heterogeneous-memory offload through ZeRO-Offload and ZeRO-Infinity.[1][9] It's got a steeper learning curve than native PyTorch, but it remains a strong choice when memory pressure is the real blocker or when the rest of your stack already depends on Megatron-DeepSpeed style training.

One important DeepSpeed-specific detail is that ZeRO-3 depends on consistent module execution order across ranks. Dynamic routing patterns such as MoE can deadlock parameter all-gathers if different ranks enter different submodules, which is why DeepSpeed exposes the idea of ZeRO-3 leaf modules.[10]

The following distributed setup snippet shows how to configure and initialize a model using DeepSpeed. Run it through the DeepSpeed launcher in a real training job. The configuration dictionary defines the ZeRO stage, offloading settings, optimizer, and training batch parameters. The learning rate is an example SFT starting point to evaluate, and the bucket sizes are tuning knobs rather than canonical values:

deepspeed-zero.py
```python
import deepspeed
import torch.nn as nn

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "bf16": {
        "enabled": True,
    },
    "optimizer": {
        # Use a standard optimizer name here. With offload_optimizer on CPU,
        # DeepSpeed runs it with its CPU Adam implementation (DeepSpeedCPUAdam).
        "type": "AdamW",
        "params": {
            "lr": 2e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01,
        },
    },
    "zero_optimization": {
        "stage": 3,
        # Keep sharded optimizer states and parameters in pinned host memory.
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True,
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        # Bucket sizes are tuning knobs, not canonical values.
        "reduce_bucket_size": 500_000_000,
        "stage3_prefetch_bucket_size": 50_000_000,
        "stage3_param_persistence_threshold": 100_000,
    },
}


class SimpleTransformer(nn.Module):
    def __init__(self, d_model=1024, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead=16, batch_first=True)
            for _ in range(n_layers)
        ])
        self.proj = nn.Linear(d_model, 1000)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.proj(x)


model = SimpleTransformer(d_model=1024, n_layers=4)

# Returns (engine, optimizer, training_dataloader, lr_scheduler).
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config,
    model_parameters=model.parameters(),
)

# Training loop sketch: the engine owns backward and the optimizer step.
# for batch in dataloader:
#     outputs = model_engine(batch)
#     loss = compute_loss(outputs)
#     model_engine.backward(loss)
#     model_engine.step()
```
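The config above sets the micro-batch and accumulation steps but not the global batch. DeepSpeed's batch rule is `train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × data-parallel size`, and a config that violates it fails at initialization. The sketch below shows that arithmetic under one assumption: a pure data-parallel (ZeRO) layout, so every GPU is a data-parallel rank. The world sizes are illustrative. The helper names are hypothetical.

```python
def effective_global_batch(micro_batch_per_gpu, grad_accum_steps, data_parallel_size):
    # Sequences contributing to one optimizer step across all data-parallel ranks.
    return micro_batch_per_gpu * grad_accum_steps * data_parallel_size


def grad_accum_for_target(target_global_batch, micro_batch_per_gpu, data_parallel_size):
    # Accumulation steps needed to keep the global batch fixed when world size changes.
    per_step = micro_batch_per_gpu * data_parallel_size
    if target_global_batch % per_step:
        raise ValueError("target batch is not divisible by micro_batch * data_parallel_size")
    return target_global_batch // per_step


micro_batch = 1  # train_micro_batch_size_per_gpu from ds_config
grad_accum = 8   # gradient_accumulation_steps from ds_config

for world_size in [8, 16, 32]:
    global_batch = effective_global_batch(micro_batch, grad_accum, world_size)
    print(f"world_size={world_size} -> train_batch_size={global_batch}")

print("grad_accum to keep 256 on 16 GPUs:", grad_accum_for_target(256, micro_batch, 16))
```

Changing world size silently changes the global batch unless accumulation is adjusted to match. That difference can alter optimizer dynamics between otherwise identical runs.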

ZeRO-Infinity: CPU and NVMe offloading

DeepSpeed extends ZeRO-3 with ZeRO-Infinity, which allows offloading sharded parameters and optimizer states to CPU RAM or even NVMe SSDs.[9]

| Tier | Memory pool | Use case |
|---|---|---|
| GPU | GPU HBM | Active compute, activations, temporary buffers |
| CPU | System RAM | Optimizer states or parameter shards when GPU memory is tight |
| NVMe | Local SSD / NVMe | Deep spillover tier when RAM still isn't enough |

This makes some otherwise impossible training runs feasible, but it doesn't make offload free. Once parameters or optimizer state spill into CPU or NVMe, throughput becomes constrained by PCIe, host memory bandwidth, storage latency, and how well the runtime can overlap data movement with compute.
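A rough back-of-envelope estimate shows why the slower tiers cost throughput. It uses the 16-bytes-per-parameter mixed-precision Adam recipe from the sizing section, a 70B-parameter model, and 32 ranks. Both bandwidth figures are illustrative assumptions, not measurements. Real numbers depend on PCIe generation, NUMA placement, pinned memory, and how many ranks share a drive. The estimate also ignores overlap, so treat it as a first-order traffic estimate per optimizer step.

```python
PARAMS = 70e9
RANKS = 32

# Illustrative per-rank bandwidth assumptions, not measurements.
PCIE_GBPS = 25.0  # host <-> GPU
NVME_GBPS = 6.0   # NVMe <-> host


def transfer_seconds(num_bytes, gb_per_s):
    return num_bytes / (gb_per_s * 1e9)


# CPU optimizer offload: BF16 gradient shard goes to host, updated BF16 params come back.
pcie_bytes = PARAMS / RANKS * (2 + 2)
# NVMe optimizer offload: FP32 master weights + Adam moments (12 B/param), read and written.
nvme_bytes = PARAMS / RANKS * 12 * 2

print(f"pcie_traffic_per_rank={pcie_bytes / 1e9:.2f} GB -> {transfer_seconds(pcie_bytes, PCIE_GBPS):.2f} s")
print(f"nvme_traffic_per_rank={nvme_bytes / 1e9:.2f} GB -> {transfer_seconds(nvme_bytes, NVME_GBPS):.2f} s")
```

Under these assumptions, NVMe traffic per step takes seconds, while the PCIe traffic takes a fraction of a second. An NVMe-offloaded run only stays efficient if the runtime can hide that movement behind compute.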

DeepSpeed's feature matrix is broader, but not every combination in it is valid. Current docs say AutoTP training supports ZeRO stages 0, 1, and 2, while PipelineModule isn't compatible with ZeRO-2 or ZeRO-3.[10][11]

Turn compatibility constraints into a config gate, not a surprise after scheduling a large job:

deepspeed_compatibility_gate.py
```python
requested_runs = [
    {"feature": "AutoTP", "zero_stage": 2},
    {"feature": "AutoTP", "zero_stage": 3},
    {"feature": "PipelineModule", "zero_stage": 2},
]


def supported(run):
    if run["feature"] == "AutoTP":
        return run["zero_stage"] in {0, 1, 2}
    if run["feature"] == "PipelineModule":
        return run["zero_stage"] in {0, 1}
    return False


for run in requested_runs:
    print(run["feature"], f"ZeRO-{run['zero_stage']}", "allowed=", supported(run))
```
Compatibility gate output
```text
AutoTP ZeRO-2 allowed= True
AutoTP ZeRO-3 allowed= False
PipelineModule ZeRO-2 allowed= False
```

FSDP vs DeepSpeed comparison

Both frameworks solve the same core problem, but they optimize for different operational constraints. FSDP optimizes for staying close to standard PyTorch. DeepSpeed optimizes for a wider set of large-scale runtime features.

| Feature | FSDP | DeepSpeed ZeRO |
|---|---|---|
| Runtime model | Native PyTorch distributed stack | Separate DeepSpeed runtime |
| ZeRO stages | ZeRO-2-like and ZeRO-3-like sharding modes | Native ZeRO-1/2/3 |
| CPU offloading | CPUOffload for params and gradients | Strong CPU and NVMe offload |
| NVMe offloading | No | Yes |
| Compile story | Compile the train step or top-level module first; fall back to the inner module if the wrapper causes issues. FSDP1 needs use_orig_params=True. | Depends on integration path |
| Pipeline parallel | External | Built-in, but PipelineModule excludes ZeRO-2/3 |
| Tensor parallel | External | AutoTP or Megatron integration (AutoTP supports ZeRO-0/1/2) |
| Checkpointing / tooling | Standard PyTorch APIs | DeepSpeed-specific runtime APIs |
| Debugging | Standard PyTorch stack traces and profiler | More runtime-specific behavior |

When to use which?

  • Use FSDP when you want the PyTorch-native path, the model fits in aggregate GPU memory once sharded, and standard PyTorch profiling and checkpointing matter more than exotic runtime features.
  • Use DeepSpeed when CPU or NVMe offload is non-negotiable, or when you're already in a DeepSpeed or Megatron-DeepSpeed stack whose TP/PP requirements fit the current stage-compatibility rules.

Beyond data parallelism: 3D parallelism

For very large jobs, data-parallel sharding alone isn't enough. As you push the data-parallel group wider, global batch size, optimizer dynamics, and cross-node communication all start fighting you. The standard answer is to combine multiple forms of parallelism in the same training topology.[12]

That's what people mean by 3D parallelism: data parallelism across replica groups, tensor parallelism inside layers (splitting weight matrices across GPUs), and pipeline parallelism across layer ranges (assigning sequential groups of layers to different devices).

The diagram below shows how data, tensor, and pipeline parallelism combine to distribute a model across a large cluster. Solid arrows show pipeline flow through layer ranges. Dotted arrows show synchronization between matching data-parallel replicas.

3D parallelism topology with two data-parallel replicas, pipeline stages across depth, and tensor-parallel ranks within each stage.
3D parallelism is not one split. Data parallelism repeats the pipeline, pipeline parallelism splits depth, and tensor parallelism splits the layer work inside each stage.

Tensor parallelism (TP)

Tensor parallelism splits the work inside a layer. Instead of putting a full weight matrix on one GPU, it partitions the matrix across GPUs.

  • How it works: In a transformer MLP, the first projection is often column-sharded and the second projection row-sharded. Each rank computes a partial result, then the ranks synchronize those partials.
  • Bandwidth requirement: This synchronization happens inside the layer, so TP strongly prefers very fast intra-node links such as NVLink or NVSwitch (high-bandwidth GPU interconnects).
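The column-then-row split can be checked numerically on one machine. Below is a single-process NumPy sketch that simulates two tensor-parallel ranks as array slices. It assumes a bias-free MLP and uses ReLU in place of GELU for brevity. The final sum over partial outputs stands in for the all-reduce that real tensor parallelism performs across GPUs.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d_model, d_ff, tp_size = 5, 8, 32, 2

x = rng.standard_normal((tokens, d_model))
w1 = rng.standard_normal((d_model, d_ff))  # first projection
w2 = rng.standard_normal((d_ff, d_model))  # second projection


def relu(a):
    return np.maximum(a, 0.0)


# Reference: the unsharded MLP on one device.
full = relu(x @ w1) @ w2

# Column-shard w1 and row-shard w2 along the d_ff dimension.
w1_shards = np.split(w1, tp_size, axis=1)
w2_shards = np.split(w2, tp_size, axis=0)

# Each simulated rank computes a partial output with no communication:
# the elementwise activation only needs that rank's own columns.
partials = [relu(x @ w1_r) @ w2_r for w1_r, w2_r in zip(w1_shards, w2_shards)]

# Summing the partials is what the all-reduce does across TP ranks.
tp_output = np.sum(partials, axis=0)

print("partial_shape=", partials[0].shape)
print("matches_unsharded=", np.allclose(tp_output, full))
```

Because the activation is elementwise, the column split needs no communication before it. That leaves a single sum of partials at the end of the block in the forward pass. This pattern is why the column-then-row pairing is the common choice for MLP blocks.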

Pipeline parallelism (PP)

Pipeline parallelism splits the model across depth by assigning different layer ranges to different devices or nodes.

  • How it works: Stage 0 might own layers 1-12, stage 1 owns 13-24, and activations flow from one stage to the next.
  • Main trade-off: You reduce memory pressure per rank and avoid TP-style per-layer collectives across the whole model, but you introduce pipeline bubbles and need micro-batching to keep stages busy.
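The bubble cost can be estimated with a simple count. Assume a GPipe-style schedule in which each of `p` stages processes `m` micro-batches, and every micro-batch takes one equal time slot per stage. Filling and draining the pipeline takes `m + p - 1` slots, and each stage is busy for only `m` of them, so the idle fraction is `(p - 1) / (m + p - 1)`. This idealized model ignores uneven stage costs and interleaved schedules, both of which change the numbers.

```python
def bubble_fraction(stages, micro_batches):
    """Idle fraction per stage for an idealized GPipe-style schedule with equal slot times."""
    total_slots = micro_batches + stages - 1
    return (stages - 1) / total_slots


for m in [1, 4, 16, 64]:
    print(f"p=4 m={m:>2} -> bubble={bubble_fraction(4, m):.1%}")
```

More micro-batches shrink the bubble. Each micro-batch gets smaller for a fixed global batch, though, which can hurt per-kernel efficiency.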

| Parallelism | Primary shard | Communication pattern | Best for |
|---|---|---|---|
| Data / ZeRO / FSDP | Batch across replicas, with model states optionally sharded | All-reduce, reduce-scatter, all-gather | General scaling and memory reduction |
| Tensor (TP) | Weight matrices within layers | Collectives inside each layer | Very wide layers on fast intra-node links |
| Pipeline (PP) | Sequential layer groups | Point-to-point activations and gradients | Very deep models and multi-node partitioning |

The escalation path should be deliberate.[13]

  • Start with single-node or data-parallel sharding when the model mostly fits and you want the simplest debug story.
  • Add tensor parallelism when single layers are too wide for one device or when matmul shards need to stay on NVLink-class links.
  • Add pipeline parallelism when layer depth itself must be partitioned across devices.

A common placement keeps TP on fast intra-node links, uses PP to partition depth across broader topology boundaries, and applies data parallelism across replica groups. The right degrees still depend on available links, memory, model shape, and global-batch constraints.

In current PyTorch, a DeviceMesh is the abstraction used to express multiple device dimensions, which lets FSDP2 compose with tensor parallelism on the same hardware (often called FSDP + TP, or 2D parallelism). Megatron Core also documents context parallelism, which splits the sequence dimension for long-context workloads. Combining data, tensor, pipeline, and context parallelism creates a four-axis topology; runtimes don't always use the same "3D" or "4D" label for that composition.[3][13]

Each topology degree multiplies into the required world size. Calculate it before reserving hardware:

parallelism_topology_size.py
```python
topology = {
    "data_parallel": 8,
    "tensor_parallel": 4,
    "pipeline_parallel": 2,
    "context_parallel": 2,
}

world_size = 1
for degree in topology.values():
    world_size *= degree

print("world_size=", world_size)
print("ranks_per_data_replica=", topology["tensor_parallel"] * topology["pipeline_parallel"] * topology["context_parallel"])
assert world_size == 128
```
Topology size output
```text
world_size= 128
ranks_per_data_replica= 16
```
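Beyond the world size, each global rank also needs coordinates on the mesh so the runtime can form its TP, CP, PP, and DP groups. The sketch below uses the same degrees as above, reordered to make an assumed layout explicit: row-major, with tensor parallelism as the fastest-varying (innermost) dimension. Under that layout, TP peers get consecutive rank numbers and can stay on one node. Real runtimes define their own ordering, so treat this one as an assumption. `rank_to_coords` and `tensor_parallel_group` are hypothetical helpers.

```python
# Outermost -> innermost; tensor_parallel varies fastest (assumed ordering).
topology = {
    "data_parallel": 8,
    "pipeline_parallel": 2,
    "context_parallel": 2,
    "tensor_parallel": 4,
}


def rank_to_coords(rank, topology):
    coords = {}
    # Peel dimensions from the innermost (last) to the outermost (first).
    for name, size in reversed(list(topology.items())):
        rank, coords[name] = divmod(rank, size)
    return {name: coords[name] for name in topology}


def tensor_parallel_group(rank, tp_size):
    # Valid only because TP is the innermost dimension in this layout.
    start = rank - rank % tp_size
    return list(range(start, start + tp_size))


world_size = 1
for size in topology.values():
    world_size *= size

for rank in [0, 1, 4, 16, 127]:
    print(rank, rank_to_coords(rank, topology))
print("tp_group_of_rank_5=", tensor_parallel_group(5, topology["tensor_parallel"]))
```

If the profiler shows tensor-parallel collectives crossing node boundaries, the first thing to check is whether the real rank ordering matches the intended one.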

Debugging & profiling distributed training

Distributed training fails in ways that single-GPU training doesn't. Jobs hang instead of crash. One bad rank can stall everybody else. Performance regressions often come from communication scheduling, not math kernels.

Common failure modes

  1. NCCL timeouts or hangs: Often one rank died earlier or entered a different collective pattern. TORCH_NCCL_BLOCKING_WAIT=1 can turn a silent hang into a visible failure.
  2. Startup OOM: Large models can OOM before training starts if you materialize full parameters too early. Lazy or meta-device initialization matters.
  3. Wrong gradient handling: With an FSDP1 sharded strategy, use fsdp_model.clip_grad_norm_() because it computes across sharded gradients. FSDP2's DTensor-based parameters instead support torch.nn.utils.clip_grad_norm_(model.parameters(), ...) in the current tutorial.[6][4]
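Sharded clipping needs a collective because no single rank sees the whole gradient. The global L2 norm is the square root of the summed squared local norms, not any one rank's local norm. Below is a single-process NumPy sketch that simulates four ranks as array slices. The sum over ranks stands in for the all-reduce that a sharding-aware `clip_grad_norm_` performs. The clip coefficient mirrors the usual `max_norm / (total_norm + eps)` form, capped at 1.

```python
import numpy as np

rng = np.random.default_rng(1)
full_grad = rng.standard_normal(1_000)
shards = np.array_split(full_grad, 4)  # each simulated rank owns one shard

max_norm = 1.0
local_sq = [float(np.sum(s ** 2)) for s in shards]
global_norm = np.sqrt(sum(local_sq))  # all-reduce(sum) of squared norms, then sqrt

# Every rank applies the same scale factor, computed from the global norm.
clip_coef = min(1.0, max_norm / (global_norm + 1e-6))
clipped = np.concatenate([s * clip_coef for s in shards])

print("rank0_local_norm=", round(float(np.sqrt(local_sq[0])), 3))
print("global_norm=", round(float(global_norm), 3))
print("clipped_norm=", round(float(np.linalg.norm(clipped)), 3))
```

Clipping each shard by its own local norm would give every rank a different scale factor, which skews the update direction.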

Profiling tools

To optimize throughput (tokens/sec), first determine whether you're compute-bound or communication-bound.

  • torch.profiler: Look at ncclAllGather, ncclReduceScatter, and gaps where GPUs wait for communication. For FSDP1, tune wrapping granularity and backward_prefetch; use limit_all_gathers=True for peak-memory rate limiting, not as the overlap knob. For FSDP2, use its module prefetch APIs when implicit scheduling is insufficient.[6][3]
  • DeepSpeed Flops Profiler or DeepSpeed runtime logs: Use them to break step time into compute vs communication and to see whether offload is the bottleneck.

For hard distributed bugs, capture one clean repro run with NCCL_DEBUG=INFO and TORCH_DISTRIBUTED_DEBUG=DETAIL instead of trying to reason from high-level symptoms alone.

Training observability signals that matter

Once the job runs, you still need to know why it is slow or unstable. Good training dashboards separate data, compute, communication, and memory instead of showing only one loss curve.[13]

| Signal | What it answers | Bad symptom |
|---|---|---|
| tokens/sec | Is end-to-end throughput improving? | flat or falling throughput after adding GPUs |
| step time split | Is time spent in data load, forward, backward, optimizer, or checkpoint save? | one phase dominates without explanation |
| all-gather / reduce-scatter share | Is sharding traffic now bottlenecking the job? | communication time grows faster than compute time |
| dataloader wait time | Are GPUs starved by CPU preprocessing or storage? | GPU idle gaps before forward |
| peak HBM and reserved memory | Are you close to OOM or fragmenting memory? | retries, allocator spikes, sudden OOM at checkpoint save |
| MFU or FLOP utilization | Are expensive kernels doing useful math? | very low utilization despite full memory use |
| checkpoint save duration | Is fault tolerance becoming a throughput tax? | long pauses every save interval |

If you only watch training loss, a run can look healthy while throughput quietly collapses. Distributed training needs systems telemetry, not only ML telemetry.
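MFU is usually estimated rather than read off a counter. A common first-order approximation for dense transformer training is about `6 × N` FLOPs per token, 2N for the forward pass and 4N for the backward pass, ignoring attention's sequence-length term. That gives MFU ≈ `6 × N × tokens/sec ÷ (num_gpus × peak FLOP/s)`. In the sketch below, the throughput values and the 989 TFLOP/s per-GPU peak are illustrative assumptions. Replace them with your measured tokens/sec and your accelerator's datasheet number for the precision you train in.

```python
def estimate_mfu(params, tokens_per_sec, num_gpus, peak_flops_per_gpu):
    # ~6 FLOPs per parameter per token (dense approximation, attention term ignored).
    achieved = 6 * params * tokens_per_sec
    return achieved / (num_gpus * peak_flops_per_gpu)


# Illustrative inputs (assumptions, not measurements).
params = 70e9
num_gpus = 32
peak = 989e12

for tokens_per_sec in [10_000, 25_000, 40_000]:
    mfu = estimate_mfu(params, tokens_per_sec, num_gpus, peak)
    print(f"{tokens_per_sec:>6} tok/s -> MFU={mfu:.1%}")
```

Tracking this number before and after adding GPUs shows directly whether the extra hardware is doing useful math or just waiting on communication.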

Training-system bottleneck map

Strong distributed-training debugging connects the abstraction to the bottleneck. A useful answer names the split, the communication primitive, and the symptom you would measure.

| Topic | What to explain | Debug signal |
|---|---|---|
| Data parallelism | each rank owns a batch shard and synchronizes gradients | all-reduce time rises with parameter size |
| Tensor parallelism | one layer is split across ranks | frequent collectives inside attention and MLP blocks |
| Pipeline parallelism | different ranks own different layer ranges | idle bubbles when micro-batching is too small |
| ZeRO / FSDP | optimizer states, gradients, and parameters are sharded | all-gather and reduce-scatter dominate step time |
| Activation checkpointing | save memory by recomputing activations in backward | lower peak memory, higher compute time |
| NCCL bottlenecks | collectives depend on rank order, topology, and matching calls | hangs, timeouts, or wide gaps in profiler traces |
| FlashAttention | reduce attention memory traffic with tiled kernels | better long-context throughput when attention is memory-bound |

A practical debugging loop is:

  1. Reproduce on the smallest rank count that still fails.
  2. Confirm every rank enters the same collective sequence.
  3. Capture profiler traces and separate compute, communication, and idle time.
  4. Check peak memory before and after forward, backward, and optimizer step.
  5. Change one axis at a time: sharding granularity, micro-batch size, checkpointing, or placement.
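Step 2 is much easier with a per-rank log of collectives. The minimal sketch below assumes each rank has recorded a list of `(op, tensor_shape)` entries. How those logs are produced is left open here; they could be parsed from NCCL_DEBUG output or profiler traces, and `first_divergence` is a hypothetical helper. It finds the first call index where any rank disagrees, which is usually where the hang starts.

```python
def first_divergence(per_rank_logs):
    """Return (index, {rank: entry}) for the first collective where ranks disagree, else None."""
    longest = max(len(log) for log in per_rank_logs.values())
    for i in range(longest):
        # A missing entry (None) means that rank never issued call i.
        entries = {rank: (log[i] if i < len(log) else None) for rank, log in per_rank_logs.items()}
        if len(set(entries.values())) > 1:
            return i, entries
    return None


logs = {
    0: [("all_gather", (1024,)), ("reduce_scatter", (1024,)), ("all_gather", (2048,))],
    1: [("all_gather", (1024,)), ("reduce_scatter", (1024,)), ("all_gather", (2048,))],
    2: [("all_gather", (1024,)), ("all_gather", (2048,))],  # this rank skipped a collective
}

result = first_divergence(logs)
if result is None:
    print("all ranks issued the same collective sequence")
else:
    index, entries = result
    print("first mismatch at call", index)
    for rank, entry in entries.items():
        print("  rank", rank, "->", entry)
```

A mismatch in op type usually points to divergent control flow, such as a data-dependent branch or an MoE route. A mismatch in shape alone often points to uneven batch or padding logic.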

This is the bridge between ML and distributed systems. The right answer is rarely "use FSDP" or "use DeepSpeed." The stronger answer is "model states don't fit, so I shard them; then I measure whether communication or activation memory became the new bottleneck."

A worked sizing exercise

A common sizing prompt is: "How many 80 GB GPUs do you need to train a 70B model with ZeRO-3?" Walk through it step by step.

Step 1: count the model states.

With mixed-precision Adam, each parameter carries:

  • 2 bytes (BF16 weight)
  • 2 bytes (BF16 gradient)
  • 12 bytes (FP32 master weight + Adam moments)

Total: 16 bytes per parameter.

For 70B parameters: $70 \times 10^9 \times 16 \text{ bytes} = 1{,}120 \text{ GB}$ of model states.

Step 2: compute the model-state lower bound.

If you looked only at model states, the absolute lower bound on 80 GB GPUs would be:

$$\frac{1{,}120 \text{ GB}}{80 \text{ GB/GPU}} = 14$$

So 14 GPUs is the math-floor for model states alone.

The same calculation is worth making executable:

zero3_gpu_floor.py
```python
import math

total_state_gb = 70 * 16
gpu_memory_gb = 80

raw_floor = math.ceil(total_state_gb / gpu_memory_gb)
print("raw_lower_bound_gpus=", raw_floor)

for gpus in [14, 16, 24]:
    per_gpu = total_state_gb / gpus
    headroom = gpu_memory_gb - per_gpu
    print(f"{gpus:>2} GPUs -> {per_gpu:5.1f} GB states/GPU, {headroom:5.1f} GB before activations")
```
GPU lower-bound output
```text
raw_lower_bound_gpus= 14
14 GPUs ->  80.0 GB states/GPU,   0.0 GB before activations
16 GPUs ->  70.0 GB states/GPU,  10.0 GB before activations
24 GPUs ->  46.7 GB states/GPU,  33.3 GB before activations
```
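For contrast, the ZeRO paper's per-GPU model-state formulas[1] show why stages 1 and 2 can't reach this floor for a 70B model. With Ψ parameters, N data-parallel ranks, and K = 12 bytes of optimizer state per parameter:

- ZeRO-1 keeps `4Ψ + KΨ/N`.
- ZeRO-2 keeps `2Ψ + (2 + K)Ψ/N`.
- ZeRO-3 keeps `(2 + 2 + K)Ψ/N`.

The sketch below evaluates these for model states only. Working in billions of parameters makes the results come out directly in GB.

```python
PARAMS_B = 70  # billions of parameters, so bytes-per-param * PARAMS_B gives GB
K = 12         # FP32 master weight + two Adam moments, bytes per parameter


def per_gpu_state_gb(stage, n):
    if stage == 0:  # plain data parallel: everything replicated
        return (2 + 2 + K) * PARAMS_B
    if stage == 1:  # shard optimizer states only
        return (2 + 2) * PARAMS_B + K * PARAMS_B / n
    if stage == 2:  # shard optimizer states and gradients
        return 2 * PARAMS_B + (2 + K) * PARAMS_B / n
    if stage == 3:  # shard optimizer states, gradients, and parameters
        return (2 + 2 + K) * PARAMS_B / n
    raise ValueError(f"unknown ZeRO stage {stage}")


for stage in range(4):
    print(f"ZeRO-{stage} on 64 GPUs -> {per_gpu_state_gb(stage, 64):7.1f} GB model states per GPU")
```

The replicated terms decide the outcome. ZeRO-1 still holds 280 GB of BF16 weights and gradients per GPU, and ZeRO-2 still holds 140 GB of BF16 weights. Both exceed an 80 GB GPU no matter how large N gets, so only ZeRO-3 has a finite GPU count at all.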

Step 3: add overhead.

The math above is for model states only. Real training also needs:

  • Activations (depends on sequence length and batch size)
  • Communication buffers for all-gather and reduce-scatter
  • CUDA allocator overhead and fragmentation

There isn't one honest fixed headroom percentage: a context-length or micro-batch change can dominate it. Measure or estimate the remaining peak memory for the specific workload, then compare candidate topologies:

zero3_workload_budget.py
```python
state_total_gb = 1120
gpu_capacity_gb = 80
measured_non_state_peak_gb = 18  # activations + buffers + allocator margin for this trial workload


def fits(world_size):
    state_per_gpu = state_total_gb / world_size
    peak = state_per_gpu + measured_non_state_peak_gb
    return state_per_gpu, peak, peak <= gpu_capacity_gb


for world_size in [16, 24, 32]:
    state, peak, ok = fits(world_size)
    print(f"{world_size} GPUs -> states={state:.1f} GB peak_estimate={peak:.1f} GB fits={ok}")
```
Workload memory budget output
```text
16 GPUs -> states=70.0 GB peak_estimate=88.0 GB fits=False
24 GPUs -> states=46.7 GB peak_estimate=64.7 GB fits=True
32 GPUs -> states=35.0 GB peak_estimate=53.0 GB fits=True
```

Step 4: answer the question.

The clean answer is: 14 GPUs is only the raw ZeRO-3 lower bound for model states under this 16-byte recipe. It isn't a runnable cluster recommendation. A practical count requires a measured or modeled activation/buffer peak and an acceptable throughput result. In the example above, an 18 GB non-state peak rules out 16 GPUs but fits at 24; a longer sequence could invalidate that answer again. Activation checkpointing, a smaller micro-batch, tensor parallelism, or pipeline parallelism may change the result.

Common pitfalls

Misidentifying which axis is split

Symptom: A team says they "added ZeRO-3 for tensor parallelism," but profiler traces still show whole-layer matmuls on each rank and the real change is only lower model-state memory.

Cause: ZeRO and FSDP shard model states across data-parallel ranks. They do not split the matrix multiply itself. Tensor parallelism splits layer computation, and pipeline parallelism splits model depth.

Fix: Check what moved in the trace or memory profile before naming the technique. If weights, gradients, and optimizer states shrank per rank, that is sharding. If one layer's matmul is split across devices, that is tensor parallelism.

Wrapping the whole model as one FSDP unit

Symptom: You enable ZeRO-3, but peak memory is still unexpectedly high and you get OOM during the forward pass.

Cause: You wrapped the entire model as a single FSDP unit instead of wrapping individual transformer layers. When the whole model is one unit, FSDP all-gathers every parameter at once. That defeats the purpose of layer-by-layer sharding.

Fix: Use auto_wrap_policy to wrap at the transformer-block level. The code example earlier shows how to do this with transformer_auto_wrap_policy.

Thinking FSDP1 limit_all_gathers=True creates overlap

Symptom: You set limit_all_gathers=True and expect faster training, but throughput doesn't improve.

Cause: limit_all_gathers is a CPU-side rate limiter. It caps how many all-gathers are in flight at once to control peak memory. It isn't an overlap or performance knob.

Fix: In FSDP1, use backward_prefetch and communication-unit granularity for overlap. In FSDP2, use its prefetch methods when needed; it doesn't expose FSDP1's limit_all_gathers argument.

Treating SHARD_GRAD_OP as identical to ZeRO-2

Symptom: You switch from ZeRO-2-style reasoning to FSDP SHARD_GRAD_OP, then wonder why peak memory is higher than expected even though gradients are sharded.

Cause: The names are similar, but the parameter lifetime differs. PyTorch's SHARD_GRAD_OP keeps parameters unsharded through forward and backward, then reshards after backward.[6] That changes peak memory behavior even when the high-level mental model sounds like ZeRO-2.

Fix: Confirm the actual parameter lifetime in docs and profiler traces before sizing memory. If you need the lowest parameter footprint, compare against FULL_SHARD rather than assuming SHARD_GRAD_OP behaves like a DeepSpeed stage name.

Ignoring activation memory

Symptom: The model-state math says everything should fit, but you still OOM.

Cause: Activations can be larger than model states for long sequences. A single forward pass with a long context can generate gigabytes of intermediate tensors.

Fix: Pair ZeRO-3 or FSDP with activation checkpointing. Trade compute for memory by recomputing activations during the backward pass instead of storing them.
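For a sense of scale, one published first-order estimate (Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models", 2022) gives per-layer activation memory for a standard transformer layer. With 16-bit activations and no tensor or sequence parallelism, it is about `s·b·h·(34 + 5·a·s/h)` bytes, where `s` is sequence length, `b` is micro-batch size, `h` is hidden size, and `a` is the number of attention heads. The `5·a·s/h` term covers stored attention scores, which fused attention kernels largely avoid. With full activation checkpointing, roughly only each layer's `2·s·b·h`-byte input is kept, plus one layer's activations while that layer is being recomputed. The 70B-like shape below is an illustrative assumption.

```python
def layer_activation_bytes(s, b, h, a):
    # Korthikanti et al. (2022) estimate: 16-bit activations, no TP/SP, no recomputation.
    return s * b * h * (34 + 5 * a * s / h)


def checkpointed_layer_bytes(s, b, h):
    # Full recomputation: keep only each layer's 16-bit input.
    return 2 * s * b * h


# Assumed 70B-like shape (illustrative): 80 layers, hidden 8192, 64 heads, micro-batch 1.
layers, h, a, b = 80, 8192, 64, 1

for s in [2048, 4096, 8192]:
    full = layers * layer_activation_bytes(s, b, h, a) / 1e9
    ckpt = layers * checkpointed_layer_bytes(s, b, h) / 1e9
    print(f"seq={s:>5}: no checkpointing ~{full:7.1f} GB, full checkpointing ~{ckpt:5.1f} GB")
```

Without checkpointing, this estimate already exceeds an 80 GB GPU by several times at a 2K context. Because ZeRO-3 shards model states but not a rank's own activations, checkpointing is what makes the per-GPU budget workable.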

Assuming every DeepSpeed feature composes with ZeRO-3

Symptom: You configure AutoTP with ZeRO-3 and the job crashes or hangs.

Cause: DeepSpeed has documented compatibility limits. AutoTP currently supports ZeRO stages 0, 1, and 2, but not 3. PipelineModule isn't compatible with ZeRO-2 or ZeRO-3.

Fix: Check the stage-compatibility matrix before you design your training topology. Don't assume that because DeepSpeed supports both features, they work together.

What you should be able to defend

By the end of this lesson, you should be able to explain and operate these parts of a distributed training setup:

  • How weights, gradients, FP32 master weights, and Adam moments create the memory wall.
  • Why DDP improves throughput but doesn't reduce per-GPU model-state memory.
  • What ZeRO-1, ZeRO-2, and ZeRO-3 shard.
  • How FSDP1 FULL_SHARD, SHARD_GRAD_OP, NO_SHARD, and HYBRID_SHARD trade memory for communication, and why new code should start with FSDP2 fully_shard.
  • Why FSDP2's fully_shard is different from the FSDP1 wrapper API.
  • How all-reduce, reduce-scatter, and all-gather show up in training traces.
  • When to use FSDP, DeepSpeed ZeRO, offload, tensor parallelism, and pipeline parallelism.
  • How to size a 70B-style training job before accounting for activations and communication buffers.
  • Which profiler and observability signals reveal communication, memory, checkpointing, and dataloader bottlenecks.

Final selection workflow

Use this sequence when sizing a real run:

  1. Compute model-state memory first. If the bytes-per-parameter math already blows past one GPU, DDP is not your answer.
  2. Ask what is actually limiting memory. If sharding fixes model states but activations still dominate, add activation checkpointing or reduce sequence and micro-batch before changing frameworks.
  3. Measure the new bottleneck after sharding. If step time now shifts into all-gather and reduce-scatter, network quality and wrapping granularity matter more than the headline ZeRO stage.
  4. Choose the operational stack. Prefer FSDP when you want native PyTorch tooling and the run fits once sharded. Prefer DeepSpeed when CPU or NVMe offload, or an existing DeepSpeed stack, is the real requirement.
  5. Escalate topology only when needed. Add tensor parallelism when layers are too wide for one device, pipeline parallelism when depth must be split, and always verify stage-compatibility before combining features.

Quick decision reference

| Scenario | Recommendation | Why |
|---|---|---|
| Model fits in aggregate GPU memory and you want native PyTorch tooling | FSDP2 fully_shard, applied bottom-up | Current native path with DTensor-backed sharding and PyTorch checkpointing |
| Maintaining FSDP1 code and willing to spend more memory for fewer reshard events | FSDP1 SHARD_GRAD_OP | Higher peak memory, less aggressive resharding |
| GPU memory is the blocker and CPU/NVMe offload is acceptable | DeepSpeed ZeRO-Infinity | Uses larger memory tiers to make the run feasible |
| Need DeepSpeed tensor parallel training with optimizer sharding | DeepSpeed AutoTP + ZeRO-1/2 | Current docs support AutoTP with ZeRO stages 0, 1, and 2, but not 3 |
| Need DeepSpeed PipelineModule | DeepSpeed pipeline + ZeRO-0/1 | Current docs say PipelineModule isn't compatible with ZeRO-2 or ZeRO-3 |

Next Step
Continue to LoRA & Parameter-Efficient Tuning

Distributed training shows how full-model updates become possible; LoRA shows when you can avoid that cost by training small adapter matrices for targeted adaptation instead of updating the entire model.

References

[1] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. Rajbhandari, S., et al. · 2020 · SC 2020

[2] PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. Zhao, Y., et al. · 2023 · VLDB 2023

[3] torch.distributed.fsdp.fully_shard. PyTorch Contributors · 2025

[4] Getting Started with Fully Sharded Data Parallel (FSDP2). PyTorch Contributors · 2025

[5] Where to apply torch.compile? PyTorch Contributors · 2025

[6] FullyShardedDataParallel. PyTorch Contributors · 2025

[7] Checkpointing in torchtune. PyTorch Contributors · 2026

[8] Universal Checkpointing with DeepSpeed: A Practical Guide. DeepSpeed Team · 2026

[9] ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. Rajbhandari, S., et al. · 2021 · SC 2021

[10] Training API. DeepSpeed Team · 2026

[11] Pipeline Parallelism. DeepSpeed Team · 2026

[12] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. Shoeybi, M., et al. · 2019

[13] Parallelism Strategies Guide. NVIDIA · 2026