Skip to content

Commit

Permalink
including multivariate normal ldpf function
Browse files Browse the repository at this point in the history
  • Loading branch information
NicholasCowie committed Sep 20, 2024
1 parent 27b2f0d commit 95bbe39
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/enzax/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
from jax._src.random import KeyArray
import jax.numpy as jnp
from jax.scipy.stats import norm
from jax.scipy.stats import norm, multivariate_normal
from jaxtyping import Array, Float, PyTree, ScalarLike

from enzax.kinetic_model import (
Expand Down Expand Up @@ -62,6 +62,13 @@ def ind_normal_prior_logdensity(param, prior: Float[Array, "2 _"]):
"""Total log density for an independent normal distribution."""
return norm.logpdf(param, loc=prior[0], scale=prior[1]).sum()

def mv_normal_prior_logdensity(
param: Float[Array, "_"],
prior: tuple[Float[Array, "_"], Float[Array, "_ _"]],
):
"""Total log density for an multivariate normal distribution."""
return jnp.sum(multivariate_normal.logpdf(param, mean=prior[0], cov=prior[1]))


def posterior_logdensity_amm(
parameters: AllostericMichaelisMentenParameterSet,
Expand Down

0 comments on commit 95bbe39

Please sign in to comment.