Key idea

"NaN" is a clue, not a stack trace. ML bugs usually present as silent degradation — bad metrics, divergent loss, slow training — not as crashes. The trick is making each level of the stack legible: structured logs you can search, useful prints at the right verbosity, a debugger that handles tensors.

Use real logging, not print. Python's logging module — or better, loguru — gives you levels, timestamps, structured records, and rotation for free. The five-minute investment pays off forever.

Log what matters. Loss, learning rate, gradient norms, data shapes, the configuration object on startup. Not every batch — sample every N batches. If you're using a tracker (W&B / MLflow), it's also your log.

Most common ML bugs. NaN / Inf propagation, shape mismatches (sometimes hidden by silent broadcasting), exploding or vanishing gradients, a learning rate that is too high, data corruption upstream of training, label noise, distribution shift between train and val.

from loguru import logger
import torch

# Structured logging with levels
logger.add("train.log", rotation="100 MB", level="INFO",
           format="{time} {level} {message}")

logger.info("config: {}", cfg)

for step in range(num_steps):
    loss, metrics = train_step()

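    # isfinite catches Inf as well as NaN
    if not torch.isfinite(loss):
        logger.error("non-finite loss at step {}", step)
        debug_dump(step)  # user-defined helper (not shown)
        raise RuntimeError("non-finite loss")

    if step % 100 == 0:
        # Global L2 norm: sqrt of summed squared per-parameter norms.
        # Assumes train_step() has not zeroed the gradients yet.
        gnorm = sum(p.grad.norm().item() ** 2
                    for p in model.parameters() if p.grad is not None) ** 0.5
        logger.info("step={} loss={:.4f} gnorm={:.4f}", step, loss.item(), gnorm)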

What to watch in training

  • Loss: smooth decrease; NaN spike = something exploded
  • Gradient norm: stable; growing = lr too high; collapsing = vanishing
  • Activation stats: most activations should be neither 0 nor saturated
  • Learning rate: log it explicitly; schedule bugs are common
  • GPU utilization: low util usually = data-loading bottleneck
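
Logging the learning rate explicitly can be sketched as below. This is a minimal example, assuming a toy nn.Linear model and a StepLR schedule (both chosen only for illustration); the point is reading the value from optimizer.param_groups rather than trusting the config.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                       # toy model, for illustration only
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.5)

lrs = []
for step in range(6):
    opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    # Read the lr the optimizer actually used, not the one in the config
    lrs.append(opt.param_groups[0]["lr"])
    sched.step()

print(lrs)
```

If the logged values do not follow the schedule you intended (wrong step_size, scheduler stepped per batch instead of per epoch), the bug shows up here long before it shows up in the loss.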

Common silent failures

  • Wrong data type (fp16 underflow, int truncation)
  • Broadcasting where you didn't expect (e.g. (B,) vs (B, 1))
  • Detached graph — gradients don't flow
  • Frozen layers you forgot to unfreeze
  • Wrong device — silent half-CPU half-GPU
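
The (B,) vs (B, 1) broadcasting failure from the list above can be reproduced in a few lines. This is a sketch with synthetic tensors; checked_mse is a hypothetical helper, not a PyTorch function.

```python
import torch

torch.manual_seed(0)
pred = torch.randn(4, 1)    # model output, shape (B, 1)
target = torch.randn(4)     # labels, shape (B,)

# (B, 1) - (B,) broadcasts to (B, B): 16 pairwise differences instead of 4
diff = pred - target
print(diff.shape)

# Fix: make the shapes agree explicitly before computing the loss
diff_ok = pred.squeeze(1) - target
print(diff_ok.shape)

def checked_mse(pred, target):
    """MSE that refuses to broadcast (hypothetical helper for illustration)."""
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {tuple(pred.shape)} vs {tuple(target.shape)}")
    return ((pred - target) ** 2).mean()
```

No error is raised in the broadcasting case, and the loss still decreases, which is why this one tends to survive for a long time.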
The first-five-things-to-check checklist

$$ \text{loss} \xrightarrow{?} \text{NaN/diverge} \;\Rightarrow\; \text{lr}, \;\text{init}, \;\text{numerical stability}, \;\text{data}, \;\text{precision} $$
  • Drop the lr by 10×: does the divergence go away?
  • Re-init: bad initialisation? deeper-net pathology?
  • Add eps to denominators, clip gradients
  • Check the data: outliers, label encoding, NaNs upstream
  • Try fp32: if the problem disappears, fp16 underflow / overflow was the cause

Anomaly detection during training. torch.autograd.set_detect_anomaly(True) — slow, but tracks where a NaN originated. Use for one debug run; turn off for production.
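
A minimal sketch of anomaly mode, using the scoped context manager torch.autograd.detect_anomaly() so it cannot leak into production code. The sqrt-times-zero expression is a contrived assumption, chosen because its backward pass computes 0/0.

```python
import torch

def backward_with_anomaly_check():
    """Return the error message raised by anomaly mode, or None if backward succeeded."""
    x = torch.tensor([0.0, 1.0], requires_grad=True)
    try:
        with torch.autograd.detect_anomaly():
            # Forward is finite, but sqrt's backward divides 0 by 0 at x = 0
            y = (torch.sqrt(x) * 0.0).sum()
            y.backward()
    except RuntimeError as err:
        return str(err)
    return None

message = backward_with_anomaly_check()
print(message)
```

The error names the backward function that produced the NaN, and the accompanying warning includes the traceback of the forward call that created it.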

Gradient clipping. torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0). Almost always worth it for transformers and RNNs. The clip-then-NaN heuristic: if clipping fixes the divergence, you have an exploding gradient problem.
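
Where the clipping call sits in a training step, as a sketch. The model and the deliberately oversized inputs are assumptions made to force a large gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 8) * 1e3     # deliberately huge inputs -> huge gradients
y = torch.randn(16, 1)

opt.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Clip after backward() and before step(); the return value is the
# total norm measured before clipping, which is worth logging.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
opt.step()

print(f"grad norm before={norm_before.item():.1f} after={norm_after.item():.4f}")
```

Logging the returned pre-clip norm gives you the gradient-norm curve for free.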

NaN forensics. Hook every layer's forward/backward to check for NaN; the first layer that produces one is the culprit. Common sources: 1/0 in normalisation, log(0), exp(huge), softmax over a fully masked row (all logits -inf), and fp16 overflow on large logits.
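
A few of these sources reproduced in isolation, with the usual eps / clamp fixes next to them. The fully masked softmax row is an extra example added here as an assumption, and first_nonfinite is a hypothetical helper.

```python
import torch

eps = 1e-8
x = torch.tensor([0.0, 1.0])

bad_log = torch.log(x)                    # log(0) -> -inf
safe_log = torch.log(x.clamp_min(eps))    # finite everywhere

bad_ratio = x / x                         # 0/0 -> nan at index 0
safe_ratio = x / (x + eps)

big = torch.exp(torch.tensor(100.0))      # overflows float32 -> inf
masked_row = torch.softmax(torch.full((4,), float("-inf")), dim=0)  # every entry nan

def first_nonfinite(named_tensors):
    """Return the name of the first tensor containing NaN or Inf, else None."""
    for name, t in named_tensors:
        if not torch.isfinite(t).all():
            return name
    return None

culprit = first_nonfinite(
    [("safe_log", safe_log), ("bad_ratio", bad_ratio), ("bad_log", bad_log)]
)
print(culprit)
```

Checking intermediate tensors in execution order, as first_nonfinite does, is the same idea as the per-layer hooks: the first non-finite value is the one to fix.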

The 1-batch overfit. Take 2–8 examples. The model should overfit perfectly in under 1000 steps. If it can't, something is fundamentally wrong (architecture, loss, data pipeline).
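
A sketch of the overfit test on synthetic data. The MLP, the Adam settings, and the random labels are assumptions; swap in your own model, loss, and one real batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 10)               # 8 examples, 10 features (synthetic)
y = torch.randint(0, 3, (8,))        # random labels over 3 classes

model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 3))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(1000):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    opt.step()

# A healthy pipeline memorises the batch: accuracy 1.0, loss near zero
acc = (model(x).argmax(dim=1) == y).float().mean().item()
print(f"final loss={loss.item():.4f} acc={acc:.2f}")
```

Because the labels are random, success here only proves that gradients flow and the loss is wired correctly, which is exactly what the test is for.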

Tensor-aware debugger. pdb works fine; for tensor inspection use display(tensor.shape, tensor.dtype, tensor.device, tensor.requires_grad). ipdb is nicer. PyCharm and VSCode have visual tensor inspectors.

The pickle vs JSON rule. Save configs, hyperparameters, and metrics as JSON or YAML — diff-friendly, language-agnostic. Save model weights as pickle / safetensors. Don't mix the two purposes.
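
One way to keep the two purposes apart, as a sketch. The file names, the temporary run directory, and the hyperparameters are illustrative assumptions; torch.save is shown for the weights, and safetensors would fill the same slot.

```python
import json
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

run_dir = Path(tempfile.mkdtemp())
cfg = {"lr": 3e-4, "batch_size": 32, "seed": 0}   # hypothetical hyperparameters

# Diff-friendly side: config (and metrics) as sorted, indented JSON
(run_dir / "config.json").write_text(json.dumps(cfg, indent=2, sort_keys=True))

# Binary side: weights only, via the state_dict
model = nn.Linear(4, 2)
torch.save(model.state_dict(), run_dir / "weights.pt")

# Each half reloads without the other
cfg_loaded = json.loads((run_dir / "config.json").read_text())
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(run_dir / "weights.pt"))
```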

import torch
import torch.nn as nn

# NaN-source finder: hook every module
def attach_nan_hooks(model):
    def hook(name):
        def fn(module, inp, out):
            if isinstance(out, torch.Tensor) and torch.isnan(out).any():
                print(f"NaN in forward output of {name}")
        return fn
    for name, mod in model.named_modules():
        mod.register_forward_hook(hook(name))

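# Print per-layer activation stats to spot saturating layers
def activation_summary(model, x):
    activations = {}
    def hook(name):
        def fn(module, inp, out):
            if isinstance(out, torch.Tensor):  # skip tuple / dict outputs
                activations[name] = out.detach().float()
        return fn
    # Keep the handles so the hooks can be removed; otherwise they pile up on every call
    handles = [mod.register_forward_hook(hook(name))
               for name, mod in model.named_modules()]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    for name, a in activations.items():
        print(f"{name:30s} mean={a.mean():.3f} std={a.std():.3f} "
              f"max={a.max():.3f} sat={(a.abs() > 5).float().mean():.2%}")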
Loss curve patterns

$$ \text{divergence}, \;\text{plateau}, \;\text{cliff}, \;\text{oscillation}, \;\text{step decay} \;\to\; \text{each implies a specific bug class} $$
  • Each pattern has a small set of likely causes
  • Documented troubleshooting tree → much faster than guessing

Loss curve forensics. Diverges sharply: lr too high, bad init, or numerical issue. Plateaus immediately: gradient flow stopped (detach, frozen). Oscillates: lr too high or batch too small. Cliff (sudden drop): a learning rate schedule kick, or the model finally found the right answer.

Distributed debugging. Two ranks producing different losses on the same data → a sync or seeding issue. Use torch.distributed.barrier() + per-rank logging to find where they diverge. Always reproduce on a single GPU before debugging distributed; you'd be amazed how often a 1-GPU debug fixes the cluster bug.

OOM (out-of-memory) forensics. Memory usage growing over time → leak: something is not being freed, commonly tensors that still reference the autograd graph (e.g. accumulating loss instead of loss.item()). Plateaus high but stable → you just need a smaller batch or more aggressive activation checkpointing. torch.cuda.memory_summary() shows the allocation breakdown.
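
A common version of the growing-memory leak, sketched on a toy model (the model and the list names are assumptions): storing the loss tensor keeps each step's autograd graph alive, while storing loss.item() does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

leaky_history = []   # tensors that still reference their autograd graph
safe_history = []    # plain Python floats

for _ in range(3):
    loss = nn.functional.mse_loss(model(x), y)
    leaky_history.append(loss)          # keeps every step's graph reachable
    safe_history.append(loss.item())    # graph can be freed after the step

leaky = all(t.grad_fn is not None for t in leaky_history)
safe = all(isinstance(v, float) for v in safe_history)
print(leaky, safe)
```

The same applies to running totals: `total += loss` grows the retained graph every step, `total += loss.item()` does not.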

Profiler-led debugging. If training is slow, profile first. Common culprits: data loading (CPU bottleneck; increase num_workers), many small ops (kernel launch overhead; fuse them), synchronization points (collectives such as torch.distributed.gather block until the slowest rank arrives).
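
A minimal torch.profiler run on CPU, as a sketch. The toy model and batch are assumptions; on a GPU machine you would typically add ProfilerActivity.CUDA to the activities list.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
x = torch.randn(64, 256)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        model(x)

# Aggregate by op name and show the most expensive ones first
report = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(report)
```

Wrap a handful of real training steps, including the data-loader iteration, to see whether time goes to the model or to waiting on input.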

The "minimal reproducer" discipline. Reproduce any non-trivial bug in < 30 lines of code with explicit seeds and data. Most ML bugs become 10× easier once isolated; many disappear on re-creation.
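
The skeleton of a reproducer with explicit seeds and inline data might look like this. seed_everything and repro are hypothetical names; add numpy or CUDA determinism settings if your bug involves them.

```python
import random

import torch

def seed_everything(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)   # seeds the torch CPU and CUDA generators

def repro() -> float:
    """Self-contained: seeds, data, and computation all live in one function."""
    seed_everything(0)
    x = torch.randn(4, 3)
    w = torch.randn(3, 1)
    return (x @ w).sum().item()

first, second = repro(), repro()
print(first == second)
```

If two calls disagree, the nondeterminism itself is the first thing to isolate.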

Logging hygiene at scale. Per-rank log files. JSON-lines format for parsing. Trace IDs across services. Sentry / Datadog / similar for alerting on real production failures. Cardinality matters — don't log per-example fields you'll have a billion of.
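
Per-rank JSON-lines logging needs nothing beyond the standard library. This is a sketch: make_jsonl_logger and the file naming are assumptions, and the rank is read from the RANK environment variable that torchrun sets, defaulting to 0.

```python
import json
import os
import tempfile
import time
from pathlib import Path

def make_jsonl_logger(log_dir: Path):
    rank = int(os.environ.get("RANK", "0"))     # set by torchrun; 0 when run standalone
    path = log_dir / f"train.rank{rank}.jsonl"  # one file per rank, no interleaving

    def log(**fields):
        record = {"ts": time.time(), "rank": rank, **fields}
        with path.open("a") as f:
            f.write(json.dumps(record) + "\n")  # one JSON object per line

    return log, path

log_dir = Path(tempfile.mkdtemp())
log, path = make_jsonl_logger(log_dir)
log(step=0, loss=2.31)
log(step=100, loss=1.05)

# Parsing back is one json.loads per line
records = [json.loads(line) for line in path.read_text().splitlines()]
print(records[-1]["step"], records[-1]["loss"])
```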

Reproducible bug reports. Save the config, the data hash, the git sha, and the exact CUDA / driver / torch version. Most "this used to work" bugs are environment drift.
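
Capturing the environment at startup can be sketched as follows. The dictionary keys are assumptions; the driver version and the data hash are left out because how to obtain them depends on your setup.

```python
import json
import platform
import subprocess
import sys

import torch

def git_sha() -> str:
    """Current commit, or 'unknown' outside a git checkout."""
    try:
        return subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        return "unknown"

env = {
    "python": sys.version.split()[0],
    "platform": platform.platform(),
    "torch": str(torch.__version__),
    "cuda": torch.version.cuda,                # None on CPU-only builds
    "cudnn": torch.backends.cudnn.version(),   # None if cuDNN is unavailable
    "git_sha": git_sha(),
}
env_json = json.dumps(env, indent=2)
print(env_json)
```

Writing this next to the config on every run makes "this used to work" a diff between two JSON files.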

import gc
import torch

def memory_audit(label=""):
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        alloc = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"[{label}] alloc {alloc:.2f} GB  reserved {reserved:.2f} GB")

# Find the biggest tensors alive — useful for leak hunting
def biggest_tensors(top_k=10):
    tensors = []
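    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append((obj.numel() * obj.element_size(), tuple(obj.shape), obj.dtype))
        except Exception:
            continue  # some objects raise on inspection; skip them
    # Sort on size only: on size/shape ties tuple comparison would fall through
    # to torch.dtype, which is not orderable
    tensors.sort(key=lambda t: t[0], reverse=True)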
    for size, shape, dtype in tensors[:top_k]:
        print(f"  {size / 1e6:6.1f} MB  {shape}  {dtype}")