Three layers of validation: schema (the data is shaped right), range (values are within expected bounds), distribution (the statistics of the new batch match what the model expects). The first two are cheap and catch ~80% of issues; the third catches the silent ones.
Data Validation
Schemas, ranges, distributions — catching bad data before it poisons your model.
"Garbage in, garbage out" — but quietly. Bad data rarely throws an error; it just makes your model worse, and the failure looks like "the model is bad" rather than "the data is bad". Catch it at the boundary: schema (right columns, right types), ranges (numbers within bounds), and distributions (no surprise drift).
Three layers of validation, plus cross-column checks
- Schema: column names, types, nullability
- Range: numeric bounds, categorical sets, regex on strings
- Distribution: mean, variance, % missing, % unique within tolerance of training stats
- Cross-column: invariants ("end_date >= start_date")
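Cross-column invariants are the one layer a per-column schema cannot express, so a small sketch may help. The one below is dependency-free and treats each row as a dictionary; the `check_invariants` helper and the sample rows are hypothetical, and only the `end_date >= start_date` rule comes from the list above.

```python
from datetime import date

def check_invariants(rows, invariants):
    """Return (row_index, invariant_name) for every violated invariant."""
    failures = []
    for i, row in enumerate(rows):
        for name, predicate in invariants.items():
            if not predicate(row):
                failures.append((i, name))
    return failures

# Each invariant is a named predicate over a whole row, not a single column.
invariants = {
    "end_date >= start_date": lambda r: r["end_date"] >= r["start_date"],
}

# Hypothetical sample rows; the second one violates the invariant.
rows = [
    {"start_date": date(2024, 1, 1), "end_date": date(2024, 1, 31)},
    {"start_date": date(2024, 3, 1), "end_date": date(2024, 2, 1)},
]

failures = check_invariants(rows, invariants)
```

Naming each invariant keeps the failure report readable: the result says which rule broke on which row, rather than just that the batch is invalid.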
Common pitfalls
- Validating output schema but not input distribution
- Validation too strict: production data is messier than training data
- Validation too loose: silent drift goes undetected
- Hand-coded checks that drift apart from the data they should match
import pandas as pd
import pandera as pa
from pandera import Column, Check

# Declare once, enforce everywhere
schema = pa.DataFrameSchema({
    "user_id": Column(int, Check.greater_than(0)),
    "age": Column(int, Check.in_range(13, 110)),
    "income": Column(float, Check.greater_than_or_equal_to(0), nullable=True),
    "country": Column(str, Check.isin(["UK", "US", "FR", "DE"])),
    "ts": Column(pd.Timestamp),
})

def load(path):
    df = pd.read_csv(path, parse_dates=["ts"])
    # lazy=True collects all failures and raises once with the full list
    return schema.validate(df, lazy=True)

# Same schema, used at every IO boundary: load, before save, after merge
- Pairwise comparison between reference (training) and current data
- Above-threshold drift triggers an alert or a retrain
Great Expectations. Declarative validation library with a "suite" of expectations per dataset. Plays well with Airflow, dbt, pandas, Spark. The reference tool for production tabular validation.
Drift metrics. KS (Kolmogorov-Smirnov) for numerical, chi-squared for categorical, PSI (Population Stability Index) for binned numerical. MMD (Maximum Mean Discrepancy) for high-dimensional distributional comparison. Pick by feature type.
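For the categorical case, the chi-squared statistic is simple enough to compute by hand. The sketch below is an illustration rather than a reference implementation: the `chi2_statistic` helper is hypothetical, it returns only the statistic and degrees of freedom, and it assumes both samples are non-empty. In practice `scipy.stats.chi2_contingency` would give the p-value directly.

```python
from collections import Counter

def chi2_statistic(reference, current):
    """Pearson chi-squared statistic for a 2 x k table of category counts.

    Rows are the two samples (reference, current); columns are categories.
    Returns (statistic, degrees_of_freedom).
    """
    ref_counts, cur_counts = Counter(reference), Counter(current)
    categories = set(ref_counts) | set(cur_counts)
    n_ref, n_cur = len(reference), len(current)
    total = n_ref + n_cur
    stat = 0.0
    for cat in categories:
        col_total = ref_counts[cat] + cur_counts[cat]
        for observed, row_total in ((ref_counts[cat], n_ref), (cur_counts[cat], n_cur)):
            # Expected count if both samples shared one category distribution
            expected = row_total * col_total / total
            stat += (observed - expected) ** 2 / expected
    return stat, len(categories) - 1

# Hypothetical samples: the current batch has shifted heavily toward "UK".
reference = ["UK"] * 50 + ["US"] * 50
current = ["UK"] * 90 + ["US"] * 10
stat, dof = chi2_statistic(reference, current)
```

With one degree of freedom, the 5% critical value is about 3.84, so a statistic far above that points to a shift in the category mix.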
What to monitor. Inputs (every feature), predictions (the output distribution), residuals (where you have ground truth), latency. Each tells you something different about how the model is doing.
Reference window. The training distribution is the natural reference. For seasonal data, use the same season last year. For streaming, a rolling reference window — but watch out for "drift" becoming "today is normal, yesterday was anomalous".
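One possible way to keep a rolling reference from absorbing anomalies is to add a batch to the window only when it was not flagged as drifted. This is a sketch of that single mitigation, not a recommendation from the text above; the `RollingReference` class and its method names are hypothetical.

```python
from collections import deque

class RollingReference:
    """Keep the last `max_batches` non-drifted batches as the reference."""

    def __init__(self, max_batches):
        self._batches = deque(maxlen=max_batches)

    def reference(self):
        # Flatten the retained batches into one reference sample
        return [x for batch in self._batches for x in batch]

    def update(self, batch, drifted):
        # Assumption: drifted batches are excluded so that an anomaly
        # does not quietly become part of the new "normal".
        if not drifted:
            self._batches.append(list(batch))

window = RollingReference(max_batches=2)
window.update([1, 2], drifted=False)
window.update([3], drifted=True)    # skipped
window.update([4], drifted=False)
window.update([5], drifted=False)   # evicts the oldest batch, [1, 2]
current_reference = window.reference()
```

The trade-off is that a genuine, permanent shift never enters the reference on its own, so someone has to decide when to reset the window.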
Validation at training vs serving. Train-time: catch corrupt training data. Serve-time: catch corrupt input requests. Same schema, different consequences — train-time often crashes the pipeline; serve-time silently degrades predictions.
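The "same schema, different consequences" idea can be sketched as one shared check with two wrappers: one that fails fast for training, one that logs and flags for serving. The function names and the `ValidationError` class are hypothetical; the bounds reuse the `age` and `country` rules from the schema example earlier on this page.

```python
import logging

logger = logging.getLogger("validation")

class ValidationError(Exception):
    """Raised when training data fails validation."""

def validate_record(record):
    """Shared rules: return a list of human-readable problems (empty if valid)."""
    errors = []
    age = record.get("age")
    if not isinstance(age, int):
        errors.append("age: expected int")
    elif not 13 <= age <= 110:
        errors.append("age: out of range [13, 110]")
    if record.get("country") not in {"UK", "US", "FR", "DE"}:
        errors.append("country: unexpected value")
    return errors

def validate_for_training(records):
    """Train-time: stop the pipeline on the first bad row."""
    for i, record in enumerate(records):
        errors = validate_record(record)
        if errors:
            raise ValidationError(f"row {i}: {errors}")
    return records

def validate_for_serving(record):
    """Serve-time: never raise; log and let the caller choose a fallback."""
    errors = validate_record(record)
    if errors:
        logger.warning("invalid request: %s", errors)
    return len(errors) == 0, errors
```

Returning a flag at serve-time turns the silent degradation into something visible: the caller can reject the request, fall back to a default, or at least count how often it happens.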
Alerting fatigue. Validation that fires often gets ignored. Tune thresholds. Group related failures. Have a clear runbook ("if FeatureX drifts, do Y"). False alarms are worse than no alarms.
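One common way to cut false alarms is to alert only when a feature stays drifted for several consecutive checks. The sketch below illustrates that idea under stated assumptions: the `DebouncedAlerter` class and its `patience` parameter are hypothetical, and the right value for `patience` depends on how often checks run.

```python
from collections import defaultdict

class DebouncedAlerter:
    """Fire only after `patience` consecutive drifted checks for a feature."""

    def __init__(self, patience=3):
        self.patience = patience
        self._streak = defaultdict(int)

    def observe(self, feature, drifted):
        """Record one check result; return True if an alert should fire now."""
        if not drifted:
            self._streak[feature] = 0
            return False
        self._streak[feature] += 1
        # Fire exactly once, when the streak first reaches `patience`
        return self._streak[feature] == self.patience

alerter = DebouncedAlerter(patience=3)
history = [True, True, False, True, True, True, True]
fired = [alerter.observe("income", d) for d in history]
```

Here a single clean check resets the streak, and the alert fires once on the third consecutive drifted check rather than on every check after it.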
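import pandas as pd
from scipy.stats import ks_2samp
from evidently.report import Report
# Legacy (0.4.x) Evidently API: presets live in evidently.metric_preset.
# Newer releases reorganized these imports.
from evidently.metric_preset import DataDriftPreset

# Drift report comparing reference (training) and current data
report = Report(metrics=[DataDriftPreset()])
report.run(reference_data=df_train, current_data=df_today)
report.save_html("drift.html")

# Or roll-your-own per feature
def drift_check(reference, current, threshold=0.05):
    """Return {column: ("ks", statistic, p_value)} for numeric columns with p < threshold."""
    drifted = {}
    for col in reference.columns:
        # Skip columns missing from the current batch; schema validation should catch those
        if col in current.columns and pd.api.types.is_numeric_dtype(reference[col]):
            stat, p = ks_2samp(reference[col].dropna(), current[col].dropna())
            if p < threshold:
                drifted[col] = ("ks", stat, p)
    return drifted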
- Bins on the reference distribution; PSI = sum over bins of (p_cur − p_ref) · log(p_cur / p_ref)
- Empirical thresholds from credit-risk literature, often used elsewhere: below 0.1 is stable, 0.1 to 0.25 is a minor shift, above 0.25 is drift
Data contracts. The producer of a dataset commits to a schema and SLA; the consumer relies on it. Tools: Soda, Monte Carlo, dbt tests. Breaks the "we changed an upstream column and the model silently broke" loop.
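At its simplest, the schema half of a contract is a mapping the consumer can check before using the data. The sketch below is a hypothetical illustration, not how Soda, Monte Carlo or dbt tests work; the `contract_violations` helper and the type-name strings are assumptions. Column names reuse the schema example earlier on this page.

```python
# What the producer has committed to: column name -> type name.
CONTRACT = {"user_id": "int", "age": "int", "income": "float", "country": "str"}

def contract_violations(contract, observed_schema):
    """List breaking changes: columns that disappeared or changed type.

    Assumption: extra columns in `observed_schema` are treated as
    non-breaking, additive changes and are not reported.
    """
    violations = []
    for column, expected in contract.items():
        if column not in observed_schema:
            violations.append(f"missing column: {column}")
        elif observed_schema[column] != expected:
            violations.append(
                f"type changed: {column} {expected} -> {observed_schema[column]}"
            )
    return violations

# Hypothetical upstream change: age became float, income was dropped,
# and a new signup_source column appeared.
observed = {"user_id": "int", "age": "float", "country": "str", "signup_source": "str"}
violations = contract_violations(CONTRACT, observed)
```

Running a check like this in the producer's CI, before the change ships, is what breaks the silent-failure loop; running it only in the consumer just detects the break sooner.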
Lineage and provenance. Track where each feature came from — which raw table, which transformation, which dataset version. dbt, OpenLineage, Marquez. Essential for debugging "why did this prediction look weird" in regulated domains.
Out-of-distribution detection. Not just feature-level drift but example-level — is this specific input from a distribution the model has seen? Mahalanobis distance, energy scores, outlier exposure. Pairs with conformal prediction for safe deployment.
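Of the methods named above, Mahalanobis distance is the easiest to sketch. The version below works on raw feature vectors and assumes the training data is roughly Gaussian; the helper names and the synthetic data are hypothetical, and any threshold would need tuning on held-out data.

```python
import numpy as np

def fit_gaussian(X):
    """Estimate the mean and inverse covariance from training rows (n_samples x n_features)."""
    mean = X.mean(axis=0)
    cov = np.cov(X, rowvar=False)
    # The pseudo-inverse tolerates a singular covariance (e.g. duplicated features)
    return mean, np.linalg.pinv(cov)

def mahalanobis(x, mean, cov_inv):
    """Distance of one example from the training distribution, in covariance-scaled units."""
    delta = x - mean
    return float(np.sqrt(delta @ cov_inv @ delta))

# Hypothetical training data: 500 points from a 2-D standard normal.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(500, 2))
mean, cov_inv = fit_gaussian(X_train)

near = mahalanobis(np.array([0.0, 0.0]), mean, cov_inv)  # typical input
far = mahalanobis(np.array([8.0, 8.0]), mean, cov_inv)   # far outside training data
```

Unlike per-feature drift tests, this scores a single example, so it can gate individual predictions at serving time.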
Privacy & PII validation. Scan incoming data for PII before training. Presidio, Snowflake's classifier, custom regex. Often a compliance requirement; nearly always a good idea anyway.
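The custom-regex option might look like the sketch below. The two patterns are deliberately simple illustrations, not complete detectors, and the `scan_for_pii` helper is hypothetical. Regex alone misses names and addresses, which is where a dedicated tool such as Presidio fits.

```python
import re

# Illustrative patterns only; real detectors need more careful rules.
PII_PATTERNS = {
    "email": re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"),
    "us_ssn": re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),
}

def scan_for_pii(values):
    """Return {pii_kind: [indices of values that matched]}."""
    hits = {}
    for i, value in enumerate(values):
        for kind, pattern in PII_PATTERNS.items():
            if pattern.search(str(value)):
                hits.setdefault(kind, []).append(i)
    return hits

# Hypothetical free-text column
values = [
    "great product",
    "contact me at jane@example.com",
    "ssn 123-45-6789",
]
hits = scan_for_pii(values)
```

Reporting row indices rather than the matched text keeps the PII itself out of logs and validation reports.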
Synthetic-data validation. When you augment with synthetic data (rare-class oversampling, simulation), validate that the synthetic distribution doesn't push the model toward spurious correlations. Compare statistics; train one model on real-only and one on real+synthetic; compare them.
Sampling for validation. Streaming or massive datasets — validate a sample, not the whole batch. Reservoir sampling for unbiased samples; stratified sampling to keep rare classes represented.
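Reservoir sampling keeps a uniform sample of fixed size from a stream of unknown length in a single pass. A minimal sketch of the classic Algorithm R follows; the function name and the `seed` parameter are illustrative choices.

```python
import random

def reservoir_sample(stream, k, seed=None):
    """Return k items chosen uniformly at random from an iterable of unknown length."""
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            # Fill the reservoir with the first k items
            reservoir.append(item)
        else:
            # Keep the new item with probability k / (i + 1)
            j = rng.randint(0, i)  # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir

# Validate a 100-row sample instead of the full stream
sample = reservoir_sample(iter(range(10_000)), k=100, seed=0)
```

Memory stays at `k` items however long the stream is. If the stream is shorter than `k`, the whole stream is returned.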
Feedback loops. A drift detector that fires triggers retraining → new model → new "normal" distribution → drift detector goes quiet. Be aware of these self-fulfilling cycles; sometimes the right answer is to fix the upstream data, not the model.
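import numpy as np

# Population Stability Index: binned distributional drift
def psi(reference, current, bins=10):
    # Drop missing values; NaN would otherwise poison the percentile edges
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    reference = reference[~np.isnan(reference)]
    current = current[~np.isnan(current)]
    # Quantile bin edges from the reference; duplicate edges collapse for low-cardinality features
    edges = np.unique(np.percentile(reference, np.linspace(0, 100, bins + 1)))
    if len(edges) < 3:
        return 0.0  # too few distinct bins to compare
    # Open the outer bins so current values outside the reference range are still counted
    edges[0], edges[-1] = -np.inf, np.inf
    ref_hist, _ = np.histogram(reference, bins=edges)
    cur_hist, _ = np.histogram(current, bins=edges)
    # Small constant avoids log(0) and division by zero for empty bins
    ref = (ref_hist + 1e-6) / ref_hist.sum()
    cur = (cur_hist + 1e-6) / cur_hist.sum()
    return float(((cur - ref) * np.log(cur / ref)).sum())

# Usage (numeric_cols, df_train and df_today are assumed to be defined by the caller)
for col in numeric_cols:
    score = psi(df_train[col], df_today[col])
    flag = "OK" if score < 0.1 else ("MINOR" if score < 0.25 else "DRIFT")
    print(f"{col:20s} psi={score:.3f} {flag}")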