Key idea

The cluster gives you many GPUs; the framework lets them cooperate. Data parallel: same model on each GPU, different data, all-reduce gradients. Model parallel: split the model across GPUs. FSDP / ZeRO: data parallel, but with parameters, gradients, and optimizer state sharded across GPUs instead of replicated. The right choice depends mostly on one constraint: what fits in GPU memory.

Data parallel (DDP). The default. Each GPU has a full copy of the model. They process different batches, then all-reduce gradients. Easy to set up, scales well for models that fit on one GPU.
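Why averaging is the right reduction: with a mean-reduced loss and equal-sized per-rank batches, the average of the per-rank gradients equals the gradient over the combined batch, so DDP on W GPUs behaves like one GPU with a W× larger batch. A minimal pure-Python sketch, assuming a one-parameter linear model y ≈ w·x trained with mean squared error; `grad_mse` and `simulated_all_reduce_mean` are hypothetical helpers standing in for autograd and NCCL.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def simulated_all_reduce_mean(values):
    """Stand-in for an all-reduce (sum) followed by division by the world size."""
    return sum(values) / len(values)


# A full batch of 8 samples, split evenly across a world size of 4 "ranks".
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
world_size = 4
w = 0.5

per_rank = [
    grad_mse(w, xs[r::world_size], ys[r::world_size])  # strided split, like DistributedSampler without shuffling
    for r in range(world_size)
]
ddp_grad = simulated_all_reduce_mean(per_rank)
full_grad = grad_mse(w, xs, ys)
print(ddp_grad, full_grad)  # equal up to floating-point rounding
```

The equality relies on every rank seeing the same number of samples; DistributedSampler pads the index list with repeated samples so the dataset divides evenly across ranks.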

FSDP (Fully Sharded Data Parallel). Shards parameters, gradients, and optimizer state across GPUs. Lets you train models that wouldn't fit on a single GPU. Built into PyTorch; DeepSpeed's ZeRO has comparable functionality.

Tensor / model parallel. Split individual layers across GPUs. Used for very large models where even FSDP is insufficient. Megatron, FairScale. Most teams don't need this.

Pipeline parallel. Run different layers on different GPUs, pipelined. GPipe, PipeDream. Useful when the whole model doesn't fit on one GPU but a contiguous group of layers (a stage) does.

Pick by model size

  • Fits on one GPU and trains fast enough: don't bother with distributed; a single GPU is simplest
  • Model fits, want more throughput: DDP
  • Model nearly fits: FSDP / ZeRO-2
  • Model too big for any GPU: FSDP + activation checkpointing + sometimes tensor parallel
  • Hundreds of GPUs: 2D / 3D parallelism (data parallel via DDP or FSDP × tensor × pipeline)

Common pitfalls

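  • Different model-init seeds per rank → replicas start from different weights and diverge silently, unless something broadcasts rank 0's weights (DDP does this at construction; hand-rolled setups may not)
  • Logging from all ranks → step counter inflated by world_size
  • BatchNorm computes statistics per rank by default → noisy stats with small per-GPU batches; use torch.nn.SyncBatchNorm.convert_sync_batchnorm for global stats
  • Saving from every rank to the same path → race conditions, corrupted checkpoints; save full checkpoints from rank 0 only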
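
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# MyNet, dataset, loss_fn and num_epochs are placeholders for your own code.

def setup_ddp():
    dist.init_process_group("nccl")             # NCCL for NVIDIA GPUs, gloo for CPU
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun: GPU index on this node
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), local_rank          # global rank, local rank

def train(rank, local_rank):
    model = DDP(MyNet().cuda(local_rank), device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # illustrative optimizer choice
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=rank, shuffle=True)
    # batch_size is per rank; the global batch is 64 * world_size
    loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)                # different shuffle per epoch
        for x, y in loader:
            opt.zero_grad()
            loss = loss_fn(model(x.cuda(local_rank)), y.cuda(local_rank))
            loss.backward()                     # DDP all-reduces gradients here
            opt.step()

        # Save only on global rank 0 (all replicas hold identical weights)
        if rank == 0:
            torch.save(model.module.state_dict(), f"epoch_{epoch}.pt")

if __name__ == "__main__":
    rank, local_rank = setup_ddp()
    train(rank, local_rank)
    dist.destroy_process_group()

# Launch: torchrun --nproc-per-node=4 train.py
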
Want FSDP, checkpoint sharding, & gradient communication optimisations?
FSDP memory savings $$ \text{memory per GPU} \approx \frac{P + G + O}{W} + A $$
  • P parameters, G gradients, O optimizer state, W world size, A activations
  • Vanilla DDP: P + G + O + A per GPU
  • FSDP: divides the first three by W
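A worked example of the formula as a small calculator. The byte counts are assumptions for illustration: fp32 parameters (4 bytes), fp32 gradients (4 bytes), and Adam's two fp32 moment buffers (8 bytes), i.e. 16 bytes per parameter before activations; real mixed-precision setups differ. `per_gpu_memory_gb` is a hypothetical helper, and the activation figure A is a made-up input, not a measurement.

```python
def per_gpu_memory_gb(n_params, world_size, activations_gb, sharded,
                      bytes_param=4, bytes_grad=4, bytes_optim=8):
    """Estimate (P + G + O) / W + A for FSDP, or P + G + O + A for DDP."""
    P = n_params * bytes_param / 1e9
    G = n_params * bytes_grad / 1e9
    O = n_params * bytes_optim / 1e9  # assumed Adam: two fp32 moments per parameter
    states = P + G + O
    if sharded:                       # FSDP divides all three by W
        states /= world_size
    return states + activations_gb    # activations are not sharded in this model


# Hypothetical 7B-parameter model on 8 GPUs, assuming 10 GB of activations per GPU.
ddp = per_gpu_memory_gb(7e9, world_size=8, activations_gb=10, sharded=False)
fsdp = per_gpu_memory_gb(7e9, world_size=8, activations_gb=10, sharded=True)
print(f"DDP:  {ddp:.0f} GB per GPU")   # DDP:  122 GB per GPU
print(f"FSDP: {fsdp:.0f} GB per GPU")  # FSDP: 24 GB per GPU
```

Under these assumptions the model states alone (112 GB) overflow an 80 GB GPU under DDP, but fit comfortably once sharded 8 ways.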

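FSDP details. FSDP wraps your model recursively at a chosen granularity (per layer or per block). In the forward pass it all-gathers the parameters of the active block, runs the block, then frees the gathered copy. Backward all-gathers them again, then reduce-scatters the gradients so each rank keeps only its own gradient shard. The extra communication is the price of the memory savings.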
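The gather / compute / free cycle and the gradient reduce-scatter can be simulated in a single process. A minimal sketch, assuming a flat list of floats stands in for one wrapped block's parameters and a toy per-rank gradient rule; `shard`, `all_gather`, and `reduce_scatter_mean` are hypothetical pure-Python stand-ins, not torch.distributed calls.

```python
def shard(flat, world_size):
    """Split a flat parameter list into equal contiguous shards, zero-padding the tail."""
    size = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (size * world_size - len(flat))
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]


def reduce_scatter_mean(per_rank_full, world_size):
    """Average full-size gradients across ranks, then keep only each rank's slice."""
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_full)]
    return shard(averaged, world_size)


world_size = 2
params = [1.0, 2.0, 3.0, 4.0, 5.0]
local_shards = shard(params, world_size)  # the only copy each rank keeps between steps

# One wrapped block: gather full params, compute with them, then free them.
gathered = all_gather(local_shards)
# Toy gradient rule: rank r's data scales the parameters by (r + 1).
local_grads = [[(rank + 1) * p for p in full] for rank, full in enumerate(gathered)]
del gathered  # stand-in for FSDP freeing the gathered copies

grad_shards = reduce_scatter_mean(local_grads, world_size)
print(local_shards)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
print(grad_shards)   # [[1.5, 3.0, 4.5], [6.0, 7.5, 0.0]]
```

Between steps each rank holds 3 of the 6 padded values for both parameters and gradients, which is the 1/W factor in the memory formula.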

Gradient bucketing. DDP packs small gradient tensors into "buckets" and all-reduces each bucket as one operation, which cuts per-call overhead. The default bucket size (bucket_cap_mb=25) is usually fine; tune it for very small or very large models.
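A simplified sketch of the bucketing idea, assuming gradients become ready in reverse parameter order (roughly what happens in backward) and that a bucket is closed once adding the next tensor would exceed the cap. DDP's real assignment has extra rules (for example, a smaller first bucket); `assign_buckets` is hypothetical, and only the cap's name and default mirror DDP's `bucket_cap_mb`.

```python
def assign_buckets(param_sizes_mb, bucket_cap_mb=25.0):
    """Greedy bucketing of gradient tensors, walking parameters in reverse order.

    Returns a list of buckets; each bucket is a list of parameter indices.
    A single tensor larger than the cap gets a bucket of its own.
    """
    buckets, current, current_mb = [], [], 0.0
    for idx in reversed(range(len(param_sizes_mb))):
        size = param_sizes_mb[idx]
        if current and current_mb + size > bucket_cap_mb:
            buckets.append(current)  # close the full bucket
            current, current_mb = [], 0.0
        current.append(idx)
        current_mb += size
    if current:
        buckets.append(current)
    return buckets


# Hypothetical model: small tensors (biases, norms) interleaved with larger weights, in MB.
sizes = [0.01, 16.0, 0.01, 16.0, 0.01, 4.0, 0.01, 40.0]
for b in assign_buckets(sizes):
    print(b, round(sum(sizes[i] for i in b), 2), "MB")
```

Eight tensors become three all-reduce calls instead of eight; fewer, larger messages amortize per-call latency, which is the point of bucketing.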

Overlap compute and communication. During backward, DDP and FSDP start communicating a bucket's gradients as soon as they are ready, while the backward pass keeps computing gradients for earlier layers. The default behavior is good; use the profiler to verify that the gaps between kernels are small.
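A toy timeline makes the benefit concrete. Assumptions (illustrative, not measured): backward runs layer by layer from last to first on one compute stream, and each layer's all-reduce runs on a single communication stream that starts as soon as that gradient is ready and the previous all-reduce has finished. `step_time` is a hypothetical helper.

```python
def step_time(compute_ms, comm_ms, overlap):
    """Backward-pass time for layers processed last-to-first.

    compute_ms[i] / comm_ms[i]: backward compute and all-reduce time of layer i.
    """
    if not overlap:
        # Finish all of backward, then all-reduce everything.
        return sum(compute_ms) + sum(comm_ms)
    t_compute = 0.0  # when the compute stream becomes free
    t_comm = 0.0     # when the communication stream becomes free
    for i in reversed(range(len(compute_ms))):
        t_compute += compute_ms[i]                    # gradient of layer i is ready
        t_comm = max(t_comm, t_compute) + comm_ms[i]  # all-reduce it as soon as possible
    return max(t_compute, t_comm)


compute = [4.0, 4.0, 4.0, 4.0]  # ms per layer (made-up numbers)
comm = [3.0, 3.0, 3.0, 3.0]
print(step_time(compute, comm, overlap=False))  # 28.0
print(step_time(compute, comm, overlap=True))   # 19.0
```

With overlap, only the last layer's all-reduce (3 ms here) stays exposed after backward finishes, which is what a healthy profiler trace should look like.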

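Checkpoint sharding. A 100B-parameter model in fp32 is roughly 400 GB; gathering it onto one rank to write a single file is slow and may not fit in host memory. Under StateDictType.SHARDED_STATE_DICT, FSDP returns a state dict in which each rank holds only its own shard, so each rank writes one file. Reload with the same sharding. torch.distributed.checkpoint is PyTorch's utility for saving and loading such sharded state dicts.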
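The one-file-per-rank layout, and why reloading must match the original sharding, can be sketched without GPUs. A minimal sketch, assuming each rank's shard is a plain dict of lists and `pickle` stands in for `torch.save`; the file names and the `save_sharded` / `load_sharded` helpers are hypothetical.

```python
import os
import pickle
import tempfile


def save_sharded(shards, directory):
    """Write one file per rank; no single rank ever materializes the full state."""
    for rank, shard in enumerate(shards):
        with open(os.path.join(directory, f"ckpt-rank{rank}.pkl"), "wb") as f:
            pickle.dump({"world_size": len(shards), "rank": rank, "shard": shard}, f)


def load_sharded(directory, rank, world_size):
    """Load this rank's shard, refusing a checkpoint written with a different world size."""
    with open(os.path.join(directory, f"ckpt-rank{rank}.pkl"), "rb") as f:
        blob = pickle.load(f)
    if blob["world_size"] != world_size:
        raise ValueError(f"checkpoint was saved with world_size={blob['world_size']}")
    return blob["shard"]


shards = [{"layer.weight": [1.0, 2.0]}, {"layer.weight": [3.0, 4.0]}]  # 2 ranks
with tempfile.TemporaryDirectory() as d:
    save_sharded(shards, d)
    print(sorted(os.listdir(d)))                  # ['ckpt-rank0.pkl', 'ckpt-rank1.pkl']
    print(load_sharded(d, rank=1, world_size=2))  # {'layer.weight': [3.0, 4.0]}
```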

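Mixed precision in distributed. Use bf16 for compute and fp32 for the master parameters. Loss scaling with GradScaler (ShardedGradScaler under FSDP) is needed for fp16, whose narrow exponent range lets small gradients underflow; bf16 usually trains without it. Most frameworks handle this automatically when you turn on AMP and the right FSDP precision policy.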
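The underflow problem can be reproduced with the standard library alone: fp16 flushes tiny values to zero, while bf16 keeps float32's exponent range at lower precision. A sketch, assuming `struct`'s half-precision `'e'` format for fp16 and emulating bf16 by truncating a float32 to its top 16 bits (real hardware rounds to nearest instead); the helper names are hypothetical.

```python
import struct


def to_fp16(x):
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding (truncation)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


grad = 1e-8        # a small gradient value
scale = 65536.0    # 2**16, GradScaler's default initial scale

print(to_fp16(grad))                  # 0.0: underflows in fp16
print(to_fp16(grad * scale) / scale)  # nonzero: scaled up before the cast, unscaled in fp32
print(to_bf16(grad))                  # nonzero without any scaling
```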

The launcher. torchrun for PyTorch, accelerate launch from HuggingFace, deepspeed for DeepSpeed. Each handles process spawning, environment variables, and (sometimes) restart-from-failure.
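The environment variables are the contract between launcher and script: torchrun sets `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` (among others, such as `MASTER_ADDR` and `MASTER_PORT`) for every process it spawns. A minimal sketch of reading them, assuming a fallback to single-process defaults when the script runs without a launcher; `DistEnv` and `read_dist_env` are hypothetical names.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    rank: int        # global rank across all nodes
    local_rank: int  # index of this process on its node; use it as the GPU index
    world_size: int  # total number of processes

    @property
    def is_main(self) -> bool:
        return self.rank == 0


def read_dist_env(environ=os.environ) -> DistEnv:
    """Read launcher-provided variables, defaulting to a single process."""
    return DistEnv(
        rank=int(environ.get("RANK", 0)),
        local_rank=int(environ.get("LOCAL_RANK", 0)),
        world_size=int(environ.get("WORLD_SIZE", 1)),
    )


# What the 6th process (rank 5) of a 2-node x 4-GPU job would see.
env = read_dist_env({"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"})
print(env, env.is_main)
```

Using the global `rank` where `local_rank` belongs is a classic multi-node bug: on a 4-GPU node, `torch.cuda.set_device(5)` fails.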

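import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# TransformerBlock is a placeholder for your model's repeated block class.
def fsdp_wrap(model):
    return FSDP(
        model,
        # Make each TransformerBlock its own FSDP unit (gathered and freed together)
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
        ),
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,   # compute in bf16
            reduce_dtype=torch.bfloat16,  # gradient reduction in bf16; fp32 is more precise but doubles traffic
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next unit's params during backward
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )

# Save sharded checkpoint: each rank writes only its own shard.
# Assumes the process group is initialized and model = fsdp_wrap(...) has been built.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sd = model.state_dict()
    torch.save(sd, f"ckpt-rank{dist.get_rank()}.pt")
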
Want tensor / pipeline parallelism, NCCL tuning, & multi-node debugging?
3D parallelism $$ \text{world} = \text{tensor} \times \text{pipeline} \times \text{data} $$
  • Three orthogonal axes of parallelism
  • Total GPUs = product of the three
  • The standard recipe at LLM-training scale
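How a global rank maps onto the three axes is a convention. A minimal sketch, assuming a layout in the spirit of Megatron-LM's default, where tensor-parallel ranks are adjacent (so a tensor-parallel group stays inside one NVLink node), data parallel varies next, and pipeline stage varies slowest; `rank_to_coords` is hypothetical and other frameworks may order the axes differently.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (tp_rank, dp_rank, pp_rank).

    Assumed layout: tensor-parallel index varies fastest, then data-parallel,
    then pipeline stage (slowest).
    """
    world = tp * pp * dp
    if not 0 <= rank < world:
        raise ValueError(f"rank {rank} outside world of size {world}")
    tp_rank = rank % tp
    dp_rank = (rank // tp) % dp
    pp_rank = rank // (tp * dp)
    return tp_rank, dp_rank, pp_rank


# 16 GPUs = tensor 2 x pipeline 4 x data 2
tp, pp, dp = 2, 4, 2
for r in range(tp * pp * dp):
    print(r, rank_to_coords(r, tp, pp, dp))
```

With this layout, ranks 0 and 1 form a tensor-parallel pair, ranks 0 and 2 are data-parallel replicas, and ranks 0, 4, 8, 12 are the four pipeline stages of one model copy.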

Tensor parallel. Split a single layer (e.g., a linear's weight matrix) across GPUs. Each GPU computes part of the output. Requires fast intra-node interconnect (NVLink). Used for the largest models.
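Why splitting the output dimension is exact: each output feature depends only on its own column of the weight matrix, so computing disjoint blocks of output columns on different ranks and concatenating them reproduces the full product. A pure-Python check, assuming y = x·W with W stored as in_features × out_features, a 2-rank split, and small hand-picked matrices; `matmul` and `split_columns` are hypothetical helpers.

```python
def matmul(a, b):
    """Plain matrix product of lists-of-rows a (n x k) and b (k x m)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w, world_size):
    """Give each rank a contiguous block of W's output columns (must divide evenly)."""
    assert len(w[0]) % world_size == 0
    cols = len(w[0]) // world_size
    return [[row[r * cols:(r + 1) * cols] for row in w] for r in range(world_size)]


x = [[1.0, 2.0], [3.0, 4.0]]  # batch of 2, in_features = 2
W = [[1.0, 0.0, 2.0, -1.0],   # in_features x out_features = 2 x 4
     [0.5, 1.0, 0.0, 3.0]]

local_outputs = [matmul(x, w_r) for w_r in split_columns(W, world_size=2)]
# "All-gather" along the feature dimension: concatenate each row's pieces in rank order.
gathered = [sum((out[i] for out in local_outputs), []) for i in range(len(x))]
print(gathered == matmul(x, W))  # True
```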

Pipeline parallel. Different layers on different GPUs. Micro-batches flow through the pipeline; "bubble" of idle time at start and end. GPipe (Huang et al. 2018), PipeDream, 1F1B (one-forward-one-backward) scheduling.
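The bubble can be quantified. For a GPipe-style schedule with p stages and m micro-batches of equal cost, the pipeline runs for m + p - 1 time slots while each stage is busy for only m of them, so the idle fraction is (p - 1) / (m + p - 1). Non-interleaved 1F1B has the same bubble but keeps fewer micro-batches' activations alive at once. A small calculator, assuming perfectly balanced stages (real ones rarely are); `bubble_fraction` is a hypothetical helper.

```python
def bubble_fraction(stages, micro_batches):
    """Idle fraction of a GPipe-style pipeline with uniform stage cost."""
    if stages < 1 or micro_batches < 1:
        raise ValueError("need at least one stage and one micro-batch")
    return (stages - 1) / (micro_batches + stages - 1)


for m in (1, 4, 16, 64):
    print(f"p=4, m={m:>2}: bubble = {bubble_fraction(4, m):.1%}")
```

More micro-batches shrink the bubble, at the cost of less work per micro-batch (lower kernel efficiency) and, for GPipe, more stored activations.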

3D parallelism. Combine DP, TP, PP. Standard at the largest scales (1000+ GPUs). Megatron-LM is the reference implementation. NVIDIA's NeMo, DeepSpeed, and PyTorch's torch.distributed all support some version.

NCCL tuning. NCCL_DEBUG=INFO for diagnostics. Topology-aware: NCCL detects PCIe / NVLink. For very large jobs, tuning NCCL_TREE_THRESHOLD, NCCL_ALGO, NCCL_PROTO. Mostly trial-and-error; profile before tuning.

Multi-node debugging. When jobs run for hours on hundreds of nodes, a single bad NIC can corrupt training. Heartbeat checks, gradient norm checks, periodic checkpoints. Sentry / Datadog / Grafana with cluster-level alerting.
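A gradient-norm check is cheap to add to the training loop. A minimal sketch, assuming the global gradient norm is already computed each step (for example, the total norm returned by `torch.nn.utils.clip_grad_norm_`) and flagging non-finite values or spikes above a multiple of an exponential moving average; `GradNormMonitor`, its thresholds, and the warm-up length are hypothetical choices, not tuned defaults.

```python
import math


class GradNormMonitor:
    """Flag NaN/Inf gradient norms and sudden spikes relative to a running average."""

    def __init__(self, spike_factor=10.0, ema_decay=0.99, warmup_steps=20):
        self.spike_factor = spike_factor
        self.ema_decay = ema_decay
        self.warmup_steps = warmup_steps
        self.ema = None
        self.steps = 0

    def check(self, grad_norm):
        """Return None if the step looks healthy, else a short reason string."""
        if not math.isfinite(grad_norm):
            return "non-finite gradient norm"
        self.steps += 1
        if (self.ema is not None and self.steps > self.warmup_steps
                and grad_norm > self.spike_factor * self.ema):
            return f"spike: {grad_norm:.3g} vs running mean {self.ema:.3g}"
        # Only healthy steps update the running average.
        self.ema = grad_norm if self.ema is None else (
            self.ema_decay * self.ema + (1 - self.ema_decay) * grad_norm)
        return None


monitor = GradNormMonitor(warmup_steps=3)
for norm in [1.0, 1.1, 0.9, 1.0, 1.05, 50.0, float("nan")]:
    problem = monitor.check(norm)
    if problem:
        print(problem)
```

What to do on a flag (skip the step, reload the last checkpoint, page someone) is a policy choice left to the surrounding training loop and alerting setup.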

Elastic training. Nodes can join or drop mid-training. torchrun --rdzv-backend=etcd for elastic rendezvous. Checkpointing must be fast (incremental, streaming) to make resumes cheap.

The cluster you wish you had. Saturated all-reduce bandwidth is rare on real clusters; profile to confirm. Often the bottleneck is something more mundane: a slow storage backend, a bad scheduler queue, or a single bad node.

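# Megatron-style tensor parallelism for a Linear layer
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.nn.functional import all_gather  # autograd-aware all-gather

class ColumnParallelLinear(nn.Module):
    """Splits the weight matrix's output dimension across ranks.

    Assumes dist.init_process_group has already been called. Each rank initializes
    its own slice here; production implementations initialize consistently across
    ranks so the sharded layer matches a single-GPU one.
    """
    def __init__(self, in_features, out_features, gather_output=True):
        super().__init__()
        world_size = dist.get_world_size()
        assert out_features % world_size == 0
        local_out = out_features // world_size
        self.gather_output = gather_output
        self.linear = nn.Linear(in_features, local_out, bias=False)

    def forward(self, x):
        local_out = self.linear(x)          # [..., out_features / world_size]
        if not self.gather_output:
            return local_out                # e.g. fed directly into a row-parallel layer
        # dist.all_gather does not backpropagate; the functional version does
        gathered = all_gather(local_out)
        return torch.cat(gathered, dim=-1)  # [..., out_features]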