Key idea

Inference cost is the binding constraint for most deployed ML. Three knobs reduce it: quantize (smaller representations), distill (smaller architectures), prune (sparser weights). Combined, they can give 4-10× speed-ups with only marginal accuracy loss; the exact recipe depends on the model and the deployment target.

Quantization. Convert fp32 weights (and often activations) to int8, int4, or fp8. int8 takes 4× less memory than fp32 and, on hardware with int8 support, matmuls can run 2-4× faster. Three flavours: post-training quantization (PTQ: apply to an already-trained model), quantization-aware training (QAT: train with simulated quantization), and weight-only quantization (the easiest for LLMs).
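
A quick back-of-envelope check of the memory claim. This is a sketch that counts weight storage only, for a hypothetical 8B-parameter model; it ignores activations, the KV cache, and the small overhead of storing quantization scales.

```python
# Back-of-envelope weight memory at different precisions.
# Counts weight storage only: ignores activations, KV cache and scale overhead.
PRECISION_BITS = {"fp32": 32, "fp16/bf16": 16, "int8": 8, "int4": 4}


def weight_gb(n_params: float, bits: int) -> float:
    """Gigabytes (1e9 bytes) needed to store n_params weights at `bits` bits each."""
    return n_params * bits / 8 / 1e9


n_params = 8e9  # hypothetical 8B-parameter model
for name, bits in PRECISION_BITS.items():
    print(f"{name:>9}: {weight_gb(n_params, bits):5.1f} GB")
```

int8 cuts weight storage 4× relative to fp32 and int4 cuts it 8×; real int4 formats add a little on top for per-group scales.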

Distillation. Train a small "student" model to imitate a large "teacher" model. The student learns from the teacher's soft probabilities (not just hard labels), often along with its hidden representations. This works because the teacher's full output distribution encodes information that the labels alone don't.
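
To make "soft probabilities" concrete, here is a small sketch of temperature-scaled softmax on made-up teacher logits (the three classes and their values are hypothetical). Raising the temperature T moves probability mass onto the runner-up classes, which is the extra signal the student gets beyond the hard label.

```python
import numpy as np


def softmax(logits, T=1.0):
    """Temperature-scaled softmax; larger T flattens the distribution."""
    z = np.asarray(logits, dtype=float) / T
    z -= z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


# Hypothetical teacher logits for classes [cat, dog, car].
teacher_logits = [8.0, 5.0, -2.0]
print(softmax(teacher_logits, T=1.0).round(3))  # almost all mass on "cat"
print(softmax(teacher_logits, T=4.0).round(3))  # "dog" gets visible mass; "car" stays small
```

At T=1 the output is nearly a one-hot label; at T=4 it also says "this cat looks somewhat like a dog and nothing like a car", which a hard label cannot express.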

Pruning. Remove weights, neurons, or heads that contribute little. Pruning is either unstructured (individual weights zeroed, giving a sparse matrix) or structured (whole channels or heads removed, giving a smaller dense matrix). Only structured pruning reliably gives real speed-ups on standard hardware.
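
Why structured pruning yields a genuinely smaller dense model can be sketched on a two-layer MLP in NumPy. The ranking heuristic here (L2 norm of each hidden unit's incoming weights) is an assumption chosen for simplicity; real toolkits use a variety of importance scores.

```python
import numpy as np


def mlp(x, W1, b1, W2):
    """Two-layer MLP: y = W2 @ relu(W1 @ x + b1)."""
    return W2 @ np.maximum(W1 @ x + b1, 0.0)


def prune_hidden_units(W1, b1, W2, keep_ratio=0.5):
    """Structured pruning: drop whole hidden units, returning smaller dense matrices.

    Units are ranked by the L2 norm of their incoming weights (a simple magnitude
    heuristic); row i of W1, entry i of b1 and column i of W2 are removed together.
    """
    scores = np.linalg.norm(W1, axis=1)                  # one score per hidden unit
    n_keep = max(1, int(round(keep_ratio * W1.shape[0])))
    keep = np.sort(np.argsort(scores)[-n_keep:])         # surviving unit indices
    return W1[keep], b1[keep], W2[:, keep], keep


rng = np.random.default_rng(0)
W1, b1, W2 = rng.normal(size=(64, 32)), rng.normal(size=64), rng.normal(size=(10, 64))
W1p, b1p, W2p, keep = prune_hidden_units(W1, b1, W2, keep_ratio=0.5)
print(W1.shape, "->", W1p.shape, "|", W2.shape, "->", W2p.shape)
```

The pruned network computes exactly what the original computes with those units zeroed, but with half the hidden-layer FLOPs and ordinary dense matmuls; in practice you would fine-tune afterwards to recover accuracy.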

When to use each

  • Quantize: nearly every deployment. Lowest risk, biggest typical win
  • Distill: when target is much smaller than teacher; needs training pipeline
  • Prune: structured pruning before quantization for compounded gains
  • Speculative decoding: for LLMs — use a small model as drafter

Trade-offs

  • int8 PTQ: ~1pp accuracy drop, sometimes none
  • int4 weight-only: usable for LLMs; vision models often suffer more
  • Distillation: needs more training compute, possibly more data
  • Unstructured pruning: rarely faster on commodity GPUs
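import torch
from torch.ao.quantization import quantize_dynamic

# Dynamic quantization: Linear weights are stored as int8 ahead of time;
# activations are quantized on the fly at runtime (layer inputs/outputs stay fp32).
# Easiest possible win; often ~2× faster CPU inference for transformers / LSTMs
# (add torch.nn.LSTM to the module set below for LSTM models).
# `model` is assumed to be an existing trained fp32 torch.nn.Module.
qmodel = quantize_dynamic(
    model.eval(),
    {torch.nn.Linear},
    dtype=torch.qint8,
)
# To reload, run quantize_dynamic on a fresh fp32 model first, then load_state_dict.
torch.save(qmodel.state_dict(), "quantized.pt")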
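Distillation loss $$ \mathcal{L} = \alpha \cdot T^2 \cdot \mathrm{KL}\!\big(\sigma(z_T / T) \,\|\, \sigma(z_S / T)\big) + (1-\alpha) \cdot \mathrm{CE}\big(y_{\text{hard}}, \sigma(z_S)\big) $$
  • z_T, z_S = teacher and student logits; σ = softmax
  • T = temperature (typically 2–10): dividing the logits by T softens both distributions, and the T² factor keeps the soft-loss gradients on the same scale as the hard loss
  • α weights the soft-distribution-matching loss against the hard-label loss
  • Hinton et al. (2015): the original temperature-based knowledge distillation recipe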

Post-training quantization (PTQ). Quantize a trained model using a small calibration set: run representative inputs through it and pick a scale and zero-point per tensor (or per channel) that minimise quantization error. Easy and fast; a ~1pp accuracy drop is typical. torch.ao, ONNX Runtime, and TensorRT all support it.
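
The calibration step can be sketched in NumPy with the simplest rule, min-max: take the observed activation range and map it onto uint8. This is an illustration, not what any particular toolkit does by default; real PTQ pipelines often use percentile or entropy (KL) calibration to ignore rare outliers. The "recorded activations" here are synthetic ReLU-like data.

```python
import numpy as np


def calibrate_minmax(samples, n_bits=8):
    """Min-max calibration for asymmetric (affine) unsigned quantization.

    Returns (scale, zero_point) such that x ≈ scale * (q - zero_point).
    """
    lo = min(0.0, float(np.min(samples)))    # the range must contain 0 so 0 is exact
    hi = max(0.0, float(np.max(samples)))
    qmax = 2 ** n_bits - 1
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = int(round(-lo / scale))
    return scale, zero_point


def quantize(x, scale, zero_point, n_bits=8):
    # uint8 storage assumes n_bits <= 8.
    q = np.round(x / scale) + zero_point
    return np.clip(q, 0, 2 ** n_bits - 1).astype(np.uint8)


def dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float32) - zero_point)


rng = np.random.default_rng(0)
calib = np.maximum(rng.normal(1.0, 2.0, size=10_000), 0.0)      # stand-in for recorded ReLU activations
scale, zp = calibrate_minmax(calib)
x = np.clip(rng.normal(1.0, 2.0, size=1_000), 0.0, calib.max())  # new inputs inside the calibrated range
x_hat = dequantize(quantize(x, scale, zp), scale, zp)
max_err = np.abs(x_hat - x).max()
print(f"scale={scale:.4f} zero_point={zp} max_abs_err={max_err:.4f}")
```

Inside the calibrated range the round-trip error is at most half a quantization step (scale / 2); values outside it get clipped, which is why the calibration set needs to be representative.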

Quantization-aware training (QAT). Train (or fine-tune) with simulated quantization in the forward pass so the model adapts to the rounding error; accuracy usually ends up close to fp32. It costs an extra training run, so it is mainly worth it when PTQ loses too much.
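
The core QAT trick, "fake" quantization with a straight-through estimator (STE), can be sketched in a few lines of PyTorch. This is a hand-rolled illustration with per-tensor symmetric int8 weights only; it is not PyTorch's own torch.ao QAT workflow, and the class names are hypothetical.

```python
import torch
import torch.nn.functional as F


class FakeQuantSTE(torch.autograd.Function):
    """Symmetric int8 quantize-dequantize with a straight-through estimator (STE)."""

    @staticmethod
    def forward(ctx, x, scale):
        # Forward pass sees the rounding error an int8 deployment would introduce.
        return torch.clamp(torch.round(x / scale), -127, 127) * scale

    @staticmethod
    def backward(ctx, grad_out):
        # STE: pretend round() is the identity so gradients reach the fp32 weights.
        return grad_out, None


class QATLinear(torch.nn.Linear):
    """nn.Linear whose weight is fake-quantized (per-tensor) on every forward pass."""

    def forward(self, x):
        scale = self.weight.detach().abs().max().clamp(min=1e-8) / 127
        w_q = FakeQuantSTE.apply(self.weight, scale)
        return F.linear(x, w_q, self.bias)


# Tiny training loop on random data, just to show gradients flow through the fake quantizer.
torch.manual_seed(0)
layer = QATLinear(16, 4)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x, y = torch.randn(256, 16), torch.randn(256, 4)
losses = []
for _ in range(50):
    loss = F.mse_loss(layer(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The optimizer keeps updating fp32 "shadow" weights, but every forward pass sees their int8-rounded version, so the model learns weights that survive rounding.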

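Weight-only quantization (LLMs). Weights are stored in int4 (sometimes int8) while activations stay in fp16/bf16; because LLM decoding is usually memory-bandwidth bound, shrinking the weights is where most of the speed-up comes from. Specialised kernels (e.g. ExLlama, Marlin, AWQ's GEMM kernels) fuse the dequantize-and-multiply. This is standard for LLM inference: GPTQ and AWQ are the reference algorithms, and bitsandbytes is a widely used library for loading models in 4-bit on the fly.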
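
A NumPy sketch of the basic storage format: symmetric int4 codes with one fp16 scale per group of input features (group size 128 is an assumption, though a common choice), dequantized at matmul time while activations stay fp16. It uses plain round-to-nearest, not GPTQ or AWQ, and keeps codes unpacked for readability.

```python
import numpy as np


def quantize_int4_groupwise(W, group_size=128):
    """Symmetric int4 weight-only quantization with one fp16 scale per group.

    W: (out_features, in_features) with in_features divisible by group_size.
    Codes are kept unpacked in int8 for clarity; real formats pack two per byte.
    """
    out_f, in_f = W.shape
    Wg = W.reshape(out_f, in_f // group_size, group_size)
    scales = np.maximum(np.abs(Wg).max(axis=-1, keepdims=True) / 7.0, 1e-4).astype(np.float16)
    codes = np.clip(np.round(Wg / scales.astype(np.float32)), -8, 7).astype(np.int8)
    return codes, scales


def int4_linear(x, codes, scales):
    """y = x @ W^T with W dequantized on the fly; activations stay fp16.

    Real kernels fuse this dequantize-and-multiply instead of materialising W.
    """
    W_hat = (codes.astype(np.float32) * scales.astype(np.float32)).reshape(codes.shape[0], -1)
    return (x.astype(np.float32) @ W_hat.T).astype(np.float16)


rng = np.random.default_rng(0)
W = rng.normal(size=(256, 512)).astype(np.float32)
x = rng.normal(size=(4, 512)).astype(np.float16)
codes, scales = quantize_int4_groupwise(W, group_size=128)
y = int4_linear(x, codes, scales)
ref = x.astype(np.float32) @ W.T
rel_err = np.linalg.norm(y.astype(np.float32) - ref) / np.linalg.norm(ref)
print(f"relative output error: {rel_err:.3f}")
```

Smaller groups track local weight ranges more closely (lower error) at the cost of storing more scales.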

Knowledge distillation. The loss matches the teacher's softmax outputs (soft labels), often along with intermediate hidden states. It works because the teacher's probabilities on the "wrong" classes encode useful similarity structure. DistilBERT (40% smaller and 60% faster than BERT-base while retaining about 97% of its language-understanding performance) is the canonical demo.

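Structured pruning. Remove entire channels, heads, or layers. torch.nn.utils.prune covers unstructured pruning (its structured variants only mask weights, so tensors keep their shape); SparseML and nn_pruning target structured pruning. Modern recipe: prune, then fine-tune to recover accuracy.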
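
A minimal example of magnitude pruning with torch.nn.utils.prune on a single layer (the layer size and 50% amount are arbitrary). It also shows the caveat above: the weight stays a dense (64, 128) tensor, just with zeros in it. The structured variant, prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0), likewise zeroes whole rows through a mask without shrinking the tensor.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(128, 64)

# Zero the 50% of weights with the smallest |w| (unstructured, L1 magnitude).
prune.l1_unstructured(layer, name="weight", amount=0.5)
# Now layer.weight = weight_orig * weight_mask, recomputed by a forward pre-hook.

# Make the pruning permanent: drop the mask/hook and keep the zeroed tensor.
prune.remove(layer, "weight")

sparsity = (layer.weight == 0).float().mean().item()
print(f"shape {tuple(layer.weight.shape)}, sparsity {sparsity:.2f}")  # shape (64, 128), sparsity 0.50
```

Getting a speed-up from this requires sparse kernels or a physical shrink of the layer; on a standard dense GPU matmul the zeros cost as much as any other value.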

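Mixed quantization. Keep different layers, or different feature dimensions, at different precisions. LLM.int8() is the outlier-aware example: it keeps the few outlier feature dimensions in fp16 and runs the rest in int8. SmoothQuant takes the opposite route, migrating activation outliers into the weights so everything can run in int8. The largest gains often come from being smart about where the precision lives.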
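
The LLM.int8() decomposition can be sketched in NumPy: feature dimensions that contain any activation above a threshold go through a high-precision matmul, and the rest use int8 with per-row scales for X and per-column scales for W. The 6.0 threshold matches the default used by LLM.int8(); everything else (shapes, outlier columns, float64 standing in for fp16) is illustrative.

```python
import numpy as np


def absmax_int8(M, axis):
    """Symmetric int8 quantization with one absmax scale per slice along `axis`."""
    s = np.maximum(np.abs(M).max(axis=axis, keepdims=True), 1e-8) / 127.0
    return np.clip(np.round(M / s), -127, 127), s


def mixed_int8_matmul(X, W, threshold=6.0):
    """LLM.int8()-style sketch of X (tokens, d_in) @ W (d_in, d_out).

    Feature dims with any |x| > threshold use a high-precision matmul
    (fp16 in the real scheme); the rest use row-wise (X) / column-wise (W) int8.
    """
    outlier = (np.abs(X) > threshold).any(axis=0)      # boolean mask over d_in
    y_hi = X[:, outlier] @ W[outlier, :]               # high-precision path
    Xq, sx = absmax_int8(X[:, ~outlier], axis=1)       # one scale per token
    Wq, sw = absmax_int8(W[~outlier, :], axis=0)       # one scale per output column
    y_lo = (Xq @ Wq) * sx * sw                         # integer matmul, then rescale
    return y_hi + y_lo


rng = np.random.default_rng(0)
X = rng.normal(size=(16, 64))
X[:, [3, 17]] *= 25.0                                  # two hypothetical outlier feature dims
W = rng.normal(size=(64, 32)) * 0.1
ref = X @ W
err_mixed = np.abs(mixed_int8_matmul(X, W) - ref).mean()
err_int8 = np.abs(mixed_int8_matmul(X, W, threshold=np.inf) - ref).mean()  # no outlier path
print(f"all-int8 error {err_int8:.4f} vs mixed error {err_mixed:.4f}")
```

Without the split, each token's int8 scale is set by its outlier values, so every ordinary feature in that row loses most of its precision.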

import torch.nn.functional as F

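# Hinton-style distillation loss: alpha weights the temperature-softened KL term,
# (1 - alpha) weights the ordinary cross-entropy on hard labels.
def distill_loss(student_logits, teacher_logits, y_hard, alpha=0.7, T=4.0):
    # F.kl_div expects log-probabilities as input and probabilities as target.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # T^2 keeps soft-loss gradients on the same scale as the hard loss
    hard_loss = F.cross_entropy(student_logits, y_hard)
    return alpha * soft_loss + (1 - alpha) * hard_loss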

# LLM weight-only quantization: 4-bit bitsandbytes via Hugging Face transformers
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
qconf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype="bfloat16")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct",  # gated repo
                                             quantization_config=qconf,
                                             device_map="auto")
Per-channel quantization scale $$ q_{i,c} = \mathrm{round}\!\left(\frac{w_{i,c}}{s_c}\right), \quad \hat{w}_{i,c} = s_c \, q_{i,c}, \quad s_c = \frac{\max_i |w_{i,c}|}{127} $$
  • Each output channel c gets its own scale s_c (symmetric int8, so q_{i,c} lies in [-127, 127])
  • Mitigates the loss from outlier channels
  • Per-tensor (one scale for the whole matrix) is the alternative: cheaper, but more loss
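
The formula above translates directly into NumPy. This sketch compares per-channel and per-tensor symmetric int8 on a random weight matrix with one artificially inflated (hypothetical) outlier channel.

```python
import numpy as np


def fake_quant_int8(W, per_channel=True):
    """Symmetric int8 quantize-dequantize of W with shape (out_channels, in_features)."""
    if per_channel:
        s = np.abs(W).max(axis=1, keepdims=True) / 127.0   # s_c: one scale per output channel
    else:
        s = np.abs(W).max() / 127.0                        # one scale for the whole tensor
    return np.round(W / s) * s


rng = np.random.default_rng(0)
W = rng.normal(size=(64, 256))
W[5] *= 50.0                                               # one hypothetical outlier channel
mse_pc = np.mean((fake_quant_int8(W, per_channel=True) - W) ** 2)
mse_pt = np.mean((fake_quant_int8(W, per_channel=False) - W) ** 2)
print(f"per-channel MSE {mse_pc:.2e} vs per-tensor MSE {mse_pt:.2e}")
```

With a single per-tensor scale, the outlier channel stretches the quantization step so far that most ordinary weights round to zero; per-channel scales confine that damage to channel 5.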

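SmoothQuant. Xiao et al. (2023). LLM activations have huge outliers in a few channels, which makes them hard to quantize. SmoothQuant divides each activation channel by a smoothing factor and multiplies the matching weight input channel by the same factor (the division is folded into the preceding layer), shifting the difficulty from activations into weights; both can then be quantized to int8 (W8A8). This enables production-quality int8 LLM inference.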
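
The smoothing step can be sketched in NumPy for Y = X @ W, using the per-channel factor s_j = max|X_j|^α / max|W_j|^(1-α) with α = 0.5 (the paper's usual default). The W8A8 matmul is simulated with per-tensor int8 for both operands; data and outlier channels are synthetic.

```python
import numpy as np


def smooth(X, W, alpha=0.5):
    """SmoothQuant-style per-input-channel smoothing for Y = X @ W.

    X: (tokens, d_in) activations; W: (d_in, d_out) weights.
    X is divided by s and W multiplied by s, so X @ W is unchanged mathematically.
    """
    s = np.abs(X).max(axis=0) ** alpha / np.abs(W).max(axis=1) ** (1 - alpha)
    s = np.maximum(s, 1e-5)
    return X / s, W * s[:, None]


def w8a8(X, W):
    """Simulated W8A8: per-tensor symmetric int8 for both operands, then rescale."""
    sx, sw = np.abs(X).max() / 127.0, np.abs(W).max() / 127.0
    return (np.round(X / sx) @ np.round(W / sw)) * sx * sw


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 128))
X[:, :4] *= 30.0                         # a few hypothetical outlier activation channels
W = rng.normal(size=(128, 64)) * 0.05
Xs, Ws = smooth(X, W, alpha=0.5)
ref = X @ W
err_plain = np.abs(w8a8(X, W) - ref).mean()
err_smooth = np.abs(w8a8(Xs, Ws) - ref).mean()
print(f"W8A8 error without smoothing {err_plain:.4f}, with smoothing {err_smooth:.4f}")
```

α controls how much difficulty moves to the weights: α = 0 leaves activations untouched, α = 1 pushes all of the outlier range into the weights.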

GPTQ. Frantar et al. (2022). Layer-wise weight quantization with second-order error minimisation. Strong int4 quantization for LLMs at relatively low cost.
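
A minimal NumPy sketch of the GPTQ idea: quantize one input column at a time and push its rounding error onto the not-yet-quantized columns, using the upper Cholesky factor of the inverse Hessian H = 2XXᵀ. It omits the real algorithm's lazy batch updates, group-wise scales, and activation ordering, and uses a per-row 4-bit grid and synthetic correlated calibration data as assumptions.

```python
import numpy as np


def quantize_column(w, scale):
    """Round-to-nearest onto a signed 4-bit grid [-8, 7], one scale per output row."""
    return np.clip(np.round(w / scale), -8, 7) * scale


def gptq_quantize(W, X, damp=0.01):
    """Minimal GPTQ-style quantization (no lazy batching, grouping or act-order).

    W: (d_out, d_in) weights; X: (d_in, n) calibration inputs to this layer.
    Quantizes one input column at a time and pushes its error onto the
    not-yet-quantized columns using the inverse Hessian of ||W X - Q X||^2.
    """
    W = W.astype(np.float64).copy()
    d_in = W.shape[1]
    scale = np.maximum(np.abs(W).max(axis=1), 1e-8) / 7.0   # per-row scale, fixed up front
    H = 2.0 * X @ X.T
    H += damp * np.mean(np.diag(H)) * np.eye(d_in)          # dampening for numerical stability
    U = np.linalg.cholesky(np.linalg.inv(H)).T              # upper Cholesky factor of H^-1
    Q = np.zeros_like(W)
    for i in range(d_in):
        Q[:, i] = quantize_column(W[:, i], scale)
        err = (W[:, i] - Q[:, i]) / U[i, i]
        W[:, i:] -= np.outer(err, U[i, i:])                  # compensate the remaining columns
    return Q


rng = np.random.default_rng(0)
d_out, d_in, n = 16, 32, 512
X = rng.normal(size=(d_in, d_in)) @ rng.normal(size=(d_in, n))  # correlated calibration inputs
W = rng.normal(size=(d_out, d_in))

scale = np.abs(W).max(axis=1) / 7.0
Q_rtn = quantize_column(W, scale[:, None])                  # plain round-to-nearest baseline
Q_gptq = gptq_quantize(W, X)


def output_error(Q):
    return np.linalg.norm((W - Q) @ X) ** 2


print(f"RTN output error {output_error(Q_rtn):.1f} vs GPTQ {output_error(Q_gptq):.1f}")
```

Both versions use the same 4-bit grid; GPTQ only chooses grid points more cleverly, exploiting correlations between input features to cancel errors in the layer's output.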

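AWQ. Lin et al. (2023). Activation-aware: use activation magnitudes to identify the small fraction of "salient" weight channels, then protect them by scaling them up before quantization (rather than keeping them in higher precision). Often reported to match or beat GPTQ for int4 LLM serving.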
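
The AWQ scale search can be sketched in NumPy: per-input-channel scales s = mean|x|^α, with α chosen by grid search to minimise the layer's output error. Per-row int4 round-to-nearest and a 20-step grid are simplifying assumptions here (the real method uses group-wise quantization and its own kernels).

```python
import numpy as np


def rtn_int4(W):
    """Per-output-row symmetric int4 round-to-nearest (quantize then dequantize)."""
    s = np.maximum(np.abs(W).max(axis=1, keepdims=True), 1e-8) / 7.0
    return np.clip(np.round(W / s), -8, 7) * s


def awq_search(W, X, n_grid=20):
    """AWQ-style grid search over per-input-channel scales s = mean|x|^alpha.

    W: (d_out, d_in) weights; X: (n, d_in) calibration activations.
    Input channels with large activations (the salient ones) are scaled up before
    quantization so they lose less relative precision; 1/s is folded back into the
    dequantized weight, standing in for dividing the activations by s.
    """
    ref = X @ W.T
    act_mag = np.abs(X).mean(axis=0)
    best_err, best_alpha, best_Wq = np.inf, None, None
    for alpha in np.linspace(0.0, 1.0, n_grid + 1):
        s = np.maximum(act_mag ** alpha, 1e-4)
        Wq = rtn_int4(W * s) / s                  # alpha = 0 gives s = 1, i.e. plain RTN
        err = np.mean((X @ Wq.T - ref) ** 2)
        if err < best_err:
            best_err, best_alpha, best_Wq = err, alpha, Wq
    return best_Wq, best_alpha, best_err


rng = np.random.default_rng(0)
X = rng.normal(size=(256, 64))
X[:, :2] *= 10.0                                  # hypothetical salient input channels
W = rng.normal(size=(32, 64))
Wq, alpha, err = awq_search(W, X)
rtn_err = np.mean((X @ rtn_int4(W).T - X @ W.T) ** 2)
print(f"plain RTN error {rtn_err:.3f}; AWQ-style (alpha={alpha:.2f}) error {err:.3f}")
```

Because α = 0 reproduces plain round-to-nearest, the search can never do worse than RTN on the calibration data.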

LoRA + quantization (QLoRA). Dettmers et al. (2023). Fine-tune a quantized base model with LoRA adapters. Lets you fine-tune 65B-parameter models on a single 48GB GPU.

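Speculative decoding. A small drafter model proposes k tokens; the large target model verifies all of them in a single forward pass and keeps the accepted prefix. Supported in vLLM, TGI, and llama.cpp. It often gives around 2× speed-up for LLMs, and with the standard accept/reject rule the output distribution is identical to the target model's, so there is no accuracy cost.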
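
The accept/reject rule is what keeps the output distribution unchanged. The sketch below implements one round of speculative sampling in NumPy with toy "models", which are just hypothetical functions from a token list to a next-token probability vector. Each draft token is accepted with probability min(1, p/q); on rejection, a replacement is drawn from the normalised residual max(0, p - q).

```python
import numpy as np


def speculative_step(prefix, target_probs, draft_probs, k, rng):
    """One round of speculative sampling; returns the tokens emitted this round.

    target_probs / draft_probs map a token list to a next-token probability vector.
    The accept/reject rule keeps emitted tokens distributed exactly as the target model.
    """
    # 1. The cheap drafter proposes k tokens autoregressively.
    ctx, drafted, q_dists = list(prefix), [], []
    for _ in range(k):
        q = draft_probs(ctx)
        tok = int(rng.choice(len(q), p=q))
        drafted.append(tok)
        q_dists.append(q)
        ctx.append(tok)
    # 2. The target scores all k + 1 positions (one batched forward pass in a real system).
    p_dists = [target_probs(list(prefix) + drafted[:i]) for i in range(k + 1)]
    # 3. Verify left to right: accept with prob min(1, p/q), else resample and stop.
    out = []
    for i, tok in enumerate(drafted):
        p, q = p_dists[i], q_dists[i]
        if rng.random() < min(1.0, p[tok] / q[tok]):
            out.append(tok)
        else:
            residual = np.maximum(p - q, 0.0)
            out.append(int(rng.choice(len(p), p=residual / residual.sum())))
            return out
    # 4. Every draft accepted: take one bonus token from the target.
    out.append(int(rng.choice(len(p_dists[k]), p=p_dists[k])))
    return out


rng = np.random.default_rng(0)
P = np.array([0.5, 0.3, 0.15, 0.05])       # hypothetical target next-token distribution
Qd = np.array([0.25, 0.25, 0.25, 0.25])    # hypothetical (worse) drafter distribution
target = lambda ctx: P                      # toy models ignore the context
draft = lambda ctx: Qd
first = [speculative_step([], target, draft, k=3, rng=rng)[0] for _ in range(10_000)]
print(np.bincount(first, minlength=4) / len(first))  # should be close to P, not Qd
```

Each round emits between 1 and k + 1 tokens for one target pass; the better the drafter agrees with the target, the more tokens survive verification.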

Mixture of Experts (MoE) inference. Only a fraction of the parameters is active per token, so the total parameter count is much higher than the active count. The deployment story differs from dense models: all experts must still be held in memory, and KV cache management, routing optimisation, and expert co-location across devices all need attention.
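
A NumPy sketch of top-k routing shows where the total-versus-active gap comes from. The ReLU feed-forward experts, the softmax over the k chosen gate logits, and all sizes are illustrative assumptions, not a specific model's design.

```python
import numpy as np


def moe_forward(x, W_gate, experts, k=2):
    """Top-k mixture-of-experts layer.

    x: (tokens, d); W_gate: (d, n_experts); experts: list of (W_in, W_out) ReLU FFNs.
    Each token is processed by only its k highest-scoring experts, weighted by a
    softmax over those k gate logits.
    """
    logits = x @ W_gate
    topk = np.argsort(logits, axis=1)[:, -k:]                 # chosen experts per token
    chosen = np.take_along_axis(logits, topk, axis=1)
    gates = np.exp(chosen - chosen.max(axis=1, keepdims=True))
    gates /= gates.sum(axis=1, keepdims=True)
    y = np.zeros_like(x)
    for e, (W_in, W_out) in enumerate(experts):
        tok, slot = np.nonzero(topk == e)                     # tokens routed to expert e
        if tok.size:
            h = np.maximum(x[tok] @ W_in, 0.0)
            y[tok] += gates[tok, slot][:, None] * (h @ W_out)
    return y, topk


rng = np.random.default_rng(0)
d, h, n_experts, k = 16, 64, 8, 2
experts = [(rng.normal(size=(d, h)) * 0.1, rng.normal(size=(h, d)) * 0.1) for _ in range(n_experts)]
W_gate = rng.normal(size=(d, n_experts))
x = rng.normal(size=(32, d))
y, topk = moe_forward(x, W_gate, experts, k=k)

expert_params = 2 * d * h
print(f"total expert params {n_experts * expert_params:,}, active per token {k * expert_params:,}")
# total expert params 16,384, active per token 4,096
```

Compute per token scales with k, but memory scales with all n_experts, which is why MoE serving leans so heavily on expert placement and routing.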

Hardware-aware design. Different hardware likes different things: GPUs reward batch parallelism; CPUs reward int8 + vectorisation; mobile rewards int4 + sparsity. The "best" compression is hardware-dependent.

# QLoRA — fine-tune a 4-bit base model with LoRA adapters
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

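base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # gated repo; requires accepting the license
    quantization_config=BitsAndBytesConfig(load_in_4bit=True,
                                            bnb_4bit_quant_type="nf4",  # QLoRA's NormalFloat4
                                            bnb_4bit_use_double_quant=True,
                                            bnb_4bit_compute_dtype="bfloat16"),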
    device_map="auto",
)
base = prepare_model_for_kbit_training(base)
peft_cfg = LoraConfig(r=16, lora_alpha=32,
                     target_modules=["q_proj", "v_proj"],
                     bias="none", task_type="CAUSAL_LM")
model = get_peft_model(base, peft_cfg)
model.print_trainable_parameters()
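# trainable: roughly 6.8M LoRA parameters, well under 1% of the model;
# a 4-bit 8B base often fits on a single 24 GB GPU (depends on sequence length / batch size)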
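
The trainable-parameter figure can be checked by hand: each adapted matrix of shape (d_in, d_out) gets LoRA factors A (r × d_in) and B (d_out × r). This sketch assumes the Llama-3-8B projection shapes (hidden size 4096, 32 layers, 8 KV heads of dimension 128, so v_proj maps 4096 to 1024).

```python
def lora_params(shapes, r, n_layers):
    """Trainable LoRA parameters: A is (r x d_in) and B is (d_out x r) per adapted matrix."""
    return n_layers * sum(r * (d_in + d_out) for d_in, d_out in shapes)


# Assumed Llama-3-8B projection shapes (d_in, d_out).
shapes = {"q_proj": (4096, 4096), "v_proj": (4096, 1024)}
n = lora_params(shapes.values(), r=16, n_layers=32)
print(f"{n:,} trainable LoRA parameters")       # 6,815,744 trainable LoRA parameters
print(f"{n / 8.0e9:.3%} of ~8B base weights")    # about 0.085%
```

Adding more target modules (e.g. k_proj, o_proj, or the MLP projections) or raising r grows this count linearly.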