Data parallel, model parallel, FSDP — the engineering nuances that turn "scales linearly" into "works at all".
Key idea
The cluster gives you many GPUs; the framework lets them cooperate. Data parallel: same model on each GPU, different data, all-reduce gradients. Model parallel: split the model across GPUs. FSDP / ZeRO: shard both. The right choice depends on what fits — and "what fits" is the main constraint.
Data parallel (DDP). The default. Each GPU has a full copy of the model. They process different batches, then all-reduce gradients. Easy to set up, scales well for models that fit on one GPU.
FSDP (Fully Sharded Data Parallel). Shards parameters, gradients, and optimizer state across GPUs. Lets you train models that wouldn't fit on a single GPU. Built into PyTorch; DeepSpeed's ZeRO has comparable functionality.
Tensor / model parallel. Split individual layers across GPUs. Used for very large models where even FSDP is insufficient. Megatron, FairScale. Most teams don't need this.
Pipeline parallel. Run different layers on different GPUs, pipelined. GPipe, PipeDream. Useful when one model doesn't fit but a layer does.
Pick by model size
Fits on one GPU and throughput is enough: don't bother with distributed; a single GPU is simplest
Model fits, want more throughput: DDP
Model nearly fits: FSDP / ZeRO-2
Model too big for any GPU: FSDP + activation checkpointing + sometimes tensor parallel
Hundreds of GPUs: 2D / 3D parallelism (data × tensor × pipeline, with the data axis often sharded FSDP-style)
Common pitfalls
Different seeds per rank for weight init (with nothing broadcasting rank 0's weights) → ranks diverge silently
Logging from all ranks → step counter inflated by world_size
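BatchNorm computes statistics per rank by default → with small per-GPU batches the stats are noisy and differ across ranks; use nn.SyncBatchNorm when you need global batch stats
Saving from every rank instead of only rank 0 → ranks race to write the same file, corrupting checkpoints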
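One way to avoid the seeding pitfall is to keep two random streams: one seeded identically on every rank for anything that must match (such as weight init), and one offset by rank for per-rank randomness such as augmentation. The sketch below is pure Python; `make_rngs` and the `base_seed + 1 + rank` offset are illustrative assumptions, not a PyTorch API.

```python
import random

def make_rngs(base_seed, rank):
    """Return (shared_rng, per_rank_rng) for one process of a distributed job."""
    # Same seed on every rank: for anything that must agree, e.g. weight init.
    shared_rng = random.Random(base_seed)
    # Rank-offset seed (hypothetical scheme): for per-rank randomness such as
    # augmentation; the +1 keeps rank 0 from reusing the shared stream.
    per_rank_rng = random.Random(base_seed + 1 + rank)
    return shared_rng, per_rank_rng

def init_weights(rng, n):
    return [rng.gauss(0.0, 0.02) for _ in range(n)]

# Simulate two ranks inside one process.
shared0, aug0 = make_rngs(1234, rank=0)
shared1, aug1 = make_rngs(1234, rank=1)
w0, w1 = init_weights(shared0, 4), init_weights(shared1, 4)
aug_draws0 = [aug0.random() for _ in range(5)]
aug_draws1 = [aug1.random() for _ in range(5)]
print(w0 == w1)                  # identical initial weights on both ranks
print(aug_draws0 != aug_draws1)  # different augmentation randomness per rank
```

With DDP, rank 0's parameters are broadcast when the model is wrapped, so the shared stream matters most for hand-rolled setups and for state that lives outside the model.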
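import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
# MyNet, dataset, loss_fn and num_epochs are placeholders defined elsewhere.
def setup_ddp():
    dist.init_process_group("nccl")  # NCCL for NVIDIA GPUs, gloo for CPU
    # torchrun sets LOCAL_RANK; index GPUs with it, not the global rank,
    # or ranks on the second node would ask for GPUs that don't exist.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), local_rank
def train(rank, local_rank):
    model = MyNet().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)  # optimizer and lr are placeholders
    # The sampler uses the *global* rank to pick this process's slice of the data.
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # different shuffle per epoch
        for x, y in loader:
            loss = loss_fn(model(x.cuda(local_rank)), y.cuda(local_rank))
            loss.backward()  # DDP all-reduces gradients during backward
            opt.step()
            opt.zero_grad()
        # Save only on rank 0; model.module is the unwrapped model
        if rank == 0:
            torch.save(model.module.state_dict(), f"epoch_{epoch}.pt")
        dist.barrier()  # other ranks wait until the checkpoint is written
if __name__ == "__main__":
    rank, local_rank = setup_ddp()
    train(rank, local_rank)
    dist.destroy_process_group()
# Launch: torchrun --nproc-per-node=4 train.py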
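How DistributedSampler splits an epoch across ranks can be sketched in plain Python. This is a simplified model of its default behaviour, assuming drop_last=False: build one permutation from seed + epoch, pad by wrapping around until the length divides by the world size, then give rank r every world_size-th index starting at r. `shard_indices` is a hypothetical helper; the real sampler shuffles with a torch.Generator, so its permutations differ from these.

```python
import random

def shard_indices(n, world_size, rank, epoch, seed=0, shuffle=True):
    """Dataset indices that `rank` would see in `epoch` (simplified, drop_last=False)."""
    assert n > 0 and 0 <= rank < world_size
    indices = list(range(n))
    if shuffle:
        # Every rank builds the same permutation from seed + epoch, so together the
        # shards cover it exactly once (plus padding); a new epoch gives a new order.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by wrapping around so every rank gets the same number of samples.
    total = -(-n // world_size) * world_size  # ceil(n / world_size) * world_size
    indices = (indices * (total // n + 1))[:total]
    # Strided split: rank r takes positions r, r + W, r + 2W, ...
    return indices[rank:total:world_size]

shards = [shard_indices(10, 4, r, epoch=0, shuffle=False) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

The padding is why a 10-sample dataset on 4 ranks yields 12 samples per epoch: indices 0 and 1 are seen twice. Forgetting `sampler.set_epoch(epoch)` keeps `epoch` fixed, so every epoch repeats the same order.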
FSDP, checkpoint sharding, and gradient communication optimisations
FSDP memory savings
$$ \text{memory per GPU} \approx \frac{P + G + O}{W} + A $$
P parameters, G gradients, O optimizer state, W world size, A activations
Vanilla DDP: P + G + O + A per GPU
FSDP: divides the first three by W
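Plugging numbers into the formula makes the savings concrete. This sketch assumes mixed-precision Adam with bf16 parameters and gradients (2 bytes each) plus fp32 master weights, momentum and variance (12 bytes), the 16-bytes-per-parameter accounting used in the ZeRO paper, and it ignores activations by default. `memory_per_gpu_gb` is a hypothetical helper.

```python
def memory_per_gpu_gb(n_params, world_size, sharded,
                      bytes_param=2, bytes_grad=2, bytes_optim=12,
                      activations_gb=0.0):
    """Approximate per-GPU memory in GB: (P + G + O) / W + A if sharded, else P + G + O + A."""
    model_states = n_params * (bytes_param + bytes_grad + bytes_optim)  # P + G + O in bytes
    if sharded:
        model_states /= world_size  # FSDP: each rank keeps 1/W of P, G and O
    return model_states / 1e9 + activations_gb

# A 7B-parameter model on 8 GPUs, activations ignored:
print(memory_per_gpu_gb(7_000_000_000, 8, sharded=False))  # 112.0
print(memory_per_gpu_gb(7_000_000_000, 8, sharded=True))   # 14.0
```

At 112 GB the unsharded model states alone exceed an 80 GB GPU, while the sharded 14 GB leaves room for activations.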
FSDP details. FSDP wraps your model recursively at a chosen granularity (per layer or per block). In the forward pass, it all-gathers the parameters for the active block, computes, then frees them. Backward all-gathers them again, then reduce-scatters the gradients so each rank keeps only its own shard. The extra communication is the price of the memory savings.
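A single-process simulation of one FSDP unit: flatten the parameters, split them into W equal shards (padding the tail), all-gather before compute, and reduce-scatter gradients afterward so each rank keeps only the summed gradient for its own shard. Lists stand in for tensors, and `shard`, `all_gather` and `reduce_scatter` are illustrative stand-ins, not torch.distributed calls.

```python
def shard(flat, world_size):
    """Split a flat list into world_size equal shards, zero-padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceil division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

def all_gather(shards, numel):
    """Every rank reconstructs the full parameter list (padding dropped)."""
    full = [x for s in shards for x in s]
    return full[:numel]

def reduce_scatter(per_rank_grads, world_size):
    """Sum full-size gradients across ranks; rank r keeps only shard r of the sum.
    (Real implementations usually average, i.e. also divide by world_size.)"""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)

params = [1.0, 2.0, 3.0, 4.0, 5.0]
W = 2
param_shards = shard(params, W)               # what each rank stores between uses
full = all_gather(param_shards, len(params))  # materialised only for this block's compute
grads = [[1.0] * 5, [3.0] * 5]                # each rank's gradient from its own micro-batch
grad_shards = reduce_scatter(grads, W)        # rank r keeps grad_shards[r]
print(param_shards, full, grad_shards)
```

Each rank ends the step holding only its shard of parameters and gradients, which is where the division by W in the memory formula comes from.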
Gradient bucketing. DDP batches small gradient tensors into "buckets" before all-reducing them, which reduces per-tensor overhead. The default bucket size (bucket_cap_mb=25 in PyTorch DDP) is usually fine; tune it for very small or very large models.
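Bucketing can be pictured as a greedy pass over gradients in the order backward produces them (roughly reverse layer order), closing a bucket once it reaches the cap so its all-reduce can start. This is a simplified illustration, not PyTorch's exact algorithm; only the name `bucket_cap_mb` mirrors the real DDP constructor argument, and the tensor names and sizes are made up.

```python
def make_buckets(grad_sizes_mb, bucket_cap_mb=25.0):
    """Group (name, size_mb) gradients, listed in backward order, into capped buckets."""
    buckets, current, current_mb = [], [], 0.0
    for name, size_mb in grad_sizes_mb:
        current.append(name)
        current_mb += size_mb
        if current_mb >= bucket_cap_mb:  # bucket full: its all-reduce can launch now
            buckets.append(current)
            current, current_mb = [], 0.0
    if current:
        buckets.append(current)  # last partial bucket, reduced when backward ends
    return buckets

# Gradients become ready in reverse layer order during backward.
grads = [("layer4.weight", 10.0), ("layer3.weight", 10.0), ("layer2.weight", 10.0),
         ("layer1.weight", 10.0), ("layer0.bias", 1.0)]
print(make_buckets(grads))
# [['layer4.weight', 'layer3.weight', 'layer2.weight'], ['layer1.weight', 'layer0.bias']]
```

In PyTorch the knob is `DDP(model, device_ids=[local_rank], bucket_cap_mb=...)`. Larger buckets mean fewer collectives but a later first all-reduce; smaller buckets start communicating sooner at higher per-call overhead.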
Overlap compute and communication. Modern DDP and FSDP start all-reducing a bucket's gradients as soon as they are ready, while backward keeps computing gradients for earlier layers. The default is good; verify with the profiler that the gaps between kernels are small.
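A toy timing model shows why the overlap matters. With per-bucket backward compute time c_i and all-reduce time m_i, a bucket's all-reduce can start once its gradients are ready and the previous all-reduce has finished. The numbers are made up, the model ignores contention between compute and communication kernels, and `step_time` is a hypothetical helper.

```python
def step_time(compute, comm, overlap=True):
    """Backward-pass time for buckets processed in order; compute[i], comm[i] in ms."""
    if not overlap:
        return sum(compute) + sum(comm)  # all-reduce everything after backward ends
    t_compute = 0.0  # when backward has produced bucket i's gradients
    t_comm = 0.0     # when the communication channel becomes free
    for c, m in zip(compute, comm):
        t_compute += c
        start = max(t_compute, t_comm)  # needs gradients ready and channel free
        t_comm = start + m
    return max(t_compute, t_comm)

compute = [4.0, 4.0, 4.0, 4.0]
comm = [3.0, 3.0, 3.0, 3.0]
print(step_time(compute, comm, overlap=False))  # 28.0
print(step_time(compute, comm, overlap=True))   # 19.0
```

With overlap, only the last bucket's all-reduce (3 ms here) stays exposed. When communication dominates (comm longer than compute per bucket), the channel becomes the bottleneck and overlap hides only the compute, which shows up in the profiler as idle gaps between backward kernels.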
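Checkpoint sharding. A 100B-parameter model in fp32 is about 400 GB; gathering all of it onto one rank to write a single file is slow and can exceed host memory. With FSDP, StateDictType.SHARDED_STATE_DICT makes each rank save only its own shard (one file per rank). Reload with the same sharding.
Mixed precision in distributed. Use bf16 for compute and fp32 for the master parameters. Loss scaling is only needed for fp16; with fp16 and FSDP, use ShardedGradScaler rather than the plain GradScaler. Most frameworks handle this automatically when you turn on AMP and set the right FSDP precision policy.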
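Why keep fp32 master parameters: bf16 has only 8 significant bits, so near 1.0 its spacing is 2^-7 ≈ 0.0078 and an update of 1e-3 rounds away entirely. This pure-Python sketch emulates bf16 by keeping the top 16 bits of the fp32 bit pattern with round-to-nearest-even (NaN and inf handling omitted); the helper names and the constant update size are illustrative assumptions.

```python
import struct

def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to the nearest bfloat16 (ties to even); NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for ties-to-even
    bits = (bits >> 16) << 16            # keep sign, exponent and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def simulate_updates(n_steps, update, fp32_master):
    master = 1.0  # fp32 master copy (only used when fp32_master=True)
    weight = 1.0  # bf16 weight used for compute
    for _ in range(n_steps):
        if fp32_master:
            master = to_fp32(master + update)  # small updates accumulate in fp32
            weight = to_bf16(master)           # re-cast for the next forward pass
        else:
            weight = to_bf16(weight + update)  # update applied directly in bf16
    return weight

print(simulate_updates(100, 1e-3, fp32_master=False))  # 1.0: every update rounded away
print(simulate_updates(100, 1e-3, fp32_master=True))   # close to 1.1
```

This is the reason FSDP's MixedPrecision casts parameters to bf16 only for compute while the optimizer updates the full-precision shards.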
The launcher. torchrun for PyTorch, accelerate launch from Hugging Face, deepspeed for DeepSpeed. Each handles process spawning and environment variables, and some can also restart failed workers (for example, torchrun's --max-restarts).
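Whatever the launcher, the training script mostly sees environment variables. torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process (among others), plus TORCHELASTIC_RESTART_COUNT. The sketch reads them into a small record; `DistEnv` and `read_dist_env` are hypothetical names, and a plain dict stands in for os.environ so it runs outside a launcher.

```python
import os
from dataclasses import dataclass
from typing import Mapping

@dataclass(frozen=True)
class DistEnv:
    rank: int           # global rank across all nodes
    local_rank: int     # rank within this node; use it to pick the GPU
    world_size: int     # total number of processes
    master_addr: str    # rendezvous host used by init_process_group
    master_port: int
    restart_count: int  # how many times the launcher has restarted workers

def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    return DistEnv(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        master_addr=env["MASTER_ADDR"],
        master_port=int(env["MASTER_PORT"]),
        restart_count=int(env.get("TORCHELASTIC_RESTART_COUNT", "0")),
    )

# Second process on the second node of a 2-node x 4-GPU job (values illustrative).
fake = {"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8",
        "MASTER_ADDR": "node0", "MASTER_PORT": "29500"}
print(read_dist_env(fake))
```

The global rank (5) decides which data shard a process reads; the local rank (1) decides which GPU it binds to.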
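import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# TransformerBlock is a placeholder for your model's repeated block class.
def fsdp_wrap(model):
    return FSDP(
        model,
        # Each TransformerBlock becomes its own FSDP unit. The policy is a callable,
        # so the block classes are bound with functools.partial.
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
        ),
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,   # fp32 shards are cast to bf16 for compute
            reduce_dtype=torch.bfloat16,  # gradient reduce-scatter in bf16
            buffer_dtype=torch.bfloat16,
        ),
        # Issue the next block's all-gather before computing this block's gradients
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )
# Save sharded checkpoint (assumes model = fsdp_wrap(...) inside an initialised process group)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sd = model.state_dict()
    torch.save(sd, f"ckpt-rank{dist.get_rank()}.pt")  # one file per rank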
3D parallelism
$$ \text{world} = \text{tensor} \times \text{pipeline} \times \text{data} $$
Three orthogonal axes of parallelism
Total GPUs = product of the three
The standard recipe at LLM-training scale
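To see how the product decomposes, here is one way to map a global rank to (data, pipeline, tensor) coordinates. The tensor-parallel index varies fastest, so each TP group lands on consecutive GPUs of one node where NVLink is available; the relative order of the pipeline and data axes is an assumption here, and frameworks differ on it.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp_idx, pp_idx, tp_idx); the tensor index varies fastest."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx

def tensor_parallel_group(rank, tp, pp, dp):
    """Ranks that share this rank's (dp_idx, pp_idx) and split its layers."""
    dp_idx, pp_idx, _ = rank_to_coords(rank, tp, pp, dp)
    base = dp_idx * tp * pp + pp_idx * tp
    return list(range(base, base + tp))

# 16 GPUs = 2 (tensor) x 4 (pipeline) x 2 (data)
print(rank_to_coords(13, tp=2, pp=4, dp=2))         # (1, 2, 1)
print(tensor_parallel_group(13, tp=2, pp=4, dp=2))  # [12, 13]
```

Every rank gets a unique coordinate triple, and the three sizes must multiply to the world size.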
Tensor parallel. Split a single layer (e.g., a linear's weight matrix) across GPUs. Each GPU computes part of the output. Requires fast intra-node interconnect (NVLink). Used for the largest models.
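The column split works because each output column depends only on its own column of the weight: computing x·W in column blocks on separate ranks and concatenating the results gives the full product. A pure-Python check with nested lists standing in for tensors; the matrices are arbitrary.

```python
def matmul(x, w):
    """x: [n][k], w: [k][m] -> [n][m]."""
    return [[sum(x[i][p] * w[p][j] for p in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]

def column_blocks(w, world_size):
    """Split w's output columns into world_size contiguous blocks (one per rank)."""
    m = len(w[0])
    assert m % world_size == 0
    step = m // world_size
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(world_size)]

x = [[1, 2], [3, 4]]
w = [[1, 0, 2, -1], [0, 1, 1, 3]]
partials = [matmul(x, block) for block in column_blocks(w, world_size=2)]  # one per "rank"
gathered = [sum((p[i] for p in partials), []) for i in range(len(x))]     # concat columns
print(gathered == matmul(x, w))  # True
```

The concatenation step is the all-gather; it can be skipped when the next layer consumes the column shards directly.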
Pipeline parallel. Different layers on different GPUs. Micro-batches flow through the pipeline; "bubble" of idle time at start and end. GPipe (Huang et al. 2018), PipeDream, 1F1B (one-forward-one-backward) scheduling.
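The bubble can be counted from a GPipe-style schedule. With p stages and m micro-batches, the forward phase takes m + p - 1 steps and each stage idles for p - 1 of them, so the idle fraction is (p - 1)/(m + p - 1). The sketch assumes every stage takes the same time per micro-batch; under that assumption the backward phase mirrors the forward and has the same fraction.

```python
from fractions import Fraction

def gpipe_schedule(num_stages, num_microbatches):
    """Forward-phase grid: grid[t][s] = micro-batch on stage s at step t (None = idle)."""
    total_steps = num_microbatches + num_stages - 1
    grid = [[None] * num_stages for _ in range(total_steps)]
    for s in range(num_stages):
        for j in range(num_microbatches):
            grid[s + j][s] = j  # micro-batch j reaches stage s after s steps
    return grid

def bubble_fraction(num_stages, num_microbatches):
    """Share of stage-steps spent idle in the forward phase."""
    grid = gpipe_schedule(num_stages, num_microbatches)
    idle = sum(cell is None for row in grid for cell in row)
    return Fraction(idle, len(grid) * num_stages)

print(bubble_fraction(4, 8))  # 3/11
```

More micro-batches shrink the bubble (m = 32 on 4 stages gives 3/35), at the cost of smaller per-micro-batch work; 1F1B scheduling keeps the same bubble but bounds the number of in-flight activations.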
3D parallelism. Combine DP, TP, PP. Standard at the largest scales (1000+ GPUs). Megatron-LM is the reference implementation. NVIDIA's NeMo, DeepSpeed, and PyTorch's torch.distributed all support some version.
NCCL tuning. Set NCCL_DEBUG=INFO for diagnostics. NCCL is topology-aware: it detects PCIe and NVLink links. For very large jobs you can tune NCCL_TREE_THRESHOLD, NCCL_ALGO and NCCL_PROTO, but it is mostly trial and error; profile before tuning.
Multi-node debugging. When jobs run for hours on hundreds of nodes, a single bad NIC can corrupt training. Heartbeat checks, gradient norm checks, periodic checkpoints. Sentry / Datadog / Grafana with cluster-level alerting.
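One of those checks, sketched: track the global gradient norm each step and flag non-finite values or sudden spikes relative to a running median. The thresholds (`spike_factor=10`, a 50-step window, 5 warm-up steps) are arbitrary assumptions; in a real job you would feed it the total norm returned by your gradient-clipping call and route alerts to your monitoring stack.

```python
import math
from collections import deque
from statistics import median

class GradNormMonitor:
    """Flags non-finite or spiking gradient norms (thresholds are illustrative)."""

    def __init__(self, window=50, spike_factor=10.0, warmup=5):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor
        self.warmup = warmup

    def check(self, grad_norm):
        """Return None if the norm looks healthy, else a short reason string."""
        if not math.isfinite(grad_norm):
            return "non-finite gradient norm"
        if len(self.history) >= self.warmup:
            ref = median(self.history)
            if grad_norm > self.spike_factor * ref:
                return f"spike: {grad_norm:.3g} vs median {ref:.3g}"
        self.history.append(grad_norm)  # only healthy norms update the baseline
        return None

monitor = GradNormMonitor()
flagged = []
for step, norm in enumerate([1.0, 1.2, 0.9, 1.1, 1.0, 1.05, 50.0, float("nan")]):
    problem = monitor.check(norm)
    if problem:
        flagged.append(step)
        print(f"step {step}: {problem}")  # in practice: alert, skip the step, or roll back
```

Keeping flagged norms out of the history stops one bad step from raising the baseline and masking the next.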
Elastic training. Nodes can join or drop mid-training. torchrun --rdzv-backend=etcd for elastic rendezvous. Checkpointing must be fast (incremental, streaming) to make resumes cheap.
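Cheap resumes also require that a job killed mid-save never leaves a half-written checkpoint behind. A common pattern, sketched with JSON standing in for tensors: write to a temporary file in the same directory, fsync it, then `os.replace` it over the final name, which is atomic on POSIX filesystems. The single "latest" file layout and the helper names are assumptions; in a distributed job only the designated rank (or each rank for its own shard) would call the save.

```python
import json
import os
import tempfile

def save_checkpoint_atomic(state, path):
    """Write state to path so readers only ever see the old or the new file."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes are on disk before the rename
        os.replace(tmp_path, path)  # atomic rename within one filesystem
    except BaseException:
        os.remove(tmp_path)  # never leave a partial temp file behind
        raise

def load_latest(path):
    """Return the saved state, or None when starting from scratch."""
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)

# Resume-loop sketch: restart this and it continues from the last completed step.
ckpt = os.path.join(tempfile.mkdtemp(), "latest.json")
state = load_latest(ckpt) or {"step": 0}
for step in range(state["step"], 3):
    state = {"step": step + 1}  # real training work would happen here
    save_checkpoint_atomic(state, ckpt)
print(load_latest(ckpt))  # {'step': 3}
```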
The cluster you wish you had. Saturated all-reduce bandwidth is rare on real clusters; profile to confirm. Often the bottleneck is something more mundane: a slow storage backend, a bad scheduler queue, or a single bad node.
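# Megatron-style tensor parallelism for a Linear layer
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn  # autograd-aware collectives
class ColumnParallelLinear(nn.Module):
    """Splits the weight matrix's output dimension across ranks.
    Every rank must receive the same input x."""
    def __init__(self, in_features, out_features, gather_output=True):
        super().__init__()
        world_size = dist.get_world_size()  # same group that forward() gathers over
        assert out_features % world_size == 0
        local_out = out_features // world_size
        # This rank's slice of the weight: [out_features / world_size, in_features].
        # Slices are initialised independently per rank here; control seeds in real code.
        self.linear = nn.Linear(in_features, local_out, bias=False)
        self.gather_output = gather_output
    def forward(self, x):
        local_out = self.linear(x)  # [..., out_features / world_size]
        if not self.gather_output:
            return local_out  # e.g. feed a row-parallel layer directly
        # dist.all_gather has no autograd support; this version sends gradients
        # back to each rank's slice during backward.
        gathered = dist_nn.all_gather(local_out)
        return torch.cat(gathered, dim=-1)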