Update beliefs about parameters as you see data — get distributions, not point estimates.
Key idea
Start with a belief, see some data, update. Classical statistics gives you a single best estimate of a parameter (and maybe a confidence interval). Bayesian inference gives you a full distribution over the parameter — capturing everything you know and don't know about it after seeing the data.
[Interactive demo: flip a coin, adjust the prior, and watch the posterior emerge as normalised prior × likelihood. Default prior: α₀ = 2.0, β₀ = 2.0.]
The dashed indigo curve is your prior, Beta(α₀, β₀). Each "Flip H" / "Flip T" click adds an observation. The bold orange curve is the posterior: the prior × likelihood, normalised to integrate to 1. With no data the posterior equals the prior; with a lot of data the posterior is dominated by the likelihood (the prior is "swamped"). The dashed grey curve in the middle is the likelihood of the current data, plotted to scale.
Suppose you want to know a coin's bias. After 10 flips you saw 7 heads. The classical maximum-likelihood estimate is 0.7, full stop. A Bayesian says: "Before flipping, I believed the bias was probably around 0.5 (the prior). Now I've updated to a distribution centered between 0.5 and 0.7 (about 0.64 under a Beta(2, 2) prior) but with a lot of spread (the posterior). I'm not very sure yet."
That distribution is the answer. Want a point estimate? Take its mean. Want uncertainty? Read off the credible interval. Want to predict? Average predictions over the posterior. The framework gives you principled answers to all of these.
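The coin example can be worked through numerically. The sketch below assumes a Beta(2, 2) prior (the text only says "around 0.5") and computes the posterior two ways: with the closed-form Beta update, and by brute force as prior × likelihood on a grid, normalised. The helper names `beta_pdf` and `grid_quantile` are illustrative, not from any library.

```python
from math import exp, lgamma, log

def beta_pdf(theta, a, b):
    """Density of Beta(a, b) at theta in (0, 1)."""
    log_norm = lgamma(a + b) - lgamma(a) - lgamma(b)
    return exp(log_norm + (a - 1) * log(theta) + (b - 1) * log(1 - theta))

# Assumed prior: Beta(2, 2), a mild belief that the coin is roughly fair.
a0, b0 = 2.0, 2.0
heads, tails = 7, 3

# Closed-form update: add heads to alpha and tails to beta.
a_post, b_post = a0 + heads, b0 + tails
post_mean = a_post / (a_post + b_post)

# Same posterior by brute force: prior x likelihood on a grid, then normalise.
n = 2000
grid = [(i + 0.5) / n for i in range(n)]
unnorm = [beta_pdf(t, a0, b0) * t**heads * (1 - t)**tails for t in grid]
total = sum(unnorm)
weights = [u / total for u in unnorm]  # probability mass per grid cell
grid_mean = sum(t * w for t, w in zip(grid, weights))

def grid_quantile(p):
    """Smallest grid point whose cumulative mass reaches p."""
    cum = 0.0
    for t, w in zip(grid, weights):
        cum += w
        if cum >= p:
            return t
    return grid[-1]

# Equal-tailed 95% credible interval read off the grid.
ci_low, ci_high = grid_quantile(0.025), grid_quantile(0.975)

print(f"posterior: Beta({a_post:g}, {b_post:g}), mean = {post_mean:.3f}")
print(f"grid mean = {grid_mean:.3f}, 95% interval = ({ci_low:.2f}, {ci_high:.2f})")
```

The point estimate is the posterior mean and the uncertainty is the credible interval, both derived from the same distribution.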
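Bayes' rule: p(θ | D) = p(D | θ) p(θ) / p(D), where p(D) = ∫ p(D | θ) p(θ) dθ is the marginal likelihood. It is usually intractable, but it is only a normalizer: it does not depend on θ.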
Conjugate priors. If the prior and likelihood are a "conjugate pair", the posterior is in the same family as the prior: a closed-form update, with no numerical integration needed. Common pairs are Beta-Binomial (coin flips), Gaussian-Gaussian (mean estimation with known variance), and Dirichlet-Multinomial (category proportions). For example, a Beta(α, β) prior plus h heads and t tails gives a Beta(α + h, β + t) posterior. Use these when you can.
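As a second conjugate example, here is a minimal sketch of the Gaussian-Gaussian update for an unknown mean with known noise variance. The function name and the measurements are made up for illustration; the update itself is the standard precision-weighted average of the prior mean and the sample mean.

```python
def normal_known_variance_update(prior_mean, prior_var, data, noise_var):
    """Posterior N(mean, var) for an unknown mean with known noise variance."""
    n = len(data)
    prior_prec = 1.0 / prior_var   # precision = 1 / variance
    data_prec = n / noise_var      # precision contributed by n observations
    post_var = 1.0 / (prior_prec + data_prec)
    sample_mean = sum(data) / n
    post_mean = post_var * (prior_prec * prior_mean + data_prec * sample_mean)
    return post_mean, post_var

# Made-up measurements, purely for illustration.
data = [4.8, 5.1, 5.3, 4.9, 5.4]
post_mean, post_var = normal_known_variance_update(
    prior_mean=0.0, prior_var=100.0, data=data, noise_var=1.0
)
print(f"posterior mean = {post_mean:.4f}, posterior variance = {post_var:.4f}")
```

With a diffuse prior the posterior mean sits just below the sample mean, pulled very slightly toward the prior mean of 0.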
When conjugacy fails. Most real models aren't conjugate. Two options:
MCMC (Markov chain Monte Carlo) generates samples from the posterior without computing the normalizer. Modern variants such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS) handle high-dimensional continuous parameters well.
Variational inference (VI) turns inference into optimization: pick a tractable approximating family q(θ), then optimize q to minimize KL(q ∥ posterior). Faster than MCMC but biased: q cannot capture posterior shapes (skew, correlations, multiple modes) that the chosen family cannot represent.
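To make the first option concrete, here is a minimal random-walk Metropolis sketch. It targets the coin posterior under an assumed Beta(2, 2) prior with 7 heads and 3 tails, using only the unnormalised log density, so p(D) is never computed. The step size, warm-up length, and function names are illustrative choices, not recommendations.

```python
import math
import random

def log_unnorm_posterior(theta, heads=7, tails=3, a0=2.0, b0=2.0):
    """log(prior x likelihood) up to an additive constant; -inf outside (0, 1)."""
    if not 0.0 < theta < 1.0:
        return -math.inf
    return (a0 - 1 + heads) * math.log(theta) + (b0 - 1 + tails) * math.log(1 - theta)

def metropolis(log_target, start, n_samples, step=0.2, seed=0):
    """Random-walk Metropolis with a Gaussian proposal of standard deviation `step`."""
    rng = random.Random(seed)
    theta, logp = start, log_target(start)
    samples = []
    for _ in range(n_samples):
        proposal = theta + rng.gauss(0.0, step)
        logp_new = log_target(proposal)
        # Accept with probability min(1, p_new / p_old); the normaliser cancels.
        if rng.random() < math.exp(min(0.0, logp_new - logp)):
            theta, logp = proposal, logp_new
        samples.append(theta)
    return samples

draws = metropolis(log_unnorm_posterior, start=0.5, n_samples=20000)
kept = draws[2000:]  # discard warm-up draws
mc_mean = sum(kept) / len(kept)
print(f"MCMC posterior mean = {mc_mean:.3f} (closed form: {9 / 14:.3f})")
```

Because this model is conjugate, the closed-form mean 9/14 is available as a sanity check; in a non-conjugate model the samples would be the only route to such summaries.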
Predictive distributions. Instead of plugging in a point estimate, average predictions over the posterior: p(x_new | D) = ∫ p(x_new | θ) p(θ | D) dθ. This automatically accounts for parameter uncertainty.
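The predictive integral can be approximated by Monte Carlo: draw θ from the posterior and average the prediction. The sketch below assumes the coin posterior is Beta(9, 5) (a Beta(2, 2) prior updated with 7 heads and 3 tails) and compares the predictive probability of three heads in a row against a plug-in estimate that uses only the posterior mean.

```python
import random

rng = random.Random(0)
a_post, b_post = 9.0, 5.0  # assumed posterior: Beta(2, 2) prior + 7 heads, 3 tails

# Monte Carlo version of the integral: sample theta from the posterior,
# then average p(x_new | theta) over the samples.
thetas = [rng.betavariate(a_post, b_post) for _ in range(50000)]

# Next flip is heads: p(heads | theta) = theta, so the predictive is E[theta].
p_next_heads_mc = sum(thetas) / len(thetas)
p_next_heads_exact = a_post / (a_post + b_post)

# Three heads in a row: p = theta^3, so the predictive is E[theta^3].
p3_predictive = sum(t**3 for t in thetas) / len(thetas)
p3_plugin = p_next_heads_exact**3  # ignores uncertainty about theta

print(f"P(next is heads): MC {p_next_heads_mc:.3f}, exact {p_next_heads_exact:.3f}")
print(f"P(3 heads): predictive {p3_predictive:.3f}, plug-in {p3_plugin:.3f}")
```

The plug-in value is lower because cubing the mean discards the posterior spread, while the predictive average keeps the chance that θ is larger than its mean.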
Reach for it when
Small / heterogeneous data — priors regularize gracefully
You need coherent uncertainty propagation through several modelling steps
Mixed-effects / hierarchical models — Bayesian framework is natural
Decision making — combine posteriors with utility functions
Skip it when
Likelihood is too expensive to evaluate repeatedly (no MCMC budget)
Prior is hard to justify and the audience wants "no priors"
Posterior is multi-modal and you don't have specialized samplers
Real-time inference — MCMC is slow
import arviz as az
import pymc as pm

# x_obs, y_obs: 1-D NumPy arrays of predictors and responses, assumed already loaded.

# Bayesian linear regression: y = α + β·x + noise
with pm.Model() as model:
    alpha = pm.Normal("alpha", mu=0, sigma=10)   # intercept prior
    beta = pm.Normal("beta", mu=0, sigma=10)     # slope prior
    sigma = pm.HalfNormal("sigma", sigma=1)      # noise scale, constrained positive
    mu = alpha + beta * x_obs
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

    # NUTS sampler (adaptive HMC); must be called inside the model context
    trace = pm.sample(2000, tune=1000, chains=4, target_accept=0.9)

# Posterior summaries
print(az.summary(trace, var_names=["alpha", "beta", "sigma"]))
Want the MCMC mechanics, VI bounds, and the diagnostics?
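Evidence lower bound (ELBO): ℒ(q) = E_q[log p(D, θ)] − E_q[log q(θ)]. Maximize ℒ over q; this is equivalent to minimizing KL(q ∥ posterior).
log p(D) = ℒ(q) + KL(q ∥ p(θ | D)). The gap between log p(D) and ℒ(q) is the KL divergence; when it is zero, q equals the posterior.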
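The relationship between ℒ, the KL gap, and log p(D) can be checked numerically on a toy problem where everything is a finite sum. The sketch below assumes a coin bias restricted to five candidate values with a uniform prior and 7 heads in 10 flips; `q_rough` is an arbitrary approximation chosen for illustration. Here ℒ(q) is computed as the sum over θ of q(θ) · [log p(D, θ) − log q(θ)].

```python
from math import comb, log

thetas = [0.1, 0.3, 0.5, 0.7, 0.9]        # toy discrete parameter space (assumed)
prior = [1 / len(thetas)] * len(thetas)   # uniform prior
heads, tails = 7, 3

# Joint p(D, theta) = p(D | theta) p(theta); evidence p(D) is its sum over theta.
joint = [p * comb(heads + tails, heads) * t**heads * (1 - t)**tails
         for t, p in zip(thetas, prior)]
evidence = sum(joint)
posterior = [j / evidence for j in joint]

def elbo(q):
    """Evidence lower bound: E_q[log p(D, theta)] - E_q[log q(theta)]."""
    return sum(qi * (log(ji) - log(qi)) for qi, ji in zip(q, joint) if qi > 0)

def kl(q, p):
    """KL(q || p) for two discrete distributions on the same support."""
    return sum(qi * (log(qi) - log(pi)) for qi, pi in zip(q, p) if qi > 0)

q_rough = [0.1, 0.2, 0.4, 0.2, 0.1]       # an imperfect approximation

print(f"log p(D)        = {log(evidence):.4f}")
print(f"ELBO(q_rough)   = {elbo(q_rough):.4f}")
print(f"KL(q_rough||p)  = {kl(q_rough, posterior):.4f}")
print(f"ELBO(posterior) = {elbo(posterior):.4f}")
```

The ELBO and the KL term sum to log p(D) for any q, and the bound is tight when q is the exact posterior.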
MCMC choices. Random-walk Metropolis is robust but mixes slowly. Gibbs sampling updates each parameter (or block) in turn from its full conditional distribution, which is efficient when conjugacy is partial. Hamiltonian Monte Carlo uses gradients to make informed proposals, giving much better mixing on continuous high-dimensional problems. NUTS auto-tunes HMC step size and trajectory length. Use NUTS by default in modern PPLs (PyMC, Stan, NumPyro).
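A Gibbs sampler is easiest to see on a target whose conditionals are known exactly. The sketch below uses a standard bivariate normal with correlation ρ as a stand-in for a posterior (an assumption made purely for illustration): each coordinate is redrawn from its conditional, x | y ~ N(ρy, 1 − ρ²) and y | x ~ N(ρx, 1 − ρ²).

```python
import math
import random

def gibbs_bivariate_normal(rho, n_samples, seed=0):
    """Gibbs sampling for a standard bivariate normal with correlation rho."""
    rng = random.Random(seed)
    cond_sd = math.sqrt(1.0 - rho**2)  # conditional standard deviation
    x, y = 0.0, 0.0
    samples = []
    for _ in range(n_samples):
        x = rng.gauss(rho * y, cond_sd)  # draw x from p(x | y)
        y = rng.gauss(rho * x, cond_sd)  # draw y from p(y | x)
        samples.append((x, y))
    return samples

samples = gibbs_bivariate_normal(rho=0.8, n_samples=20000)[1000:]  # drop warm-up
xs = [s[0] for s in samples]
ys = [s[1] for s in samples]
mean_x = sum(xs) / len(xs)
mean_y = sum(ys) / len(ys)
cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(xs, ys)) / len(xs)
var_x = sum((a - mean_x) ** 2 for a in xs) / len(xs)
var_y = sum((b - mean_y) ** 2 for b in ys) / len(ys)
corr = cov / math.sqrt(var_x * var_y)
print(f"sample correlation = {corr:.3f} (target 0.8)")
```

As ρ approaches 1 the conditional steps become tiny and the chain mixes slowly, which is the kind of geometry where gradient-based samplers tend to do better.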
Convergence diagnostics. R-hat compares between-chain to within-chain variance; values near 1 (a common threshold is below 1.01) indicate the chains agree. Effective sample size estimates how many independent draws the autocorrelated chains are worth. Energy plots reveal HMC pathologies. Trust nothing; always check.
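To show what R-hat measures, here is a sketch of the classic Gelman-Rubin statistic for one scalar parameter. This is the original non-split version, written out for clarity; libraries such as ArviZ report a more robust rank-normalized split variant, so treat this as an illustration rather than a replacement. The chains are synthetic draws, not real sampler output.

```python
import random
from statistics import mean, variance

def gelman_rubin_rhat(chains):
    """Classic (non-split) R-hat for equal-length chains of one scalar parameter."""
    n = len(chains[0])
    within = mean([variance(c) for c in chains])        # W: mean within-chain variance
    between = n * variance([mean(c) for c in chains])   # B: scaled variance of chain means
    var_plus = (n - 1) / n * within + between / n       # pooled variance estimate
    return (var_plus / within) ** 0.5

rng = random.Random(0)
# Four chains drawing from the same distribution: R-hat should be close to 1.
good = [[rng.gauss(0, 1) for _ in range(1000)] for _ in range(4)]
# One chain stuck in a different region: R-hat should be well above 1.
bad = good[:3] + [[rng.gauss(3, 1) for _ in range(1000)]]

rhat_good = gelman_rubin_rhat(good)
rhat_bad = gelman_rubin_rhat(bad)
print(f"R-hat (agreeing chains)  = {rhat_good:.3f}")
print(f"R-hat (one stray chain)  = {rhat_bad:.3f}")
```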
VI methods. Mean-field VI factorizes q(θ) = ∏ q(θ_i), which is fast but blind to correlations. Black-box VI / SVI uses gradient-based optimization with Monte Carlo gradient estimates (works for any model in a PPL). Normalizing flows give an expressive q while keeping density evaluation tractable, narrowing the gap to MCMC for high-dimensional continuous posteriors.
Prior sensitivity. Refit with a few alternative plausible priors and check how much the posterior moves; large shifts mean the data are not informative enough to override the prior. Tighter priors mean stronger regularization. For hierarchical models, weakly informative priors on top-level scales (HalfNormal, HalfCauchy) help tame funnel pathologies in MCMC. Reparameterize centered models to non-centered when geometry is hostile.
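A prior sensitivity check can be as simple as recomputing the posterior under several priors and comparing. The sketch below does this for the conjugate coin model, where the posterior mean is available in closed form; the three priors and both data sets are hypothetical choices for illustration.

```python
def posterior_mean(a0, b0, heads, tails):
    """Posterior mean of a coin's bias under a Beta(a0, b0) prior."""
    return (a0 + heads) / (a0 + b0 + heads + tails)

# Three candidate priors (hypothetical choices for illustration).
priors = {
    "flat Beta(1, 1)": (1, 1),
    "mild Beta(2, 2)": (2, 2),
    "tight Beta(50, 50)": (50, 50),
}

def spread(heads, tails):
    """Range of posterior means across the candidate priors."""
    means = [posterior_mean(a0, b0, heads, tails) for a0, b0 in priors.values()]
    return max(means) - min(means)

spread_small = spread(7, 3)        # 10 flips
spread_large = spread(700, 300)    # 1000 flips with the same head rate

for name, (a0, b0) in priors.items():
    print(f"{name:>18}: 10 flips -> {posterior_mean(a0, b0, 7, 3):.3f}, "
          f"1000 flips -> {posterior_mean(a0, b0, 700, 300):.3f}")
print(f"spread across priors: {spread_small:.3f} (10 flips), {spread_large:.3f} (1000 flips)")
```

With 10 flips the tight prior pulls the estimate strongly toward 0.5; with 1000 flips all three priors give nearly the same answer.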
Model comparison. Bayes factors are tempting but unstable for diffuse priors. Use information criteria — WAIC, LOO-CV via importance sampling (PSIS-LOO) — for predictive comparison.
Reach for it when
Hierarchical / multilevel structure to exploit
You need to propagate uncertainty through a pipeline of models
Sparse signals where priors do the regularizing work
Probabilistic programming gives you compositional model construction
Skip it when
Big data & deep models: full-posterior MCMC (and often VI) becomes too costly at this scale; use SGD-friendly point estimates with approximate uncertainty (deep ensembles, MC dropout)
Likelihood-free / simulation-based setting — use ABC or SBI methods
Strict latency budget at inference time
Posterior is sharply multi-modal and you can't engineer good chains
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

# Optional on CPU: expose 4 host devices so the chains can run in parallel.
# Call this before any JAX computation.
numpyro.set_host_device_count(4)

# x_obs, y_obs: 1-D arrays of predictors and responses, assumed already loaded.

# Same Bayesian regression in NumPyro, with JAX-accelerated NUTS
def model(x, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0., 10.))   # intercept prior
    beta = numpyro.sample("beta", dist.Normal(0., 10.))     # slope prior
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.))    # noise scale
    mu = alpha + beta * x
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

kernel = NUTS(model, target_accept_prob=0.9)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(random.PRNGKey(0), x=x_obs, y=y_obs)
mcmc.print_summary()