From ee7bc8aa2d4e338e0f62edbf90caec7b42f2b5bb Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 4 Dec 2023 15:50:22 +0100 Subject: [PATCH 1/4] Adding Powerlaw prior class --- src/jimgw/prior.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 2137b018..b58436b8 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -240,6 +240,79 @@ def log_prob(self, x: dict) -> Float: return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) +class Powerlaw(Prior): + + """ + A prior following the power-law with alpha in the range [xmin, xmax). + p(x) ~ x^{\alpha} + """ + + xmin: float = 0.0 + xmax: float = 1.0 + alpha: int = 0.0 + + def __init__( + self, + xmin: float, + xmax: float, + alpha: int | float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert isinstance(alpha, (int, float)), "alpha must be a int or a float" + if alpha < 0.: + assert alpha < 0. or xmin > 0., "With negative alpha, xmin must > 0" + assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + """ + Sample from a power-law distribution. + + Parameters + ---------- + rng_key : jax.random.PRNGKey + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + q_samples = jax.random.uniform( + rng_key, (n_samples,), minval=0., maxval=1. + ) + if self.alpha == -1: + samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) + else: + samples = (self.xmin ** (1. + self.alpha) + q_samples * + (self.xmax ** (1. + self.alpha) - self.xmin ** (1. + self.alpha))) ** (1. / (1. + self.alpha)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + if self.alpha == -1: + normalization_constant = 1. / jnp.log(self.xmax / self.xmin) + else: + normalization_constant = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) - + self.xmin ** (1 + self.alpha)) + log_in_range = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + log_p = self.alpha * jnp.log(variable) + jnp.log(normalization_constant) + return log_p + log_in_range + + class Composite(Prior): priors: list[Prior] = field(default_factory=list) From 43b6525fd0cf91f502056ecc4efdd3cef35b368a Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 4 Dec 2023 18:28:30 +0100 Subject: [PATCH 2/4] Adding Alignedspin prior --- src/jimgw/prior.py | 102 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index b58436b8..d8e6369a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -240,6 +240,108 @@ def log_prob(self, x: dict) -> Float: return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) +class Alignedspin(Prior): + + """ + Prior distribution for the aligned (z) component of the spin. + + This assume the prior distribution on the spin magnitude to be uniform in [0, amax] + with its orientation uniform on a sphere + + p(chi) = -log(|chi| / amax) / 2 / amax + + This is useful when comparing results between an aligned-spin run and + a precessing spin run. + + See (A7) of https://arxiv.org/abs/1805.10457. + """ + + amax: float = 0.99 + chi_axis: Array = jnp.linspace(0, 1, num=1000) + cdf_vals: Array = jnp.linspace(0, 1, num=1000) + + def __init__( + self, + amax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) + assert isinstance(amax, float), "xmin must be a float" + assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" + self.amax = amax + + # build the interpolation table for the ppf of the one-sided distribution + chi_axis = jnp.linspace(1e-31, self.amax, num=1000) + cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.) / self.amax + self.chi_axis = chi_axis + self.cdf_vals = cdf_vals + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + """ + Sample from the Alignedspin distribution. + + for chi > 0; + p(chi) = -log(chi / amax) / amax # halved normalization constant + cdf(chi) = -chi * (log(chi / amax) - 1) / amax + + Since there is a pole at chi=0, we will sample with the following steps + 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise + 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) + 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) + 3. Map the quantile to chi via the ppf by checking against the table + built during the initialization + 4. add back the sign + + Parameters + ---------- + rng_key : jax.random.PRNGKey + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + q_samples = jax.random.uniform( + rng_key, (n_samples,), minval=0., maxval=1. + ) + # 1. calculate the sign of chi from the q_samples + sign_samples = jnp.where( + q_samples >= 0.5, + jnp.zeros_like(q_samples) + 1., + jnp.zeros_like(q_samples) - 1., + ) + # 2. remap q_samples + q_samples = jnp.where( + q_samples >=0.5, + 2 * (q_samples - 0.5), + 2 * (0.5 - q_samples), + ) + # 3. map the quantile to chi via interpolation + samples = jnp.interp( + q_samples, + self.cdf_vals, + self.chi_axis, + ) + # 4. add back the sign + samples *= sign_samples + + return self.add_name(samples[None]) + + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + log_p = jnp.where( + (variable >= self.amax) | (variable <= -self.amax), + jnp.zeros_like(variable) - jnp.inf, + jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2. / self.amax), + ) + return log_p + + class Powerlaw(Prior): """ From cb8d95b2c5bb9a207d4a22c344e1c66335b00b08 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 4 Dec 2023 15:57:37 -0500 Subject: [PATCH 3/4] Put normaliziation compute in initialization --- src/jimgw/prior.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index d8e6369a..2b0981d6 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -352,6 +352,7 @@ class Powerlaw(Prior): xmin: float = 0.0 xmax: float = 1.0 alpha: int = 0.0 + normalization: float = 1.0 def __init__( self, @@ -371,6 +372,11 @@ def __init__( self.xmax = xmax self.xmin = xmin self.alpha = alpha + if alpha == -1: + self.normalization = 1. / jnp.log(self.xmax / self.xmin) + else: + self.normalization = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) - + self.xmin ** (1 + self.alpha)) def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: """ @@ -401,17 +407,12 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: def log_prob(self, x: dict) -> Float: variable = x[self.naming[0]] - if self.alpha == -1: - normalization_constant = 1. / jnp.log(self.xmax / self.xmin) - else: - normalization_constant = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) - - self.xmin ** (1 + self.alpha)) log_in_range = jnp.where( (variable >= self.xmax) | (variable <= self.xmin), jnp.zeros_like(variable) - jnp.inf, jnp.zeros_like(variable), ) - log_p = self.alpha * jnp.log(variable) + jnp.log(normalization_constant) + log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) return log_p + log_in_range From 76bb56acbb2076e7855363a984c1bb0464c8dc20 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 4 Dec 2023 16:05:52 -0500 Subject: [PATCH 4/4] update to adhere to python 3.9 and 3.11 convention --- src/jimgw/prior.py | 57 ++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 2b0981d6..aa1366af 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -87,7 +87,6 @@ def logpdf(self, x: dict) -> Float: class Uniform(Prior): - xmin: float = 0.0 xmax: float = 1.0 @@ -138,7 +137,6 @@ def log_prob(self, x: dict) -> Float: class Unconstrained_Uniform(Prior): - xmin: float = 0.0 xmax: float = 1.0 to_range: Callable = lambda x: x @@ -228,11 +226,9 @@ def __init__(self, naming: str): def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: rng_keys = jax.random.split(rng_key, 3) theta = jnp.arccos( - jax.random.uniform( - rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 - ) + jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) + phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) @@ -257,8 +253,8 @@ class Alignedspin(Prior): """ amax: float = 0.99 - chi_axis: Array = jnp.linspace(0, 1, num=1000) - cdf_vals: Array = jnp.linspace(0, 1, num=1000) + chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) def __init__( self, @@ -273,7 +269,7 @@ def __init__( # build the interpolation table for the ppf of the one-sided distribution chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.) / self.amax + cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax self.chi_axis = chi_axis self.cdf_vals = cdf_vals @@ -306,18 +302,16 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: Samples from the distribution. The keys are the names of the parameters. """ - q_samples = jax.random.uniform( - rng_key, (n_samples,), minval=0., maxval=1. - ) + q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) # 1. calculate the sign of chi from the q_samples sign_samples = jnp.where( q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1., - jnp.zeros_like(q_samples) - 1., + jnp.zeros_like(q_samples) + 1.0, + jnp.zeros_like(q_samples) - 1.0, ) # 2. remap q_samples q_samples = jnp.where( - q_samples >=0.5, + q_samples >= 0.5, 2 * (q_samples - 0.5), 2 * (0.5 - q_samples), ) @@ -337,7 +331,7 @@ def log_prob(self, x: dict) -> Float: log_p = jnp.where( (variable >= self.amax) | (variable <= -self.amax), jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2. / self.amax), + jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), ) return log_p @@ -358,25 +352,26 @@ def __init__( self, xmin: float, xmax: float, - alpha: int | float, + alpha: float, naming: list[str], transforms: dict[tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) assert isinstance(xmin, float), "xmin must be a float" assert isinstance(xmax, float), "xmax must be a float" - assert isinstance(alpha, (int, float)), "alpha must be a int or a float" - if alpha < 0.: - assert alpha < 0. or xmin > 0., "With negative alpha, xmin must > 0" + assert isinstance(alpha, (float)), "alpha must be a float" + if alpha < 0.0: + assert alpha < 0.0 or xmin > 0.0, "With negative alpha, xmin must > 0" assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" self.xmax = xmax self.xmin = xmin self.alpha = alpha if alpha == -1: - self.normalization = 1. / jnp.log(self.xmax / self.xmin) + self.normalization = 1.0 / jnp.log(self.xmax / self.xmin) else: - self.normalization = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) - - self.xmin ** (1 + self.alpha)) + self.normalization = (1 + self.alpha) / ( + self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) + ) def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: """ @@ -395,14 +390,15 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: Samples from the distribution. The keys are the names of the parameters. """ - q_samples = jax.random.uniform( - rng_key, (n_samples,), minval=0., maxval=1. - ) + q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) if self.alpha == -1: samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) else: - samples = (self.xmin ** (1. + self.alpha) + q_samples * - (self.xmax ** (1. + self.alpha) - self.xmin ** (1. + self.alpha))) ** (1. / (1. + self.alpha)) + samples = ( + self.xmin ** (1.0 + self.alpha) + + q_samples + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) ** (1.0 / (1.0 + self.alpha)) return self.add_name(samples[None]) def log_prob(self, x: dict) -> Float: @@ -417,10 +413,11 @@ def log_prob(self, x: dict) -> Float: class Composite(Prior): - priors: list[Prior] = field(default_factory=list) - def __init__(self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}): + def __init__( + self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {} + ): naming = [] self.transforms = {} for prior in priors: