diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..45ef3064 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,28 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + + pre-commit: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + python -m pip install . + - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 56ce5dd7..44d9bb6b 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -29,6 +29,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + python -m pip install . - name: Test with pytest run: | pytest diff --git a/example/GW150914.py b/example/GW150914.py index 9c373b6a..8ba26ead 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -3,7 +3,7 @@ from jimgw.detector import H1, L1 from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD from jimgw.waveform import RippleIMRPhenomD -from jimgw.prior import Uniform +from jimgw.prior import Unconstrained_Uniform, Composite import jax.numpy as jnp import jax @@ -29,26 +29,70 @@ H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -prior = Uniform( - xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0], - xmax=[80.0, 1.0, 1.0, 1.0, 2000.0, 0.05, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], - naming=[ - "M_c", - "q", - "s1_z", - "s2_z", - "d_L", - "t_c", - "phase_c", - "cos_iota", - "psi", - "ra", - "sin_dec", - ], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec +Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Unconstrained_Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, ) +s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) +s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) +dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) +likelihood = TransientLikelihoodFD( + [H1, L1], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(11) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 6dc91e79..ac164357 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -2,8 +2,8 @@ from jimgw.jim import Jim from jimgw.detector import H1, L1 from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 -from jimgw.prior import Uniform +from jimgw.waveform import RippleIMRPhenomPv2 +from jimgw.prior import Uniform, Composite, Sphere import jax.numpy as jnp import jax @@ -28,19 +28,66 @@ L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) waveform = RippleIMRPhenomPv2(f_ref=20) -prior = Uniform( - xmin = [10, 0.125, 0, 0, 0, 0, 0, 0, 0., -0.05, 0., -1, 0., 0.,-1.], - xmax = [80., 1., jnp.pi, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1., 2000., 0.05, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1.], - naming = ["M_c", "q", "s1_theta", "s1_phi", "s1_mag", "s2_theta", "s2_phi", "s2_mag", "d_L", "t_c", "phase_c", "cos_iota", "psi", "ra", "sin_dec"], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "s1_theta": ("s1_x", lambda params: jnp.sin(params['s1_theta'])*jnp.cos(params['s1_phi'])*params['s1_mag']), - "s1_phi": ("s1_y", lambda params: jnp.sin(params['s1_theta'])*jnp.sin(params['s1_phi'])*params['s1_mag']), - "s1_mag": ("s1_z", lambda params: jnp.cos(params['s1_theta'])*params['s1_mag']), - "s2_theta": ("s2_x", lambda params: jnp.sin(params['s2_theta'])*jnp.cos(params['s2_phi'])*params['s2_mag']), - "s2_phi": ("s2_y", lambda params: jnp.sin(params['s2_theta'])*jnp.sin(params['s2_phi'])*params['s2_mag']), - "s2_mag": ("s2_z", lambda params: jnp.cos(params['s2_theta'])*params['s2_mag']), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec + +########################################### +########## Set up priors ################## +########################################### + +Mc_prior = Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1_prior = Sphere(naming="s1") +s2_prior = Sphere(naming="s2") +dL_prior = Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1_prior, + s2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ], ) likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) # likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) @@ -61,19 +108,14 @@ n_chains=500, n_epochs=300, learning_rate=0.001, - max_samples = 60000, + max_samples = 10000, momentum=0.9, - batch_size=30000, + batch_size=10000, use_global=True, keep_quantile=0., train_thinning=1, output_thinning=30, local_sampler_arg=local_sampler_arg, - num_layers = 6, - hidden_size = [32,32], - num_bins = 8 ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) -# initial_guess = jnp.array(jnp.load('initial.npz')['chain']) jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_PV2_newglobal.py b/example/GW150914_PV2_newglobal.py new file mode 100644 index 00000000..fce24f74 --- /dev/null +++ b/example/GW150914_PV2_newglobal.py @@ -0,0 +1,152 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite, Sphere +import jax.numpy as jnp +import jax + + +jax.config.update("jax_enable_x64", True) + +########################################### +########## This script is experimental #### +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +start = gps - 2 +end = gps + 2 +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +waveform = RippleIMRPhenomPv2(f_ref=20) + +Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Unconstrained_Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1_prior = Sphere("s1") +s2_prior = Sphere("s2") +dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1_prior, + s2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) + +optimization_bounds = jnp.array( + [ + [-10.0, 10.0], + [-10.0, 10.0], + [0.0, 2.0 * jnp.pi], + [-1.0, 1.0], + [0.01, 1.0], + [0.0, 2.0 * jnp.pi], + [-1.0, 1.0], + [0.01, 1.0], + [-10.0, 10.0], + [-30.0, 30.0], + [-10.0, 10.0], + [-10.0, 10.0], + [-10.0, 10.0], + [-10.0, 10.0], + [-10.0, 10.0], + ] +) + +likelihood = TransientLikelihoodFD( + [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 +) + + +mass_matrix = jnp.eye(prior.n_dim) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +mass_matrix = mass_matrix * 3e-3 +local_sampler_arg = {"step_size": mass_matrix} + + +jim = Jim( + likelihood, + prior, + n_loop_training=20, + n_loop_production=10, + n_local_steps=300, + n_global_steps=300, + n_chains=500, + n_epochs=300, + learning_rate=0.001, + max_samples=60000, + momentum=0.9, + batch_size=30000, + use_global=True, + keep_quantile=0.0, + train_thinning=1, + output_thinning=30, + local_sampler_arg=local_sampler_arg, + num_layers=6, + hidden_size=[32, 32], + num_bins=8, + flowHMC_params={ + "step_size": 1e-2, + "n_leapfrog": 3, + "condition_matrix": jnp.linalg.inv(mass_matrix), + }, +) + +# jim.maximize_likelihood([prior.xmin, prior.xmax]) +# initial_guess = jnp.array(jnp.load('initial.npz')['chain']) +jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 9f4cae5b..5df9cf48 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jaxtyping import Array import jax import jax.numpy as jnp +from flowMC.sampler.flowHMC import flowHMC class Jim(object): @@ -33,20 +34,35 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.posterior, True, local_sampler_arg ) # Remember to add routine to find automated mass matrix - model = MaskedCouplingRQSpline( - self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1] - ) - self.Sampler = Sampler( - self.Prior.n_dim, rng_key_set, None, local_sampler, model, **kwargs - ) - def maximize_likelihood( - self, - bounds: tuple[Array, Array], - set_nwalkers: int = 100, - n_loops: int = 2000, - seed=92348, - ): + flowHMC_params = kwargs.get("flowHMC_params", {}) + model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) + if len(flowHMC_params) > 0: + global_sampler = flowHMC( + self.posterior, + True, + model, + params={ + "step_size": flowHMC_params["step_size"], + "n_leapfrog": flowHMC_params["n_leapfrog"], + "condition_matrix": flowHMC_params["condition_matrix"], + }, + ) + else: + global_sampler = None + + + self.Sampler = Sampler( + self.Prior.n_dim, + rng_key_set, + None, + local_sampler, + model, + global_sampler = global_sampler, + **kwargs) + + + def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 100, n_loops: int = 2000, seed = 92348): bounds = jnp.array(bounds).T key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers @@ -65,16 +81,15 @@ def maximize_likelihood( return best_fit def posterior(self, params: Array, data: dict): - named_params = self.Prior.add_name( - params, transform_name=True, transform_value=True - ) - return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob( - params - ) + prior_params = self.Prior.add_name(params.T) + prior = self.Prior.log_prob(prior_params) + return self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior - def sample(self, key: jax.random.PRNGKey, initial_guess: Array = None): + def sample(self, key: jax.random.PRNGKey, + initial_guess: Array = None): if initial_guess is None: initial_guess = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess = jnp.stack([i for i in initial_guess.values()]).T self.Sampler.sample(initial_guess, None) def print_summary(self): diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 7bcdd6c0..63164805 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -150,8 +150,8 @@ def __init__( trigger_time: float = 0, duration: float = 4, post_trigger_duration: float = 2, - n_walkers: int = 100, - n_loops: int = 200, + popsize: int = 100, + n_loops: int = 2000, ) -> None: super().__init__( detectors, waveform, trigger_time, duration, post_trigger_duration @@ -181,7 +181,7 @@ def __init__( print("Finding reference parameters..") self.ref_params = self.maximize_likelihood( - bounds=bounds, prior=prior, set_nwalkers=n_walkers, n_loops=n_loops + bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops ) print("Constructing reference waveforms..") @@ -455,11 +455,11 @@ def maximize_likelihood( self, bounds: tuple[Array, Array], prior: Prior, - set_nwalkers: int = 100, + popsize: int = 100, n_loops: int = 2000, ): bounds = jnp.array(bounds).T - set_nwalkers = set_nwalkers + popsize = popsize # TODO remove this? def y(x): return -self.evaluate_original( @@ -469,7 +469,7 @@ def y(x): y = jax.jit(jax.vmap(y)) print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(len(bounds), verbose=True) - optimizer.optimize(y, bounds, n_loops=n_loops) + optimizer = EvolutionaryOptimizer(len(bounds), popsize=popsize, verbose=True) + state = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return prior.add_name(best_fit, transform_name=True, transform_value=True) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 4baf4298..2137b018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -2,9 +2,10 @@ import jax.numpy as jnp from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float -from typing import Callable, Union +from typing import Callable from dataclasses import field + class Prior(Distribution): """ A thin wrapper build on top of flowMC distributions to do book keeping. @@ -16,13 +17,13 @@ class Prior(Distribution): """ naming: list[str] - transforms: dict[tuple[str,Callable]] = field(default_factory=dict) + transforms: dict[tuple[str, Callable]] = field(default_factory=dict) @property def n_dim(self): return len(self.naming) - - def __init__(self, naming: list[str], transforms: dict[tuple[str,Callable]] = {}): + + def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = {}): """ Parameters ---------- @@ -37,15 +38,17 @@ def __init__(self, naming: list[str], transforms: dict[tuple[str,Callable]] = {} self.transforms = {} def make_lambda(name): - return lambda x: x[name] + return lambda x: x[name] for name in naming: if name in transforms: self.transforms[name] = transforms[name] else: - self.transforms[name] = (name, make_lambda(name)) # Without the function, the lambda will refer to the variable name instead of its value, which will make lambda reference the last value of the variable name + # Without the function, the lambda will refer to the variable name instead of its value, + # which will make lambda reference the last value of the variable name + self.transforms[name] = (name, make_lambda(name)) - def transform(self, x: Array) -> Array: + def transform(self, x: dict) -> dict: """ Apply the transforms to the parameters. @@ -59,37 +62,115 @@ def transform(self, x: Array) -> Array: x : dict A dictionary of parameters with the transforms applied. """ - output = self.add_name(x, transform_name = False, transform_value = False) - for i, (key, value) in enumerate(self.transforms.items()): - x = x.at[i].set(value[1](output)) - return x + output = {} + for value in self.transforms.values(): + output[value[0]] = value[1](x) + return output - def add_name(self, x: Array, transform_name: bool = False, transform_value: bool = False) -> dict: + def add_name(self, x: Array) -> dict: """ Turn an array into a dictionary + + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim, n_sample). """ - if transform_name: - naming = [value[0] for value in self.transforms.values()] - else: - naming = self.naming - if transform_value: - x = self.transform(x) - value = x - else: - value = x - return dict(zip(naming,value)) + + return dict(zip(self.naming, x)) + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + raise NotImplementedError + + def logpdf(self, x: dict) -> Float: + raise NotImplementedError + class Uniform(Prior): - xmin: Array - xmax: Array + xmin: float = 0.0 + xmax: float = 1.0 - def __init__(self, xmin: Union[float,Array], xmax: Union[float,Array], **kwargs): - super().__init__(kwargs.get("naming"), kwargs.get("transforms")) - self.xmax = jnp.array(xmax) - self.xmin = jnp.array(xmin) - - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def __init__( + self, + xmin: float, + xmax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert self.n_dim == 1, "Uniform needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + """ + Sample from a uniform distribution. + + Parameters + ---------- + rng_key : jax.random.PRNGKey + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.uniform( + rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax + ) + return self.add_name(samples[None]) + + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + output = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + return output + jnp.log(1.0 / (self.xmax - self.xmin)) + + +class Unconstrained_Uniform(Prior): + + xmin: float = 0.0 + xmax: float = 1.0 + to_range: Callable = lambda x: x + + def __init__( + self, + xmin: float, + xmax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + local_transform = self.transforms + self.to_range = ( + lambda x: (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) + + self.xmin + ) + + def new_transform(param): + param[self.naming[0]] = self.to_range(param) + return local_transform[self.naming[0]][1](param) + + self.transforms = { + self.naming[0]: (local_transform[self.naming[0]][0], new_transform) + } + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: """ Sample from a uniform distribution. @@ -102,13 +183,86 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: Returns ------- - samples : Array + samples : An array of shape (n_samples, n_dim) containing the samples. - + """ - samples = jax.random.uniform(rng_key, (n_samples,self.n_dim), minval=self.xmin, maxval=self.xmax) - return samples # TODO: remember to cast this to a named array + samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) + samples = jnp.log(samples / (1 - samples)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) + + +class Sphere(Prior): + + """ + A prior on a sphere represented by Cartesian coordinates. + + Magnitude is sampled from a uniform distribution. + """ + + def __init__(self, naming: str): + self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"] + self.transforms = { + self.naming[0]: ( + f"{naming}_x", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.cos(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[1]: ( + f"{naming}_y", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.sin(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[2]: ( + f"{naming}_z", + lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], + ), + } + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + rng_keys = jax.random.split(rng_key, 3) + theta = jnp.arccos( + jax.random.uniform( + rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 + ) + ) + phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) + mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) + return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) + + def log_prob(self, x: dict) -> Float: + return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) + + +class Composite(Prior): + + priors: list[Prior] = field(default_factory=list) + + def __init__(self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}): + naming = [] + self.transforms = {} + for prior in priors: + naming += prior.naming + self.transforms.update(prior.transforms) + self.priors = priors + self.naming = naming + self.transforms.update(transforms) + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output - def log_prob(self, x: Array) -> Float: - output = jnp.sum(jnp.where((x>=self.xmax) | (x<=self.xmin), jnp.zeros_like(x)-jnp.inf, jnp.zeros_like(x))) - return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin))) + def log_prob(self, x: dict) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output diff --git a/test/test_prior.py b/test/test_prior.py new file mode 100644 index 00000000..d4c5be59 --- /dev/null +++ b/test/test_prior.py @@ -0,0 +1 @@ +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite