Skip to content

Commit

Permalink
Update type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 25, 2024
1 parent 6c2c93f commit a5af082
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions gpax/models/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def sample_biases(name: str, channels: int) -> jnp.ndarray:
return b


def get_mlp(architecture: List[int]) -> Callable:
def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
Expand All @@ -65,7 +65,7 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior():
params = {}
Expand Down
6 changes: 3 additions & 3 deletions gpax/models/dkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def sample_biases(name: str, channels: int) -> jnp.ndarray:
return b


def get_mlp(architecture: List[int]) -> Callable:
def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Returns a function that represents an MLP for a given architecture."""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""MLP for a single MCMC sample of weights and biases, handling arbitrary number of layers."""
Expand All @@ -177,7 +177,7 @@ def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
return mlp


def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Dict[str, jnp.ndarray]:
def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior():
params = {}
Expand All @@ -190,4 +190,4 @@ def mlp_prior():
params[f"w{len(architecture)}"] = sample_weights(f"w{len(architecture)}", in_channels, output_dim)
params[f"b{len(architecture)}"] = sample_biases(f"b{len(architecture)}", output_dim)
return params
return mlp_prior
return mlp_prior

0 comments on commit a5af082

Please sign in to comment.