From ca1d6b6c1fa2ac7cae0d9a313693ea97af070fe4 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 14:07:16 -0400 Subject: [PATCH] Combine should be working now --- src/jimgw/prior.py | 75 +++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 232b4a0f..85414064 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -145,44 +145,43 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: x = transform.forward(x) return x -# class Combine(Prior): -# """ -# A prior class constructed by joinning multiple priors together to form a multivariate prior. -# This assumes the priors composing the Combine class are independent. -# """ - -# priors: list[Prior] = field(default_factory=list) - -# def __repr__(self): -# return ( -# f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" -# ) - -# def __init__( -# self, -# priors: list[Prior], -# **kwargs, -# ): -# parameter_names = [] -# for prior in priors: -# parameter_names += prior.parameter_names -# self.priors = priors -# self.parameter_names = parameter_names - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# output = {} -# for prior in self.priors: -# rng_key, subkey = jax.random.split(rng_key) -# output.update(prior.sample(subkey, n_samples)) -# return output - -# def log_prob(self, x: dict[str, Float]) -> Float: -# output = 0.0 -# for prior in self.priors: -# output -= prior.log_prob(x) -# return output +class Combine(Prior): + """ + A prior class constructed by joinning multiple priors together to form a multivariate prior. + This assumes the priors composing the Combine class are independent. + """ + + priors: list[Prior] = field(default_factory=list) + + def __repr__(self): + return ( + f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" + ) + + def __init__( + self, + priors: list[Prior], + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.priors = priors + self.parameter_names = parameter_names + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output