From a5af082c526251f137efd0ae8bb024107a88026b Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 25 Feb 2024 14:26:58 -0800 Subject: [PATCH] Update type annotations --- gpax/models/bnn.py | 4 ++-- gpax/models/dkl.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/gpax/models/bnn.py b/gpax/models/bnn.py index 2ff2eef..ffd0954 100644 --- a/gpax/models/bnn.py +++ b/gpax/models/bnn.py @@ -52,7 +52,7 @@ def sample_biases(name: str, channels: int) -> jnp.ndarray: return b -def get_mlp(architecture: List[int]) -> Callable: +def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]: """Returns a function that represents an MLP for a given architecture.""" def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: """MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers.""" @@ -65,7 +65,7 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: return mlp -def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]: +def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]: """Priors over weights and biases for a Bayesian MLP""" def mlp_prior(): params = {} diff --git a/gpax/models/dkl.py b/gpax/models/dkl.py index ad7a022..689af97 100644 --- a/gpax/models/dkl.py +++ b/gpax/models/dkl.py @@ -164,7 +164,7 @@ def sample_biases(name: str, channels: int) -> jnp.ndarray: return b -def get_mlp(architecture: List[int]) -> Callable: +def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]: """Returns a function that represents an MLP for a given architecture.""" def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: """MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers.""" @@ -177,7 +177,7 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray: return mlp -def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]: +def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]: """Priors over weights and biases for a Bayesian MLP""" def mlp_prior(): params = {} @@ -190,4 +190,4 @@ def mlp_prior(): params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim) params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim) return params - return mlp_prior + return mlp_prior \ No newline at end of file