-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: added kernel density estimation to score matching. #258
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
from jax.scipy.stats import multivariate_normal, norm | ||
from optax import sgd | ||
|
||
import coreax.kernel as ck | ||
import coreax.networks as cn | ||
import coreax.score_matching as csm | ||
|
||
|
@@ -37,6 +38,177 @@ def __call__(self, x: csm.ArrayLike) -> csm.ArrayLike: | |
return x | ||
|
||
|
||
class TestKernelDensityMatching(unittest.TestCase): | ||
""" | ||
Tests related to the class in score_matching.py | ||
""" | ||
|
||
def test_univariate_gaussian_score(self) -> None: | ||
""" | ||
Test a simple univariate Gaussian with a known score function. | ||
""" | ||
# Setup univariate Gaussian | ||
mu = 0.0 | ||
std_dev = 1.0 | ||
num_points = 500 | ||
np.random.seed(0) | ||
samples = np.random.normal(mu, std_dev, size=(num_points, 1)) | ||
|
||
def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | ||
return -(x_ - mu) / std_dev**2 | ||
|
||
# Define data | ||
x = np.linspace(-2, 2).reshape(-1, 1) | ||
true_score_result = true_score(x) | ||
|
||
# Define a kernel density matching object | ||
kernel_density_matcher = csm.KernelDensityMatching( | ||
length_scale=ck.median_heuristic(samples), kde_data=samples | ||
) | ||
|
||
# Extract the score function (this is not really learned from the data, more | ||
# defined within the object) | ||
learned_score = kernel_density_matcher.match(samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If comment here is followed, we shouldn't need the Repeat for all other calls to match in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated to not input anything (since it's unused) |
||
score_result = learned_score(x) | ||
|
||
# Check learned score and true score align | ||
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.5) | ||
|
||
def test_multivariate_gaussian_score(self) -> None: | ||
""" | ||
Test a simple multivariate Gaussian with a known score function. | ||
""" | ||
# Setup multivariate Gaussian | ||
dimension = 2 | ||
mu = np.zeros(dimension) | ||
sigma_matrix = np.eye(dimension) | ||
lambda_matrix = np.linalg.pinv(sigma_matrix) | ||
num_points = 500 | ||
np.random.seed(0) | ||
samples = np.random.multivariate_normal(mu, sigma_matrix, size=num_points) | ||
|
||
def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | ||
return np.array(list(map(lambda z: -lambda_matrix @ (z - mu), x_))) | ||
|
||
# Define data | ||
x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2)) | ||
data_stacked = np.vstack([x.ravel(), y.ravel()]).T | ||
true_score_result = true_score(data_stacked) | ||
|
||
# Define a kernel density matching object | ||
kernel_density_matcher = csm.KernelDensityMatching( | ||
length_scale=ck.median_heuristic(samples), kde_data=samples | ||
) | ||
|
||
# Extract the score function (this is not really learned from the data, more | ||
# defined within the object) | ||
learned_score = kernel_density_matcher.match(samples) | ||
score_result = learned_score(data_stacked) | ||
|
||
# Check learned score and true score align | ||
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.75) | ||
|
||
def test_univariate_gmm_score(self): | ||
""" | ||
Test a univariate Gaussian mixture model with a known score function. | ||
""" | ||
# Define the univariate Gaussian mixture model | ||
mus = np.array([-4.0, 4.0]) | ||
std_devs = np.array([1.0, 2.0]) | ||
p = 0.7 | ||
mix = np.array([1 - p, p]) | ||
num_points = 1000 | ||
np.random.seed(0) | ||
comp = np.random.binomial(1, p, size=num_points) | ||
samples = np.random.normal(mus[comp], std_devs[comp]).reshape(-1, 1) | ||
|
||
def egrad(g: csm.Callable) -> csm.Callable: | ||
def wrapped(x_, *rest): | ||
y, g_vjp = jax.vjp(lambda x__: g(x, *rest), x_) | ||
(x_bar,) = g_vjp(np.ones_like(y)) | ||
return x_bar | ||
|
||
return wrapped | ||
|
||
def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | ||
log_pdf = lambda y: jax.numpy.log(norm.pdf(y, mus, std_devs) @ mix) | ||
return egrad(log_pdf)(x_) | ||
|
||
# Define data | ||
x = np.linspace(-10, 10).reshape(-1, 1) | ||
true_score_result = true_score(x) | ||
|
||
# Define a kernel density matching object | ||
kernel_density_matcher = csm.KernelDensityMatching( | ||
length_scale=ck.median_heuristic(samples), kde_data=samples | ||
) | ||
|
||
# Extract the score function (this is not really learned from the data, more | ||
# defined within the object) | ||
learned_score = kernel_density_matcher.match(samples) | ||
score_result = learned_score(x) | ||
|
||
# Check learned score and true score align | ||
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.5) | ||
|
||
def test_multivariate_gmm_score(self): | ||
""" | ||
Test a multivariate Gaussian mixture model with a known score function. | ||
""" | ||
# Define the multivariate Gaussian mixture model (we don't want to go much | ||
# higher than dimension=2) | ||
np.random.seed(0) | ||
dimension = 2 | ||
k = 10 | ||
mus = np.random.multivariate_normal( | ||
np.zeros(dimension), np.eye(dimension), size=k | ||
) | ||
sigmas = np.array( | ||
[np.random.gamma(2.0, 1.0) * np.eye(dimension) for _ in range(k)] | ||
) | ||
mix = np.random.dirichlet(np.ones(k)) | ||
num_points = 500 | ||
comp = np.random.choice(k, size=num_points, p=mix) | ||
samples = np.array( | ||
[np.random.multivariate_normal(mus[c], sigmas[c]) for c in comp] | ||
) | ||
|
||
def egrad(g: csm.Callable) -> csm.Callable: | ||
def wrapped(x_, *rest): | ||
y, g_vjp = jax.vjp(lambda x__: g(x_, *rest), x_) | ||
(x_bar,) = g_vjp(np.ones_like(y)) | ||
return x_bar | ||
|
||
return wrapped | ||
|
||
def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | ||
def logpdf(y: csm.ArrayLike) -> csm.ArrayLike: | ||
lpdf = 0.0 | ||
for k_ in range(k): | ||
lpdf += multivariate_normal.pdf(y, mus[k_], sigmas[k_]) * mix[k_] | ||
return jax.numpy.log(lpdf) | ||
|
||
return egrad(logpdf)(x_) | ||
|
||
# Define data | ||
coords = np.meshgrid(*[np.linspace(-7.5, 7.5) for _ in range(dimension)]) | ||
x_stacked = np.vstack([c.ravel() for c in coords]).T | ||
true_score_result = true_score(x_stacked) | ||
|
||
# Define a kernel density matching object | ||
kernel_density_matcher = csm.KernelDensityMatching( | ||
length_scale=ck.median_heuristic(samples), kde_data=samples | ||
) | ||
|
||
# Extract the score function (this is not really learned from the data, more | ||
# defined within the object) | ||
learned_score = kernel_density_matcher.match(samples) | ||
score_result = learned_score(x_stacked) | ||
|
||
# Check learned score and true score align | ||
self.assertLessEqual(np.abs(true_score_result - score_result).mean(), 0.5) | ||
|
||
|
||
class TestSlicedScoreMatching(unittest.TestCase): | ||
""" | ||
Tests related to the class SlicedScoreMatching in score_matching.py. | ||
|
@@ -401,7 +573,7 @@ def test_train_step(self) -> None: | |
|
||
def test_univariate_gaussian_score(self): | ||
""" | ||
Test a simple univariate Gaussian known score function. | ||
Test a simple univariate Gaussian with a known score function. | ||
""" | ||
# Setup univariate Gaussian | ||
mu = 0.0 | ||
|
@@ -443,7 +615,7 @@ def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | |
|
||
def test_multivariate_gaussian_score(self) -> None: | ||
""" | ||
Test a simple multivariate Gaussian known score function. | ||
Test a simple multivariate Gaussian with a known score function. | ||
""" | ||
# Setup multivariate Gaussian | ||
dimension = 2 | ||
|
@@ -487,7 +659,7 @@ def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | |
|
||
def test_univariate_gmm_score(self): | ||
""" | ||
Test a univariate Gaussian mixture model known score function. | ||
Test a univariate Gaussian mixture model with a known score function. | ||
""" | ||
# Define the univariate Gaussian mixture model | ||
mus = np.array([-4.0, 4.0]) | ||
|
@@ -540,7 +712,7 @@ def true_score(x_: csm.ArrayLike) -> csm.ArrayLike: | |
|
||
def test_multivariate_gmm_score(self): | ||
""" | ||
Test a multivariate Gaussian mixture model known score function. | ||
Test a multivariate Gaussian mixture model with a known score function. | ||
""" | ||
# Define the multivariate Gaussian mixture model (we don't want to go much | ||
# higher than dimension=2) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
x
is not used, should we putx: ArrayLike | None = None
? Then we don't need to callmatch
with anything; particularly intest_score_matching.py
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed and added