diff --git a/gpax/models/bnn.py b/gpax/models/bnn.py
index 2ff2eef..ffd0954 100644
--- a/gpax/models/bnn.py
+++ b/gpax/models/bnn.py
@@ -52,7 +52,7 @@ def sample_biases(name: str, channels: int) -> jnp.ndarray:
     return b
 
 
-def get_mlp(architecture: List[int]) -> Callable:
+def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
     """Returns a function that represents an MLP for a given architecture."""
     def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
         """MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
@@ -65,7 +65,7 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
     return mlp
 
 
-def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
+def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
     """Priors over weights and biases for a Bayesian MLP"""
     def mlp_prior():
         params = {}
diff --git a/gpax/models/dkl.py b/gpax/models/dkl.py
index ad7a022..689af97 100644
--- a/gpax/models/dkl.py
+++ b/gpax/models/dkl.py
@@ -164,7 +164,7 @@ def sample_biases(name: str, channels: int) -> jnp.ndarray:
     return b
 
 
-def get_mlp(architecture: List[int]) -> Callable:
+def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
     """Returns a function that represents an MLP for a given architecture."""
     def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
         """MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
@@ -177,7 +177,7 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
     return mlp
 
 
-def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
+def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
     """Priors over weights and biases for a Bayesian MLP"""
     def mlp_prior():
         params = {}
@@ -190,4 +190,4 @@ def mlp_prior():
         params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
         params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
         return params
-    return mlp_prior
+    return mlp_prior
\ No newline at end of file