Merge pull request #91 from ziatdinovmax/bnn
Add a standalone BNN
Showing 7 changed files with 230 additions and 26 deletions.
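For orientation, here is a minimal usage sketch of the new class, pieced together from the tests added in this PR (the import paths and `get_keys` helper follow the test file below; the small `num_warmup`/`num_samples` values are for illustration only):

import numpy as onp
from gpax.models.bnn import BNN
from gpax.utils import get_keys

key, _ = get_keys()
X, y = onp.random.randn(8, 1), onp.random.randn(8)

bnn = BNN(input_dim=1, output_dim=1)  # defaults to two hidden layers, [64, 32]
bnn.fit(key, X, y, num_warmup=50, num_samples=50)  # MCMC over weights, biases, and noise
f_pred, f_samples = bnn.predict(key, X)  # point predictions and posterior predictive samples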
gpax/models/bnn.py (new file)
@@ -0,0 +1,81 @@
""" | ||
bnn.py | ||
======= | ||
Fully Bayesian MLPs | ||
Created by Maxim Ziatdinov (email: [email protected]) | ||
""" | ||
|
||
from typing import Callable, Dict, Optional, List, Union, Tuple | ||
|
||
import jax.numpy as jnp | ||
import numpyro | ||
import numpyro.distributions as dist | ||
|
||
from .spm import sPM | ||
|
||
|
||
class BNN(sPM): | ||
"""Fully Bayesian MLP""" | ||
def __init__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
noise_prior_dist: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, | ||
hidden_dim: Optional[List[int]] = None, **kwargs): | ||
hidden_dim = [64, 32] if not hidden_dim else hidden_dim | ||
nn = kwargs.get("nn", get_mlp(hidden_dim)) | ||
nn_prior = kwargs.get("nn_prior", get_mlp_prior(input_dim, output_dim, hidden_dim)) | ||
super(BNN, self).__init__(nn, nn_prior, None, noise_prior_dist) | ||
|
||
def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None | ||
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]: | ||
X = X if X.ndim > 1 else X[:, None] | ||
if y is not None: | ||
y = y[:, None] if y.ndim < 1 else y | ||
return X, y | ||
return X | ||
|
||
|
||
def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray: | ||
"""Sampling weights matrix""" | ||
w = numpyro.sample(name=name, fn=dist.Normal( | ||
loc=jnp.zeros((in_channels, out_channels)), | ||
scale=jnp.ones((in_channels, out_channels)))) | ||
return w | ||
|
||
|
||
def sample_biases(name: str, channels: int) -> jnp.ndarray: | ||
"""Sampling bias vector""" | ||
b = numpyro.sample(name=name, fn=dist.Normal( | ||
loc=jnp.zeros((channels)), scale=jnp.ones((channels)))) | ||
return b | ||
|
||
|
||
def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]: | ||
"""Returns a function that represents an MLP for a given architecture.""" | ||
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: | ||
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers.""" | ||
h = X | ||
for i in range(len(architecture)): | ||
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"]) | ||
# No non-linearity after the last layer | ||
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"] | ||
return z | ||
return mlp | ||
|
||
|
||
def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]: | ||
"""Priors over weights and biases for a Bayesian MLP""" | ||
def mlp_prior(): | ||
params = {} | ||
in_channels = input_dim | ||
for i, out_channels in enumerate(architecture): | ||
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels) | ||
params[f"b{i}"] = sample_biases(f"b{i}", out_channels) | ||
in_channels = out_channels | ||
# Output layer | ||
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim) | ||
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim) | ||
return params | ||
return mlp_prior |
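As a quick illustration of how the two helpers compose (a sketch, not part of the commit; it assumes the gpax.models.bnn import path used by the tests below):

import jax.numpy as jnp
import numpyro
from gpax.models.bnn import get_mlp, get_mlp_prior

arch = [8, 4]  # two hidden layers
mlp = get_mlp(arch)  # deterministic forward pass for one parameter sample
prior = get_mlp_prior(input_dim=1, output_dim=1, architecture=arch)

# Draw a single prior sample of all weights and biases outside of inference
with numpyro.handlers.seed(rng_seed=0):
    params = prior()  # keys: w0, b0, w1, b1, w2, b2
out = mlp(jnp.ones((5, 1)), params)  # shape (5, 1)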
gpax/models/spm.py
@@ -10,6 +10,7 @@
 Created by Maxim Ziatdinov (email: [email protected])
 """
 
+import warnings
 from typing import Callable, Optional, Tuple, Type, Dict
 
 import jax
@@ -41,13 +42,20 @@ class sPM:
     def __init__(self,
                  model: model_type,
                  model_prior: prior_type,
-                 noise_prior: Optional[prior_type] = None) -> None:
+                 noise_prior: Optional[prior_type] = None,
+                 noise_prior_dist: Optional[dist.Distribution] = None) -> None:
         self._model = model
         self.model_prior = model_prior
-        if noise_prior is None:
-            self.noise_prior = lambda: numpyro.sample("sig", dist.LogNormal(0, 1))
-        else:
-            self.noise_prior = noise_prior
+        if noise_prior is not None:
+            warnings.warn(
+                "`noise_prior` is deprecated and will be removed in a future version. "
+                "Please use `noise_prior_dist` instead, which accepts an instance of a "
+                "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
+                "rather than a function that calls `numpyro.sample`.",
+                FutureWarning,
+            )
+        self.noise_prior = noise_prior
+        self.noise_prior_dist = noise_prior_dist
         self.mcmc = None
 
     def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
@@ -59,10 +67,20 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
         # Compute the function's value
         mu = numpyro.deterministic("mu", self._model(X, params))
         # Sample observational noise
-        sig = self.noise_prior()
+        if self.noise_prior:  # this will be removed in future releases
+            sig = self.noise_prior()
+        else:
+            sig = self._sample_noise()
         # Score against the observed data points
         numpyro.sample("y", dist.Normal(mu, sig), obs=y)
 
+    def _sample_noise(self) -> jnp.ndarray:
+        if self.noise_prior_dist is not None:
+            noise_dist = self.noise_prior_dist
+        else:
+            noise_dist = dist.LogNormal(0, 1)
+        return numpyro.sample("noise", noise_dist)
+
     def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             num_warmup: int = 2000, num_samples: int = 2000,
             num_chains: int = 1, chain_method: str = 'sequential',
@@ -85,6 +103,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             optionally specify a cpu or gpu device on which to run the inference;
             e.g., ``device=jax.devices("cpu")[0]``
         """
+        X, y = self._set_data(X, y)
         if device:
             X = jax.device_put(X, device)
             y = jax.device_put(y, device)
@@ -147,6 +166,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
         Returns:
             Point predictions (or their mean) and posterior predictive distribution
         """
+        X_new = self._set_data(X_new)
         if samples is None:
             samples = self.get_samples(chain_dim=False)
         if device:
@@ -165,3 +185,10 @@
 
     def _print_summary(self):
         self.mcmc.print_summary()
+
+    def _set_data(self,
+                  X: jnp.ndarray, y: Optional[jnp.ndarray] = None,
+                  ) -> Tuple[jnp.ndarray]:
+        if y is not None:
+            return X, y
+        return X
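The deprecation path in sPM.__init__ above can be summarized with a short sketch (the HalfNormal example is taken from the warning text itself; note that the noise sample site is renamed from "sig" to "noise"):

import numpyro.distributions as dist
from gpax.models.bnn import BNN

# Deprecated style (still accepted by sPM, with a FutureWarning):
#     noise_prior=lambda: numpyro.sample("sig", dist.LogNormal(0, 1))
# Preferred style: pass a numpyro Distribution; it is sampled at the "noise" site
bnn = BNN(1, 1, noise_prior_dist=dist.HalfNormal(scale=0.1))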
New test file for the BNN model
@@ -0,0 +1,98 @@
import sys

import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
from numpy.testing import assert_equal, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.models.bnn import BNN
from gpax.utils import get_keys


def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False):
    X = onp.random.randn(8, feature_dim)
    y = onp.random.randn(X.shape[0], target_dim)
    if squeezed:
        return X.squeeze(), y.squeeze()
    return X, y


def test_bnn_fit():
    key, _ = get_keys()
    X, y = get_dummy_data()
    bnn = BNN(1, 1)
    bnn.fit(key, X, y, num_warmup=50, num_samples=50)
    assert bnn.mcmc is not None


def test_bnn_custom_layers_fit():
    key, _ = get_keys()
    X, y = get_dummy_data()
    bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
    bnn.fit(key, X, y, num_warmup=50, num_samples=50)
    samples = bnn.get_samples()
    assert_equal(samples["w0"].shape, (50, 1, 32))
    assert_equal(samples["w1"].shape, (50, 32, 16))
    assert_equal(samples["w2"].shape, (50, 16, 8))
    assert_equal(samples["w3"].shape, (50, 8, 1))
    assert_equal(samples["b0"].shape, (50, 32))
    assert_equal(samples["b1"].shape, (50, 16))
    assert_equal(samples["b2"].shape, (50, 8))
    assert_equal(samples["b3"].shape, (50, 1))


def test_bnn_predict_with_samples():
    key, _ = get_keys()
    X_test, _ = get_dummy_data()

    params = {"w0": jax.random.normal(key, shape=(50, 1, 64)),
              "w1": jax.random.normal(key, shape=(50, 64, 32)),
              "w2": jax.random.normal(key, shape=(50, 32, 1)),
              "b0": jax.random.normal(key, shape=(50, 64)),
              "b1": jax.random.normal(key, shape=(50, 32)),
              "b2": jax.random.normal(key, shape=(50, 1)),
              "noise": jax.random.normal(key, shape=(50,))
              }

    bnn = BNN(1, 1)
    f_pred, f_samples = bnn.predict(key, X_test, params)
    assert_equal(f_pred.shape, (len(X_test), 1))
    assert_equal(f_samples.shape, (50, len(X_test), 1))


def test_bnn_custom_layers_predict_custom_with_samples():
    key, _ = get_keys()
    X_test, _ = get_dummy_data()

    params = {"w0": jax.random.normal(key, shape=(50, 1, 32)),
              "w1": jax.random.normal(key, shape=(50, 32, 16)),
              "w2": jax.random.normal(key, shape=(50, 16, 8)),
              "w3": jax.random.normal(key, shape=(50, 8, 1)),
              "b0": jax.random.normal(key, shape=(50, 32)),
              "b1": jax.random.normal(key, shape=(50, 16)),
              "b2": jax.random.normal(key, shape=(50, 8)),
              "b3": jax.random.normal(key, shape=(50, 1)),
              "noise": jax.random.normal(key, shape=(50,))
              }

    bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
    f_pred, f_samples = bnn.predict(key, X_test, params)
    assert_equal(f_pred.shape, (len(X_test), 1))
    assert_equal(f_samples.shape, (50, len(X_test), 1))


@pytest.mark.parametrize("squeezed", [True, False])
@pytest.mark.parametrize("target_dim", [1, 2])
@pytest.mark.parametrize("feature_dim", [1, 2])
def test_bnn_fit_predict(feature_dim, target_dim, squeezed):
    key, _ = get_keys()
    X, y = get_dummy_data(feature_dim, target_dim, squeezed)
    X_test, _ = get_dummy_data(feature_dim, target_dim, squeezed)
    bnn = BNN(feature_dim, target_dim, hidden_dim=[4, 2])
    bnn.fit(key, X, y, num_warmup=5, num_samples=5)
    f_pred, f_samples = bnn.predict(key, X_test)
    assert_equal(f_pred.shape, (len(X_test), target_dim))
    assert_equal(f_samples.shape, (5, len(X_test), target_dim))
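The tests are self-contained and fast at these MCMC settings. A sketch for invoking just this module (the file name below is assumed, since the diff view does not show it):

import pytest
pytest.main(["test_bnn.py", "-v"])  # hypothetical file name; adjust to the actual path in the repo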