Skip to content

Commit

Permalink
Combine should be working now
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Jul 25, 2024
1 parent d8b2d1f commit ca1d6b6
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,44 +145,43 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]:
x = transform.forward(x)
return x

# class Combine(Prior):
# """
# A prior class constructed by joinning multiple priors together to form a multivariate prior.
# This assumes the priors composing the Combine class are independent.
# """

# priors: list[Prior] = field(default_factory=list)

# def __repr__(self):
# return (
# f"Composite(priors={self.priors}, parameter_names={self.parameter_names})"
# )

# def __init__(
# self,
# priors: list[Prior],
# **kwargs,
# ):
# parameter_names = []
# for prior in priors:
# parameter_names += prior.parameter_names
# self.priors = priors
# self.parameter_names = parameter_names

# def sample(
# self, rng_key: PRNGKeyArray, n_samples: int
# ) -> dict[str, Float[Array, " n_samples"]]:
# output = {}
# for prior in self.priors:
# rng_key, subkey = jax.random.split(rng_key)
# output.update(prior.sample(subkey, n_samples))
# return output

# def log_prob(self, x: dict[str, Float]) -> Float:
# output = 0.0
# for prior in self.priors:
# output -= prior.log_prob(x)
# return output
class Combine(Prior):
"""
A prior class constructed by joinning multiple priors together to form a multivariate prior.
This assumes the priors composing the Combine class are independent.
"""

priors: list[Prior] = field(default_factory=list)

def __repr__(self):
return (
f"Composite(priors={self.priors}, parameter_names={self.parameter_names})"
)

def __init__(
self,
priors: list[Prior],
):
parameter_names = []
for prior in priors:
parameter_names += prior.parameter_names
self.priors = priors
self.parameter_names = parameter_names

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
output = {}
for prior in self.priors:
rng_key, subkey = jax.random.split(rng_key)
output.update(prior.sample(subkey, n_samples))
return output

def log_prob(self, x: dict[str, Float]) -> Float:
output = 0.0
for prior in self.priors:
output += prior.log_prob(x)
return output



Expand Down

0 comments on commit ca1d6b6

Please sign in to comment.