feat: added kernel density estimation to score matching. #258

Merged: 2 commits, Dec 6, 2023
2 changes: 1 addition & 1 deletion coreax/kernel.py
@@ -50,7 +50,7 @@ class Kernel(ABC):
"""
Base class for kernels.

-:param length_scale: Kernel length_scale to use
+:param length_scale: Kernel ``length_scale`` to use
:param output_scale: Output scale to use
"""

98 changes: 96 additions & 2 deletions coreax/score_matching.py
@@ -19,14 +19,16 @@
child class of the abstract base class :class:`ScoreMatching`.

When using :class:`SlicedScoreMatching`, the score function is approximated using a
-neural network.
+neural network, whereas in :class:`KernelDensityMatching`, it is approximated by
+fitting a kernel density estimate to the data and then differentiating it.
"""

from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial

import jax
import numpy as np
import optax
from flax.training.train_state import TrainState
from jax import jit, jvp
@@ -36,6 +38,7 @@
from jax.typing import ArrayLike
from tqdm import tqdm

import coreax.kernel as ck
from coreax.networks import ScoreNetwork, create_train_state


@@ -447,9 +450,100 @@ def match(self, x: ArrayLike) -> Callable:
return lambda x_: state.apply_fn({"params": state.params}, x_)


class KernelDensityMatching(ScoreMatching):
r"""
Implementation of a kernel density estimate to determine a score function.

The score function of some data is the derivative of the log-PDF. Score matching
aims to determine a model by 'matching' the score function of the model to that
of the data. Exactly how the score function is modelled is specific to each
child class of the abstract base class :class:`ScoreMatching`.

With kernel density matching, we approximate the underlying distribution function
from a dataset using kernel density estimation, and then differentiate this to
compute an estimate of the score function. A Gaussian kernel is used to construct
the kernel density estimate.

:param length_scale: Kernel ``length_scale`` to use when fitting the kernel density
estimate
:param kde_data: Set of :math:`n \times d` samples from the underlying distribution
that are used to build the kernel density estimate
"""

def __init__(self, length_scale: float, kde_data: ArrayLike):
r"""
Define the kernel density matching class.
"""
# Define a normalised Gaussian kernel (a special case of the squared
# exponential kernel) to construct the kernel density estimate
self.kernel = ck.SquaredExponentialKernel(
length_scale=length_scale,
output_scale=1.0 / (np.sqrt(2 * np.pi) * length_scale),
)

# Hold the data the kernel density estimate will be built from, which will be
# needed for any call of the score function.
self.kde_data = kde_data

# Initialise parent
super().__init__()

def _tree_flatten(self):
"""
Flatten a pytree.

Define arrays & dynamic values (children) and auxiliary data (static values).
A method to flatten the pytree needs to be specified to enable jit decoration
of methods inside this class.
"""
children = (self.kde_data,)
aux_data = {"kernel": self.kernel}

return children, aux_data

def match(self, x: ArrayLike | None = None) -> Callable:
r"""
Learn a score function using kernel density estimation to model a distribution.

For the kernel density matching approach, the score function is determined by
fitting a kernel density estimate to samples from the underlying distribution
and then differentiating it. 'Learning' here therefore amounts to defining the
score function in terms of the kernel density estimate held by the object; no
optimisation is performed, and the input ``x`` is unused.

:param x: The :math:`n \times d` data vectors. Unused in this implementation.
:return: A function that applies the learned score function to input ``x``
"""

def score_function(x_):
r"""
Compute the score function using a kernel density estimation.

The score function is determined by fitting a kernel density estimate to
samples from the underlying distribution and then differentiating it. The
kernel density estimate is created using a Gaussian kernel.

:param x_: The :math:`n \times d` data vectors we wish to evaluate the score
function at
:return: The estimated score function evaluated at each row of ``x_``
"""
# Check format
x_ = jnp.atleast_2d(x_)

# Get the gram matrix row means
gram_matrix_row_means = self.kernel.compute(x_, self.kde_data).mean(axis=1)

# Compute gradients with respect to x
gradients = self.kernel.grad_x(x_, self.kde_data).mean(axis=1)

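# The ratio below is the score of the kernel density estimate:
# grad log p_hat(x) = sum_i grad_x k(x, x_i) / sum_i k(x, x_i); the common
# 1/n and kernel normalisation factors cancel.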
return gradients / gram_matrix_row_means[:, None]

return score_function


# Define the pytree node for the added class to ensure methods with jit decorators
# are able to run. This tuple must be updated when a new class object is defined.
-score_matching_classes = (SlicedScoreMatching,)
+score_matching_classes = (SlicedScoreMatching, KernelDensityMatching)
for current_class in score_matching_classes:
tree_util.register_pytree_node(
current_class, current_class._tree_flatten, current_class._tree_unflatten
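
The score function returned by ``match`` is the gradient of the log of the
Gaussian kernel density estimate (see the comment above the return statement in
``score_function``). A minimal usage sketch follows; the sample data and the use
of ``median_heuristic`` to pick the length scale are illustrative choices
mirroring the tests below, not requirements of the class:

import numpy as np

import coreax.kernel as ck
import coreax.score_matching as csm

# Illustrative data: 500 draws from a standard normal distribution.
rng = np.random.default_rng(0)
samples = rng.normal(size=(500, 1))

# Build the matcher; no training occurs, as the kernel density estimate
# fully defines the score function.
matcher = csm.KernelDensityMatching(
    length_scale=float(ck.median_heuristic(samples)), kde_data=samples
)
score = matcher.match()

# For a standard normal the true score at x is -x, so this should be
# roughly -0.5.
print(score(np.array([[0.5]])))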
180 changes: 176 additions & 4 deletions tests/unit/test_score_matching.py
@@ -19,6 +19,7 @@
from jax.scipy.stats import multivariate_normal, norm
from optax import sgd

import coreax.kernel as ck
import coreax.networks as cn
import coreax.score_matching as csm

@@ -37,6 +38,177 @@ def __call__(self, x: csm.ArrayLike) -> csm.ArrayLike:
return x


class TestKernelDensityMatching(unittest.TestCase):
"""
Tests related to the class KernelDensityMatching in score_matching.py.
"""

def test_univariate_gaussian_score(self) -> None:
"""
Test a simple univariate Gaussian with a known score function.
"""
# Setup univariate Gaussian
mu = 0.0
std_dev = 1.0
num_points = 500
np.random.seed(0)
samples = np.random.normal(mu, std_dev, size=(num_points, 1))

def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:
return -(x_ - mu) / std_dev**2

# Define data
x = np.linspace(-2, 2).reshape(-1, 1)
true_score_result = true_score(x)

# Define a kernel density matching object
kernel_density_matcher = csm.KernelDensityMatching(
length_scale=ck.median_heuristic(samples), kde_data=samples
)

# Extract the score function (not learned from the data as such; it is
# defined directly by the kernel density estimate held by the object)
learned_score = kernel_density_matcher.match()
score_result = learned_score(x)

# Check learned score and true score align
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.5)

def test_multivariate_gaussian_score(self) -> None:
"""
Test a simple multivariate Gaussian with a known score function.
"""
# Setup multivariate Gaussian
dimension = 2
mu = np.zeros(dimension)
sigma_matrix = np.eye(dimension)
lambda_matrix = np.linalg.pinv(sigma_matrix)
num_points = 500
np.random.seed(0)
samples = np.random.multivariate_normal(mu, sigma_matrix, size=num_points)

def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:
return np.array(list(map(lambda z: -lambda_matrix @ (z - mu), x_)))

# Define data
x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))
data_stacked = np.vstack([x.ravel(), y.ravel()]).T
true_score_result = true_score(data_stacked)

# Define a kernel density matching object
kernel_density_matcher = csm.KernelDensityMatching(
length_scale=ck.median_heuristic(samples), kde_data=samples
)

# Extract the score function (not learned from the data as such; it is
# defined directly by the kernel density estimate held by the object)
learned_score = kernel_density_matcher.match()
score_result = learned_score(data_stacked)

# Check learned score and true score align
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.75)

def test_univariate_gmm_score(self):
"""
Test a univariate Gaussian mixture model with a known score function.
"""
# Define the univariate Gaussian mixture model
mus = np.array([-4.0, 4.0])
std_devs = np.array([1.0, 2.0])
p = 0.7
mix = np.array([1 - p, p])
num_points = 1000
np.random.seed(0)
comp = np.random.binomial(1, p, size=num_points)
samples = np.random.normal(mus[comp], std_devs[comp]).reshape(-1, 1)

def egrad(g: csm.Callable) -> csm.Callable:
def wrapped(x_, *rest):
# For an elementwise function, the vector-Jacobian product against a ones
# vector recovers the elementwise gradient.
y, g_vjp = jax.vjp(lambda x__: g(x__, *rest), x_)
(x_bar,) = g_vjp(np.ones_like(y))
return x_bar

return wrapped

def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:
log_pdf = lambda y: jax.numpy.log(norm.pdf(y, mus, std_devs) @ mix)
return egrad(log_pdf)(x_)

# Define data
x = np.linspace(-10, 10).reshape(-1, 1)
true_score_result = true_score(x)

# Define a kernel density matching object
kernel_density_matcher = csm.KernelDensityMatching(
length_scale=ck.median_heuristic(samples), kde_data=samples
)

# Extract the score function (not learned from the data as such; it is
# defined directly by the kernel density estimate held by the object)
learned_score = kernel_density_matcher.match()
score_result = learned_score(x)

# Check learned score and true score align
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.5)

def test_multivariate_gmm_score(self):
"""
Test a multivariate Gaussian mixture model with a known score function.
"""
# Define the multivariate Gaussian mixture model (we don't want to go much
# higher than dimension=2)
np.random.seed(0)
dimension = 2
k = 10
mus = np.random.multivariate_normal(
np.zeros(dimension), np.eye(dimension), size=k
)
sigmas = np.array(
[np.random.gamma(2.0, 1.0) * np.eye(dimension) for _ in range(k)]
)
mix = np.random.dirichlet(np.ones(k))
num_points = 500
comp = np.random.choice(k, size=num_points, p=mix)
samples = np.array(
[np.random.multivariate_normal(mus[c], sigmas[c]) for c in comp]
)

def egrad(g: csm.Callable) -> csm.Callable:
def wrapped(x_, *rest):
y, g_vjp = jax.vjp(lambda x__: g(x__, *rest), x_)
(x_bar,) = g_vjp(np.ones_like(y))
return x_bar

return wrapped

def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:
def logpdf(y: csm.ArrayLike) -> csm.ArrayLike:
lpdf = 0.0
for k_ in range(k):
lpdf += multivariate_normal.pdf(y, mus[k_], sigmas[k_]) * mix[k_]
return jax.numpy.log(lpdf)

return egrad(logpdf)(x_)

# Define data
coords = np.meshgrid(*[np.linspace(-7.5, 7.5) for _ in range(dimension)])
x_stacked = np.vstack([c.ravel() for c in coords]).T
true_score_result = true_score(x_stacked)

# Define a kernel density matching object
kernel_density_matcher = csm.KernelDensityMatching(
length_scale=ck.median_heuristic(samples), kde_data=samples
)

# Extract the score function (not learned from the data as such; it is
# defined directly by the kernel density estimate held by the object)
learned_score = kernel_density_matcher.match()
score_result = learned_score(x_stacked)

# Check learned score and true score align
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.5)


class TestSlicedScoreMatching(unittest.TestCase):
"""
Tests related to the class SlicedScoreMatching in score_matching.py.
@@ -401,7 +573,7 @@ def test_train_step(self) -> None:

def test_univariate_gaussian_score(self):
"""
-Test a simple univariate Gaussian known score function.
+Test a simple univariate Gaussian with a known score function.
"""
# Setup univariate Gaussian
mu = 0.0
@@ -443,7 +615,7 @@ def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:

def test_multivariate_gaussian_score(self) -> None:
"""
-Test a simple multivariate Gaussian known score function.
+Test a simple multivariate Gaussian with a known score function.
"""
# Setup multivariate Gaussian
dimension = 2
@@ -487,7 +659,7 @@ def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:

def test_univariate_gmm_score(self):
"""
-Test a univariate Gaussian mixture model known score function.
+Test a univariate Gaussian mixture model with a known score function.
"""
# Define the univariate Gaussian mixture model
mus = np.array([-4.0, 4.0])
@@ -540,7 +712,7 @@ def true_score(x_: csm.ArrayLike) -> csm.ArrayLike:

def test_multivariate_gmm_score(self):
"""
-Test a multivariate Gaussian mixture model known score function.
+Test a multivariate Gaussian mixture model with a known score function.
"""
# Define the multivariate Gaussian mixture model (we don't want to go much
# higher than dimension=2)
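
As a standalone sanity check (not part of this PR), the closed-form ratio used
in ``KernelDensityMatching`` can be compared against automatic differentiation
of the log of a kernel density estimate; the two should agree to numerical
precision. This sketch assumes only ``jax`` and hand-rolls a normalised 1-D
Gaussian kernel rather than using ``coreax.kernel``:

import jax
import jax.numpy as jnp


def gaussian_kernel(x, y, length_scale=1.0):
    # Normalised 1-D Gaussian kernel, mirroring the output_scale chosen in
    # KernelDensityMatching.__init__ above.
    z = (x - y) / length_scale
    return jnp.exp(-0.5 * z**2) / (jnp.sqrt(2.0 * jnp.pi) * length_scale)


def log_kde(x, kde_data):
    # Log of the kernel density estimate built from kde_data.
    return jnp.log(jnp.mean(gaussian_kernel(x, kde_data)))


kde_data = jnp.linspace(-1.0, 1.0, 5)
x0 = 0.3

# Score via automatic differentiation of the log-KDE.
autodiff_score = jax.grad(log_kde)(x0, kde_data)

# Score via the closed form used in score_function: mean kernel gradient
# divided by mean kernel value.
closed_form_score = jnp.mean(
    jax.vmap(jax.grad(gaussian_kernel), in_axes=(None, 0))(x0, kde_data)
) / jnp.mean(gaussian_kernel(x0, kde_data))

print(autodiff_score, closed_form_score)  # These should match closely.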