Skip to content

Commit

Permalink
Add standalone bnn module
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 25, 2024
1 parent 7d317ca commit 6c2c93f
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
81 changes: 81 additions & 0 deletions gpax/models/bnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
bnn.py
=======
Fully Bayesian MLPs
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Callable, Dict, Optional, List, Union, Tuple

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from .spm import sPM


class BNN(sPM):
"""Fully Bayesian MLP"""
def __init__(self,
input_dim: int,
output_dim: int,
noise_prior_dist: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
hidden_dim: Optional[List[int]] = None, **kwargs):
hidden_dim = [64, 32] if not hidden_dim else hidden_dim
nn = kwargs.get("nn", get_mlp(hidden_dim))
nn_prior = kwargs.get("nn_prior", get_mlp_prior(input_dim, output_dim, hidden_dim))
super(BNN, self).__init__(nn, nn_prior, None, noise_prior_dist)

def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X if X.ndim > 1 else X[:, None]
if y is not None:
y = y[:, None] if y.ndim < 1 else y
return X, y
return X


def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
"""Sampling weights matrix"""
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


def get_mlp(architecture: List[int]) -> Callable:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
h = X
for i in range(len(architecture)):
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
# No non-linearity after the last layer
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
return z
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
return params
return mlp_prior
98 changes: 98 additions & 0 deletions tests/test_bnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
from numpy.testing import assert_equal, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.models.bnn import BNN
from gpax.utils import get_keys


def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False):
X = onp.random.randn(8, feature_dim)
y = onp.random.randn(X.shape[0], target_dim)
if squeezed:
return X.squeeze(), y.squeeze()
return X, y


def test_bnn_fit():
key, _ = get_keys()
X, y = get_dummy_data()
bnn = BNN(1, 1)
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
assert bnn.mcmc is not None


def test_bnn_custom_layers_fit():
key, _ = get_keys()
X, y = get_dummy_data()
bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
samples = bnn.get_samples()
assert_equal(samples["w0"].shape, (50, 1, 32))
assert_equal(samples["w1"].shape, (50, 32, 16))
assert_equal(samples["w2"].shape, (50, 16, 8))
assert_equal(samples["w3"].shape, (50, 8, 1))
assert_equal(samples["b0"].shape, (50, 32))
assert_equal(samples["b1"].shape, (50, 16))
assert_equal(samples["b2"].shape, (50, 8))
assert_equal(samples["b3"].shape, (50, 1))


def test_bnn_predict_with_samples():
key, _ = get_keys()
X_test, _ = get_dummy_data()

params = {"w0": jax.random.normal(key, shape=(50, 1, 64)),
"w1": jax.random.normal(key, shape=(50, 64, 32)),
"w2": jax.random.normal(key, shape=(50, 32, 1)),
"b0": jax.random.normal(key, shape=(50, 64,)),
"b1": jax.random.normal(key, shape=(50, 32,)),
"b2": jax.random.normal(key, shape=(50, 1,)),
"noise": jax.random.normal(key, shape=(50,))
}

bnn = BNN(1, 1)
f_pred, f_samples = bnn.predict(key, X_test, params)
assert_equal(f_pred.shape, (len(X_test), 1))
assert_equal(f_samples.shape, (50, len(X_test), 1))


def test_bnn_custom_layers_predict_custom_with_samples():
key, _ = get_keys()
X_test, _ = get_dummy_data()

params = {"w0": jax.random.normal(key, shape=(50, 1, 32)),
"w1": jax.random.normal(key, shape=(50, 32, 16)),
"w2": jax.random.normal(key, shape=(50, 16, 8)),
"w3": jax.random.normal(key, shape=(50, 8, 1)),
"b0": jax.random.normal(key, shape=(50, 32,)),
"b1": jax.random.normal(key, shape=(50, 16,)),
"b2": jax.random.normal(key, shape=(50, 8,)),
"b3": jax.random.normal(key, shape=(50, 1,)),
"noise": jax.random.normal(key, shape=(50,))
}

bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
f_pred, f_samples = bnn.predict(key, X_test, params)
assert_equal(f_pred.shape, (len(X_test), 1))
assert_equal(f_samples.shape, (50, len(X_test), 1))


@pytest.mark.parametrize("squeezed", [True, False])
@pytest.mark.parametrize("target_dim", [1, 2])
@pytest.mark.parametrize("feature_dim", [1, 2])
def test_bnn_fit_predict(feature_dim, target_dim, squeezed):
key, _ = get_keys()
X, y = get_dummy_data(feature_dim, target_dim, squeezed)
X_test, _ = get_dummy_data(feature_dim, target_dim, squeezed)
bnn = BNN(feature_dim, target_dim, hidden_dim=[4, 2])
bnn.fit(key, X, y, num_warmup=5, num_samples=5)
f_pred, f_samples = bnn.predict(key, X_test)
assert_equal(f_pred.shape, (len(X_test), target_dim))
assert_equal(f_samples.shape, (5, len(X_test), target_dim))

0 comments on commit 6c2c93f

Please sign in to comment.