Score-Based Divergences: Fisher and Stein¶
"Every divergence we have seen so far compares distributions by looking at their densities — the 'height' of the probability landscape. What if we compare the slopes instead?"
In the preceding notebooks, we built a rich toolkit for comparing distributions:
- Notebook 1: Shannon's information-theoretic measures (entropy, KL, MI)
- Notebook 2: f-divergences and Rényi families
- Notebook 3: Integral probability metrics (Wasserstein, energy, MMD)
- Notebook 4: Multivariate dependence and causal information flow
This notebook ventures into the modern frontier of divergence measures, where two powerful mathematical traditions converge:
| Measure | Core idea | Tradition | Key advantage |
|---|---|---|---|
| Fisher Divergence | Compare slopes of log-densities | Differential geometry | Works without normalizing constants |
| Kernel Stein Discrepancy | Stein's identity meets kernel methods | Functional analysis | Goodness-of-fit without computing $Z$ |
These measures have become essential tools in modern machine learning — powering score-matching models and MCMC diagnostics.
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from divergence import (
fisher_divergence,
kernel_stein_discrepancy,
)
plt.rcParams.update({
'figure.figsize': (8, 4),
'axes.spines.top': False,
'axes.spines.right': False,
'font.size': 12,
})
np.random.seed(42)
from pathlib import Path
FIGURES_DIR = Path('figures/scores_and_transport')
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
1. The Score Function: Fisher's Enduring Insight¶
The man who shaped modern statistics¶
Sir Ronald Aylmer Fisher (1890–1962) is arguably the most influential statistician of the twentieth century. Working at Rothamsted Experimental Station in the English countryside — surrounded by agricultural field trials — he laid the foundations of modern statistical inference: maximum likelihood, sufficiency, analysis of variance, experimental design.
Among Fisher's many contributions, one concept stands out for its depth and elegance: the score function. For a probability density $p(x)$, the score is simply the gradient of the log-density:
$$ s_p(x) = \nabla_x \log p(x) $$
Why is this important? Because the score captures the local geometry of the distribution — at each point, it tells you which direction increases probability the fastest. It is the "slope" of the log-density landscape.
For the standard normal $\mathcal{N}(0, 1)$, the score is delightfully simple:
$$ s(x) = \frac{d}{dx} \log\left(\frac{1}{\sqrt{2\pi}} e^{-x^2/2}\right) = -x $$
At $x = 0$ (the peak), the score is zero — you're at the top. At $x = 2$, the score is $-2$ — the density is falling steeply to the right, pulling you back toward the center.
# Visualize the density and its score function side by side
x_grid = np.linspace(-4, 4, 300)
density = norm.pdf(x_grid)
score = -x_grid # ∇ log p(x) = -x for N(0,1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4.5))
# Density
ax1.fill_between(x_grid, density, alpha=0.3, color='steelblue')
ax1.plot(x_grid, density, color='steelblue', lw=2)
ax1.set_title('Density $p(x)$', fontsize=14)
ax1.set_xlabel('$x$')
ax1.set_ylabel('$p(x)$')
ax1.annotate('peak', xy=(0, norm.pdf(0)), xytext=(1.5, 0.35),
arrowprops=dict(arrowstyle='->', color='gray'),
fontsize=11, color='gray')
# Score
ax2.plot(x_grid, score, color='coral', lw=2)
ax2.axhline(0, color='gray', lw=0.5, ls='--')
ax2.axvline(0, color='gray', lw=0.5, ls='--')
ax2.set_title('Score $s(x) = \\nabla \\log p(x) = -x$', fontsize=14)
ax2.set_xlabel('$x$')
ax2.set_ylabel('$s(x)$')
ax2.annotate('score = 0\nat the peak', xy=(0, 0), xytext=(1.8, 1.5),
arrowprops=dict(arrowstyle='->', color='gray'),
fontsize=11, color='gray')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'density_and_score.png', dpi=300, bbox_inches='tight')
plt.show()
The density tells you how likely each point is. The score tells you which way probability increases fastest. This shift in perspective — from heights to slopes — is the foundation of modern score-based methods.
2. Fisher Divergence: Comparing Slopes¶
The Fisher divergence between distributions $P$ and $Q$ measures how different their score functions are, as experienced by points drawn from $P$:
$$ D_F(P \| Q) = \mathbb{E}_{x \sim P}\!\left[\|\nabla \log p(x) - \nabla \log q(x)\|^2\right] $$
This is the expected squared difference between the slopes of the two log-densities. If $P$ and $Q$ have the same shape (i.e., the same score function everywhere), the Fisher divergence is zero.
A beautiful closed-form case: For $P = \mathcal{N}(0, 1)$ and $Q = \mathcal{N}(\mu, 1)$ (same variance, different means), the Fisher divergence is exactly:
$$ D_F = \mathbb{E}_P\!\left[\|(-x) - (-(x - \mu))\|^2\right] = \mathbb{E}_P\!\left[\|-\mu\|^2\right] = \mu^2 $$
Elegant! The Fisher divergence between two unit-variance normals is simply the squared distance between their means.
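When both scores are known in closed form, the definition above is a one-line Monte Carlo average. Here is a minimal sketch of that computation (the library's fisher_divergence, demonstrated next, handles the same case and can additionally estimate the score from samples):
# Direct Monte Carlo estimate of D_F(N(0,1) || N(2,1)) from the definition above
x_p = np.random.randn(5000)                   # x ~ P = N(0, 1)
s_p_vals = -x_p                               # ∇ log p(x) for N(0, 1)
s_q_vals = -(x_p - 2.0)                       # ∇ log q(x) for Q = N(2, 1)
print(np.mean((s_p_vals - s_q_vals) ** 2))    # exactly 4.0: the score difference is the constant -2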
n = 5000
samples_p = np.random.randn(n) # P = N(0, 1)
# Known score functions
def score_standard_normal(x):
    """Score of N(0, 1): s(x) = -x"""
    return -x

def score_shifted_normal(x, mu=2.0):
    """Score of N(mu, 1): s(x) = -(x - mu)"""
    return -(x - mu)
# Fisher divergence: P = Q → should be 0
fd_same = fisher_divergence(
samples_p, score_standard_normal, score_p=score_standard_normal
)
print(f'D_F(N(0,1) || N(0,1)) = {fd_same:.4f} (expected: 0.0)')
# Fisher divergence: P = N(0,1), Q = N(2,1) → should be μ² = 4.0
fd_shifted = fisher_divergence(
samples_p,
lambda x: score_shifted_normal(x, mu=2.0),
score_p=score_standard_normal,
)
print(f'D_F(N(0,1) || N(2,1)) = {fd_shifted:.4f} (expected: μ² = 4.0)')
D_F(N(0,1) || N(0,1)) = 0.0000 (expected: 0.0)
D_F(N(0,1) || N(2,1)) = 4.0000 (expected: μ² = 4.0)
How Fisher divergence grows with separation¶
Let's sweep the mean shift $\mu$ and verify the $D_F = \mu^2$ relationship:
mu_values = np.linspace(0, 3, 15)
fd_values = []
for mu in mu_values:
    fd = fisher_divergence(
        samples_p,
        lambda x, m=mu: -(x - m),  # score of N(mu, 1)
        score_p=score_standard_normal,
    )
    fd_values.append(fd)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(mu_values, fd_values, 'o', color='coral', markersize=8, label='Estimated $D_F$')
ax.plot(mu_values, mu_values**2, '-', color='steelblue', lw=2, label='Analytical $\\mu^2$')
ax.set_xlabel('Mean shift $\\mu$')
ax.set_ylabel('Fisher Divergence')
ax.set_title('Fisher Divergence: $D_F(\\mathcal{N}(0,1) \\| \\mathcal{N}(\\mu,1)) = \\mu^2$',
fontsize=13)
ax.legend()
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'fisher_divergence_parabola.png', dpi=300, bbox_inches='tight')
plt.show()
The estimated values trace the analytical $\mu^2$ parabola almost perfectly. This is a hallmark of a well-behaved divergence: clean, interpretable, and efficiently estimable.
When you don't know the score: estimation from data¶
In practice, we often don't have an analytical expression for $\nabla \log p$. The fisher_divergence function can estimate the score of $P$ from samples using a kernel density gradient estimator — trading some accuracy for generality:
# Estimate score of P from data (no analytical score_p provided)
fd_estimated = fisher_divergence(
samples_p,
lambda x: score_shifted_normal(x, mu=2.0),
# score_p not specified → estimated via kernel density gradient
)
print(f'D_F with estimated score_p: {fd_estimated:.4f}')
print(f'D_F with known score_p: {fd_shifted:.4f}')
print(f'Analytical value: 4.0000')
print(f'\nScore estimation introduces some noise, but the order of magnitude is right.')
D_F with estimated score_p: 4.1912
D_F with known score_p: 4.0000
Analytical value: 4.0000

Score estimation introduces some noise, but the order of magnitude is right.
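For intuition, here is one way such a kernel density gradient estimator can be built from scratch. This is only a sketch: the library's internal estimator and bandwidth rule may differ, and the helper name kde_score_1d and the fixed bandwidth h=0.3 are illustrative choices, not part of the library.
def kde_score_1d(samples, h=0.3):
    """Estimate s(x) = ∇ log p(x) from samples via a Gaussian KDE gradient (illustrative sketch)."""
    samples = np.asarray(samples, dtype=float)
    def score(x):
        x = np.atleast_1d(np.asarray(x, dtype=float))
        diff = samples[None, :] - x[:, None]               # x_i - x, shape (queries, n)
        logw = -diff**2 / (2 * h**2)                       # log kernel weights
        w = np.exp(logw - logw.max(axis=1, keepdims=True))
        w /= w.sum(axis=1, keepdims=True)                  # normalized weights per query point
        return (w * diff).sum(axis=1) / h**2               # ∇ log p̂(x)
    return score

s_hat = kde_score_1d(samples_p)
for x0 in (-2.0, 0.0, 1.0):
    print(f'x = {x0:+.1f}: estimated score = {s_hat(x0)[0]:+.3f}, true score = {-x0:+.3f}')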
Why Fisher divergence matters today¶
The Fisher divergence has become a cornerstone of modern generative modeling:
- Score matching (Hyvärinen, 2005): Train a model by minimizing the Fisher divergence between data and model — crucially, this doesn't require computing the normalizing constant of the model (the identity after this list shows why).
- Diffusion models (Song & Ermon, 2019): The explosion of diffusion-based image generation (Stable Diffusion, DALL-E 3) is built on score matching — estimating the score function at different noise levels.
- Stein variational inference (Liu & Wang, 2016): Uses score functions to transport particles toward the posterior.
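Why score matching avoids the normalizing constant deserves one extra line. Integrating the cross term by parts (Hyvärinen's trick, assuming the densities decay suitably at infinity) removes the unknown data score from the objective entirely:

$$ D_F(p_{\text{data}} \| p_\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\|s_\theta(x)\|^2 + 2\,\nabla \cdot s_\theta(x)\right] + \text{const}, $$

where the constant does not depend on $\theta$. Everything on the right involves only the model score $s_\theta = \nabla_x \log p_\theta$, so the normalizing constant of $p_\theta$ never appears.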
Fisher's original insight about the importance of log-likelihood gradients, conceived amid English wheat fields in the 1920s, now powers the most impressive generative AI systems in the world.
3. Kernel Stein Discrepancy: Testing Without the Normalizing Constant¶
The problem that plagued Bayesian statistics¶
Here is one of the deepest practical problems in statistics: you have samples (say, from an MCMC sampler) and you want to test whether they actually come from a target distribution $P$. But computing $P$'s density requires evaluating a normalizing constant $Z = \int \exp(-U(x))\,dx$ that is often intractable.
This is the situation in Bayesian inference, where $P$ is the posterior distribution and $Z$ is the marginal likelihood — the integral that makes Bayesian computation so challenging.
Stein's identity: a mathematical miracle¶
In 1972, Charles Stein at Stanford discovered a remarkable identity: for any smooth function $h$ and distribution $P$ with score $s_p$,
$$ \mathbb{E}_{x \sim P}\bigl[s_p(x)\,h(x) + \nabla h(x)\bigr] = 0 $$
This identity holds for every sufficiently smooth test function $h$ (under mild decay conditions), and the converse characterizes $P$: if the expectation vanishes for all such $h$ under some distribution, that distribution must be $P$. Crucially, it depends on $P$ only through the score function, which doesn't require the normalizing constant!
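A quick Monte Carlo sanity check of the identity for $P = \mathcal{N}(0, 1)$; the test functions $\sin(x)$, $x^2$ and $\tanh(x)$ are arbitrary illustrative choices:
# Verify E[s(x)·h(x) + h'(x)] ≈ 0 under P = N(0,1) for a few test functions
x_mc = np.random.randn(100_000)
s_mc = -x_mc                                    # score of N(0, 1)
test_functions = [
    ('sin(x)',  np.sin(x_mc),  np.cos(x_mc)),
    ('x^2',     x_mc**2,       2 * x_mc),
    ('tanh(x)', np.tanh(x_mc), 1 - np.tanh(x_mc)**2),
]
for name, h, dh in test_functions:
    print(f"h = {name:8s}: E[s(x)h(x) + h'(x)] = {np.mean(s_mc * h + dh):+.4f}")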
From Stein's identity to a practical test¶
In 2016, Qiang Liu, Jason Lee, and Michael Jordan combined Stein's identity with kernel methods to create the kernel Stein discrepancy (KSD): a practical, computable measure of how well a set of samples approximates a target distribution.
The KSD uses a Stein kernel:
$$ u_p(x, y) = s_p(x)^\top s_p(y)\,k(x,y) + s_p(x)^\top \nabla_y k(x,y) + s_p(y)^\top \nabla_x k(x,y) + \nabla_x \cdot \nabla_y k(x,y) $$
where $k$ is an RBF kernel. The squared KSD is estimated via a U-statistic:
$$ \widehat{\mathrm{KSD}}^2 = \frac{1}{n(n-1)} \sum_{i \neq j} u_p(x_i, x_j) $$
The key insight: under mild conditions on the kernel, $\mathrm{KSD}(Q, P) = 0$ if and only if $Q = P$. And computing it requires only the score function of $P$, not the normalizing constant.
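Before calling the library, here is a minimal from-scratch sketch of the U-statistic above for 1D samples with an RBF kernel. The fixed bandwidth h=1.0 is an illustrative assumption; the library's kernel_stein_discrepancy may choose its bandwidth differently (for example via a median heuristic) and also handles the multivariate case.
def ksd_squared_1d(x, score, h=1.0):
    """U-statistic estimate of KSD² in 1D with an RBF kernel (illustrative sketch)."""
    x = np.asarray(x, dtype=float)
    s = score(x)                                   # s_p(x_i)
    diff = x[:, None] - x[None, :]                 # x_i - x_j
    k = np.exp(-diff**2 / (2 * h**2))              # k(x_i, x_j)
    dk_dy = diff / h**2 * k                        # ∂_y k(x_i, x_j)
    dk_dx = -diff / h**2 * k                       # ∂_x k(x_i, x_j)
    dk_dxdy = (1 / h**2 - diff**2 / h**4) * k      # ∂_x ∂_y k(x_i, x_j)
    u = (s[:, None] * s[None, :] * k               # s_p(x) s_p(y) k(x, y)
         + s[:, None] * dk_dy                      # s_p(x) ∂_y k
         + s[None, :] * dk_dx                      # s_p(y) ∂_x k
         + dk_dxdy)                                # ∂_x ∂_y k
    np.fill_diagonal(u, 0.0)                       # drop i = j terms (U-statistic)
    n = len(x)
    return u.sum() / (n * (n - 1))

print(ksd_squared_1d(np.random.randn(500), score_standard_normal))              # ≈ 0
print(ksd_squared_1d(np.random.normal(2.0, 1.0, 500), score_standard_normal))   # clearly > 0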
n = 2000
# Case 1: Samples actually from N(0,1) → KSD should be ≈ 0
samples_correct = np.random.randn(n)
ksd_correct = kernel_stein_discrepancy(samples_correct, score_standard_normal)
# Case 2: Samples from N(2,1) tested against N(0,1) → KSD should be large
samples_wrong = np.random.normal(2.0, 1.0, n)
ksd_wrong = kernel_stein_discrepancy(samples_wrong, score_standard_normal)
print('Kernel Stein Discrepancy (testing against N(0,1) score):')
print(f' Samples from N(0,1): KSD² = {ksd_correct:.6f} ← close to zero ✓')
print(f' Samples from N(2,1): KSD² = {ksd_wrong:.6f} ← clearly positive ✗')
Kernel Stein Discrepancy (testing against N(0,1) score):
  Samples from N(0,1): KSD² = 0.001471 ← close to zero ✓
  Samples from N(2,1): KSD² = 2.272547 ← clearly positive ✗
How KSD grows with mismatch¶
Let's visualize how the KSD responds as we shift the sample distribution further from the target:
mu_shifts = np.linspace(0, 3, 12)
ksd_values = []
for mu in mu_shifts:
    samples_shifted = np.random.normal(mu, 1.0, n)
    ksd = kernel_stein_discrepancy(samples_shifted, score_standard_normal)
    ksd_values.append(ksd)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(mu_shifts, ksd_values, 'o-', color='coral', markersize=8, lw=2)
ax.axhline(0, color='gray', lw=0.5, ls='--')
ax.set_xlabel('Mean of sample distribution (target is $\\mu = 0$)')
ax.set_ylabel('KSD²')
ax.set_title('KSD detects distribution mismatch without the normalizing constant',
fontsize=13)
ax.annotate('Samples match target', xy=(0, ksd_values[0]),
xytext=(0.5, max(ksd_values)*0.3),
arrowprops=dict(arrowstyle='->', color='steelblue'),
fontsize=11, color='steelblue')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'ksd_vs_mismatch.png', dpi=300, bbox_inches='tight')
plt.show()
The KSD is essentially zero when the samples match the target, then rises steeply as the mismatch grows. This makes it an excellent diagnostic for MCMC convergence: a small KSD is strong evidence that the sampler is drawing from the target distribution.
Two dimensions: the same principle scales¶
# 2D standard normal
samples_2d_correct = np.random.randn(n, 2)
samples_2d_wrong = np.random.normal(1.5, 1.0, (n, 2))
ksd_2d_correct = kernel_stein_discrepancy(samples_2d_correct, score_standard_normal)
ksd_2d_wrong = kernel_stein_discrepancy(samples_2d_wrong, score_standard_normal)
print('2D KSD (testing against N(0, I₂) score):')
print(f' Correct samples: KSD² = {ksd_2d_correct:.6f}')
print(f' Wrong samples: KSD² = {ksd_2d_wrong:.6f}')
2D KSD (testing against N(0, I₂) score):
  Correct samples: KSD² = 0.000412
  Wrong samples: KSD² = 2.659863
Summary: The Modern Landscape¶
We have explored two fundamentally different approaches to score-based comparison of distributions:
| Measure | Compares | Key advantage | Computational cost |
|---|---|---|---|
| Fisher Divergence | Score functions (slopes) | No normalizing constant needed | $O(n)$ with known scores |
| Kernel Stein Discrepancy | Samples vs. score function | Goodness-of-fit without $Z$ | $O(n^2)$ (kernel matrix) |
These measures represent two of the most active research frontiers in computational statistics and machine learning. They illustrate how ideas from very different mathematical traditions — Fisher's score functions (1920s) and Stein's characterization of distributions (1970s) — can be combined into practical tools that power today's most advanced AI systems.
The Divergence Notebook Series¶
| # | Notebook | What it covers |
|---|---|---|
| 1 | Divergence | Shannon's foundations: entropy, cross entropy, KL divergence, Jensen-Shannon, mutual information, joint and conditional entropy |
| 2 | Beyond KL | f-divergences (TV, Hellinger, chi-squared, Jeffreys, Cressie-Read) and the Rényi family |
| 3 | Distances and Testing | Sample-based methods: Wasserstein, Sinkhorn, energy distance, MMD, kNN estimators, two-sample permutation tests |
| 4 | Dependence and Causality | Multivariate dependence (TC, NMI, VI) and directed information flow (transfer entropy) |
| 5 | Bayesian Diagnostics | End-to-end MCMC with emcee on the Nile change-point — convergence diagnostics, information gain, Bayesian surprise |
| 6 | Real-World Applications | Stock market contagion, crop yields, Phillips Curve — real data, real stakes |
| 7 | Score-Based Divergences: Fisher and Stein (this notebook) | Fisher divergence and kernel Stein discrepancy |
| 8 | Did My Sampler Find the Truth? | KSD as convergence diagnostic with NumPyro: NUTS vs VI, the 250-year journey from Bayes to Stein |
| 9 | Phillips Curve TVP | Time-varying Phillips Curve with PyJAGS Gibbs sampling — stagflation as a structural break |