Allows passing number of layers/neurons to BNN
ziatdinovmax committed Feb 23, 2024
1 parent 1b3e5c6 commit be70d33
Showing 1 changed file with 37 additions and 28 deletions.
65 changes: 37 additions & 28 deletions gpax/models/dkl.py
@@ -4,11 +4,11 @@
Fully Bayesian implementation of deep kernel learning
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""

from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union, List

import jax
import jax.numpy as jnp
@@ -39,6 +39,10 @@ class DKL(vExactGP):
Priors over the weights and biases in 'nn'; uses normal priors by default
latent_prior:
Optional prior over the latent space (BNN embedding); uses none by default
hidden_dim:
Optional custom MLP architecture. For example [16, 8, 4] corresponds to a 3-layer
neural network backbone containing 16, 8, and 4 neurons activated by tanh(). The latent
layer is added automatically and doesn't have to be specified here. Defaults to [64, 32].
**kwargs:
Optional custom prior distributions over observational noise (noise_dist_prior)
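
For context, a minimal construction sketch using the new argument (not part of the diff; the import path mirrors the file edited here, gpax/models/dkl.py, and the dimensions are illustrative):

```python
from gpax.models.dkl import DKL

# Hypothetical example: a 3-layer tanh backbone with 16, 8, and 4 neurons.
# The z_dim=2 latent layer is appended automatically, per the docstring above.
dkl = DKL(input_dim=5, z_dim=2, kernel='RBF', hidden_dim=[16, 8, 4])
```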
@@ -67,11 +71,12 @@ def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
nn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
nn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
**kwargs
hidden_dim: Optional[List[int]] = None, **kwargs
) -> None:
super(DKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs)
self.nn = nn if nn else mlp
self.nn_prior = nn_prior if nn_prior else mlp_prior(input_dim, z_dim)
hdim = hidden_dim if hidden_dim is not None else [64, 32]
self.nn = nn if nn else get_mlp(hdim)
self.nn_prior = nn_prior if nn_prior else get_mlp_prior(input_dim, z_dim, hdim)
self.kernel_dim = z_dim
self.latent_prior = latent_prior

@@ -108,7 +113,6 @@ def model(self,
obs=y,
)

#@partial(jit, static_argnames='self')
def _get_mvn_posterior(self,
X_train: jnp.ndarray, y_train: jnp.ndarray,
X_new: jnp.ndarray, params: Dict[str, jnp.ndarray],
@@ -174,25 +178,30 @@ def sample_biases(name: str, channels: int, task_dim: int) -> jnp.ndarray:
return b


def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Simple MLP for a single MCMC sample of weights and biases"""
h1 = jnp.tanh(jnp.matmul(X, params["w1"]) + params["b1"])
h2 = jnp.tanh(jnp.matmul(h1, params["w2"]) + params["b2"])
z = jnp.matmul(h2, params["w3"]) + params["b3"]
return z


def mlp_prior(input_dim: int, zdim: int = 2) -> Dict[str, jnp.array]:
"""Priors over weights and biases in the default Bayesian MLP"""
hdim = [64, 32]

def _bnn_prior(task_dim: int):
w1 = sample_weights("w1", input_dim, hdim[0], task_dim)
b1 = sample_biases("b1", hdim[0], task_dim)
w2 = sample_weights("w2", hdim[0], hdim[1], task_dim)
b2 = sample_biases("b2", hdim[1], task_dim)
w3 = sample_weights("w3", hdim[1], zdim, task_dim)
b3 = sample_biases("b3", zdim, task_dim)
return {"w1": w1, "b1": b1, "w2": w2, "b2": b2, "w3": w3, "b3": b3}

return _bnn_prior
def get_mlp(architecture: List[int]) -> Callable:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
h = X
for i in range(len(architecture)):
h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
# No non-linearity after the last layer
z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
return z
return mlp
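
To make the parameter naming convention concrete, here is a small self-contained sketch (shapes and values are hypothetical, not taken from the diff) that applies the returned forward function to a params dict keyed w0/b0 through wN/bN:

```python
import jax
import jax.numpy as jnp

forward = get_mlp([16, 8])                      # two hidden layers
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w0": jax.random.normal(keys[0], (5, 16)), "b0": jnp.zeros(16),
    "w1": jax.random.normal(keys[1], (16, 8)), "b1": jnp.zeros(8),
    "w2": jax.random.normal(keys[2], (8, 2)),  "b2": jnp.zeros(2),   # output layer (z_dim=2)
}
X = jax.random.normal(keys[3], (10, 5))         # 10 points with 5 input features
print(forward(X, params).shape)                 # (10, 2)
```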


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior(task_dim: int):
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels, task_dim)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels, task_dim)
in_channels = out_channels
# Output layer
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim, task_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim, task_dim)
return params
return mlp_prior
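
And a hedged sketch of inspecting the sample sites this prior registers, assuming sample_weights and sample_biases call numpyro.sample internally (as elsewhere in this module); the task_dim value is illustrative:

```python
import numpyro

prior = get_mlp_prior(input_dim=5, output_dim=2, architecture=[16, 8])
with numpyro.handlers.seed(rng_seed=0):
    tr = numpyro.handlers.trace(lambda: prior(task_dim=1)).get_trace()
# Site names line up with what get_mlp reads back: w0, b0, w1, b1, w2, b2
print(list(tr.keys()))
```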
