Add a standalone BNN #91

Merged · 10 commits · Feb 29, 2024
4 changes: 2 additions & 2 deletions gpax/__init__.py
@@ -5,8 +5,8 @@
from .hypo import sample_next
from .models import (DKL, CoregGP, ExactGP, MultiTaskGP, iBNN, vExactGP,
vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP,
MeasuredNoiseGP, viSparseGP)
MeasuredNoiseGP, viSparseGP, BNN)

__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
"viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
"UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "sample_next", "__version__"]
"UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"]
6 changes: 4 additions & 2 deletions gpax/models/__init__.py
@@ -14,6 +14,7 @@
from .mngp import MeasuredNoiseGP
from .linreg import LinReg
from .sparse_gp import viSparseGP
from .bnn import BNN

__all__ = [
"ExactGP",
@@ -30,6 +31,7 @@
"CoregGP",
"UIGP",
"LinReg",
"MeasuredNoiseGP"
"viSparseGP"
"MeasuredNoiseGP",
"viSparseGP",
"BNN"
]
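
With both `__init__.py` updates, the new model becomes importable from the top-level package as well as the models subpackage. An illustrative check (not part of the diff):

from gpax import BNN
from gpax.models import BNN      # equivalent
from gpax.models.bnn import BNN  # the module itself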
81 changes: 81 additions & 0 deletions gpax/models/bnn.py
@@ -0,0 +1,81 @@
"""
bnn.py
=======

Fully Bayesian MLPs

Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Callable, Dict, Optional, List, Union, Tuple

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from .spm import sPM


class BNN(sPM):
"""Fully Bayesian MLP"""
def __init__(self,
input_dim: int,
output_dim: int,
noise_prior_dist: Optional[dist.Distribution] = None,
hidden_dim: Optional[List[int]] = None, **kwargs):
hidden_dim = [64, 32] if not hidden_dim else hidden_dim
nn = kwargs.get("nn", get_mlp(hidden_dim))
nn_prior = kwargs.get("nn_prior", get_mlp_prior(input_dim, output_dim, hidden_dim))
super(BNN, self).__init__(nn, nn_prior, None, noise_prior_dist)

def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X if X.ndim > 1 else X[:, None]
if y is not None:
y = y[:, None] if y.ndim < 2 else y
return X, y
return X


def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
"""Sampling weights matrix"""
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
h = X
for i in range(len(architecture)):
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
# No non-linearity after the last layer
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
return z
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
return params
return mlp_prior
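
Usage sketch for the new standalone BNN (illustrative, not part of the diff; the synthetic data and MCMC settings here are arbitrary assumptions, and `gpax.utils.get_keys` is used the same way the tests below use it):

import numpy as np
import gpax

key_fit, key_predict = gpax.utils.get_keys()

# Toy 1D regression data
X = np.linspace(-1, 1, 32)[:, None]
y = np.sin(3 * X).squeeze() + 0.1 * np.random.randn(32)

bnn = gpax.BNN(input_dim=1, output_dim=1)  # default hidden_dim=[64, 32]
bnn.fit(key_fit, X, y, num_warmup=500, num_samples=500)

X_new = np.linspace(-1.5, 1.5, 50)[:, None]
y_pred, y_samples = bnn.predict(key_predict, X_new)  # point predictions and posterior samples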
25 changes: 12 additions & 13 deletions gpax/models/dkl.py
@@ -87,17 +87,16 @@ def model(self,
) -> None:
"""DKL probabilistic model"""
jitter = kwargs.get("jitter", 1e-6)
task_dim = X.shape[0]
# BNN part
nn_params = self.nn_prior(task_dim)
nn_params = self.nn_prior()
z = self.nn(X, nn_params)
if self.latent_prior: # Sample latent variable
z = self.latent_prior(z)
# Sample GP kernel parameters
if self.kernel_prior:
kernel_params = self.kernel_prior()
else:
kernel_params = self._sample_kernel_params(task_dim)
kernel_params = self._sample_kernel_params()
# Sample noise
noise = self._sample_noise()
# GP's mean function
@@ -150,22 +149,22 @@ def _print_summary(self):
{k: v for (k, v) in samples.items() if k in list_of_keys})


def sample_weights(name: str, in_channels: int, out_channels: int, task_dim: int) -> jnp.ndarray:
def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
"""Sampling weights matrix"""
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


def sample_biases(name: str, channels: int, task_dim: int) -> jnp.ndarray:
def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


def get_mlp(architecture: List[int]) -> Callable:
def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
@@ -178,17 +177,17 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior(task_dim: int):
def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels, task_dim)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels, task_dim)
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim, task_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim, task_dim)
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
return params
return mlp_prior
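
Note the dropped task_dim argument throughout these helpers. A quick sanity check of the new prior signature (illustrative, not part of the diff; uses numpyro's seed handler to draw parameters outside of MCMC):

import numpyro
from gpax.models.bnn import get_mlp_prior

mlp_prior = get_mlp_prior(input_dim=2, output_dim=1, architecture=[64, 32])
with numpyro.handlers.seed(rng_seed=0):
    params = mlp_prior()  # no task_dim argument anymore
print({k: v.shape for k, v in params.items()})
# expected shapes: w0 (2, 64), b0 (64,), w1 (64, 32), b1 (32,), w2 (32, 1), b2 (1,)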
3 changes: 0 additions & 3 deletions gpax/models/gp.py
@@ -8,7 +8,6 @@
"""

import warnings
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Type, Union

import jax
@@ -17,7 +16,6 @@
import jax.random as jra
import numpyro
import numpyro.distributions as dist
from jax import jit
from numpyro.infer import MCMC, NUTS, init_to_median, Predictive

from ..kernels import get_kernel
@@ -252,7 +250,6 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
"""Get posterior samples (after running the MCMC chains)"""
return self.mcmc.get_samples(group_by_chain=chain_dim)

# @partial(jit, static_argnames='self')
def get_mvn_posterior(
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
39 changes: 33 additions & 6 deletions gpax/models/spm.py
@@ -10,6 +10,7 @@
Created by Maxim Ziatdinov (email: [email protected])
"""

import warnings
from typing import Callable, Optional, Tuple, Type, Dict

import jax
@@ -41,13 +42,20 @@ class sPM:
def __init__(self,
model: model_type,
model_prior: prior_type,
noise_prior: Optional[prior_type] = None) -> None:
noise_prior: Optional[prior_type] = None,
noise_prior_dist: Optional[dist.Distribution] = None,) -> None:
self._model = model
self.model_prior = model_prior
if noise_prior is None:
self.noise_prior = lambda: numpyro.sample("sig", dist.LogNormal(0, 1))
else:
self.noise_prior = noise_prior
if noise_prior is not None:
warnings.warn(
"`noise_prior` is deprecated and will be removed in a future version. "
"Please use `noise_prior_dist` instead, which accepts an instance of a "
"numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
"rather than a function that calls `numpyro.sample`.",
FutureWarning,
)
self.noise_prior = noise_prior
self.noise_prior_dist = noise_prior_dist
self.mcmc = None

def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
@@ -59,10 +67,20 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
# Compute the function's value
mu = numpyro.deterministic("mu", self._model(X, params))
# Sample observational noise
sig = self.noise_prior()
if self.noise_prior: # this will be removed in the future releases
sig = self.noise_prior()
else:
sig = self._sample_noise()
# Score against the observed data points
numpyro.sample("y", dist.Normal(mu, sig), obs=y)

def _sample_noise(self) -> jnp.ndarray:
if self.noise_prior_dist is not None:
noise_dist = self.noise_prior_dist
else:
noise_dist = dist.LogNormal(0, 1)
return numpyro.sample("noise", noise_dist)

def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
num_warmup: int = 2000, num_samples: int = 2000,
num_chains: int = 1, chain_method: str = 'sequential',
@@ -85,6 +103,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
optionally specify a cpu or gpu device on which to run the inference;
e.g., ``device=jax.devices("cpu")[0]``
"""
X, y = self._set_data(X, y)
if device:
X = jax.device_put(X, device)
y = jax.device_put(y, device)
@@ -147,6 +166,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
Returns:
Point predictions (or their mean) and posterior predictive distribution
"""
X_new = self._set_data(X_new)
if samples is None:
samples = self.get_samples(chain_dim=False)
if device:
@@ -165,3 +185,10 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,

def _print_summary(self):
self.mcmc.print_summary()

def _set_data(self,
X: jnp.ndarray, y: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray]:
if y is not None:
return X, y
return X
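
The net effect in spm.py: observational noise can now be specified as a numpyro distribution object, with the old closure-based noise_prior kept behind a FutureWarning. Illustrative use of the new argument with the standalone BNN (not part of the diff):

import numpyro.distributions as dist
from gpax import BNN

# New style: pass a Distribution instance; if omitted,
# _sample_noise() falls back to dist.LogNormal(0, 1)
bnn = BNN(input_dim=1, output_dim=1, noise_prior_dist=dist.HalfNormal(scale=0.1))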
98 changes: 98 additions & 0 deletions tests/test_bnn.py
@@ -0,0 +1,98 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
from numpy.testing import assert_equal, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.models.bnn import BNN
from gpax.utils import get_keys


def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False):
X = onp.random.randn(8, feature_dim)
y = onp.random.randn(X.shape[0], target_dim)
if squeezed:
return X.squeeze(), y.squeeze()
return X, y


def test_bnn_fit():
key, _ = get_keys()
X, y = get_dummy_data()
bnn = BNN(1, 1)
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
assert bnn.mcmc is not None


def test_bnn_custom_layers_fit():
key, _ = get_keys()
X, y = get_dummy_data()
bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
samples = bnn.get_samples()
assert_equal(samples["w0"].shape, (50, 1, 32))
assert_equal(samples["w1"].shape, (50, 32, 16))
assert_equal(samples["w2"].shape, (50, 16, 8))
assert_equal(samples["w3"].shape, (50, 8, 1))
assert_equal(samples["b0"].shape, (50, 32))
assert_equal(samples["b1"].shape, (50, 16))
assert_equal(samples["b2"].shape, (50, 8))
assert_equal(samples["b3"].shape, (50, 1))


def test_bnn_predict_with_samples():
key, _ = get_keys()
X_test, _ = get_dummy_data()

params = {"w0": jax.random.normal(key, shape=(50, 1, 64)),
"w1": jax.random.normal(key, shape=(50, 64, 32)),
"w2": jax.random.normal(key, shape=(50, 32, 1)),
"b0": jax.random.normal(key, shape=(50, 64,)),
"b1": jax.random.normal(key, shape=(50, 32,)),
"b2": jax.random.normal(key, shape=(50, 1,)),
"noise": jax.random.normal(key, shape=(50,))
}

bnn = BNN(1, 1)
f_pred, f_samples = bnn.predict(key, X_test, params)
assert_equal(f_pred.shape, (len(X_test), 1))
assert_equal(f_samples.shape, (50, len(X_test), 1))


def test_bnn_custom_layers_predict_custom_with_samples():
key, _ = get_keys()
X_test, _ = get_dummy_data()

params = {"w0": jax.random.normal(key, shape=(50, 1, 32)),
"w1": jax.random.normal(key, shape=(50, 32, 16)),
"w2": jax.random.normal(key, shape=(50, 16, 8)),
"w3": jax.random.normal(key, shape=(50, 8, 1)),
"b0": jax.random.normal(key, shape=(50, 32,)),
"b1": jax.random.normal(key, shape=(50, 16,)),
"b2": jax.random.normal(key, shape=(50, 8,)),
"b3": jax.random.normal(key, shape=(50, 1,)),
"noise": jax.random.normal(key, shape=(50,))
}

bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
f_pred, f_samples = bnn.predict(key, X_test, params)
assert_equal(f_pred.shape, (len(X_test), 1))
assert_equal(f_samples.shape, (50, len(X_test), 1))


@pytest.mark.parametrize("squeezed", [True, False])
@pytest.mark.parametrize("target_dim", [1, 2])
@pytest.mark.parametrize("feature_dim", [1, 2])
def test_bnn_fit_predict(feature_dim, target_dim, squeezed):
key, _ = get_keys()
X, y = get_dummy_data(feature_dim, target_dim, squeezed)
X_test, _ = get_dummy_data(feature_dim, target_dim, squeezed)
bnn = BNN(feature_dim, target_dim, hidden_dim=[4, 2])
bnn.fit(key, X, y, num_warmup=5, num_samples=5)
f_pred, f_samples = bnn.predict(key, X_test)
assert_equal(f_pred.shape, (len(X_test), target_dim))
assert_equal(f_samples.shape, (5, len(X_test), target_dim))