Key idea

The shape of your loss curve is a diagnosis. Spike-then-NaN is exploding gradients. Plateau at the start is gradients not flowing. Smooth decrease then sharp drop is a learning-rate schedule kick. Train loss low but val high is overfitting. Pattern → diagnosis → fix is a much faster loop than guesswork.


Six canonical loss-curve patterns, side by side with their likely causes. The healthy curve is your target. The other five are common pathologies; each has 1–3 typical root causes and a known fix.

Healthy curve. Smooth, monotonic-ish decrease in train loss. Val loss tracks it for a while, may diverge a bit (mild overfitting). Both flatten near the end.

Divergent. Loss rises rapidly, then goes to NaN. Almost always one of: learning rate too high, gradient explosion, fp16 overflow, or bad initialisation. Fix: drop lr 10×, clip gradients, try bf16 or fp32.
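A minimal sketch of the clip-and-check part of that fix in PyTorch, assuming an ordinary model/optimizer training loop. The helper name `safe_step` and the `max_norm=1.0` default are illustrative choices, not something the text prescribes.

```python
import math
import torch

def safe_step(model, optimizer, loss, max_norm=1.0):
    """Backprop, clip, and step; skip the update if the gradient is non-finite.

    Returns the gradient norm measured before clipping, so it can be logged.
    """
    optimizer.zero_grad()
    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the total norm
    # as it was before clipping.
    total_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm))
    if not math.isfinite(total_norm):
        optimizer.zero_grad()  # drop the bad gradients instead of applying them
        return total_norm
    optimizer.step()
    return total_norm

# Usage on a toy regression problem (hypothetical data)
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
norm = safe_step(model, optimizer, torch.nn.functional.mse_loss(model(x), y))
```

Logging the returned pre-clip norm shows how often clipping is actually active; if it clips on nearly every step, the learning rate is probably still too high.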

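Plateau from step 1. Loss never decreases meaningfully. Gradients aren't flowing: a detached graph, a frozen layer, or an optimizer built over the wrong parameters. Run a one-batch overfit test: train on a single fixed batch and check that the loss heads toward zero. If it doesn't, you have a bug in the model or the training loop.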
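The one-batch overfit check can be sketched as follows, assuming a PyTorch model. `overfit_one_batch`, the step count, and the Adam learning rate are hypothetical choices for illustration; the helper also lists parameters that receive no gradient, which is where detached graphs and frozen layers show up.

```python
import torch

def overfit_one_batch(model, loss_fn, x, y, steps=200, lr=1e-2):
    """Train on one fixed batch; return (first_loss, last_loss, dead_params)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    first = last = None
    dead = []
    for step in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        if step == 0:
            first = loss.item()
            # A parameter with no gradient (or an all-zero one) points at a
            # detached graph or a frozen layer.
            dead = [name for name, p in model.named_parameters()
                    if p.grad is None or p.grad.abs().sum().item() == 0.0]
        optimizer.step()
        last = loss.item()
    return first, last, dead

# Usage on a toy model and a random batch
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)
x, y = torch.randn(8, 4), torch.randn(8, 1)
first, last, dead = overfit_one_batch(model, torch.nn.functional.mse_loss, x, y)
```

A healthy setup drives `last` well below `first` and reports no dead parameters.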

Oscillates. Loss bounces up and down without converging. The learning rate is too high for the batch size: reduce the lr, or increase the batch size so the gradient estimate is less noisy.

Train ≪ val (overfitting). Train loss continues to drop but val rises. Need regularization, more data, or earlier stopping.

Schedule kick. Smooth curve with a sudden drop at a known step. A learning-rate scheduler moved to a lower lr; the model finally settled. Often a good sign, but check there's not an unintended schedule kicking in.

What to log

  • Train loss every N batches; val loss per epoch (or every K steps)
  • Gradient norm
  • Learning rate (yes, even though you control it)
  • Activation stats (a sample of layers)
  • System metrics (GPU util, memory)

What loss curves don't tell you

  • Whether the model has learned the right thing (run evals!)
  • Subgroup performance (look at per-class metrics)
  • Calibration (low loss doesn't mean calibrated)
  • Generalization to OOD inputs
Want subgroup diagnostics & eval-time forensics?
The 4 forensic plots: loss curve, gradient norm, per-class val accuracy, calibration
  • Loss: smooth or not
  • Gradient norm: stable, growing, or vanishing
  • Per-class: catches "great on average, awful on class 7"
  • Calibration: are the probabilities trustworthy?
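One common way to put a number on the calibration question is the expected calibration error (ECE). The text above does not name a metric, so the choice of ECE, the function name, and the 10-bin default are assumptions in this sketch.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: bin predictions by confidence, then average |accuracy - confidence|
    over the bins, weighted by the share of samples in each bin."""
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        idx = min(int(conf * n_bins), n_bins - 1)  # conf == 1.0 goes in the last bin
        bins[idx].append((conf, ok))
    total = len(confidences)
    ece = 0.0
    for bucket in bins:
        if not bucket:
            continue
        avg_conf = sum(c for c, _ in bucket) / len(bucket)
        accuracy = sum(1 for _, ok in bucket if ok) / len(bucket)
        ece += (len(bucket) / total) * abs(accuracy - avg_conf)
    return ece

# Usage: ten predictions at 90% confidence, only half of them right
overconfident = expected_calibration_error([0.9] * 10, [True] * 5 + [False] * 5)
```

Here `confidences` is the predicted probability of the chosen class and `correct` says whether that prediction was right; a well-calibrated model scores close to 0.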

Gradient norm. Roughly stable (the absolute level depends on the model; order 1 is common) → healthy. Growing exponentially → explosion. Decaying toward 0 → vanishing. Always log it.
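A small sketch of how these norms might be computed in PyTorch, per parameter and overall. The helper names are hypothetical; the per-parameter view is useful because vanishing gradients usually show up in the earliest layers first.

```python
import torch

def layer_grad_norms(model):
    """L2 norm of each parameter's gradient, keyed by parameter name."""
    return {
        name: p.grad.detach().norm(2).item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }

def global_grad_norm(model):
    """L2 norm over all gradients together (the single number to log)."""
    norms = layer_grad_norms(model)
    return sum(n * n for n in norms.values()) ** 0.5

# Usage: call after loss.backward() and before optimizer.step()
model = torch.nn.Linear(2, 1, bias=False)
model(torch.tensor([[3.0, 4.0]])).sum().backward()
norms = layer_grad_norms(model)
```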

Per-class metrics. Average accuracy can hide a class that's at 0% recall. Plot per-class val accuracy over time; a class that's always near zero needs attention (more data, weighted loss, or oversampling).
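As a sketch, per-class recall can be computed from plain label lists like this (the function name is illustrative). The resulting list has the shape that a `per_class_acc` argument indexed by class would expect.

```python
from collections import Counter

def per_class_recall(y_true, y_pred, num_classes):
    """Fraction of each class's examples predicted correctly (NaN if class absent)."""
    seen = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return [
        hits[c] / seen[c] if seen[c] else float("nan")
        for c in range(num_classes)
    ]

# Usage: overall accuracy is 4/6, yet class 2 is never predicted correctly
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 1, 0, 1]
recalls = per_class_recall(y_true, y_pred, num_classes=3)
```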

Smoothed vs raw. Raw loss per batch is noisy. Smoothed (EMA or window mean) is what you compare across runs. Most trackers do this automatically; if not, apply an EMA, s_t = α·s_{t−1} + (1 − α)·x_t, with α ≈ 0.9.
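A minimal EMA smoother, assuming the loss history is a plain list of floats. The debiasing division is an addition in this sketch: it stops the first few smoothed values from being dragged toward the zero starting point.

```python
def ema_smooth(values, alpha=0.9):
    """Exponential moving average with start-up debiasing."""
    smoothed, avg = [], 0.0
    for t, x in enumerate(values, start=1):
        avg = alpha * avg + (1 - alpha) * x
        smoothed.append(avg / (1 - alpha ** t))  # undo the bias from avg starting at 0
    return smoothed

# Usage: a constant signal stays constant after smoothing
flat = ema_smooth([2.0] * 5)
```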

Loss spikes. Common with large batch + adaptive optimizer (Adam). Usually safe to ignore if recovery is quick; investigate if recovery takes a while. Gradient clipping reduces them.
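To make "quick recovery" measurable, one possible heuristic (an assumption of this sketch, not an established method) is to flag any loss above `ratio` times a running EMA baseline and count the steps until it drops back under that threshold. The function name and both defaults are hypothetical.

```python
def find_spikes(losses, ratio=1.5, alpha=0.9):
    """Return (spike_step, steps_to_recover) pairs; recovery is None if the
    loss never drops back under the spike threshold."""
    if not losses:
        return []
    spikes = []
    baseline = losses[0]
    spike_start = None
    for step, loss in enumerate(losses):
        is_high = loss > ratio * baseline
        if spike_start is None and is_high:
            spike_start = step
        elif spike_start is not None and not is_high:
            spikes.append((spike_start, step - spike_start))
            spike_start = None
        if not is_high:
            # Only non-spike losses update the baseline, so a spike
            # cannot raise its own threshold.
            baseline = alpha * baseline + (1 - alpha) * loss
    if spike_start is not None:
        spikes.append((spike_start, None))
    return spikes

# Usage: one spike at step 5 that recovers two steps later
found = find_spikes([1.0] * 5 + [3.0, 2.0, 1.0] + [1.0] * 3)
```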

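Warm-up. Linear lr ramp from 0 to target over the first few hundred steps. Close to essential for transformers trained with Adam: in the first steps the second-moment estimate rests on very few gradients, so per-parameter step sizes are noisy, and each update is about lr in magnitude even where gradients are tiny. Without warm-up you often see a brief early spike that training later recovers from; with it, the start is cleaner.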
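A linear warm-up can be sketched as a multiplier on the target learning rate. The function name and the 500-step default are assumptions; in PyTorch, a function of this shape can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`. This version uses `step + 1` so the very first update is small but not exactly zero.

```python
def warmup_factor(step, warmup_steps=500):
    """Multiplier on the target lr: ramps linearly up to 1.0, then stays there."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, (step + 1) / warmup_steps)

# Usage: effective lr at a few points of a run with target lr 3e-4
target_lr = 3e-4
schedule = [target_lr * warmup_factor(s) for s in (0, 249, 499, 10_000)]
```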

Train vs val together. Plot them on the same axes. Train below val and both decreasing: healthy. Train below val and val rising: overfitting; stop or regularise. Train above val (rare but real): often just dropout or augmentation being active only during training; otherwise suspect a metric bug or different normalisation between the two pipelines.

import wandb

def log_health(step, loss, grad_norm, lr, per_class_acc=None):
    payload = {"train/loss": loss, "train/grad_norm": grad_norm, "lr": lr}
    if per_class_acc is not None:
        for c, acc in enumerate(per_class_acc):
            payload[f"val/acc_class_{c}"] = acc
    wandb.log(payload, step=step)

# Quick "early stopping watchdog": bail after `patience` evals without improvement
class EarlyStop:
    def __init__(self, patience=5, min_delta=0.001):
        self.best = float("inf")   # best validation loss seen so far
        self.wait = 0              # evals since the last real improvement
        self.patience = patience
        self.min_delta = min_delta

    def step(self, val_loss):
        """Record one eval; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience
Want curvature signs, double descent, & the modern "loss spike" literature?
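Adam's bias correction effect
$$ \hat m_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat v_t = \frac{v_t}{1 - \beta_2^t} $$
  • At small t the denominators are small, but so are m_t and v_t (both start at 0); the correction only rescales them, and at t = 1 it gives m̂_1 = g_1 and v̂_1 = g_1²
  • So the first update is ≈ lr per coordinate however small the gradient, and v̂_t rests on very few gradients → noisy per-parameter step sizes early on
  • This early noise is the usual explanation for why warm-up matters so much for transformers + Adam / AdamW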
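To see what the correction does numerically, here is a sketch of Adam on a single scalar parameter. The function name is hypothetical and the defaults are the commonly used ones (β1 = 0.9, β2 = 0.999). With these values the first update comes out at about lr in magnitude whatever the scale of the gradient.

```python
import math

def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Per-step update magnitudes of scalar Adam for a sequence of gradients."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
        v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
        updates.append(lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates

# Usage: a tiny gradient and a large one give nearly the same first step
tiny = adam_updates([1e-4])[0]
large = adam_updates([100.0])[0]
```

Because the step size is decoupled from gradient magnitude, a noisy early gradient moves the weights as far as a reliable one would, which is the situation a low initial lr protects against.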

Loss spikes at scale. LLM training (~1B+ params) often shows occasional sharp loss spikes followed by recovery. Discussed in the PaLM paper (Chowdhery et al., 2022) and in Megatron training logs. Sometimes attributed to outlier examples; sometimes to optimiser numerical state. A common mitigation is to rewind to a checkpoint from before the spike, skip the offending batches, and continue.

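Double descent. In overparameterised models, test loss can rise and then fall again as capacity grows past the interpolation threshold; a similar rise-then-fall can appear over training time. So a val curve that looks like the onset of overfitting may be the hump before the second descent. The modern response is often to train longer or go bigger, then check whether val loss comes back down.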

Grokking. On small algorithmic tasks (e.g. modular arithmetic), train loss drops to near zero early while val loss stays high, sometimes for 100× more steps, before catching up. The val curve looks like a long plateau followed by a sudden cliff. Nanda et al. 2023 reverse-engineered the mechanism.

Curvature diagnostics. Hessian eigenvalues, gradient covariance, Fisher information at the trained parameters. Flat minima → better generalization (sometimes). Tools: PyHessian, BackPACK. Useful for advanced analysis; not routine debugging.

Catastrophic forgetting traces. Multi-task or sequential fine-tuning — a task's loss rises after training on a different task. Plot per-task loss to see which tasks are at risk; mitigate with EWC, replay buffers, or lower learning rates.

Generalization gap. Val loss − train loss. Stable gap → expected; widening gap → overfitting. In well-regularised runs it stays small and levels off toward the end of training.

Beyond loss. Some pathologies don't show in loss but in downstream metrics — calibration drift, subgroup regressions, hallucination rates for LLMs. Make sure those are also tracked.

import torch

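# Largest-magnitude Hessian eigenvalue via power iteration (Yao et al., PyHessian)
def top_hessian_eig(loss_fn, params, num_iters=20):
    params = list(params)  # allow a generator such as model.parameters()
    # Gradient with the graph kept, so it can be differentiated a second time
    grads = torch.autograd.grad(loss_fn(), params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    norm = torch.sqrt(sum((vi * vi).sum() for vi in v))
    v = [vi / norm for vi in v]  # start from a unit vector
    eig = 0.0
    for _ in range(num_iters):
        # Hessian-vector product: differentiate (g . v) w.r.t. the parameters
        gv = sum((g * vi).sum() for g, vi in zip(grads, v))
        Hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient v^T H v keeps the eigenvalue's sign
        eig = sum((h * vi).sum() for h, vi in zip(Hv, v)).item()
        # Normalise for the next iteration
        n = torch.sqrt(sum((h * h).sum() for h in Hv))
        v = [h / (n + 1e-12) for h in Hv]
    return eig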