Structured logs, the right print statements, and which debugger tricks work on tensors.
Key idea
"NaN" is a clue, not a stack trace. ML bugs usually present as silent degradation — bad metrics, divergent loss, slow training — not as crashes. The trick is making each level of the stack legible: structured logs you can search, useful prints at the right verbosity, a debugger that handles tensors.
Use real logging, not print. Python's logging module — or better, loguru — gives you levels, timestamps, structured records, and rotation for free. The five-minute investment pays off forever.
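If you would rather stay on the standard library, the following is a rough equivalent of the loguru setup used below. It relies on `logging.handlers.RotatingFileHandler`. This is a minimal sketch: the logger name `train`, the file name and the 100 MB / 5-backup limits are illustrative choices, not requirements.

```python
import logging
from logging.handlers import RotatingFileHandler

def make_logger(path: str = "train.log", level: int = logging.INFO) -> logging.Logger:
    """Stdlib logger with timestamps, levels and size-based rotation."""
    logger = logging.getLogger("train")  # illustrative logger name
    logger.setLevel(level)
    if not logger.handlers:  # avoid stacking duplicate handlers on re-import
        handler = RotatingFileHandler(path, maxBytes=100 * 1024 * 1024, backupCount=5)
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
    return logger

log = make_logger()
log.info("config: %s", {"lr": 3e-4, "batch_size": 32})
```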
Log what matters. Loss, learning rate, gradient norms, data shapes, the configuration object on startup. Not every batch — sample every N batches. If you're using a tracker (W&B / MLflow), it's also your log.
The most common ML bugs. NaN / Inf propagation, shape mismatches (sometimes silently broadcast instead of raising an error), exploding or vanishing gradients, a learning rate that is too high, data corruption upstream of training, label noise, and distribution shift between train and validation.
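```python
from loguru import logger
import torch

# Structured logging with levels; rotate the log file every 100 MB
logger.add("train.log", rotation="100 MB", level="INFO",
           format="{time} {level} {message}")
# cfg, model, num_steps, train_step() and debug_dump() are placeholders
# for your own config object, model, loop length and helpers.
logger.info("config: {}", cfg)
for step in range(num_steps):
    loss, metrics = train_step()
    if not torch.isfinite(loss):  # catches NaN and Inf
        logger.error("non-finite loss at step {}", step)
        debug_dump(step)
        raise RuntimeError(f"non-finite loss at step {step}")
    if step % 100 == 0:  # sample every N steps, not every batch
        # Global gradient norm = L2 norm of the per-parameter norms
        grads = [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]
        gnorm = torch.norm(torch.stack(grads)).item() if grads else 0.0
        logger.info("step={} loss={:.4f} gnorm={:.4f}", step, loss.item(), gnorm)
```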
What to watch in training
Loss: smooth decrease; NaN spike = something exploded
Gradient norm: stable; growing = lr too high; collapsing = vanishing
Activation stats: most activations should be neither 0 nor saturated
Learning rate: log it explicitly; schedule bugs are common
GPU utilization: low util usually = data-loading bottleneck
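To catch schedule bugs, log the learning rate the optimizer actually used rather than the value in the config. The sketch below assumes a linear warm-up implemented with `torch.optim.lr_scheduler.LambdaLR`. The model, warm-up length and base rate are placeholders.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 1)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
warmup = 5  # illustrative warm-up length, in steps
sched = LambdaLR(opt, lr_lambda=lambda s: min(1.0, (s + 1) / warmup))

lrs = []
for step in range(8):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Record the rate the optimizer really used this step, not the config value
    lrs.append(opt.param_groups[0]["lr"])
    sched.step()
print([round(lr, 6) for lr in lrs])
```

If the printed rates never ramp, or jump at the wrong step, the schedule is wired up incorrectly.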
Common silent failures
Wrong data type (fp16 underflow, int truncation)
Broadcasting where you didn't expect (e.g. (B,) vs (B, 1))
Detached graph — gradients don't flow
Frozen layers you forgot to unfreeze
Wrong device: part of the pipeline silently runs on CPU while the rest runs on GPU
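Cheap assertions catch several of these failures. The sketch below first shows a `(B,)` target silently broadcasting against a `(B, 1)` prediction into a `(B, B)` tensor. It then uses a hypothetical helper, `check_trainable`, which after a backward pass reports two things: frozen parameters, and trainable parameters that received no gradient (the symptom of a detached graph). The toy model is illustrative.

```python
import torch
import torch.nn as nn

B = 4
pred = torch.randn(B, 1)            # model output with a trailing dim
target = torch.randn(B)             # labels without it
bad = (pred - target) ** 2          # broadcasts to (B, B): no error raised
good = (pred.squeeze(-1) - target) ** 2
print(bad.shape, good.shape)        # torch.Size([4, 4]) torch.Size([4])

def check_trainable(model: nn.Module) -> dict:
    """Hypothetical helper: call after loss.backward() to find dead parameters."""
    frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
    no_grad = [n for n, p in model.named_parameters()
               if p.requires_grad and p.grad is None]
    return {"frozen": frozen, "no_grad": no_grad}

model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 1))
for p in model[0].parameters():
    p.requires_grad_(False)         # a layer someone forgot to unfreeze
x = torch.randn(2, 3)
h = model[1](model[0](x))
out = model[2](h.detach())          # detach cuts the graph: layer 1 gets no grad
out.sum().backward()
print(check_trainable(model))
```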
Anomaly detection, NaN sources, and debugging recipes
When the loss goes NaN or diverges, work through:
Re-init: is the initialisation bad, or is this a deep-network pathology?
Add eps to denominators, clip gradients
Check the data: outliers, label encoding, NaNs upstream
Try fp32: an fp16 underflow / overflow may be masking the problem
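To see why switching to fp32 is a useful test, cast the same values to fp16, whose largest finite value is 65504 and smallest positive subnormal is about 6e-8, and compare. This is a minimal illustration; the specific numbers are arbitrary.

```python
import torch

vals = torch.tensor([1e-8, 1e-3, 7e4])  # fp32 input
half = vals.to(torch.float16)
print(half)                              # 1e-8 underflows to 0, 7e4 overflows to inf
print(torch.finfo(torch.float16).max)    # 65504.0

# Typical consequence: a tiny variance becomes exactly 0 in fp16,
# so any later division by it produces inf or NaN.
var16 = torch.tensor(1e-8).to(torch.float16)
print(var16.item() == 0.0)               # True
```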
Anomaly detection during training. torch.autograd.set_detect_anomaly(True) is slow, but when a backward pass produces NaN it raises an error and prints the traceback of the forward operation responsible. Use it for one debug run; turn it off for production.
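The following is a minimal demonstration. `torch.sqrt` of a negative number yields NaN, and so does its backward. Inside the `torch.autograd.detect_anomaly()` context manager, PyTorch raises a `RuntimeError` naming the backward function that returned NaN, and warns with the forward traceback that created it. The toy input is an illustrative assumption.

```python
import torch

x = torch.tensor([4.0, -1.0], requires_grad=True)
caught = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x).sum()      # sqrt(-1) = nan in the forward pass
    try:
        y.backward()             # the backward of sqrt also yields nan
    except RuntimeError as err:  # anomaly mode turns the nan into an error
        caught = str(err)
print(caught)
```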
Gradient clipping. torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0). Almost always worth it for transformers and RNNs. The clip-then-NaN heuristic: if clipping fixes the divergence, you have an exploding-gradient problem.
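Clipping goes between `backward()` and `optimizer.step()`. `clip_grad_norm_` returns the total gradient norm measured before clipping, so the same call also gives you the value to log. The sketch below uses a placeholder model and an artificially scaled loss to force large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)                      # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = 1e3 * nn.functional.mse_loss(model(x), y)  # scaled up to make gradients big
opt.zero_grad()
loss.backward()
pre_clip = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post_clip = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
opt.step()
print(f"grad norm before clip {pre_clip:.1f}, after {post_clip:.3f}")
```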
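NaN forensics. Hook every layer's forward/backward to check for NaN; the first layer (in execution order) that produces one is the culprit. Common sources: division by zero in normalisation (e.g. a zero variance), log(0), exp(huge), and softmax over a row whose logits are all -inf (a fully masked row) or whose fp16 logits overflowed to inf.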
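These sources are easy to reproduce, and each has a standard guard. In the sketch below, `clamp_min` keeps a denominator or a `log` argument away from zero. For masking, filling logits with the dtype's most negative finite value instead of `-inf` keeps a fully masked softmax row finite; it degrades to a uniform distribution. The `eps` value of 1e-6 is an illustrative choice.

```python
import torch

eps = 1e-6  # illustrative; pick relative to your data's scale
x = torch.zeros(3)

# Unsafe versions
print(torch.log(x))                         # -inf
print(x / x.std())                          # 0 / 0 -> nan
all_masked = torch.full((3,), float("-inf"))
print(torch.softmax(all_masked, dim=0))     # nan: every logit masked

# Guarded versions
print(torch.log(x.clamp_min(eps)))          # finite
print(x / x.std().clamp_min(eps))           # zeros
fill = torch.finfo(torch.float32).min       # most negative finite float32
print(torch.softmax(torch.full((3,), fill), dim=0))  # uniform 1/3
```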
The 1-batch overfit. Take 2–8 examples. The model should overfit perfectly in under 1000 steps. If it can't, something is fundamentally wrong (architecture, loss, data pipeline).
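The following is a minimal version of the test for a classification setup: eight random inputs with random labels, a small MLP, and Adam. The architecture, learning rate and step count are illustrative. In practice, swap in your real model, loss and one real batch from your data pipeline, since the point is to exercise those.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16)             # 8 examples
y = torch.randint(0, 4, (8,))      # 4 classes, random labels
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

for step in range(500):
    loss = loss_fn(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if loss.item() < 1e-3:         # memorised: the pipeline can learn
        break
print(f"step {step}: loss {loss.item():.5f}")
```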
Tensor-aware debugger. pdb works fine; for tensor inspection use its display command, e.g. display (tensor.shape, tensor.dtype, tensor.device, tensor.requires_grad), which re-prints the expression whenever it changes. ipdb is nicer. PyCharm and VS Code have visual tensor inspectors.
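A hypothetical one-line summary helper makes the same information available at any `pdb`/`ipdb` prompt (e.g. `p describe(x)`). It also flags non-finite values, which the plain attributes do not show.

```python
import torch

def describe(t: torch.Tensor) -> str:
    """Hypothetical helper: one-line tensor summary for a debugger prompt."""
    finite = torch.isfinite(t).all().item() if t.is_floating_point() else True
    return (f"shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
            f"requires_grad={t.requires_grad} finite={finite}")

print(describe(torch.tensor([1.0, float("nan")])))
# shape=(2,) dtype=torch.float32 device=cpu requires_grad=False finite=False
```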
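The pickle vs JSON rule. Save configs, hyperparameters, and metrics as JSON or YAML: diff-friendly and language-agnostic. Save model weights with torch.save (pickle-based) or safetensors; prefer safetensors for files from untrusted sources, since unpickling can execute arbitrary code. Don't mix the two purposes.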
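The following is a minimal sketch of the split, using only `json` and `torch.save`; the file names and config keys are illustrative. `torch.save` pickles the state dict. Passing `weights_only=True` to `torch.load` (available in recent PyTorch releases) restricts unpickling to tensors and plain containers. If the `safetensors` package is installed, its `save_file`/`load_file` functions avoid pickle altogether.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn

cfg = {"lr": 3e-4, "batch_size": 32, "layers": [16, 64, 4]}  # illustrative config
model = nn.Linear(16, 4)

run_dir = tempfile.mkdtemp()
# Human-readable and diff-friendly: config (and metrics) as JSON
with open(os.path.join(run_dir, "config.json"), "w") as f:
    json.dump(cfg, f, indent=2, sort_keys=True)
# Binary: weights only (a state dict), never mixed into the config file
torch.save(model.state_dict(), os.path.join(run_dir, "weights.pt"))

with open(os.path.join(run_dir, "config.json")) as f:
    cfg_loaded = json.load(f)
state = torch.load(os.path.join(run_dir, "weights.pt"), weights_only=True)
model.load_state_dict(state)
```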
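```python
import torch
import torch.nn as nn

# NaN-source finder: hook every module's forward pass.
# Hooks fire as each module finishes, so the first message names the
# earliest (innermost) module whose output went non-finite.
def attach_nan_hooks(model: nn.Module):
    def hook(name):
        def fn(module, inp, out):
            if isinstance(out, torch.Tensor) and not torch.isfinite(out).all():
                print(f"NaN/Inf in forward output of {name or '<root>'}")
        return fn
    # Return the handles so the hooks can be removed after debugging
    return [mod.register_forward_hook(hook(name))
            for name, mod in model.named_modules()]

# Print per-layer activation stats to spot dead or saturating layers
def activation_summary(model: nn.Module, x: torch.Tensor, sat_thresh: float = 5.0):
    activations = {}
    def hook(name):
        def fn(m, i, o):
            if isinstance(o, torch.Tensor):  # skip tuple/dict outputs
                activations[name or "<root>"] = o.detach().float()
        return fn
    handles = [mod.register_forward_hook(hook(name))
               for name, mod in model.named_modules()]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:  # don't pile up hooks on repeated calls
            h.remove()
    for name, a in activations.items():
        print(f"{name:30s} mean={a.mean():.3f} std={a.std():.3f} "
              f"max={a.max():.3f} sat={(a.abs() > sat_thresh).float().mean():.2%}")
```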
Anomaly detection at scale, distributed debugging, and OOM forensics
Loss curve patterns
$$ \text{divergence}, \;\text{plateau}, \;\text{cliff}, \;\text{oscillation}, \;\text{step decay} \;\to\; \text{each implies a specific bug class} $$
Each pattern has a small set of likely causes
Documented troubleshooting tree → much faster than guessing
Loss curve forensics. Diverges sharply: lr too high, bad init, or a numerical issue. Plateaus immediately: gradient flow stopped (a detached tensor, frozen layers). Oscillates: lr too high or batch too small. Cliff (sudden drop): a learning-rate schedule step, or the model finally found the right answer.
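Writing the troubleshooting tree down as code makes it easier to apply consistently. The following is a rough, hypothetical sketch. `classify_loss_curve` and every threshold in it (the window size and the 1.5x, 1%, 5%, 50% and 90% ratios) are arbitrary assumptions. They show the shape of the decision tree; they are not tuned values.

```python
import math
from typing import Sequence

def classify_loss_curve(losses: Sequence[float], window: int = 10) -> str:
    """Hypothetical first-guess diagnosis; all thresholds are arbitrary."""
    if any(not math.isfinite(v) for v in losses):
        return "divergence: lr too high, bad init, or numerical issue"
    if len(losses) < 2 * window:
        return "too short to tell"
    head = sum(losses[:window]) / window
    tail = sum(losses[-window:]) / window
    recent = losses[-2 * window:]
    spread = max(recent) - min(recent)
    if tail > 1.5 * head:
        return "divergence: lr too high, bad init, or numerical issue"
    if abs(tail - head) < 0.01 * abs(head) and spread < 0.05 * abs(head):
        return "plateau: gradient flow stopped (detach, frozen layers)?"
    diffs = [b - a for a, b in zip(losses, losses[1:])]
    flips = sum(1 for d1, d2 in zip(diffs, diffs[1:]) if d1 * d2 < 0)
    if flips > 0.5 * len(diffs) and tail > 0.9 * head:
        return "oscillation: lr too high or batch too small"
    if head > tail and max(-d for d in diffs) > 0.5 * (head - tail):
        return "cliff: lr schedule step, or the model found the solution"
    return "smooth decrease"

print(classify_loss_curve([2.3] * 100))               # plateau: ...
print(classify_loss_curve([1.0] * 50 + [0.2] * 50))   # cliff: ...
```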
Distributed debugging. Two ranks producing different losses on the same data → a sync or seeding issue. Use torch.distributed.barrier() + per-rank logging to find where they diverge. Always reproduce on a single GPU before debugging distributed; you'd be amazed how often a 1-GPU debug fixes the cluster bug.
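One way to find where ranks diverge is to compute a fingerprint of the parameters on every rank at the same step and compare them. The sketch below assumes the process group has already been initialised (e.g. by `torchrun`). `param_fingerprint` and `assert_ranks_in_sync` are hypothetical helpers. Hashing every parameter is slow on large models, so call them sparingly, e.g. every few hundred steps.

```python
import hashlib

import torch
import torch.distributed as dist
import torch.nn as nn

def param_fingerprint(model: nn.Module) -> str:
    """Hypothetical helper: SHA-256 over every parameter's raw bytes."""
    h = hashlib.sha256()
    for name, p in model.named_parameters():
        h.update(name.encode())
        raw = p.detach().cpu().contiguous().reshape(-1).view(torch.uint8)
        h.update(bytes(raw.tolist()))
    return h.hexdigest()

def assert_ranks_in_sync(model: nn.Module, step: int) -> None:
    """Gather every rank's fingerprint; raise if any rank disagrees."""
    if not (dist.is_available() and dist.is_initialized()):
        return  # single process: nothing to compare
    fps = [None] * dist.get_world_size()
    dist.all_gather_object(fps, param_fingerprint(model))
    if len(set(fps)) > 1:
        raise RuntimeError(f"step {step}: ranks diverged: {fps}")
```

Bisecting over steps with this check narrows the divergence down to the first step where the fingerprints disagree.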
OOM (Out-of-Memory) forensics. Memory usage growing over time → leak: something is not getting freed, commonly an undetached tensor kept alive (e.g. total_loss += loss instead of loss.item(), which keeps each step's autograd graph reachable). Memory plateaus high but stable → you just need a smaller batch or more aggressive activation checkpointing. torch.cuda.memory_summary() shows the allocation breakdown.
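The classic form of this leak is accumulating the loss tensor itself. Each `loss` carries a `grad_fn` pointing at its graph, so summing the tensors keeps every step's history reachable. Without a `backward()` call (for example, an evaluation loop not wrapped in `torch.no_grad()`), that history includes the saved activations. Below is a small CPU sketch of the pattern and the fix; the toy model is a placeholder.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)                      # placeholder model
batches = [torch.randn(16, 8) for _ in range(3)]

# Leaky: summing loss tensors keeps each step's autograd history reachable
leaky_total = 0.0
for x in batches:
    loss = model(x).pow(2).mean()
    leaky_total = leaky_total + loss         # still a tensor with a grad_fn

# Fixed: convert to a Python float (or .detach()) before accumulating
safe_total = 0.0
for x in batches:
    loss = model(x).pow(2).mean()
    safe_total += loss.item()

print(leaky_total.grad_fn is not None, type(safe_total).__name__)  # True float
```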
Profiler-led debugging. If training is slow, profile first. Common culprits: data loading (a CPU bottleneck; increase num_workers), many small ops (kernel-launch overhead; fuse them), and synchronisation points (a torch.distributed collective such as gather blocking on the slowest rank).
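The following is a minimal `torch.profiler` run on CPU. It uses `record_function` labels so that data loading and the forward pass show up as separate rows. The model and the `load_batch` stand-in are placeholders. On a GPU you would add `ProfilerActivity.CUDA` to `activities`.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

def load_batch() -> torch.Tensor:
    """Placeholder for a real DataLoader iteration."""
    return torch.randn(64, 256)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        with record_function("data_loading"):
            x = load_batch()
        with record_function("forward"):
            model(x)

table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=30)
print(table)
```

If the `data_loading` row dominates, the GPU is waiting on input. In that case, look at `num_workers` and preprocessing before touching the model.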
The "minimal reproducer" discipline. Reproduce any non-trivial bug in < 30 lines of code with explicit seeds and data. Most ML bugs become 10× easier once isolated; many disappear on re-creation.
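A skeleton for such a reproducer might look like the code below. It assumes the bug lives in a small piece of model code. Fix every seed, build the smallest inputs that still trigger the bug, and end with an assertion that fails while the bug is present. `seed_everything` and `run` are hypothetical names, and the finiteness assertion is a placeholder for whatever invariant your bug violates.

```python
import random

import torch

def seed_everything(seed: int = 0) -> None:
    """Hypothetical helper: seed Python and PyTorch RNGs (CPU and all GPUs)."""
    random.seed(seed)
    torch.manual_seed(seed)  # also seeds every CUDA device
    # If NumPy is involved, also call numpy.random.seed(seed).

def run(seed: int = 0) -> torch.Tensor:
    seed_everything(seed)
    layer = torch.nn.Linear(4, 2)   # smallest model that still shows the bug
    x = torch.randn(3, 4)           # explicit, tiny input
    return layer(x)

out_a, out_b = run(), run()
# Same seed must give bit-identical results, or the repro itself is flaky
assert torch.equal(out_a, out_b)
assert torch.isfinite(out_a).all(), "bug reproduced: non-finite output"
```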
Logging hygiene at scale. Per-rank log files. JSON-lines format for parsing. Trace IDs across services. Sentry / Datadog / similar for alerting on real production failures. Cardinality matters — don't log per-example fields you'll have a billion of.
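Below is a stdlib-only sketch of per-rank JSON-lines logging. It assumes the launcher exports the rank in the `RANK` environment variable (as `torchrun` does) and defaults to 0 otherwise. The file naming and the `JsonLinesFormatter` class are illustrative. Keep the structured fields low-cardinality (step, loss), not per-example.

```python
import json
import logging
import os
import time

class JsonLinesFormatter(logging.Formatter):
    """One JSON object per line: easy to grep, parse, and ship to a log store."""
    def __init__(self, rank: int):
        super().__init__()
        self.rank = rank

    def format(self, record: logging.LogRecord) -> str:
        payload = {"ts": time.time(), "level": record.levelname,
                   "rank": self.rank, "msg": record.getMessage()}
        payload.update(getattr(record, "fields", {}))  # structured extras
        return json.dumps(payload)

def make_rank_logger(log_dir: str = ".") -> logging.Logger:
    rank = int(os.environ.get("RANK", "0"))
    logger = logging.getLogger(f"train.rank{rank}")
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep each rank's records in its own file
    if not logger.handlers:
        handler = logging.FileHandler(os.path.join(log_dir, f"rank{rank}.jsonl"))
        handler.setFormatter(JsonLinesFormatter(rank))
        logger.addHandler(handler)
    return logger

log = make_rank_logger()
log.info("step", extra={"fields": {"step": 100, "loss": 0.4213}})
```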
Reproducible bug reports. Save the config, the data hash, the git sha, and the exact CUDA / driver / torch version. Most "this used to work" bugs are environment drift.
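The sketch below collects those facts into one dict you can paste into a bug report or save next to the config. `environment_report` is a hypothetical helper. A few caveats: the git call assumes `git` is installed and falls back to `None` outside a repository, and the data hash covers a single file (a whole dataset would need a directory walk). The NVIDIA driver version is not exposed by these torch calls; `nvidia-smi` reports it.

```python
import hashlib
import platform
import subprocess
import sys
from typing import Optional

import torch

def git_sha() -> Optional[str]:
    try:
        out = subprocess.run(["git", "rev-parse", "HEAD"],
                             capture_output=True, text=True, check=True)
        return out.stdout.strip()
    except (OSError, subprocess.CalledProcessError):
        return None  # git missing, or not inside a repository

def file_sha256(path: str, chunk: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            h.update(block)
    return h.hexdigest()

def environment_report(data_path: Optional[str] = None) -> dict:
    """Hypothetical helper: the facts most 'it used to work' bugs hinge on."""
    report = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda": torch.version.cuda,                # None on CPU-only builds
        "cudnn": torch.backends.cudnn.version(),   # None if cuDNN unavailable
        "git_sha": git_sha(),
    }
    if torch.cuda.is_available():
        report["gpu"] = torch.cuda.get_device_name(0)
    if data_path is not None:
        report["data_sha256"] = file_sha256(data_path)
    return report

print(environment_report())
```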
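```python
import gc
import torch

def memory_audit(label=""):
    """Print allocated vs reserved CUDA memory after a GC pass."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached blocks so "reserved" reflects live use
        alloc = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"[{label}] alloc {alloc:.2f} GB  reserved {reserved:.2f} GB")
    else:
        print(f"[{label}] CUDA not available")

# Find the biggest tensors alive; useful for leak hunting.
# Only sees tensors reachable as Python objects, and views share storage
# with their base tensor, so some memory may be double-counted.
def biggest_tensors(top_k=10):
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append((obj.numel() * obj.element_size(),
                                tuple(obj.shape), obj.dtype, obj.device))
        except Exception:  # some objects (e.g. dead weakref proxies) raise on access
            continue
    # Sort by size only: torch.dtype objects are not orderable
    tensors.sort(key=lambda t: t[0], reverse=True)
    for size, shape, dtype, device in tensors[:top_k]:
        print(f"  {size / 1e6:8.1f} MB  {shape} {dtype} {device}")
```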