Score-Based Measures

Divergences that use score functions (gradients of log-densities) rather than densities themselves. Because the normalizing constant enters the log-density only as an additive constant, it vanishes under the gradient, so these measures apply directly to unnormalized models.

Performance

kernel_stein_discrepancy dispatches to a Numba JIT kernel at n ≥ 500. The vectorized path materializes (n, n, d) diff arrays plus several (n, n) intermediates; the JIT path streams with O(n) memory (for the pre-evaluated scores). At n=3000 the JIT path is roughly 17× faster.

fisher_divergence with an estimated score_p uses a kernel density gradient that is now expressed via a linearity identity rather than an (m, n, d) intermediate. The median-bandwidth helper used by both fisher_divergence and kernel_stein_discrepancy switches to a subsampling kernel at n ≥ 500.
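
The subsampled median heuristic mentioned above can be sketched as follows (the helper name median_bandwidth, the seed, and the exact subsample size are illustrative assumptions, not the library's internals):

```python
import numpy as np

def median_bandwidth(x, subsample=500, seed=0):
    """Median pairwise distance, with subsampling for large n (illustrative)."""
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        x = x[:, None]                     # promote (n,) to (n, 1)
    n = x.shape[0]
    if n > subsample:                      # cap the O(n^2) distance computation
        idx = np.random.default_rng(seed).choice(n, size=subsample, replace=False)
        x = x[idx]
    # squared pairwise distances, then the median over distinct pairs
    d2 = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    iu = np.triu_indices(d2.shape[0], k=1)
    return float(np.sqrt(np.median(d2[iu])))
```

Subsampling barely moves the result because the median is a robust statistic of the pairwise-distance distribution.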

fisher_divergence(samples_p, score_q, *, score_p=None, bandwidth=None)

Estimate the Fisher divergence between distributions P and Q.

.. math::

   D_F(P \| Q) = \mathbb{E}_P\!\left[\|\nabla \log p(x) - \nabla \log q(x)\|^2\right]

Parameters:

samples_p : ndarray (required)
    Samples from distribution P, shape (n,) or (n, d).
score_q : callable (required)
    Score function of Q: takes an array of shape (n, d) and returns an array
    of shape (n, d) with :math:`\nabla \log q(x)` at each point.
score_p : callable or None (default: None)
    Score function of P. If None, it is estimated from samples_p via a kernel
    density gradient with an RBF kernel.
bandwidth : float or None (default: None)
    Bandwidth for the kernel score estimator (used when score_p=None). If
    None, the median heuristic is used.

Returns:

float
    Estimated Fisher divergence, non-negative.
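
When both scores are available in closed form, the expectation above reduces to a plain Monte Carlo average over the samples. A minimal self-contained sketch (fisher_divergence_sketch is an illustrative stand-in, not the library implementation):

```python
import numpy as np

def fisher_divergence_sketch(samples_p, score_p, score_q):
    """Monte Carlo estimate of E_P[||score_p(x) - score_q(x)||^2] (illustrative)."""
    x = np.asarray(samples_p, dtype=float)
    if x.ndim == 1:
        x = x[:, None]
    diff = score_p(x) - score_q(x)              # (n, d) score residuals
    return float(np.mean(np.sum(diff ** 2, axis=1)))

# Worked example: P = N(0, 1) vs Q = N(1, 1). The scores are -x and -(x - 1),
# so the residual is the constant -1 and the true divergence is mu^2 = 1.
rng = np.random.default_rng(0)
x = rng.standard_normal(5000)
est = fisher_divergence_sketch(x, lambda x: -x, lambda x: -(x - 1.0))
# est == 1.0 exactly here, because the residual is constant in x
```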

kernel_stein_discrepancy(samples, score_fn, *, kernel='rbf', bandwidth=None)

Compute the kernel Stein discrepancy (KSD).

Measures how well samples approximate the distribution P whose score function :math:`\nabla \log p` is provided.

The U-statistic estimator of the squared KSD is:

.. math::

   \widehat{\mathrm{KSD}}^2 = \frac{1}{n(n-1)} \sum_{i \neq j} u_p(x_i, x_j)

where the Stein kernel is:

.. math::

   u_p(x, y) = s_p(x)^\top s_p(y)\, k(x, y)
             + s_p(x)^\top \nabla_y k(x, y)
             + s_p(y)^\top \nabla_x k(x, y)
             + \nabla_x \cdot \nabla_y k(x, y)
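
For an RBF kernel the gradients in :math:`u_p` have closed forms, so the whole U-statistic reduces to a few vectorized array operations. A self-contained O(n² d)-memory sketch with a fixed bandwidth (ksd_rbf_sketch and sigma=1.0 are illustrative; the library's JIT dispatch and bandwidth selection differ):

```python
import numpy as np

def ksd_rbf_sketch(samples, score_fn, sigma=1.0):
    """Squared-KSD U-statistic with an RBF kernel (illustrative sketch)."""
    x = np.asarray(samples, dtype=float)
    if x.ndim == 1:
        x = x[:, None]
    n, d = x.shape
    s = score_fn(x)                                   # (n, d) scores s_p(x_i)
    diff = x[:, None, :] - x[None, :, :]              # (n, n, d), x_i - x_j
    r2 = np.sum(diff ** 2, axis=-1)                   # squared distances
    k = np.exp(-r2 / (2 * sigma ** 2))                # k(x_i, x_j)
    # Closed-form RBF derivatives:
    #   grad_y k =  (x - y) / sigma^2 * k
    #   grad_x k = -(x - y) / sigma^2 * k
    #   div_x div_y k = (d / sigma^2 - r2 / sigma^4) * k
    ss = s @ s.T                                      # s_p(x_i)^T s_p(x_j)
    sx_dot_diff = np.sum(s[:, None, :] * diff, axis=-1)  # s_p(x_i)^T (x_i - x_j)
    sy_dot_diff = np.sum(s[None, :, :] * diff, axis=-1)  # s_p(x_j)^T (x_i - x_j)
    u = (ss * k
         + (sx_dot_diff / sigma ** 2) * k             # s_p(x)^T grad_y k
         - (sy_dot_diff / sigma ** 2) * k             # s_p(y)^T grad_x k
         + (d / sigma ** 2 - r2 / sigma ** 4) * k)    # trace term
    np.fill_diagonal(u, 0.0)                          # U-statistic: drop i == j
    return float(u.sum() / (n * (n - 1)))
```

The sign flip between the two cross terms comes from :math:`\nabla_x k = -\nabla_y k`, which holds for any translation-invariant kernel.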

Parameters:

samples : ndarray (required)
    Sample points, shape (n,) or (n, d).
score_fn : callable (required)
    Score function of the target distribution P. Takes an array of shape
    (n, d) and returns an array of shape (n, d) with :math:`\nabla \log p(x)`
    evaluated at each sample point.
kernel : str (default: 'rbf')
    Kernel function: "rbf" (Gaussian, default) or "imq" (inverse
    multiquadric). The IMQ kernel :math:`k(x,y) = (c^2 + \|x-y\|^2)^{-1/2}`
    has provable convergence-control guarantees that the RBF kernel lacks.
bandwidth : float or None (default: None)
    Bandwidth parameter. For the RBF kernel this is :math:`\sigma` in
    :math:`k(x,y) = \exp(-\|x-y\|^2 / (2\sigma^2))`. For the IMQ kernel this
    is :math:`c` in :math:`k(x,y) = (c^2 + \|x-y\|^2)^{-1/2}`. If None, the
    median heuristic is used for both kernels.

Returns:

float
    Squared KSD (U-statistic estimator). Close to 0 when the samples come
    from P.

Raises:

ValueError
    If kernel is not "rbf" or "imq", or if fewer than 2 samples are provided.

Notes

The RBF kernel :math:`k(x,y) = \exp(-\|x-y\|^2/(2\sigma^2))` is the standard choice. The IMQ kernel :math:`k(x,y) = (c^2 + \|x-y\|^2)^{-1/2}` is recommended for MCMC convergence diagnostics because it provides convergence control: :math:`\mathrm{KSD}(\mu_n, \pi) \to 0` implies :math:`\mu_n \Rightarrow \pi` (weak convergence) and tightness of :math:`\{\mu_n\}` [3]_.

References

.. [1] Liu, Q., Lee, J., & Jordan, M. (2016). "A kernelized Stein discrepancy
   for goodness-of-fit tests." ICML.
.. [2] Chwialkowski, K., Strathmann, H., & Gretton, A. (2016). "A kernel test
   of goodness of fit." ICML.
.. [3] Gorham, J. & Mackey, L. (2017). "Measuring sample quality with
   kernels." ICML.

Examples:

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> samples = rng.standard_normal(1000)
>>> ksd = kernel_stein_discrepancy(samples, lambda x: -x, kernel="imq")
>>> abs(ksd) < 0.1  # close to zero for matching distribution
True