From 6c2c93f6c2f6ae4c90f6c1b75c8c4a4387c723bc Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Sun, 25 Feb 2024 14:21:40 -0800
Subject: [PATCH] Add standalone bnn module

---
 gpax/models/bnn.py | 81 ++++++++++++++++++++++++++++++++++++++
 tests/test_bnn.py  | 98 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 179 insertions(+)
 create mode 100644 gpax/models/bnn.py
 create mode 100644 tests/test_bnn.py

diff --git a/gpax/models/bnn.py b/gpax/models/bnn.py
new file mode 100644
index 0000000..2ff2eef
--- /dev/null
+++ b/gpax/models/bnn.py
@@ -0,0 +1,81 @@
+"""
+bnn.py
+======
+
+Fully Bayesian MLPs
+
+Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
+"""
+
+from typing import Callable, Dict, Optional, List, Union, Tuple
+
+import jax.numpy as jnp
+import numpyro
+import numpyro.distributions as dist
+
+from .spm import sPM
+
+
+class BNN(sPM):
+    """Fully Bayesian MLP"""
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int,
+                 noise_prior_dist: Optional[dist.Distribution] = None,
+                 hidden_dim: Optional[List[int]] = None, **kwargs):
+        hidden_dim = [64, 32] if not hidden_dim else hidden_dim
+        nn = kwargs.get("nn", get_mlp(hidden_dim))
+        nn_prior = kwargs.get("nn_prior", get_mlp_prior(input_dim, output_dim, hidden_dim))
+        super(BNN, self).__init__(nn, nn_prior, None, noise_prior_dist)
+
+    def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None
+                  ) -> Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+        X = X if X.ndim > 1 else X[:, None]
+        if y is not None:
+            y = y[:, None] if y.ndim < 2 else y
+            return X, y
+        return X
+
+
+def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
+    """Samples a weight matrix from a standard normal prior"""
+    w = numpyro.sample(name=name, fn=dist.Normal(
+        loc=jnp.zeros((in_channels, out_channels)),
+        scale=jnp.ones((in_channels, out_channels))))
+    return w
+
+
+def sample_biases(name: str, channels: int) -> jnp.ndarray:
+    """Samples a bias vector from a standard normal prior"""
+    b = numpyro.sample(name=name, fn=dist.Normal(
+        loc=jnp.zeros((channels,)), scale=jnp.ones((channels,))))
+    return b
+
+
+def get_mlp(architecture: List[int]) -> Callable:
+    """Returns a function that represents an MLP for a given architecture."""
+    def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
+        """MLP for a single MCMC sample of weights and biases, handling an arbitrary number of layers."""
+        h = X
+        for i in range(len(architecture)):
+            h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
+        # No non-linearity after the last layer
+        z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
+        return z
+    return mlp
+
+
+def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
+    """Returns a function that samples priors over weights and biases for a Bayesian MLP"""
+    def mlp_prior():
+        params = {}
+        in_channels = input_dim
+        for i, out_channels in enumerate(architecture):
+            params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels)
+            params[f"b{i}"] = sample_biases(f"b{i}", out_channels)
+            in_channels = out_channels
+        # Output layer
+        params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
+        params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
+        return params
+    return mlp_prior
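Reviewer note: the two factory functions above compose as follows. This is an illustrative sketch, not part of the diff; it assumes the import path introduced by this patch. mlp_prior is a numpyro model (it calls numpyro.sample), so drawing one set of weights requires a seed handler, and the resulting dict then parameterizes the deterministic forward pass returned by get_mlp.

# Sketch only, assuming the gpax.models.bnn import path from this patch
import jax.numpy as jnp
from numpyro.handlers import seed

from gpax.models.bnn import get_mlp, get_mlp_prior

architecture = [64, 32]                        # two hidden layers
mlp = get_mlp(architecture)                    # deterministic forward pass
mlp_prior = get_mlp_prior(1, 1, architecture)  # numpyro model over weights/biases

# mlp_prior calls numpyro.sample, so run it under a seed handler
params = seed(mlp_prior, rng_seed=0)()         # dict with w0, b0, w1, b1, w2, b2
X = jnp.linspace(-1, 1, 16)[:, None]           # (16, 1) input
y = mlp(X, params)                             # (16, 1) output for one prior draw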
diff --git a/tests/test_bnn.py b/tests/test_bnn.py
new file mode 100644
index 0000000..2c30479
--- /dev/null
+++ b/tests/test_bnn.py
@@ -0,0 +1,98 @@
+import sys
+import pytest
+import numpy as onp
+import jax.numpy as jnp
+import jax
+import numpyro
+from numpy.testing import assert_equal, assert_array_equal
+
+sys.path.insert(0, "../gpax/")
+
+from gpax.models.bnn import BNN
+from gpax.utils import get_keys
+
+
+def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False):
+    X = onp.random.randn(8, feature_dim)
+    y = onp.random.randn(X.shape[0], target_dim)
+    if squeezed:
+        return X.squeeze(), y.squeeze()
+    return X, y
+
+
+def test_bnn_fit():
+    key, _ = get_keys()
+    X, y = get_dummy_data()
+    bnn = BNN(1, 1)
+    bnn.fit(key, X, y, num_warmup=50, num_samples=50)
+    assert bnn.mcmc is not None
+
+
+def test_bnn_custom_layers_fit():
+    key, _ = get_keys()
+    X, y = get_dummy_data()
+    bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
+    bnn.fit(key, X, y, num_warmup=50, num_samples=50)
+    samples = bnn.get_samples()
+    assert_equal(samples["w0"].shape, (50, 1, 32))
+    assert_equal(samples["w1"].shape, (50, 32, 16))
+    assert_equal(samples["w2"].shape, (50, 16, 8))
+    assert_equal(samples["w3"].shape, (50, 8, 1))
+    assert_equal(samples["b0"].shape, (50, 32))
+    assert_equal(samples["b1"].shape, (50, 16))
+    assert_equal(samples["b2"].shape, (50, 8))
+    assert_equal(samples["b3"].shape, (50, 1))
+
+
+def test_bnn_predict_with_samples():
+    key, _ = get_keys()
+    X_test, _ = get_dummy_data()
+
+    params = {"w0": jax.random.normal(key, shape=(50, 1, 64)),
+              "w1": jax.random.normal(key, shape=(50, 64, 32)),
+              "w2": jax.random.normal(key, shape=(50, 32, 1)),
+              "b0": jax.random.normal(key, shape=(50, 64,)),
+              "b1": jax.random.normal(key, shape=(50, 32,)),
+              "b2": jax.random.normal(key, shape=(50, 1,)),
+              "noise": jax.random.normal(key, shape=(50,))
+              }
+
+    bnn = BNN(1, 1)
+    f_pred, f_samples = bnn.predict(key, X_test, params)
+    assert_equal(f_pred.shape, (len(X_test), 1))
+    assert_equal(f_samples.shape, (50, len(X_test), 1))
+
+
+def test_bnn_custom_layers_predict_custom_with_samples():
+    key, _ = get_keys()
+    X_test, _ = get_dummy_data()
+
+    params = {"w0": jax.random.normal(key, shape=(50, 1, 32)),
+              "w1": jax.random.normal(key, shape=(50, 32, 16)),
+              "w2": jax.random.normal(key, shape=(50, 16, 8)),
+              "w3": jax.random.normal(key, shape=(50, 8, 1)),
+              "b0": jax.random.normal(key, shape=(50, 32,)),
+              "b1": jax.random.normal(key, shape=(50, 16,)),
+              "b2": jax.random.normal(key, shape=(50, 8,)),
+              "b3": jax.random.normal(key, shape=(50, 1,)),
+              "noise": jax.random.normal(key, shape=(50,))
+              }
+
+    bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
+    f_pred, f_samples = bnn.predict(key, X_test, params)
+    assert_equal(f_pred.shape, (len(X_test), 1))
+    assert_equal(f_samples.shape, (50, len(X_test), 1))
+
+
+@pytest.mark.parametrize("squeezed", [True, False])
+@pytest.mark.parametrize("target_dim", [1, 2])
+@pytest.mark.parametrize("feature_dim", [1, 2])
+def test_bnn_fit_predict(feature_dim, target_dim, squeezed):
+    key, _ = get_keys()
+    X, y = get_dummy_data(feature_dim, target_dim, squeezed)
+    X_test, _ = get_dummy_data(feature_dim, target_dim, squeezed)
+    bnn = BNN(feature_dim, target_dim, hidden_dim=[4, 2])
+    bnn.fit(key, X, y, num_warmup=5, num_samples=5)
+    f_pred, f_samples = bnn.predict(key, X_test)
+    assert_equal(f_pred.shape, (len(X_test), target_dim))
+    assert_equal(f_samples.shape, (5, len(X_test), target_dim))
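For completeness, a minimal end-to-end usage sketch mirroring the tests above. It assumes only what the tests already exercise: the gpax.utils.get_keys helper (which returns two PRNG keys) and the fit/predict signatures, where f_pred has shape (n_test, output_dim) and f_samples has shape (num_samples, n_test, output_dim).

# Usage sketch mirroring test_bnn.py; not part of the diff
import numpy as onp
from gpax.models.bnn import BNN
from gpax.utils import get_keys

key1, key2 = get_keys()
X = onp.random.randn(8, 1)            # (n_samples, input_dim)
y = onp.random.randn(8, 1)            # (n_samples, output_dim)

bnn = BNN(input_dim=1, output_dim=1)  # default hidden_dim=[64, 32]
bnn.fit(key1, X, y, num_warmup=100, num_samples=100)

X_test = onp.linspace(-2, 2, 50)[:, None]
# f_pred: posterior predictive mean, shape (50, 1)
# f_samples: per-draw predictions, shape (100, 50, 1)
f_pred, f_samples = bnn.predict(key2, X_test)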