diff --git a/gpax/models/dkl.py b/gpax/models/dkl.py
index 12c2d3d..1e93222 100644
--- a/gpax/models/dkl.py
+++ b/gpax/models/dkl.py
@@ -4,11 +4,11 @@
 
 Fully Bayesian implementation of deep kernel learning
 
-Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
+Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
 """
 
 from functools import partial
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 
 import jax
 import jax.numpy as jnp
@@ -39,6 +39,10 @@ class DKL(vExactGP):
         Priors over the weights and biases in 'nn'; uses normal priors by default
     latent_prior:
         Optional prior over the latent space (BNN embedding); uses none by default
+    hidden_dim:
+        Optional custom MLP architecture. For example, [16, 8, 4] corresponds to a 3-layer
+        neural network backbone containing 16, 8, and 4 neurons activated by tanh(). The latent
+        layer is added automatically and doesn't have to be specified here. Defaults to [64, 32].
     **kwargs:
         Optional custom prior distributions over observational noise (noise_dist_prior)
 
@@ -67,11 +71,12 @@ def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
                  nn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                  nn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
-                 **kwargs
+                 hidden_dim: Optional[List[int]] = None, **kwargs
                  ) -> None:
         super(DKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs)
-        self.nn = nn if nn else mlp
-        self.nn_prior = nn_prior if nn_prior else mlp_prior(input_dim, z_dim)
+        hdim = hidden_dim if hidden_dim is not None else [64, 32]
+        self.nn = nn if nn else get_mlp(hdim)
+        self.nn_prior = nn_prior if nn_prior else get_mlp_prior(input_dim, z_dim, hdim)
         self.kernel_dim = z_dim
         self.latent_prior = latent_prior
 
@@ -108,7 +113,6 @@ def model(self,
             obs=y,
         )
 
-    #@partial(jit, static_argnames='self')
     def _get_mvn_posterior(self,
                            X_train: jnp.ndarray, y_train: jnp.ndarray,
                            X_new: jnp.ndarray, params: Dict[str, jnp.ndarray],
@@ -174,25 +178,30 @@ def sample_biases(name: str, channels: int, task_dim: int) -> jnp.ndarray:
     return b
 
 
-def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
-    """Simple MLP for a single MCMC sample of weights and biases"""
-    h1 = jnp.tanh(jnp.matmul(X, params["w1"]) + params["b1"])
-    h2 = jnp.tanh(jnp.matmul(h1, params["w2"]) + params["b2"])
-    z = jnp.matmul(h2, params["w3"]) + params["b3"]
-    return z
-
-
-def mlp_prior(input_dim: int, zdim: int = 2) -> Dict[str, jnp.array]:
-    """Priors over weights and biases in the default Bayesian MLP"""
-    hdim = [64, 32]
-
-    def _bnn_prior(task_dim: int):
-        w1 = sample_weights("w1", input_dim, hdim[0], task_dim)
-        b1 = sample_biases("b1", hdim[0], task_dim)
-        w2 = sample_weights("w2", hdim[0], hdim[1], task_dim)
-        b2 = sample_biases("b2", hdim[1], task_dim)
-        w3 = sample_weights("w3", hdim[1], zdim, task_dim)
-        b3 = sample_biases("b3", zdim, task_dim)
-        return {"w1": w1, "b1": b1, "w2": w2, "b2": b2, "w3": w3, "b3": b3}
-
-    return _bnn_prior
+def get_mlp(architecture: List[int]) -> Callable:
+    """Returns a function that represents an MLP for a given architecture."""
+    def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
+        """MLP for a single MCMC sample of weights and biases, handling an arbitrary number of layers."""
+        h = X
+        for i in range(len(architecture)):
+            h = jnp.tanh(jnp.matmul(h, params[f"w{i}"]) + params[f"b{i}"])
+        # No non-linearity after the last layer
+        z = jnp.matmul(h, params[f"w{len(architecture)}"]) + params[f"b{len(architecture)}"]
+        return z
+    return mlp
+
+
+def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable:
+    """Priors over weights and biases for a Bayesian MLP"""
+    def mlp_prior(task_dim: int):
+        params = {}
+        in_channels = input_dim
+        for i, out_channels in enumerate(architecture):
+            params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels, task_dim)
+            params[f"b{i}"] = sample_biases(f"b{i}", out_channels, task_dim)
+            in_channels = out_channels
+        # Output layer
+        params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim, task_dim)
+        params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim, task_dim)
+        return params
+    return mlp_prior
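
For reviewers, a short illustrative sketch (not part of the patch) of how the new hidden_dim argument and the get_mlp factory fit together. It assumes the patched gpax.models.dkl module is importable; the array shapes and zero-valued parameters are placeholders chosen only to demonstrate the w{i}/b{i} naming convention and the expected shapes.

import jax.numpy as jnp
from gpax.models.dkl import DKL, get_mlp

input_dim, z_dim = 16, 2
architecture = [16, 8, 4]          # three tanh-activated hidden layers

# DKL now builds its default Bayesian MLP from `hidden_dim`;
# the z_dim-sized latent layer is appended automatically.
dkl = DKL(input_dim, z_dim=z_dim, kernel='RBF', hidden_dim=architecture)

# The factory can also be used standalone. Parameter names follow the
# w{i}/b{i} convention expected by get_mlp_prior; zeros stand in for a
# single MCMC sample of the weights, just to show the shapes involved.
mlp = get_mlp(architecture)
dims = [input_dim] + architecture + [z_dim]
params = {}
for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
    params[f"w{i}"] = jnp.zeros((d_in, d_out))
    params[f"b{i}"] = jnp.zeros(d_out)

z = mlp(jnp.ones((5, input_dim)), params)
print(z.shape)  # (5, 2): five inputs embedded into the 2D latent space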