Merge pull request #91 from ziatdinovmax/bnn

Add a standalone BNN
ziatdinovmax authored Feb 29, 2024
2 parents 5d509eb + d82b1b9 commit 72c7255
Showing 7 changed files with 230 additions and 26 deletions.
4 changes: 2 additions & 2 deletions gpax/__init__.py
@@ -5,8 +5,8 @@
from .hypo import sample_next
from .models import (DKL, CoregGP, ExactGP, MultiTaskGP, iBNN, vExactGP,
                     vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP,
-                     MeasuredNoiseGP, viSparseGP)
+                     MeasuredNoiseGP, viSparseGP, BNN)

__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
           "viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
-          "UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "sample_next", "__version__"]
+          "UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"]
6 changes: 4 additions & 2 deletions gpax/models/__init__.py
@@ -14,6 +14,7 @@
from .mngp import MeasuredNoiseGP
from .linreg import LinReg
from .sparse_gp import viSparseGP
+from .bnn import BNN

__all__ = [
    "ExactGP",
@@ -30,6 +31,7 @@
    "CoregGP",
    "UIGP",
    "LinReg",
-    "MeasuredNoiseGP"
-    "viSparseGP"
+    "MeasuredNoiseGP",
+    "viSparseGP",
+    "BNN"
]
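Note that the commas added here fix a real bug: without them, Python's implicit string-literal concatenation silently fused the two adjacent names into a single __all__ entry. A quick illustration (hypothetical snippet, not from the repo):

# adjacent string literals merge into one list element
broken = ["MeasuredNoiseGP" "viSparseGP"]
fixed = ["MeasuredNoiseGP", "viSparseGP"]
assert broken == ["MeasuredNoiseGPviSparseGP"]
assert len(fixed) == 2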
81 changes: 81 additions & 0 deletions gpax/models/bnn.py
@@ -0,0 +1,81 @@
"""
bnn.py
=======
Fully Bayesian MLPs
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Callable, Dict, Optional, List, Union, Tuple

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from .spm import sPM


class BNN(sPM):
"""Fully Bayesian MLP"""
def __init__(self,
input_dim: int,
output_dim: int,
                 noise_prior_dist: Optional[dist.Distribution] = None,
hidden_dim: Optional[List[int]] = None, **kwargs):
hidden_dim = [64, 32] if not hidden_dim else hidden_dim
nn = kwargs.get("nn", get_mlp(hidden_dim))
nn_prior = kwargs.get("nn_prior", get_mlp_prior(input_dim, output_dim, hidden_dim))
super(BNN, self).__init__(nn, nn_prior, None, noise_prior_dist)

def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X if X.ndim > 1 else X[:, None]
if y is not None:
            y = y[:, None] if y.ndim < 2 else y
return X, y
return X


def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
"""Sampling weights matrix"""
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
h = X
for i in range(len(architecture)):
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
# No non-linearity after the last layer
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
return z
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
return params
return mlp_prior
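The new class follows the same fit/predict workflow as the package's other fully Bayesian models. A minimal usage sketch based on the shapes exercised in tests/test_bnn.py below (the data here are random placeholders):

import numpy as np
from gpax.models.bnn import BNN
from gpax.utils import get_keys

key, _ = get_keys()
X = np.random.randn(8, 1)  # 8 points, 1 feature
y = np.random.randn(8, 1)

# two hidden layers of 32 and 16 units; the default is [64, 32]
bnn = BNN(input_dim=1, output_dim=1, hidden_dim=[32, 16])
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
f_pred, f_samples = bnn.predict(key, X)  # shapes (8, 1) and (50, 8, 1)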
25 changes: 12 additions & 13 deletions gpax/models/dkl.py
@@ -87,17 +87,16 @@ def model(self,
) -> None:
"""DKL probabilistic model"""
jitter = kwargs.get("jitter", 1e-6)
-        task_dim = X.shape[0]
        # BNN part
-        nn_params = self.nn_prior(task_dim)
+        nn_params = self.nn_prior()
z = self.nn(X, nn_params)
if self.latent_prior: # Sample latent variable
z = self.latent_prior(z)
# Sample GP kernel parameters
if self.kernel_prior:
kernel_params = self.kernel_prior()
else:
-            kernel_params = self._sample_kernel_params(task_dim)
+            kernel_params = self._sample_kernel_params()
# Sample noise
noise = self._sample_noise()
# GP's mean function
@@ -150,22 +149,22 @@ def _print_summary(self):
{k: v for (k, v) in samples.items() if k in list_of_keys})


-def sample_weights(name: str, in_channels: int, out_channels: int, task_dim: int) -> jnp.ndarray:
+def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
"""Sampling weights matrix"""
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


-def sample_biases(name: str, channels: int, task_dim: int) -> jnp.ndarray:
+def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
-    b = numpyro.sample(name=name, fn=dist.Normal(
+    b = numpyro.sample(name=name, fn=dist.Cauchy(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


-def get_mlp(architecture: List[int]) -> Callable:
+def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
@@ -178,17 +177,17 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
return mlp


-def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
+def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
-    def mlp_prior(task_dim: int):
+    def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels, task_dim)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels, task_dim)
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim, task_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim, task_dim)
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
return params
return mlp_prior
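The net effect of this refactor is that the weight/bias samplers and the MLP prior no longer carry a task_dim argument, so mlp_prior is now a zero-argument callable, matching the standalone BNN module above. A small sketch of the shared helpers (using the gpax.models.bnn versions; a numpyro seed handler is needed because they call numpyro.sample):

import jax.random as jra
import numpyro
from gpax.models.bnn import get_mlp, get_mlp_prior

architecture = [8, 4]
prior = get_mlp_prior(input_dim=2, output_dim=1, architecture=architecture)
with numpyro.handlers.seed(rng_seed=0):
    # draws {"w0": (2, 8), "b0": (8,), "w1": (8, 4), "b1": (4,), "w2": (4, 1), "b2": (1,)}
    params = prior()
mlp = get_mlp(architecture)
X = jra.normal(jra.PRNGKey(1), (16, 2))
print(mlp(X, params).shape)  # (16, 1)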
3 changes: 0 additions & 3 deletions gpax/models/gp.py
@@ -8,7 +8,6 @@
"""

import warnings
-from functools import partial
from typing import Callable, Dict, Optional, Tuple, Type, Union

import jax
@@ -17,7 +16,6 @@
import jax.random as jra
import numpyro
import numpyro.distributions as dist
-from jax import jit
from numpyro.infer import MCMC, NUTS, init_to_median, Predictive

from ..kernels import get_kernel
@@ -252,7 +250,6 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
"""Get posterior samples (after running the MCMC chains)"""
return self.mcmc.get_samples(group_by_chain=chain_dim)

-    # @partial(jit, static_argnames='self')
def get_mvn_posterior(
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
39 changes: 33 additions & 6 deletions gpax/models/spm.py
@@ -10,6 +10,7 @@
Created by Maxim Ziatdinov (email: [email protected])
"""

+import warnings
from typing import Callable, Optional, Tuple, Type, Dict

import jax
@@ -41,13 +42,20 @@ class sPM:
def __init__(self,
model: model_type,
model_prior: prior_type,
-                 noise_prior: Optional[prior_type] = None) -> None:
+                 noise_prior: Optional[prior_type] = None,
+                 noise_prior_dist: Optional[dist.Distribution] = None,) -> None:
self._model = model
self.model_prior = model_prior
-        if noise_prior is None:
-            self.noise_prior = lambda: numpyro.sample("sig", dist.LogNormal(0, 1))
-        else:
-            self.noise_prior = noise_prior
+        if noise_prior is not None:
+            warnings.warn(
+                "`noise_prior` is deprecated and will be removed in a future version. "
+                "Please use `noise_prior_dist` instead, which accepts an instance of a "
+                "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
+                "rather than a function that calls `numpyro.sample`.",
+                FutureWarning,
+            )
+        self.noise_prior = noise_prior
+        self.noise_prior_dist = noise_prior_dist
self.mcmc = None

def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
@@ -59,10 +67,20 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
# Compute the function's value
mu = numpyro.deterministic("mu", self._model(X, params))
# Sample observational noise
-        sig = self.noise_prior()
+        if self.noise_prior:  # this will be removed in a future release
+            sig = self.noise_prior()
+        else:
+            sig = self._sample_noise()
# Score against the observed data points
numpyro.sample("y", dist.Normal(mu, sig), obs=y)

+    def _sample_noise(self) -> jnp.ndarray:
+        if self.noise_prior_dist is not None:
+            noise_dist = self.noise_prior_dist
+        else:
+            noise_dist = dist.LogNormal(0, 1)
+        return numpyro.sample("noise", noise_dist)

def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
num_warmup: int = 2000, num_samples: int = 2000,
num_chains: int = 1, chain_method: str = 'sequential',
@@ -85,6 +103,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
optionally specify a cpu or gpu device on which to run the inference;
e.g., ``device=jax.devices("cpu")[0]``
"""
+        X, y = self._set_data(X, y)
if device:
X = jax.device_put(X, device)
y = jax.device_put(y, device)
@@ -147,6 +166,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
Returns:
Point predictions (or their mean) and posterior predictive distribution
"""
+        X_new = self._set_data(X_new)
if samples is None:
samples = self.get_samples(chain_dim=False)
if device:
Expand All @@ -165,3 +185,10 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,

def _print_summary(self):
self.mcmc.print_summary()

+    def _set_data(self,
+                  X: jnp.ndarray, y: Optional[jnp.ndarray] = None,
+                  ) -> Tuple[jnp.ndarray]:
+        if y is not None:
+            return X, y
+        return X
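Since BNN forwards noise_prior_dist straight to this constructor, the new noise API can be exercised directly; a minimal sketch (the HalfNormal scale is an arbitrary choice, mirroring the deprecation warning's example):

import numpyro.distributions as dist
from gpax.models.bnn import BNN

# replaces the old function-style noise_prior; the default remains LogNormal(0, 1)
bnn = BNN(input_dim=1, output_dim=1, noise_prior_dist=dist.HalfNormal(scale=0.1))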
98 changes: 98 additions & 0 deletions tests/test_bnn.py
@@ -0,0 +1,98 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
from numpy.testing import assert_equal, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.models.bnn import BNN
from gpax.utils import get_keys


def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False):
X = onp.random.randn(8, feature_dim)
y = onp.random.randn(X.shape[0], target_dim)
if squeezed:
return X.squeeze(), y.squeeze()
return X, y


def test_bnn_fit():
key, _ = get_keys()
X, y = get_dummy_data()
bnn = BNN(1, 1)
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
assert bnn.mcmc is not None


def test_bnn_custom_layers_fit():
key, _ = get_keys()
X, y = get_dummy_data()
bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
bnn.fit(key, X, y, num_warmup=50, num_samples=50)
samples = bnn.get_samples()
assert_equal(samples["w0"].shape, (50, 1, 32))
assert_equal(samples["w1"].shape, (50, 32, 16))
assert_equal(samples["w2"].shape, (50, 16, 8))
assert_equal(samples["w3"].shape, (50, 8, 1))
assert_equal(samples["b0"].shape, (50, 32))
assert_equal(samples["b1"].shape, (50, 16))
assert_equal(samples["b2"].shape, (50, 8))
assert_equal(samples["b3"].shape, (50, 1))


def test_bnn_predict_with_samples():
key, _ = get_keys()
X_test, _ = get_dummy_data()

params = {"w0": jax.random.normal(key, shape=(50, 1, 64)),
"w1": jax.random.normal(key, shape=(50, 64, 32)),
"w2": jax.random.normal(key, shape=(50, 32, 1)),
"b0": jax.random.normal(key, shape=(50, 64,)),
"b1": jax.random.normal(key, shape=(50, 32,)),
"b2": jax.random.normal(key, shape=(50, 1,)),
"noise": jax.random.normal(key, shape=(50,))
}

bnn = BNN(1, 1)
f_pred, f_samples = bnn.predict(key, X_test, params)
assert_equal(f_pred.shape, (len(X_test), 1))
assert_equal(f_samples.shape, (50, len(X_test), 1))


def test_bnn_custom_layers_predict_custom_with_samples():
key, _ = get_keys()
X_test, _ = get_dummy_data()

params = {"w0": jax.random.normal(key, shape=(50, 1, 32)),
"w1": jax.random.normal(key, shape=(50, 32, 16)),
"w2": jax.random.normal(key, shape=(50, 16, 8)),
"w3": jax.random.normal(key, shape=(50, 8, 1)),
"b0": jax.random.normal(key, shape=(50, 32,)),
"b1": jax.random.normal(key, shape=(50, 16,)),
"b2": jax.random.normal(key, shape=(50, 8,)),
"b3": jax.random.normal(key, shape=(50, 1,)),
"noise": jax.random.normal(key, shape=(50,))
}

bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
f_pred, f_samples = bnn.predict(key, X_test, params)
assert_equal(f_pred.shape, (len(X_test), 1))
assert_equal(f_samples.shape, (50, len(X_test), 1))


@pytest.mark.parametrize("squeezed", [True, False])
@pytest.mark.parametrize("target_dim", [1, 2])
@pytest.mark.parametrize("feature_dim", [1, 2])
def test_bnn_fit_predict(feature_dim, target_dim, squeezed):
key, _ = get_keys()
X, y = get_dummy_data(feature_dim, target_dim, squeezed)
X_test, _ = get_dummy_data(feature_dim, target_dim, squeezed)
bnn = BNN(feature_dim, target_dim, hidden_dim=[4, 2])
bnn.fit(key, X, y, num_warmup=5, num_samples=5)
f_pred, f_samples = bnn.predict(key, X_test)
assert_equal(f_pred.shape, (len(X_test), target_dim))
assert_equal(f_samples.shape, (5, len(X_test), target_dim))
