diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 56ce5dd7..44d9bb6b 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -29,6 +29,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + python -m pip install . - name: Test with pytest run: | pytest diff --git a/example/GW150914.py b/example/GW150914.py index 2e73ad58..a9c1c9c8 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -3,7 +3,7 @@ from jimgw.detector import H1, L1 from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD from jimgw.waveform import RippleIMRPhenomD -from jimgw.prior import Uniform +from jimgw.prior import Unconstrained_Uniform, Composite import jax.numpy as jnp import jax @@ -27,28 +27,69 @@ H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -prior = Uniform( - xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0], - xmax=[80.0, 1.0, 1.0, 1.0, 2000.0, 0.05, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], - naming=[ - "M_c", - "q", - "s1_z", - "s2_z", - "d_L", - "t_c", - "phase_c", - "cos_iota", - "psi", - "ra", - "sin_dec", - ], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec +Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Unconstrained_Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) +s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) +dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) +likelihood = TransientLikelihoodFD( + [H1, L1], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, ) -likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) -# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(11) @@ -76,5 +117,5 @@ local_sampler_arg=local_sampler_arg, ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) +# jim.maximize_likelihood([prior.xmin, prior.xmax]) jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 6dc91e79..140af05b 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -2,8 +2,8 @@ from jimgw.jim import Jim from jimgw.detector import H1, L1 from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 -from jimgw.prior import Uniform +from jimgw.waveform import RippleIMRPhenomPv2 +from jimgw.prior import Uniform, Composite, Sphere import jax.numpy as jnp import jax @@ -28,19 +28,66 @@ L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) waveform = RippleIMRPhenomPv2(f_ref=20) -prior = Uniform( - xmin = [10, 0.125, 0, 0, 0, 0, 0, 0, 0., -0.05, 0., -1, 0., 0.,-1.], - xmax = [80., 1., jnp.pi, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1., 2000., 0.05, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1.], - naming = ["M_c", "q", "s1_theta", "s1_phi", "s1_mag", "s2_theta", "s2_phi", "s2_mag", "d_L", "t_c", "phase_c", "cos_iota", "psi", "ra", "sin_dec"], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "s1_theta": ("s1_x", lambda params: jnp.sin(params['s1_theta'])*jnp.cos(params['s1_phi'])*params['s1_mag']), - "s1_phi": ("s1_y", lambda params: jnp.sin(params['s1_theta'])*jnp.sin(params['s1_phi'])*params['s1_mag']), - "s1_mag": ("s1_z", lambda params: jnp.cos(params['s1_theta'])*params['s1_mag']), - "s2_theta": ("s2_x", lambda params: jnp.sin(params['s2_theta'])*jnp.cos(params['s2_phi'])*params['s2_mag']), - "s2_phi": ("s2_y", lambda params: jnp.sin(params['s2_theta'])*jnp.sin(params['s2_phi'])*params['s2_mag']), - "s2_mag": ("s2_z", lambda params: jnp.cos(params['s2_theta'])*params['s2_mag']), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec + +########################################### +########## Set up priors ################## +########################################### + +Mc_prior = Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1_prior = Sphere(naming="s1") +s2_prior = Sphere(naming="s2") +dL_prior = Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1_prior, + s2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ], ) likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) # likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) @@ -54,7 +101,7 @@ jim = Jim( likelihood, prior, - n_loop_training=200, + n_loop_training=100, n_loop_production=10, n_local_steps=300, n_global_steps=300, @@ -63,17 +110,12 @@ learning_rate=0.001, max_samples = 60000, momentum=0.9, - batch_size=30000, + batch_size=60000, use_global=True, keep_quantile=0., train_thinning=1, output_thinning=30, local_sampler_arg=local_sampler_arg, - num_layers = 6, - hidden_size = [32,32], - num_bins = 8 ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) -# initial_guess = jnp.array(jnp.load('initial.npz')['chain']) jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_PV2_newglobal.py b/example/GW150914_PV2_newglobal.py new file mode 100644 index 00000000..d995df98 --- /dev/null +++ b/example/GW150914_PV2_newglobal.py @@ -0,0 +1,81 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite, Sphere +import jax.numpy as jnp +import jax + + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +start = gps - 2 +end = gps + 2 +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +waveform = RippleIMRPhenomPv2(f_ref=20) + +Mc_prior = Unconstrained_Uniform(10., 80., naming=["M_c"]) +q_prior = Unconstrained_Uniform(0.125, 1., naming=["q"], transforms={"q": ("eta", lambda params: params['q']/(1+params['q'])**2)}) +s1_prior = Sphere("s1") +s2_prior = Sphere("s2") +dL_prior = Unconstrained_Uniform(0., 2000., naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform(-1., 1., naming=["cos_iota"], transforms={"cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi))}) +psi_prior = Unconstrained_Uniform(0., jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform(-1., 1., naming=["sin_dec"], transforms={"sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))}) + +prior = Composite([Mc_prior, q_prior, s1_prior, s2_prior, dL_prior, t_c_prior, phase_c_prior, cos_iota_prior, psi_prior, ra_prior, sin_dec_prior]) + +likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) + + +mass_matrix = jnp.eye(prior.n_dim) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + + +jim = Jim( + likelihood, + prior, + n_loop_training=50, + n_loop_production=10, + n_local_steps=300, + n_global_steps=300, + n_chains=500, + n_epochs=300, + learning_rate=0.001, + max_samples = 60000, + momentum=0.9, + batch_size=30000, + use_global=True, + keep_quantile=0., + train_thinning=1, + output_thinning=30, + local_sampler_arg=local_sampler_arg, + num_layers = 6, + hidden_size = [32,32], + num_bins = 8 +) + +# jim.maximize_likelihood([prior.xmin, prior.xmax]) +# initial_guess = jnp.array(jnp.load('initial.npz')['chain']) +jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 12aa89c1..41dd8ef9 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jaxtyping import Array import jax import jax.numpy as jnp +from flowMC.sampler.flowHMC import flowHMC class Jim(object): """ @@ -30,13 +31,30 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix + flowHMC_params = kwargs.get("flowHMC_params", {}) model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) + if len(flowHMC_params) > 0: + global_sampler = flowHMC( + self.posterior, + True, + model, + params={ + "step_size": flowHMC_params["step_size"], + "n_leapfrog": flowHMC_params["n_leapfrog"], + "condition_matrix": flowHMC_params["condition_matrix"], + }, + ) + else: + global_sampler = None + + self.Sampler = Sampler( self.Prior.n_dim, rng_key_set, None, local_sampler, model, + global_sampler = global_sampler, **kwargs) @@ -59,13 +77,15 @@ def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 10 return best_fit def posterior(self, params: Array, data: dict): - named_params = self.Prior.add_name(params, transform_name=True, transform_value=True) - return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob(params) + prior_params = self.Prior.add_name(params.T) + prior = self.Prior.log_prob(prior_params) + return self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior def sample(self, key: jax.random.PRNGKey, initial_guess: Array = None): if initial_guess is None: initial_guess = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess = jnp.stack([i for i in initial_guess.values()]).T self.Sampler.sample(initial_guess, None) def print_summary(self): diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 4baf4298..2137b018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -2,9 +2,10 @@ import jax.numpy as jnp from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float -from typing import Callable, Union +from typing import Callable from dataclasses import field + class Prior(Distribution): """ A thin wrapper build on top of flowMC distributions to do book keeping. @@ -16,13 +17,13 @@ class Prior(Distribution): """ naming: list[str] - transforms: dict[tuple[str,Callable]] = field(default_factory=dict) + transforms: dict[tuple[str, Callable]] = field(default_factory=dict) @property def n_dim(self): return len(self.naming) - - def __init__(self, naming: list[str], transforms: dict[tuple[str,Callable]] = {}): + + def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = {}): """ Parameters ---------- @@ -37,15 +38,17 @@ def __init__(self, naming: list[str], transforms: dict[tuple[str,Callable]] = {} self.transforms = {} def make_lambda(name): - return lambda x: x[name] + return lambda x: x[name] for name in naming: if name in transforms: self.transforms[name] = transforms[name] else: - self.transforms[name] = (name, make_lambda(name)) # Without the function, the lambda will refer to the variable name instead of its value, which will make lambda reference the last value of the variable name + # Without the function, the lambda will refer to the variable name instead of its value, + # which will make lambda reference the last value of the variable name + self.transforms[name] = (name, make_lambda(name)) - def transform(self, x: Array) -> Array: + def transform(self, x: dict) -> dict: """ Apply the transforms to the parameters. @@ -59,37 +62,115 @@ def transform(self, x: Array) -> Array: x : dict A dictionary of parameters with the transforms applied. """ - output = self.add_name(x, transform_name = False, transform_value = False) - for i, (key, value) in enumerate(self.transforms.items()): - x = x.at[i].set(value[1](output)) - return x + output = {} + for value in self.transforms.values(): + output[value[0]] = value[1](x) + return output - def add_name(self, x: Array, transform_name: bool = False, transform_value: bool = False) -> dict: + def add_name(self, x: Array) -> dict: """ Turn an array into a dictionary + + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim, n_sample). """ - if transform_name: - naming = [value[0] for value in self.transforms.values()] - else: - naming = self.naming - if transform_value: - x = self.transform(x) - value = x - else: - value = x - return dict(zip(naming,value)) + + return dict(zip(self.naming, x)) + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + raise NotImplementedError + + def logpdf(self, x: dict) -> Float: + raise NotImplementedError + class Uniform(Prior): - xmin: Array - xmax: Array + xmin: float = 0.0 + xmax: float = 1.0 - def __init__(self, xmin: Union[float,Array], xmax: Union[float,Array], **kwargs): - super().__init__(kwargs.get("naming"), kwargs.get("transforms")) - self.xmax = jnp.array(xmax) - self.xmin = jnp.array(xmin) - - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def __init__( + self, + xmin: float, + xmax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert self.n_dim == 1, "Uniform needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + """ + Sample from a uniform distribution. + + Parameters + ---------- + rng_key : jax.random.PRNGKey + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.uniform( + rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax + ) + return self.add_name(samples[None]) + + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + output = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + return output + jnp.log(1.0 / (self.xmax - self.xmin)) + + +class Unconstrained_Uniform(Prior): + + xmin: float = 0.0 + xmax: float = 1.0 + to_range: Callable = lambda x: x + + def __init__( + self, + xmin: float, + xmax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + local_transform = self.transforms + self.to_range = ( + lambda x: (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) + + self.xmin + ) + + def new_transform(param): + param[self.naming[0]] = self.to_range(param) + return local_transform[self.naming[0]][1](param) + + self.transforms = { + self.naming[0]: (local_transform[self.naming[0]][0], new_transform) + } + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: """ Sample from a uniform distribution. @@ -102,13 +183,86 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: Returns ------- - samples : Array + samples : An array of shape (n_samples, n_dim) containing the samples. - + """ - samples = jax.random.uniform(rng_key, (n_samples,self.n_dim), minval=self.xmin, maxval=self.xmax) - return samples # TODO: remember to cast this to a named array + samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) + samples = jnp.log(samples / (1 - samples)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) + + +class Sphere(Prior): + + """ + A prior on a sphere represented by Cartesian coordinates. + + Magnitude is sampled from a uniform distribution. + """ + + def __init__(self, naming: str): + self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"] + self.transforms = { + self.naming[0]: ( + f"{naming}_x", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.cos(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[1]: ( + f"{naming}_y", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.sin(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[2]: ( + f"{naming}_z", + lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], + ), + } + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + rng_keys = jax.random.split(rng_key, 3) + theta = jnp.arccos( + jax.random.uniform( + rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 + ) + ) + phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) + mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) + return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) + + def log_prob(self, x: dict) -> Float: + return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) + + +class Composite(Prior): + + priors: list[Prior] = field(default_factory=list) + + def __init__(self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}): + naming = [] + self.transforms = {} + for prior in priors: + naming += prior.naming + self.transforms.update(prior.transforms) + self.priors = priors + self.naming = naming + self.transforms.update(transforms) + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output - def log_prob(self, x: Array) -> Float: - output = jnp.sum(jnp.where((x>=self.xmax) | (x<=self.xmin), jnp.zeros_like(x)-jnp.inf, jnp.zeros_like(x))) - return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin))) + def log_prob(self, x: dict) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output diff --git a/test/test_prior.py b/test/test_prior.py new file mode 100644 index 00000000..d4c5be59 --- /dev/null +++ b/test/test_prior.py @@ -0,0 +1 @@ +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite