Key idea

Vanilla SGD is rarely the fastest. Real loss surfaces have narrow ravines, saddle points, and curvature that changes from one direction to another. Modern optimizers solve these in different ways: momentum remembers past direction, RMSprop / Adagrad scale each parameter by its own gradient history, Adam combines both. The viz below races them on the same start.

[Interactive visualization: four optimizers race across a curved loss surface from the same start with the same step budget; the demo opens at lr = 0.05.]

All four optimizers see the same loss surface from the same start. SGD moves straight down the local gradient and zig-zags in narrow valleys. Momentum accumulates velocity and "rolls through" valleys. RMSprop shrinks per-parameter steps where gradients are big, freeing it to take bigger steps where they're small. Adam combines momentum with RMSprop's adaptive scaling — usually the safest first pick.

SGD. Update is θ ← θ − η·∇L. Cheap, well-understood, often the best generalization in deep learning — but slow to converge on ill-conditioned problems where one direction needs huge steps and another tiny ones.
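
A minimal sketch of why ill-conditioning hurts plain gradient descent, assuming a hypothetical two-parameter quadratic loss (not the surface used in the visualization). The function names, starting point, and learning rates are illustrative choices only.

```python
# Plain gradient descent on an ill-conditioned quadratic (hypothetical toy loss):
#   L(x, y) = 0.5 * (x**2 + 100 * y**2)
# Curvature is 1 along x and 100 along y, so a single learning rate has to
# serve two directions that want very different step sizes.

def grad(x, y):
    """Gradient of the toy loss."""
    return x, 100.0 * y

def run_sgd(lr, steps=100, start=(1.0, 1.0)):
    x, y = start
    for _ in range(steps):
        gx, gy = grad(x, y)
        x -= lr * gx          # theta <- theta - lr * grad
        y -= lr * gy
    return x, y

# For this loss the steep y direction diverges once lr exceeds 2 / 100 = 0.02,
# while the shallow x direction would happily take steps ~100x larger.
x_slow, y_slow = run_sgd(lr=0.019)
x_bad, y_bad = run_sgd(lr=0.021)
print(f"lr=0.019: x={x_slow:.3f}, y={y_slow:.2e}")
print(f"lr=0.021: x={x_bad:.3f}, y={y_bad:.2e}")
```

With the stable learning rate, y collapses quickly but x has barely moved after 100 steps; nudging the learning rate just past the stability limit makes y blow up instead.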

Momentum. Keep a running velocity that smooths gradient noise and accelerates along consistent directions. Nesterov momentum looks one step ahead before computing the gradient, which gives a small but real speedup.
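
The two momentum variants can be sketched in a few lines. This is an illustrative comparison on a hypothetical quadratic loss, using the textbook look-ahead form of Nesterov momentum; library implementations may use an algebraically rearranged version, and the hyperparameters here are assumptions.

```python
# Heavy-ball vs. Nesterov momentum on a toy quadratic (hypothetical example):
#   L(x, y) = 0.5 * (x**2 + 100 * y**2)

def grad(theta):
    x, y = theta
    return (x, 100.0 * y)

def loss(theta):
    x, y = theta
    return 0.5 * (x * x + 100.0 * y * y)

def run(method, lr=0.01, beta=0.9, steps=100, start=(1.0, 1.0)):
    theta = list(start)
    v = [0.0, 0.0]                      # running velocity
    for _ in range(steps):
        if method == "sgd":
            v = list(grad(theta))       # no memory: velocity is just the gradient
        elif method == "momentum":
            g = grad(theta)
            v = [beta * vi + gi for vi, gi in zip(v, g)]
        elif method == "nesterov":
            # Evaluate the gradient where the current velocity is about to take us
            lookahead = [t - lr * beta * vi for t, vi in zip(theta, v)]
            g = grad(lookahead)
            v = [beta * vi + gi for vi, gi in zip(v, g)]
        else:
            raise ValueError(method)
        theta = [t - lr * vi for t, vi in zip(theta, v)]
    return loss(theta)

losses = {m: run(m) for m in ("sgd", "momentum", "nesterov")}
for name, value in losses.items():
    print(f"{name:9s} final loss = {value:.3e}")
```

All three runs share the learning rate and step budget; the only difference is how much gradient history each update carries.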

Adam. Maintain per-parameter running estimates of the gradient's mean (first moment) and of its squared values (second moment); scale each step by first moment / √(second moment). Essentially momentum × RMSprop. The default for most deep learning today, though SGD-with-momentum sometimes generalizes better on vision tasks.

Practical truth: the learning rate matters more than the optimizer choice. A well-tuned SGD often beats a poorly-tuned Adam. Always sweep the learning rate first.
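
A learning-rate sweep can be as simple as the sketch below: try log-spaced values for a short budget and keep the best. The toy quadratic loss, the grid, and the step budget are assumptions for illustration; in real training you would compare validation loss after a short run of your actual model.

```python
import math

def final_loss(lr, steps=50):
    """Run plain gradient descent on a toy quadratic and report the final loss."""
    x, y = 1.0, 1.0
    for _ in range(steps):
        x -= lr * x
        y -= lr * 100.0 * y
        if not (math.isfinite(x) and math.isfinite(y)):
            return math.inf            # treat a blown-up run as the worst outcome
    return 0.5 * (x * x + 100.0 * y * y)

# Log-spaced grid: half-decade steps from 1e-4 to 1e-1
candidates = [10 ** e for e in (-4, -3.5, -3, -2.5, -2, -1.5, -1)]
results = {lr: final_loss(lr) for lr in candidates}
best_lr = min(results, key=results.get)

for lr, value in results.items():
    print(f"lr={lr:.1e}  final loss={value:.3e}")
print(f"best lr: {best_lr:.1e}")
```

The useful range is narrow: values below the best one make slow progress, and values above it diverge.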

Adam (and friends)

  • Default for transformer / NLP training
  • You don't have time to tune lr per layer / per phase
  • Sparse gradients (embeddings, NLP) — Adam handles them well
  • RNN / LSTM training is much friendlier with Adam

SGD-with-momentum

  • Image classification, where it generalizes a little better
  • You've tuned the learning rate schedule carefully
  • Reproducing classical results (most ResNet papers use it)
  • You want the simplest possible thing in the loop
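import torch

# Placeholder model so the snippet runs on its own (an assumption, not from the article)
model = torch.nn.Linear(10, 2)

# SGD with Nesterov momentum: the deep-learning workhorse
opt_sgd  = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)

# AdamW (Adam with decoupled weight decay): the safe default everywhere else
opt_adam = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Learning rate matters more than optimizer. Always have a schedule,
# and call sched.step() once per epoch (T_max counts those calls).
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt_adam, T_max=100)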
Want the math: momentum, Adam, second-order, line search?
Adam update $$ \begin{aligned} m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\ v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\ \hat m_t &= m_t / (1 - \beta_1^t), \quad \hat v_t = v_t / (1 - \beta_2^t) \\ \theta_{t+1} &= \theta_t - \eta\, \hat m_t / (\sqrt{\hat v_t} + \epsilon) \end{aligned} $$
  • m: EMA of the gradient (momentum)
  • v: EMA of squared gradients (per-parameter scale)
  • Bias correction 1 − β^t matters early in training, when both EMAs start at zero
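
The update above translates almost line for line into code. This is a minimal scalar sketch, not a library implementation; the function name `adam_step` and its defaults are assumptions chosen to mirror the symbols in the equations.

```python
import math

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t counts steps from 1."""
    m = beta1 * m + (1 - beta1) * g            # EMA of the gradient
    v = beta2 * v + (1 - beta2) * g * g        # EMA of the squared gradient
    m_hat = m / (1 - beta1 ** t)               # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# First step from zero-initialized EMAs, for two very different gradient scales
first_steps = {}
for g in (1e-3, 1e3):
    theta, m, v = adam_step(theta=0.0, g=g, m=0.0, v=0.0, t=1)
    first_steps[g] = theta
    print(f"g={g:g}  theta after one step = {theta:.6e}")
```

Because the step is a ratio of the two corrected moments, the first update has magnitude close to lr whether the gradient is tiny or huge. Without the bias correction, m and v would both start far too small and the ratio would be off early in training.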

Momentum, properly. The velocity update v ← β v + g is an exponentially weighted sum of past gradients with time-constant ≈ 1/(1−β): a constant gradient g builds up to a velocity of about g/(1−β). Gradients that keep pointing the same way reinforce each other; gradients that alternate in sign largely cancel. This is exactly why momentum helps in narrow ravines: the consistent ravine-floor direction gets amplified, while the bouncing perpendicular component cancels.
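
A small worked example of that amplification and cancellation, assuming β = 0.9 and unit-size gradients (both arbitrary illustrative choices). The helper name `velocity_after` is hypothetical.

```python
def velocity_after(grads, beta=0.9):
    """Apply v <- beta * v + g over a gradient sequence and return the final v."""
    v = 0.0
    for g in grads:
        v = beta * v + g
    return v

steps = 200
# Ravine floor: the gradient keeps pointing the same way
consistent = velocity_after([1.0] * steps)
# Ravine walls: the gradient flips sign every step
bouncing = velocity_after([(-1.0) ** t for t in range(steps)])

print(f"consistent gradients: |v| = {abs(consistent):.3f}")
print(f"alternating gradients: |v| = {abs(bouncing):.3f}")
```

The consistent sequence settles near 1/(1−β) = 10, while the alternating one settles near 1/(1+β) ≈ 0.53, so the same β that amplifies the useful direction tenfold roughly halves the bouncing one.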

Adagrad → RMSprop → Adam. Adagrad scales each parameter's step by 1/√(Σ g²) — but the sum grows monotonically, killing learning eventually. RMSprop replaces the sum with an EMA, fixing the decay-to-zero problem. Adam adds momentum on top. AdamW decouples weight decay from the gradient (Loshchilov & Hutter, 2019) — for transformers this is a meaningful improvement.
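
The difference between Adagrad's running sum and RMSprop's EMA shows up even with a constant gradient. The sketch below is an illustration under that simplifying assumption; the function name and hyperparameters are hypothetical.

```python
import math

def effective_steps(steps=1000, lr=0.1, rho=0.9, eps=1e-8, g=1.0):
    """Return the step sizes Adagrad and RMSprop would take after `steps` updates
    when every gradient equals g."""
    acc_sum = 0.0    # Adagrad: sum of squared gradients, only ever grows
    acc_ema = 0.0    # RMSprop: EMA of squared gradients, levels off
    for _ in range(steps):
        acc_sum += g * g
        acc_ema = rho * acc_ema + (1 - rho) * g * g
    adagrad = lr * g / (math.sqrt(acc_sum) + eps)
    rmsprop = lr * g / (math.sqrt(acc_ema) + eps)
    return adagrad, rmsprop

adagrad_step, rmsprop_step = effective_steps()
print(f"Adagrad step after 1000 updates: {adagrad_step:.5f}")
print(f"RMSprop step after 1000 updates: {rmsprop_step:.5f}")
```

Adagrad's step shrinks like lr/√t even though the gradient never changes, while RMSprop's step stays near lr.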

Learning rate schedules. The single highest-leverage knob is the lr schedule, not the optimizer. Warm-up (linear ramp from 0 to lr over a few hundred steps) is essential for transformers: in the first steps Adam's second-moment estimate rests on very few gradients, so its per-parameter step sizes are unreliable. Cosine decay is the modern standard for the rest of training. One-cycle and triangular schedules (Smith 2017) are useful when you want quick convergence.

Second-order methods. Newton's method scales steps by the inverse Hessian: it works wonderfully in low dimensions but is intractable in high dimensions, because the Hessian has one entry per pair of parameters. Practical approximations include K-FAC (block-diagonal Fisher), Shampoo (Kronecker-factored preconditioners built from gradient statistics), and L-BFGS (history-based quasi-Newton). These are useful for fine-tuning or small models; they rarely beat tuned Adam at scale.
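
To make the low-dimensional case concrete, here is a sketch of a single Newton step on a hypothetical two-parameter quadratic, using the closed-form 2×2 inverse. The matrix values and function names are assumptions for illustration.

```python
def newton_step(theta, grad, hess):
    """theta <- theta - H^{-1} grad, with H a 2x2 matrix given as nested tuples."""
    (a, b), (c, d) = hess
    det = a * d - b * c
    # H^{-1} = (1 / det) * [[d, -b], [-c, a]]
    dx = (d * grad[0] - b * grad[1]) / det
    dy = (-c * grad[0] + a * grad[1]) / det
    return (theta[0] - dx, theta[1] - dy)

# Toy loss L(theta) = 0.5 * theta^T H theta, minimized at the origin
H = ((3.0, 1.0), (1.0, 100.0))

def grad(theta):
    return (H[0][0] * theta[0] + H[0][1] * theta[1],
            H[1][0] * theta[0] + H[1][1] * theta[1])

theta0 = (1.0, 1.0)
theta1 = newton_step(theta0, grad(theta0), H)
print("after one Newton step:", theta1)
```

On a quadratic, one step lands on the minimum regardless of conditioning. The catch is scale: for d parameters the Hessian needs O(d²) memory and a direct solve costs O(d³), which is what the approximations above try to avoid.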

Implicit regularization. Different optimizers reach different solutions on the same loss surface. SGD's gradient noise has been argued to act as a regularizer that finds "flatter" minima — this is part of why it sometimes generalizes better than Adam, despite Adam reaching a lower training loss.

import math
import torch

# Linear warm-up + cosine decay: the modern transformer recipe
def warm_cos_lr(step, warmup, total, lr_max):
    if step < warmup:
        return lr_max * step / warmup              # linear ramp from 0 to lr_max
    p = (step - warmup) / max(1, total - warmup)   # decay progress, from 0 to 1
    return 0.5 * lr_max * (1 + math.cos(math.pi * p))

# Placeholders so the snippet runs on its own (assumptions, not from the article)
model, lr = torch.nn.Linear(10, 2), 3e-4

# AdamW: weight decay should NOT pass through Adam's denominator
opt = torch.optim.AdamW(model.parameters(), lr=lr,
                        betas=(0.9, 0.95), weight_decay=0.1)

# Gradient clipping: almost always worth it for transformers / RNNs.
# Call it after loss.backward() and before opt.step().
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Want K-FAC, Shampoo, sharpness-aware minimization, and natural gradients?
Natural gradient $$ \theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla \mathcal{L}(\theta_t) $$
  • F: Fisher information matrix, the Hessian of the KL divergence between distributions
  • Steps in distribution space rather than parameter space
  • Reparameterization-invariant — re-scaling parameters doesn't change the trajectory
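
A one-parameter sketch of that invariance, assuming a Bernoulli model fit by maximum likelihood (a hypothetical example, not from the text). It takes one small step in two parameterizations, the probability p and its logit z, and compares where each lands in probability space. The invariance is exact only in the limit of small steps, so the two natural-gradient results agree to first order in the step size.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logit(p):
    return math.log(p / (1.0 - p))

# Hypothetical data mean, starting probability, and step size
x_bar, p0, lr = 0.7, 0.2, 1e-3

# Parameterization 1: theta = p
# Average negative log-likelihood gradient and Fisher information in p
grad_p = (p0 - x_bar) / (p0 * (1 - p0))
fisher_p = 1.0 / (p0 * (1 - p0))
plain_p = p0 - lr * grad_p
natural_p = p0 - lr * grad_p / fisher_p

# Parameterization 2: theta = z = logit(p)
z0 = logit(p0)
grad_z = p0 - x_bar
fisher_z = p0 * (1 - p0)
plain_z = sigmoid(z0 - lr * grad_z)
natural_z = sigmoid(z0 - lr * grad_z / fisher_z)

print(f"plain gradient:   p-space {plain_p:.6f}  vs  z-space {plain_z:.6f}")
print(f"natural gradient: p-space {natural_p:.6f}  vs  z-space {natural_z:.6f}")
```

The plain gradient steps land in visibly different places depending on the parameterization, while the natural gradient steps nearly coincide.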

K-FAC. Kronecker-Factored Approximate Curvature (Martens & Grosse, 2015) approximates the Fisher as a block-diagonal Kronecker product per layer. Practical for medium-sized networks; offers real wall-clock speedup for some workloads but adds substantial implementation complexity.

Shampoo. Anil et al. (2020) — second-order preconditioning that maintains Kronecker factors of the gradient covariance and applies their inverse 1/4 root. Slower per step than Adam but converges in fewer steps; recent variants are competitive on real-world training.

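SAM (Sharpness-Aware Minimization). Foret et al. (2021): look for parameters whose whole neighbourhood has low loss (flat minima), not just a low loss at a single point. Each update first perturbs the weights toward the nearby worst case (one gradient-ascent step of radius ρ), then applies the gradient computed there to the original weights. Reliably improves generalization on vision tasks, at ~2× the compute cost because every update needs two forward-backward passes.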

Lion / Sophia / Schedule-Free. Three recent optimizers from Chen et al. (2023), Liu et al. (2024), and Defazio et al. (2024) respectively. Lion uses sign-of-momentum updates; Sophia uses a clipped second-order term; Schedule-Free does away with the lr schedule entirely. Each shows wins on some benchmarks but hasn't displaced AdamW as the default.
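
As an illustration of the sign-of-momentum idea, here is a scalar sketch of a Lion-style step following the update rule described by Chen et al. (2023). The function name and default values are assumptions for this example, and it is not a substitute for a tested implementation.

```python
def sign(x):
    """Return -1, 0, or 1."""
    return (x > 0) - (x < 0)

def lion_step(theta, g, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One Lion-style update for a single scalar parameter."""
    c = beta1 * m + (1 - beta1) * g                    # interpolate momentum and gradient
    theta = theta - lr * (sign(c) + weight_decay * theta)  # only the sign sets the step
    m = beta2 * m + (1 - beta2) * g                    # momentum EMA, updated afterwards
    return theta, m

lion_first_steps = {}
for g in (1e-3, 1e3):
    theta, m = lion_step(theta=0.0, g=g, m=0.0)
    lion_first_steps[g] = theta
    print(f"g={g:g}  theta after one step = {theta:.1e}")
```

Every parameter moves by exactly lr per step (plus weight decay), so the gradient's magnitude never enters the update. Since each step has fixed magnitude lr rather than being scaled down by a second-moment term, a learning rate tuned for AdamW generally does not carry over unchanged.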

Why SGD generalizes better. Empirically, SGD-with-momentum often finds "wider" minima than Adam, as measured by the Hessian eigenvalues at the solution. There are theoretical arguments (for example, modelling SGD's gradient noise with a stochastic differential equation) and counter-arguments. The practical implication is unchanged: try AdamW first, and for vision benchmarks try SGD-with-momentum as a comparison.

Trust region methods. In RL, TRPO limits each policy update with a KL constraint instead of relying on the learning rate alone, and PPO approximates that constraint with a clipped objective or a KL penalty. Both behave roughly like a natural gradient step. Lessons from those have leaked back into supervised learning (e.g., the "trust region for gradient steps" view of Adam's denominator).

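import torch

# Sharpness-Aware Minimization: a minimal two-step wrapper around a base optimizer
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_opt, rho=0.05, **kwargs):
        self.base = base_opt(params, **kwargs)
        defaults = dict(rho=rho, **self.base.defaults)
        # Share the base optimizer's param_groups so both see the same tensors
        super().__init__(self.base.param_groups, defaults)

    @torch.no_grad()
    def first_step(self):
        # Global L2 norm of the gradient across all parameters
        norm = torch.norm(torch.stack([p.grad.norm() for g in self.param_groups
                                       for p in g["params"] if p.grad is not None]))
        for g in self.param_groups:
            scale = g["rho"] / (norm + 1e-12)
            for p in g["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale   # ascent step; the full perturbation has norm rho
                p.add_(e_w); self.state[p]["e_w"] = e_w

    @torch.no_grad()
    def second_step(self):
        # Undo the perturbation, then step using the gradient from the perturbed point
        for g in self.param_groups:
            for p in g["params"]:
                if "e_w" in self.state[p]: p.sub_(self.state[p]["e_w"])
        self.base.step()

# Usage (compute_loss is a placeholder for your forward pass).
# Zero the gradients between the two passes so they do not accumulate:
# loss = compute_loss(); loss.backward(); opt.first_step(); opt.zero_grad()
# loss = compute_loss(); loss.backward(); opt.second_step(); opt.zero_grad()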