Don't start from scratch — take a pre-trained model and specialise it. The trick that made modern ML scale.
Key idea
Most of what a model learns is generally useful. Edges, textures, syntax, word relationships: these aren't specific to your downstream task; they're general structure of images or language. Pre-train a model on a huge corpus, then specialise the last layers to your task with whatever labels you have. The head-start can be worth thousands of labelled examples.
[Interactive chart: test accuracy as the labelled set shrinks, comparing from scratch, frozen features (linear probe) and full fine-tune]
The chart shows test accuracy as a function of labelled-training-set size for three strategies. From scratch trains a fresh model on just the labels. Linear probe uses a frozen pre-trained encoder, training only a linear head. Fine-tune initialises with the pre-trained weights and updates everything. As labels become plentiful the curves converge — but at the low end, transfer learning provides a 10-30 point head-start.
Why transfer works. The first layers of a CNN learn edge, colour and texture detectors that look much the same whatever the task. The lower and middle layers of a language-model transformer capture much of the syntactic structure. These are reusable. The last layers are task-specific. Replace those; keep the rest.
Linear probe / feature extraction. Freeze the encoder. Train only a linear classifier on top. Fast, robust on small datasets, and hard to overfit because only the head's weights are learned. Best when your task is similar to the pre-training task.
Fine-tuning. Initialise with pre-trained weights and update everything (or just the later layers). More flexible, sometimes more accurate, more sensitive to overfitting on small labelled sets.
Adapter-based fine-tuning. Insert small trainable modules into a frozen backbone (LoRA, adapters). Cheap to train, cheap to switch tasks. The dominant fine-tuning recipe for LLMs.
Domain adaptation. Train on labelled source domain; deploy on unlabelled target domain. Use adversarial losses, importance weighting, or feature alignment to bridge the gap. Useful when labels exist for one domain (medical imaging from hospital A) but not another (hospital B).
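Importance weighting can be sketched with a domain classifier. First, train a classifier to tell source inputs from target inputs. Then turn its odds into a density ratio w(x) ≈ p_target(x) / p_source(x) and use that ratio to reweight the source-domain loss. The sketch below is a minimal NumPy illustration on assumed toy data (two shifted 1-D Gaussians) with a hand-rolled logistic regression. `fit_logreg` and `importance_weights` are hypothetical helpers, not a library API.

```python
import numpy as np

def fit_logreg(X, y, lr=0.1, steps=2000):
    """Plain gradient-descent logistic regression; returns (w, b)."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        grad = p - y                      # d(log-loss) / d(logit)
        w -= lr * X.T @ grad / len(y)
        b -= lr * grad.mean()
    return w, b

def importance_weights(X_src, X_tgt):
    """Estimate w(x) = p_tgt(x) / p_src(x) for each source point via a domain classifier."""
    X = np.vstack([X_src, X_tgt])
    d = np.concatenate([np.zeros(len(X_src)), np.ones(len(X_tgt))])  # 1 = target domain
    w, b = fit_logreg(X, d)
    p_tgt = 1.0 / (1.0 + np.exp(-(X_src @ w + b)))
    # Classifier odds, corrected for the source/target sample-size ratio
    return (p_tgt / (1.0 - p_tgt)) * (len(X_src) / len(X_tgt))

rng = np.random.default_rng(0)
X_src = rng.normal(0.0, 1.0, size=(500, 1))  # labelled source domain
X_tgt = rng.normal(1.0, 1.0, size=(500, 1))  # unlabelled, shifted target domain
weights = importance_weights(X_src, X_tgt)
```

Multiplying each source example's loss by its weight makes source training mimic the target distribution. Here, source points that look like target data (larger x) get larger weights. In practice the weights are often clipped, because a few points can receive very large values.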
Reach for it when
You have a small labelled dataset for your task
A pre-trained model exists in your domain (vision, NLP, audio)
You can't afford to train from scratch (compute or data)
The pre-training task is similar to your downstream task
Watch out
Domain mismatch — features from natural images may not transfer to medical X-rays
Aggressive fine-tuning on too few labels (high learning rate, many epochs) can overwrite useful pre-trained features (catastrophic forgetting) while overfitting the small set
Different pre-training objectives → different features → not all encoders transfer equally
Beware "leaderboard gains" that come from pre-training data overlapping or closely resembling the test set (contamination)
import torch
import torch.nn as nn
import torchvision.models as models

# Linear probe: freeze backbone, train head
def linear_probe(num_classes):
    m = models.resnet50(weights="DEFAULT")
    for p in m.parameters():
        p.requires_grad = False  # freeze every pre-trained weight
    m.fc = nn.Linear(m.fc.in_features, num_classes)  # new head: the only layer that trains
    return m
# Keep the frozen backbone in eval mode while probing: BatchNorm running
# statistics still update in train mode even when requires_grad is False.

# Full fine-tune: train everything, smaller learning rate on the backbone
def fine_tune(num_classes):
    m = models.resnet50(weights="DEFAULT")
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

# Use param groups with different lrs (10 classes is just an example value):
m = fine_tune(num_classes=10)
opt = torch.optim.AdamW([
    {"params": m.fc.parameters(), "lr": 1e-3},  # new head
    {"params": [p for n, p in m.named_parameters() if not n.startswith("fc.")],
     "lr": 1e-5},  # pre-trained backbone
])
Going deeper: LoRA, adapters, prompt tuning, and catastrophic forgetting
LoRA update
$$ W_{\text{new}} = W_{\text{pretrained}} + \frac{\alpha}{r} \cdot B A, \quad A \in \mathbb{R}^{r \times d_{\text{in}}}, \; B \in \mathbb{R}^{d_{\text{out}} \times r} $$
Frozen W, low-rank trainable update BA
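r ≪ min(d_in, d_out) ⇒ r(d_in + d_out) trainable parameters instead of d_in·d_out per matrix (256× fewer for a 4096×4096 projection at r = 8)
Composable: train one small adapter per task and swap adapters on the same frozen model; BA can also be merged into W after training, so inference costs nothing extra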
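To make the LoRA update concrete, here is a minimal NumPy sketch of one LoRA-wrapped linear layer. It follows the initialisation described by Hu et al.: A is small and random and B is zero, so BA = 0 and training starts exactly at the pre-trained model. `LoRALinear` is a hypothetical class for illustration, not the `peft` implementation, and no training loop is shown.

```python
import numpy as np

class LoRALinear:
    """y = x W^T + (alpha / r) * x (B A)^T, with W frozen; only A and B would be trained."""

    def __init__(self, W, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = W.shape
        self.W = W                                      # frozen pre-trained weight
        self.A = rng.normal(0.0, 0.01, size=(r, d_in))  # small random init
        self.B = np.zeros((d_out, r))                   # zero init, so BA = 0 at the start
        self.scale = alpha / r

    def __call__(self, x):
        # Low-rank path computed as two thin matmuls; BA is never materialised here
        return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T

    def trainable_params(self):
        return self.A.size + self.B.size                 # r * (d_in + d_out)

    def merged_weight(self):
        # Fold the update into W after training: no extra cost at inference time
        return self.W + self.scale * self.B @ self.A

d_in, d_out = 1024, 1024
W = np.random.default_rng(1).normal(size=(d_out, d_in))
layer = LoRALinear(W, r=8, alpha=16)
x = np.random.default_rng(2).normal(size=(4, d_in))
```

For a 1024×1024 layer at r = 8 the adapter has 16,384 parameters against 1,048,576 frozen ones (64× fewer). `merged_weight` shows why a trained adapter adds no inference latency once it is folded into W.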
Parameter-efficient fine-tuning (PEFT). Don't update all the weights, just a small fraction. Adapters (Houlsby et al. 2019): insert small bottleneck MLPs after the attention and feed-forward sublayers of each transformer layer. LoRA (Hu et al. 2021): add low-rank updates to frozen weights. Prefix tuning / prompt tuning: train a small "prompt" embedding while keeping the model frozen. All are cheap to train, cheap to store, easy to switch.
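A Houlsby-style bottleneck adapter can be sketched as four steps: a down-projection, a non-linearity, an up-projection and a residual connection. This NumPy sketch assumes ReLU and a zero-initialised up-projection, so the adapter starts as the identity. Houlsby et al. use a near-identity initialisation instead. The class name `Adapter` and the sizes (d_model = 768, bottleneck = 64) are illustrative choices.

```python
import numpy as np

class Adapter:
    """Bottleneck adapter: h + relu(h W_down + b_down) W_up + b_up, placed after a frozen sublayer."""

    def __init__(self, d_model, bottleneck=64, seed=0):
        rng = np.random.default_rng(seed)
        self.W_down = rng.normal(0.0, 0.02, size=(d_model, bottleneck))
        self.b_down = np.zeros(bottleneck)
        self.W_up = np.zeros((bottleneck, d_model))  # zero init: adapter starts as the identity
        self.b_up = np.zeros(d_model)

    def __call__(self, h):
        z = np.maximum(0.0, h @ self.W_down + self.b_down)  # project down + ReLU
        return h + z @ self.W_up + self.b_up                 # project up + residual

    def num_params(self):
        return sum(p.size for p in (self.W_down, self.b_down, self.W_up, self.b_up))

adapter = Adapter(d_model=768, bottleneck=64)
h = np.random.default_rng(1).normal(size=(2, 10, 768))  # (batch, seq, d_model)
out = adapter(h)
```

With these sizes one adapter holds 99,136 parameters. That is a small fraction of the roughly 7M weights in a BERT-base-sized transformer layer, which stays frozen.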
Discriminative fine-tuning. Different layers, different learning rates. A common recipe: small lr for early (general) layers; larger lr for later (task-specific) layers. Layer-wise lr decay, which shrinks the lr geometrically from the top layer down, is the usual way to set this up.
Distribution shift in transfer. The pre-training distribution rarely matches the downstream one exactly. Domain adaptation (adversarial alignment, importance weighting, test-time adaptation) bridges the gap when labels for the target domain are scarce.
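Test-time adaptation can be as simple as re-estimating normalisation statistics on target data, in the spirit of AdaBN. Every learned weight stays fixed; only the source-domain mean and variance used by normalisation layers are swapped for ones computed on target inputs. The `Norm` class below is a hypothetical NumPy stand-in for a BatchNorm layer in eval mode, and the Gaussian data for the two domains is made up.

```python
import numpy as np

class Norm:
    """Feature normalisation with stored statistics, like a BatchNorm layer at eval time."""

    def __init__(self, mean, var, eps=1e-5):
        self.mean, self.var, self.eps = mean, var, eps

    def __call__(self, x):
        return (x - self.mean) / np.sqrt(self.var + self.eps)

    def adapt(self, x_target):
        # Replace source statistics with target statistics; no weights are trained
        self.mean = x_target.mean(axis=0)
        self.var = x_target.var(axis=0)

rng = np.random.default_rng(0)
x_src = rng.normal(0.0, 1.0, size=(1000, 16))
x_tgt = rng.normal(3.0, 2.0, size=(1000, 16))  # shifted and rescaled target domain
norm = Norm(x_src.mean(axis=0), x_src.var(axis=0))
before = norm(x_tgt)   # target features normalised with source statistics
norm.adapt(x_tgt)
after = norm(x_tgt)    # target features normalised with their own statistics
```

Downstream layers then see target features on the same scale as during training. Whether this helps depends on how much of the shift is captured by first and second moments.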
Foundation models as the new default. CLIP, GPT-3/4, BERT, ViT, DINOv2. Pre-train once at enormous cost; reuse across many tasks. Fine-tune cheaply. The economics of ML have changed: many teams now start from a foundation model and fine-tune, and rarely train from scratch.
Zero-shot and few-shot. Modern foundation models can do many tasks without any task-specific training — prompt them appropriately. Few-shot uses a handful of demonstrations in the prompt (in-context learning). Often beats traditional fine-tuning when labels are very scarce.
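Few-shot prompting needs no training code at all, only prompt construction. Here is a minimal sketch assuming a plain "Input: / Output:" template, which is just one of many possible formats; the function name `few_shot_prompt` is hypothetical. The prompt is the instruction, then k demonstrations, then the query with its output left blank for the model to complete. Passing an empty demonstration list gives the zero-shot prompt.

```python
def few_shot_prompt(instruction, demos, query):
    """Build an in-context-learning prompt: instruction, k worked examples, then the query."""
    lines = [instruction, ""]
    for text, label in demos:
        lines.append(f"Input: {text}")
        lines.append(f"Output: {label}")
        lines.append("")
    lines.append(f"Input: {query}")
    lines.append("Output:")  # the model continues from here
    return "\n".join(lines)

demos = [("The plot dragged on forever.", "negative"),
         ("A warm, funny, beautifully acted film.", "positive")]
prompt = few_shot_prompt("Classify the sentiment of each review.", demos,
                         "I would happily watch it again.")
```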
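from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# LoRA fine-tuning a transformer: only adapter params are trainable
# Example base model (an assumption): OPT-1.3B, whose attention layers name their projections q_proj / v_proj
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
# LoRA params: 24 layers × 2 modules × r·(d_in + d_out) = 24 × 2 × 8 × (2048 + 2048) = 1,572,864,
# roughly 0.1% of the ~1.3B total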
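# Discriminative fine-tuning: layer-wise lr decay
# Assumes `model.layers` is an ordered nn.ModuleList (input side first);
# any head outside `model.layers` needs its own param group.
def layer_lrs(model, base_lr=1e-3, decay=0.7):
    groups = []
    # Walk from the top layer down: the i-th layer from the top gets base_lr * decay**i
    for i, layer in enumerate(reversed(list(model.layers))):
        groups.append({"params": list(layer.parameters()),
                       "lr": base_lr * (decay ** i)})
    return groups

# opt = torch.optim.AdamW(layer_lrs(model))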
Going deeper: continual learning, prompt tuning, RLHF, and meta-learning
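EWC penalty
$$ \mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda \sum_i F_i \left(\theta_i - \theta^*_i\right)^2 $$
θ*: pre-trained parameters; F_i: Fisher information per parameter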
Penalise moving important pre-trained parameters far from their original values
One classical antidote to catastrophic forgetting
Continual / lifelong learning. Train sequentially on a stream of tasks; don't forget the old ones. Classical methods: EWC, Synaptic Intelligence, MAS — penalise moving important parameters. Modern: rehearsal buffers, progressive networks. Hard in general; trade-offs between plasticity and stability are inescapable.
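A rehearsal buffer is the simplest of the modern approaches: keep a small sample of past examples and mix them into every new batch. This sketch uses reservoir sampling, one common way to keep the buffer a uniform sample of everything seen so far in fixed memory. The class `ReservoirBuffer` and the toy stream of (task, index) pairs are illustrative.

```python
import random

class ReservoirBuffer:
    """Fixed-size rehearsal buffer; reservoir sampling keeps a uniform sample of the whole stream."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.items = []
        self.seen = 0
        self.rng = random.Random(seed)

    def add(self, item):
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(item)
        else:
            j = self.rng.randrange(self.seen)  # uniform in [0, seen)
            if j < self.capacity:
                self.items[j] = item           # kept with probability capacity / seen

    def sample(self, k):
        return self.rng.sample(self.items, min(k, len(self.items)))

buf = ReservoirBuffer(capacity=100)
for task in range(3):       # a stream of three tasks, seen one after another
    for i in range(1000):
        buf.add((task, i))
```

During training on a new task, each batch would be the new examples plus `buf.sample(k)` old ones. Because the buffer stays uniform over the stream, early tasks remain represented.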
Prompt tuning & prefix tuning. Keep the base model frozen; learn a small embedding ("soft prompt") prepended to inputs. Lester et al. (2021), Li & Liang (2021). Prompt tuning lags full fine-tuning on smaller models but closes the gap as the model grows; Lester et al. report it matching fine-tuning at around the 10B-parameter scale.
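Mechanically, prompt tuning only changes what enters the frozen model. A small trainable matrix of "virtual token" embeddings is concatenated in front of the real token embeddings. Below is a NumPy sketch of that step with assumed sizes (20 virtual tokens, d_model = 512) and a hypothetical helper, `prepend_soft_prompt`. The frozen model, and the gradient step that would update `soft_prompt`, are omitted.

```python
import numpy as np

def prepend_soft_prompt(token_embeds, soft_prompt):
    """token_embeds: (batch, seq, d); soft_prompt: (n_virtual, d) -> (batch, n_virtual + seq, d)."""
    batch = token_embeds.shape[0]
    prompt = np.broadcast_to(soft_prompt, (batch,) + soft_prompt.shape)  # same prompt for every example
    return np.concatenate([prompt, token_embeds], axis=1)

d_model, n_virtual = 512, 20
rng = np.random.default_rng(0)
soft_prompt = rng.normal(0.0, 0.02, size=(n_virtual, d_model))  # the only trainable tensor
token_embeds = rng.normal(size=(4, 32, d_model))                # from the frozen embedding table
x = prepend_soft_prompt(token_embeds, soft_prompt)
```

Here the task-specific state is 20 × 512 = 10,240 numbers. Prefix tuning differs in that it prepends trainable vectors to the keys and values inside every layer, not just at the input.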
In-context learning. The strangest discovery of LLM-era ML: large enough models can learn from examples in their prompt without any gradient updates. Few-shot prompting is the canonical case; reasoning chains (chain-of-thought) extend it. The mechanism is still actively researched.
Multi-task learning. Train a single model on many tasks simultaneously. Share early layers; specialise heads. Risks: tasks can interfere with each other (negative transfer), and per-task gradients can conflict in direction and scale, so the joint objective is hard to balance. Careful loss weighting helps (uncertainty weighting, GradNorm).
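Uncertainty weighting (Kendall et al. 2018) learns one scalar per task instead of hand-tuning loss weights. Below is a minimal sketch of one commonly used simplified form, total = Σ_i exp(−s_i)·L_i + s_i with s_i = log σ_i². The helper names are hypothetical. The task losses are held fixed so you can see where the learned weights settle; the network update is omitted.

```python
import numpy as np

def uncertainty_weighted_loss(task_losses, log_vars):
    """sum_i exp(-s_i) * L_i + s_i, with s_i = log(sigma_i^2) learned per task."""
    task_losses = np.asarray(task_losses, dtype=float)
    return float(np.sum(np.exp(-log_vars) * task_losses + log_vars))

def grad_log_vars(task_losses, log_vars):
    # d/ds_i [exp(-s_i) * L_i + s_i] = 1 - exp(-s_i) * L_i
    return 1.0 - np.exp(-log_vars) * np.asarray(task_losses, dtype=float)

losses = np.array([4.0, 0.25])  # e.g. a high-loss task and a low-loss task
s = np.zeros(2)
for _ in range(2000):           # gradient descent on s with the task losses held fixed
    s -= 0.05 * grad_log_vars(losses, s)
weights = np.exp(-s)            # effective per-task loss weights
```

Setting the derivative 1 − exp(−s_i)·L_i to zero gives s_i = log L_i, so each task's effective weight exp(−s_i) settles at 1 / L_i. High-loss tasks are down-weighted, and the +s_i term stops the weights from collapsing to zero.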
Meta-learning ("learn to learn"). Train on a distribution over tasks; the model learns to adapt quickly to a new task with few examples. MAML (Finn et al. 2017), ProtoNets (Snell et al. 2017), Reptile. Mostly superseded by foundation-model fine-tuning at scale, but still relevant for low-data regimes.
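Prototypical networks reduce few-shot classification to a nearest-centroid rule in embedding space. Average the support embeddings of each class into a prototype, then assign each query to the closest prototype. The sketch fakes the embeddings with 2-D Gaussian clusters, an assumption standing in for a meta-trained encoder, and runs one 3-way, 5-shot episode. The function names are illustrative.

```python
import numpy as np

def prototypes(support_emb, support_labels, n_classes):
    """Class prototype = mean embedding of that class's support examples."""
    return np.stack([support_emb[support_labels == c].mean(axis=0) for c in range(n_classes)])

def classify(query_emb, protos):
    # Squared Euclidean distance from every query to every prototype; nearest one wins
    d2 = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)

# One 3-way, 5-shot episode with fake, well-separated embeddings
rng = np.random.default_rng(0)
centres = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
support_labels = np.repeat(np.arange(3), 5)
support_emb = centres[support_labels] + rng.normal(0.0, 0.5, size=(15, 2))
query_emb = centres + rng.normal(0.0, 0.5, size=(3, 2))
protos = prototypes(support_emb, support_labels, 3)
pred = classify(query_emb, protos)
```

During meta-training, ProtoNets apply a softmax over negative distances and backpropagate through the encoder, episode after episode.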
RLHF as transfer. Pre-trained LLMs are "supervised fine-tuned" on instructions, then RLHF-tuned against human preferences. This is transfer learning with multiple objectives and a learned reward. The pre-train → SFT → RLHF pipeline is now standard for assistant-grade models.
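The "learned reward" is usually trained on pairwise human preferences. A common choice, used for example in InstructGPT, is the Bradley-Terry style loss −log σ(r_chosen − r_rejected), which pushes the reward model to score the preferred response higher. Here is a NumPy sketch of just that loss; the reward model producing the scores is assumed, and the function name is hypothetical.

```python
import numpy as np

def reward_model_loss(r_chosen, r_rejected):
    """Pairwise preference loss: mean of -log(sigmoid(r_chosen - r_rejected))."""
    margin = np.asarray(r_chosen, dtype=float) - np.asarray(r_rejected, dtype=float)
    # log(1 + exp(-margin)) == -log(sigmoid(margin)), computed stably
    return float(np.mean(np.logaddexp(0.0, -margin)))
```

A tie gives log 2 ≈ 0.693; the loss shrinks towards zero as the preferred response's reward pulls ahead.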
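import torch

# Elastic Weight Consolidation: penalise moving important pre-trained params
class EWC:
    def __init__(self, model, fisher_info, lam=1000.0):
        # fisher_info: dict param-name -> tensor of per-parameter importances (F_i)
        self.fisher, self.lam = fisher_info, lam
        # theta*: detached snapshot of the parameters right now (after the original task)
        self.theta_star = {n: p.detach().clone() for n, p in model.named_parameters()}

    def penalty(self, model):
        loss = 0.0
        for n, p in model.named_parameters():
            if n in self.fisher:
                loss += (self.fisher[n] * (p - self.theta_star[n]) ** 2).sum()
        return self.lam * loss

# Once, right after training on the original task (`fisher` assumed precomputed):
ewc = EWC(model, fisher)

# In each training step on the new task:
loss = task_loss + ewc.penalty(model)
loss.backward()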
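The EWC code expects a precomputed `fisher` dict. One common approximation is the empirical diagonal Fisher: the mean of squared per-example gradients of the log-likelihood, evaluated at the observed labels. For logistic regression that gradient is (y − p)·x, so the sketch below can be written in closed form in NumPy. The model, data and function name are illustrative assumptions.

```python
import numpy as np

def diagonal_fisher_logreg(w, X, y):
    """Empirical diagonal Fisher for logistic regression: mean over examples of (d log p(y|x) / dw)^2."""
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    per_example_grads = (y - p)[:, None] * X      # gradient of the log-likelihood, one row per example
    return (per_example_grads ** 2).mean(axis=0)  # one importance value per parameter

rng = np.random.default_rng(0)
X = np.column_stack([rng.normal(size=200), np.zeros(200)])  # second feature never varies
w = np.array([1.0, 0.0])
y = (rng.random(200) < 1.0 / (1.0 + np.exp(-(X @ w)))).astype(float)
F = diagonal_fisher_logreg(w, X, y)
```

The second feature is always zero, so its Fisher value is exactly zero. EWC would therefore leave that weight free to move while anchoring the first. In PyTorch the same quantity is usually accumulated from squared gradients over a sample of the original task's data.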