Merge pull request #89 from ziatdinovmax/bnn
Streamline usage of BNNs in DKL
ziatdinovmax authored Feb 23, 2024
2 parents 1b3e5c6 + 9284cb5 commit 5d509eb
Showing 3 changed files with 89 additions and 88 deletions.
120 changes: 58 additions & 62 deletions gpax/models/dkl.py
@@ -4,22 +4,22 @@
Fully Bayesian implementation of deep kernel learning
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""

from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union, List

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import jit

from .vgp import vExactGP
from .gp import ExactGP


class DKL(vExactGP):
class DKL(ExactGP):
"""
Fully Bayesian implementation of deep kernel learning
@@ -39,6 +39,10 @@ class DKL(vExactGP):
Priors over the weights and biases in 'nn'; uses normal priors by default
latent_prior:
Optional prior over the latent space (BNN embedding); uses none by default
hidden_dim:
Optional custom MLP architecture. For example [16, 8, 4] corresponds to a 3-layer
neural network backbone containing 16, 8, and 4 neurons activated by tanh(). The latent
layer is added automatically and doesn't have to be specified here. Defaults to [64, 32].
**kwargs:
Optional custom prior distributions over observational noise (noise_dist_prior)
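As a quick illustration of the new hidden_dim argument, here is a minimal usage sketch based on the updated constructor together with the fit() and embed() calls exercised in the tests; the data shapes, random key handling, and sample counts are illustrative only.

```python
import jax
import numpy as onp
from gpax.models.dkl import DKL

# Toy data: 12 points with 8 features (shapes chosen purely for illustration)
X = onp.random.randn(12, 8)
y = onp.random.randn(12)

# Three-layer tanh backbone with 16, 8, and 4 neurons;
# the 2D latent layer is appended automatically
m = DKL(input_dim=X.shape[-1], z_dim=2, kernel='RBF', hidden_dim=[16, 8, 4])
m.fit(jax.random.PRNGKey(0), X, y, num_warmup=100, num_samples=100)

# Project unseen points into the learned latent space
z = m.embed(onp.random.randn(10, 8))
```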
@@ -67,11 +71,12 @@ def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
nn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
nn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
**kwargs
hidden_dim: Optional[List[int]] = None, **kwargs
) -> None:
super(DKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs)
self.nn = nn if nn else mlp
self.nn_prior = nn_prior if nn_prior else mlp_prior(input_dim, z_dim)
hdim = hidden_dim if hidden_dim is not None else [64, 32]
self.nn = nn if nn else get_mlp(hdim)
self.nn_prior = nn_prior if nn_prior else get_mlp_prior(input_dim, z_dim, hdim)
self.kernel_dim = z_dim
self.latent_prior = latent_prior

@@ -84,8 +89,8 @@ def model(self,
jitter = kwargs.get("jitter", 1e-6)
task_dim = X.shape[0]
# BNN part
bnn_params = self.nn_prior(task_dim)
z = jax.jit(jax.vmap(self.nn))(X, bnn_params)
nn_params = self.nn_prior(task_dim)
z = self.nn(X, nn_params)
if self.latent_prior: # Sample latent variable
z = self.latent_prior(z)
# Sample GP kernel parameters
@@ -94,30 +99,28 @@
else:
kernel_params = self._sample_kernel_params(task_dim)
# Sample noise
noise = self._sample_noise(task_dim)
noise = self._sample_noise()
# GP's mean function
f_loc = jnp.zeros(z.shape[:2])
# compute kernel(s)
jitter = jnp.array(jitter).repeat(task_dim)
k_args = (z, z, kernel_params, noise)
k = jax.vmap(self.kernel)(*k_args, jitter=jitter)
f_loc = jnp.zeros(z.shape[0])
# compute kernel
k = self.kernel(z, z, kernel_params, noise, jitter=jitter)
# Sample y according to the standard Gaussian process formula
numpyro.sample(
"y",
dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
obs=y,
)
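For reference, the model above implements the usual deep kernel learning generative process: the Bayesian MLP g_φ embeds the inputs, and the GP prior/likelihood is placed on that embedding. A sketch of the corresponding math, matching the code (optional latent_prior and the small jitter term omitted; θ are the sampled kernel hyperparameters and σ² the sampled noise):

```latex
z = g_{\phi}(X), \qquad \phi \sim p(\phi)
y \mid z \sim \mathcal{N}\!\left(0,\; K_{\theta}(z, z) + \sigma^{2} I\right)
```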

#@partial(jit, static_argnames='self')
def _get_mvn_posterior(self,
X_train: jnp.ndarray, y_train: jnp.ndarray,
X_new: jnp.ndarray, params: Dict[str, jnp.ndarray],
noiseless: bool = False, **kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
def get_mvn_posterior(self,
X_new: jnp.ndarray,
params: Dict[str, jnp.ndarray],
noiseless: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
noise = params["noise"]
noise_p = noise * (1 - jnp.array(noiseless, int))
# embed data into the latent space
z_train = self.nn(X_train, params)
z_train = self.nn(self.X_train, params)
z_new = self.nn(X_new, params)
# compute kernel matrices for train and new ('test') data
k_pp = self.kernel(z_new, z_new, params, noise_p, **kwargs)
@@ -126,7 +129,7 @@ def _get_mvn_posterior(self,
# compute the predictive covariance and mean
K_xx_inv = jnp.linalg.inv(k_XX)
cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, self.y_train))
return mean, cov
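The matrix algebra above is the standard GP posterior computed in the latent space, with z = nn(X_train, params) and z_* = nn(X_new, params); the noiseless flag simply zeroes the noise term added to the test–test block k_pp. Written out as a sketch:

```latex
\mu_* = K_{*z}\, K_{zz}^{-1}\, y_{\text{train}}, \qquad
\Sigma_* = K_{**} - K_{*z}\, K_{zz}^{-1}\, K_{*z}^{\top}
```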

@partial(jit, static_argnames='self')
@@ -140,16 +143,6 @@ def embed(self, X_new: jnp.ndarray) -> jnp.ndarray:
z = predictive(samples)
return z

def _set_data(self,
X: jnp.ndarray,
y: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X[None] if X.ndim == 2 else X # add task pseudo-dimension
if y is not None:
y = y[None] if y.ndim == 1 else y # add task pseudo-dimension
return X, y
return X

def _print_summary(self):
list_of_keys = ["k_scale", "k_length", "noise", "period"]
samples = self.get_samples(1)
@@ -159,40 +152,43 @@ def _print_summary(self):

def sample_weights(name: str, in_channels: int, out_channels: int, task_dim: int) -> jnp.ndarray:
"""Sampling weights matrix"""
with numpyro.plate("batch_dim", task_dim, dim=-3):
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


def sample_biases(name: str, channels: int, task_dim: int) -> jnp.ndarray:
"""Sampling bias vector"""
with numpyro.plate("batch_dim", task_dim, dim=-3):
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Simple MLP for a single MCMC sample of weights and biases"""
h1 = jnp.tanh(jnp.matmul(X, params["w1"]) + params["b1"])
h2 = jnp.tanh(jnp.matmul(h1, params["w2"]) + params["b2"])
z = jnp.matmul(h2, params["w3"]) + params["b3"]
return z


def mlp_prior(input_dim: int, zdim: int = 2) -> Dict[str, jnp.array]:
"""Priors over weights and biases in the default Bayesian MLP"""
hdim = [64, 32]

def _bnn_prior(task_dim: int):
w1 = sample_weights("w1", input_dim, hdim[0], task_dim)
b1 = sample_biases("b1", hdim[0], task_dim)
w2 = sample_weights("w2", hdim[0], hdim[1], task_dim)
b2 = sample_biases("b2", hdim[1], task_dim)
w3 = sample_weights("w3", hdim[1], zdim, task_dim)
b3 = sample_biases("b3", zdim, task_dim)
return {"w1": w1, "b1": b1, "w2": w2, "b2": b2, "w3": w3, "b3": b3}

return _bnn_prior
def get_mlp(architecture: List[int]) -> Callable:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
h = X
for i in range(len(architecture)):
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
# No non-linearity after the last layer
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
return z
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior(task_dim: int):
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels, task_dim)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels, task_dim)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim, task_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim, task_dim)
return params
return mlp_prior
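To make the new parameter naming concrete (the updated tests below rely on the w0/b0, w1/b1, ... keys), here is a minimal sketch of how the two factory functions compose; the architecture, shapes, and seeding are illustrative only.

```python
import jax.numpy as jnp
import numpyro

from gpax.models.dkl import get_mlp, get_mlp_prior

# Hypothetical two-hidden-layer backbone: 8 input features -> [16, 8] -> 2D latent output
architecture = [16, 8]
nn = get_mlp(architecture)                     # forward pass expecting w0/b0, w1/b1, w2/b2
nn_prior = get_mlp_prior(input_dim=8, output_dim=2, architecture=architecture)

# Draw one set of weights and biases from the prior outside of inference
with numpyro.handlers.seed(rng_seed=0):
    params = nn_prior(task_dim=1)              # task_dim is still accepted by the prior

z = nn(jnp.ones((5, 8)), params)               # -> embedding of shape (5, 2)
```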
4 changes: 2 additions & 2 deletions tests/test_acq.py
@@ -121,11 +121,11 @@ def test_acq_dkl(acq):
rng_keys = get_keys()
X = onp.random.randn(12, 8)
y = onp.random.randn(12,)
X_new = onp.random.randn(10, 8)[None]
X_new = onp.random.randn(10, 8)
m = DKL(X.shape[-1], 2, 'RBF')
m.fit(rng_keys[0], X, y, num_samples=5, num_warmup=5)
obj = acq(rng_keys[1], m, X_new, subsample_size=4)
assert_equal(obj.shape, (X_new.shape[1],))
assert_equal(obj.shape, (X_new.shape[0],))


def test_UCB_beta():
53 changes: 29 additions & 24 deletions tests/test_dkl.py
@@ -32,39 +32,44 @@ def test_get_mvn_posterior():
rng_key = get_keys()[0]
X, y = get_dummy_data()
X_test, _ = get_dummy_data()
params = {"w1": jax.random.normal(rng_key, shape=(36, 64)),
"w2": jax.random.normal(rng_key, shape=(64, 32)),
"w3": jax.random.normal(rng_key, shape=(32, 2)),
"b1": jax.random.normal(rng_key, shape=(64,)),
"b2": jax.random.normal(rng_key, shape=(32,)),
"b3": jax.random.normal(rng_key, shape=(2,)),
params = {"w0": jax.random.normal(rng_key, shape=(36, 64)),
"w1": jax.random.normal(rng_key, shape=(64, 32)),
"w2": jax.random.normal(rng_key, shape=(32, 2)),
"b0": jax.random.normal(rng_key, shape=(64,)),
"b1": jax.random.normal(rng_key, shape=(32,)),
"b2": jax.random.normal(rng_key, shape=(2,)),
"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1)}
m = DKL(X.shape[-1], kernel='RBF')
mean, cov = m._get_mvn_posterior(X, y, X_test, params)
m.X_train = X
m.y_train = y
mean, cov = m.get_mvn_posterior(X_test, params)
assert isinstance(mean, jnp.ndarray)
assert isinstance(cov, jnp.ndarray)
assert_equal(mean.shape, (X_test.shape[0],))
assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


def test_get_mvn_posterior_noiseless():
rng_key = get_keys()[0]
X, y = get_dummy_data()
X_test, _ = get_dummy_data()
params = {"w1": jax.random.normal(rng_key, shape=(36, 64)),
"w2": jax.random.normal(rng_key, shape=(64, 32)),
"w3": jax.random.normal(rng_key, shape=(32, 2)),
"b1": jax.random.normal(rng_key, shape=(64,)),
"b2": jax.random.normal(rng_key, shape=(32,)),
"b3": jax.random.normal(rng_key, shape=(2,)),
params = {"w0": jax.random.normal(rng_key, shape=(36, 64)),
"w1": jax.random.normal(rng_key, shape=(64, 32)),
"w2": jax.random.normal(rng_key, shape=(32, 2)),
"b0": jax.random.normal(rng_key, shape=(64,)),
"b1": jax.random.normal(rng_key, shape=(32,)),
"b2": jax.random.normal(rng_key, shape=(2,)),
"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1)}
m = DKL(X.shape[-1], kernel='RBF')
mean1, cov1 = m._get_mvn_posterior(X, y, X_test, params, noiseless=False)
mean1_, cov1_ = m._get_mvn_posterior(X, y, X_test, params, noiseless=False)
mean2, cov2 = m._get_mvn_posterior(X, y, X_test, params, noiseless=True)
m.X_train = X
m.y_train = y
mean1, cov1 = m.get_mvn_posterior(X_test, params, noiseless=False)
mean1_, cov1_ = m.get_mvn_posterior(X_test, params, noiseless=False)
mean2, cov2 = m.get_mvn_posterior(X_test, params, noiseless=True)
assert_array_equal(mean1, mean1_)
assert_array_equal(cov1, cov1_)
assert_array_equal(mean1, mean2)
@@ -94,19 +99,19 @@ def test_jitter_mvn_posterior():
# X = X[None]
# y = y[None]
# X_test = X_test[None]
params = {"w1": jax.random.normal(rng_key, shape=(36, 64)),
"w2": jax.random.normal(rng_key, shape=(64, 32)),
"w3": jax.random.normal(rng_key, shape=(32, 2)),
"b1": jax.random.normal(rng_key, shape=(64,)),
"b2": jax.random.normal(rng_key, shape=(32,)),
"b3": jax.random.normal(rng_key, shape=(2,)),
params = {"w0": jax.random.normal(rng_key, shape=(36, 64)),
"w1": jax.random.normal(rng_key, shape=(64, 32)),
"w2": jax.random.normal(rng_key, shape=(32, 2)),
"b0": jax.random.normal(rng_key, shape=(64,)),
"b1": jax.random.normal(rng_key, shape=(32,)),
"b2": jax.random.normal(rng_key, shape=(2,)),
"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1)}
m = DKL(X.shape[-1], 2, 'RBF')
m.X_train = X
m.y_train = y
mean1, cov1 = m._get_mvn_posterior(X, y, X_test, params, jitter=1e-6)
mean2, cov2 = m._get_mvn_posterior(X, y, X_test, params, jitter=1e-5)
mean1, cov1 = m.get_mvn_posterior(X_test, params, jitter=1e-6)
mean2, cov2 = m.get_mvn_posterior(X_test, params, jitter=1e-5)
assert_(onp.count_nonzero(mean1 - mean2) > 0)
assert_(onp.count_nonzero(cov1 - cov2) > 0)
