Commit 6c2c93f (parent: 7d317ca): 2 changed files, 179 additions, 0 deletions.
bnn.py (new file, 81 lines added):
""" | ||
bnn.py | ||
======= | ||
Fully Bayesian MLPs | ||
Created by Maxim Ziatdinov (email: [email protected]) | ||
""" | ||
|
||
from typing import Callable, Dict, Optional, List, Union, Tuple | ||
|
||
import jax.numpy as jnp | ||
import numpyro | ||
import numpyro.distributions as dist | ||
|
||
from .spm import sPM | ||
|
||
|
||
class BNN(sPM): | ||
"""Fully Bayesian MLP""" | ||
def __init__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
noise_prior_dist: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, | ||
hidden_dim: Optional[List[int]] = None, **kwargs): | ||
hidden_dim = [64, 32] if not hidden_dim else hidden_dim | ||
nn = kwargs.get("nn", get_mlp(hidden_dim)) | ||
nn_prior = kwargs.get("nn_prior", get_mlp_prior(input_dim, output_dim, hidden_dim)) | ||
super(BNN, self).__init__(nn, nn_prior, None, noise_prior_dist) | ||
|
||
def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None | ||
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]: | ||
X = X if X.ndim > 1 else X[:, None] | ||
if y is not None: | ||
y = y[:, None] if y.ndim < 1 else y | ||
return X, y | ||
return X | ||
|
||
|
||
def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray: | ||
"""Sampling weights matrix""" | ||
w = numpyro.sample(name=name, fn=dist.Normal( | ||
loc=jnp.zeros((in_channels, out_channels)), | ||
scale=jnp.ones((in_channels, out_channels)))) | ||
return w | ||
|
||
|
||
def sample_biases(name: str, channels: int) -> jnp.ndarray: | ||
"""Sampling bias vector""" | ||
b = numpyro.sample(name=name, fn=dist.Normal( | ||
loc=jnp.zeros((channels)), scale=jnp.ones((channels)))) | ||
return b | ||
|
||
|
||
def get_mlp(architecture: List[int]) -> Callable: | ||
"""Returns a function that represents an MLP for a given architecture.""" | ||
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: | ||
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers.""" | ||
h = X | ||
for i in range(len(architecture)): | ||
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"]) | ||
# No non-linearity after the last layer | ||
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"] | ||
return z | ||
return mlp | ||
|
||
|
||
def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]: | ||
"""Priors over weights and biases for a Bayesian MLP""" | ||
def mlp_prior(): | ||
params = {} | ||
in_channels = input_dim | ||
for i, out_channels in enumerate(architecture): | ||
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels) | ||
params[f"b{i}"] = sample_biases(f"b{i}", out_channels) | ||
in_channels = out_channels | ||
# Output layer | ||
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim) | ||
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim) | ||
return params | ||
return mlp_prior |
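
A quick way to see how `get_mlp` and `get_mlp_prior` compose: each prior draw is a dictionary of weight matrices and bias vectors keyed `w0`, `b0`, ..., and the function returned by `get_mlp` evaluates one such draw on a batch of inputs. Below is a minimal sketch, not part of the commit, assuming `bnn.py` is importable as `gpax.models.bnn` (the path the tests below use); `seed` is NumPyro's standard handler for running `numpyro.sample` statements outside of inference:

import jax.numpy as jnp
from numpyro.handlers import seed

from gpax.models.bnn import get_mlp, get_mlp_prior

mlp = get_mlp([64, 32])  # two hidden layers of widths 64 and 32
prior = get_mlp_prior(input_dim=1, output_dim=1, architecture=[64, 32])

# One draw of all weights/biases: {"w0": (1, 64), "b0": (64,), ..., "w2": (32, 1), "b2": (1,)}
params = seed(prior, rng_seed=0)()
X = jnp.linspace(-1., 1., 8)[:, None]  # (8, 1) inputs
print(mlp(X, params).shape)            # -> (8, 1)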
Test file (new file, 98 lines added):
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
import numpyro
from numpy.testing import assert_equal, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.models.bnn import BNN
from gpax.utils import get_keys


def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False):
    X = onp.random.randn(8, feature_dim)
    y = onp.random.randn(X.shape[0], target_dim)
    if squeezed:
        return X.squeeze(), y.squeeze()
    return X, y


def test_bnn_fit():
    key, _ = get_keys()
    X, y = get_dummy_data()
    bnn = BNN(1, 1)
    bnn.fit(key, X, y, num_warmup=50, num_samples=50)
    assert bnn.mcmc is not None


def test_bnn_custom_layers_fit():
    key, _ = get_keys()
    X, y = get_dummy_data()
    bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
    bnn.fit(key, X, y, num_warmup=50, num_samples=50)
    samples = bnn.get_samples()
    assert_equal(samples["w0"].shape, (50, 1, 32))
    assert_equal(samples["w1"].shape, (50, 32, 16))
    assert_equal(samples["w2"].shape, (50, 16, 8))
    assert_equal(samples["w3"].shape, (50, 8, 1))
    assert_equal(samples["b0"].shape, (50, 32))
    assert_equal(samples["b1"].shape, (50, 16))
    assert_equal(samples["b2"].shape, (50, 8))
    assert_equal(samples["b3"].shape, (50, 1))


def test_bnn_predict_with_samples():
    key, _ = get_keys()
    X_test, _ = get_dummy_data()

    params = {"w0": jax.random.normal(key, shape=(50, 1, 64)),
              "w1": jax.random.normal(key, shape=(50, 64, 32)),
              "w2": jax.random.normal(key, shape=(50, 32, 1)),
              "b0": jax.random.normal(key, shape=(50, 64)),
              "b1": jax.random.normal(key, shape=(50, 32)),
              "b2": jax.random.normal(key, shape=(50, 1)),
              "noise": jax.random.normal(key, shape=(50,))}

    bnn = BNN(1, 1)
    f_pred, f_samples = bnn.predict(key, X_test, params)
    assert_equal(f_pred.shape, (len(X_test), 1))
    assert_equal(f_samples.shape, (50, len(X_test), 1))


def test_bnn_custom_layers_predict_custom_with_samples():
    key, _ = get_keys()
    X_test, _ = get_dummy_data()

    params = {"w0": jax.random.normal(key, shape=(50, 1, 32)),
              "w1": jax.random.normal(key, shape=(50, 32, 16)),
              "w2": jax.random.normal(key, shape=(50, 16, 8)),
              "w3": jax.random.normal(key, shape=(50, 8, 1)),
              "b0": jax.random.normal(key, shape=(50, 32)),
              "b1": jax.random.normal(key, shape=(50, 16)),
              "b2": jax.random.normal(key, shape=(50, 8)),
              "b3": jax.random.normal(key, shape=(50, 1)),
              "noise": jax.random.normal(key, shape=(50,))}

    bnn = BNN(1, 1, hidden_dim=[32, 16, 8])
    f_pred, f_samples = bnn.predict(key, X_test, params)
    assert_equal(f_pred.shape, (len(X_test), 1))
    assert_equal(f_samples.shape, (50, len(X_test), 1))


@pytest.mark.parametrize("squeezed", [True, False])
@pytest.mark.parametrize("target_dim", [1, 2])
@pytest.mark.parametrize("feature_dim", [1, 2])
def test_bnn_fit_predict(feature_dim, target_dim, squeezed):
    key, _ = get_keys()
    X, y = get_dummy_data(feature_dim, target_dim, squeezed)
    X_test, _ = get_dummy_data(feature_dim, target_dim, squeezed)
    bnn = BNN(feature_dim, target_dim, hidden_dim=[4, 2])
    bnn.fit(key, X, y, num_warmup=5, num_samples=5)
    f_pred, f_samples = bnn.predict(key, X_test)
    assert_equal(f_pred.shape, (len(X_test), target_dim))
    assert_equal(f_samples.shape, (5, len(X_test), target_dim))
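
Taken together, the tests pin down the public workflow: `fit` runs MCMC (with `num_warmup` and `num_samples` controlling the chain) and stores it on `bnn.mcmc`, `get_samples` returns per-layer weight and bias draws with a leading sample axis, and `predict` returns a point prediction of shape `(n_points, target_dim)` together with per-sample predictions of shape `(num_samples, n_points, target_dim)`. A condensed end-to-end sketch, using only the API the tests above exercise:

import numpy as onp
from gpax.models.bnn import BNN
from gpax.utils import get_keys

key, _ = get_keys()
X, y = onp.random.randn(8, 1), onp.random.randn(8, 1)

bnn = BNN(1, 1, hidden_dim=[4, 2])                 # input_dim=1, output_dim=1
bnn.fit(key, X, y, num_warmup=50, num_samples=50)  # MCMC warmup + sampling
f_pred, f_samples = bnn.predict(key, X)            # posterior predictive
print(f_pred.shape)     # (8, 1)
print(f_samples.shape)  # (50, 8, 1)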