Commit 151c219

Merge pull request #45 from tsunhopang/main
Adding additional prior class
2 parents e59fa02 + 76bb56a

File tree

1 file changed: +181 -8 lines changed


src/jimgw/prior.py

+181 -8
@@ -87,7 +87,6 @@ def logpdf(self, x: dict) -> Float:
 
 
 class Uniform(Prior):
-
     xmin: float = 0.0
     xmax: float = 1.0
 
@@ -138,7 +137,6 @@ def log_prob(self, x: dict) -> Float:
 
 
 class Unconstrained_Uniform(Prior):
-
     xmin: float = 0.0
     xmax: float = 1.0
     to_range: Callable = lambda x: x
@@ -228,23 +226,198 @@ def __init__(self, naming: str):
     def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
         rng_keys = jax.random.split(rng_key, 3)
         theta = jnp.arccos(
-            jax.random.uniform(
-                rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0
-            )
+            jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
         )
-        phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi)
+        phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi)
         mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1)
         return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)
 
     def log_prob(self, x: dict) -> Float:
         return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]]))
 
 
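Aside (not part of the commit): the reformat-only change above touches the inverse-CDF draw for the polar angle. For an isotropic orientation, cos(theta) is uniform on [-1, 1], so theta = arccos(u) with u ~ U(-1, 1), which implies the polar density p(theta) = sin(theta) / 2. A minimal sketch checking this numerically:

# Sanity sketch (not from the commit): an isotropic direction has
# cos(theta) uniform on [-1, 1], so theta = arccos(u) is the
# inverse-CDF draw for p(theta) = sin(theta) / 2 on [0, pi].
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
u = jax.random.uniform(key, (100_000,), minval=-1.0, maxval=1.0)
theta = jnp.arccos(u)

# E[theta] under p(theta) = sin(theta)/2 on [0, pi] is pi/2.
print(theta.mean())  # ~ 1.5708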

-class Composite(Prior):
+class Alignedspin(Prior):
+    """
+    Prior distribution for the aligned (z) component of the spin.
+
+    This assumes the prior distribution on the spin magnitude to be uniform
+    in [0, amax], with its orientation uniform on a sphere:
+
+        p(chi) = -log(|chi| / amax) / 2 / amax
+
+    This is useful when comparing results between an aligned-spin run and
+    a precessing-spin run.
+
+    See (A7) of https://arxiv.org/abs/1805.10457.
+    """
+
+    amax: float = 0.99
+    chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000))
+    cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000))
+
+    def __init__(
+        self,
+        amax: float,
+        naming: list[str],
+        transforms: dict[tuple[str, Callable]] = {},
+    ):
+        super().__init__(naming, transforms)
+        assert isinstance(amax, float), "amax must be a float"
+        assert self.n_dim == 1, "Alignedspin needs to be a 1D distribution"
+        self.amax = amax
+
+        # build the interpolation table for the ppf of the one-sided distribution
+        chi_axis = jnp.linspace(1e-31, self.amax, num=1000)
+        cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax
+        self.chi_axis = chi_axis
+        self.cdf_vals = cdf_vals
+
+    def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
+        """
+        Sample from the Alignedspin distribution.
+
+        For chi > 0:
+            p(chi) = -log(chi / amax) / amax  # normalization constant halves from 2*amax to amax
+            cdf(chi) = -chi * (log(chi / amax) - 1) / amax
+
+        Since there is a pole at chi = 0, we sample with the following steps:
+        1. Map the samples with quantile > 0.5 to positive chi and negative otherwise.
+        2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q).
+        2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5).
+        3. Map the quantile to chi via the ppf, by checking against the table
+           built during initialization.
+        4. Add back the sign.
+
+        Parameters
+        ----------
+        rng_key : jax.random.PRNGKey
+            A random key to use for sampling.
+        n_samples : int
+            The number of samples to draw.
+
+        Returns
+        -------
+        samples : dict
+            Samples from the distribution. The keys are the names of the parameters.
+        """
+        q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
+        # 1. calculate the sign of chi from the q_samples
+        sign_samples = jnp.where(
+            q_samples >= 0.5,
+            jnp.zeros_like(q_samples) + 1.0,
+            jnp.zeros_like(q_samples) - 1.0,
+        )
+        # 2. remap q_samples
+        q_samples = jnp.where(
+            q_samples >= 0.5,
+            2 * (q_samples - 0.5),
+            2 * (0.5 - q_samples),
+        )
+        # 3. map the quantile to chi via interpolation
+        samples = jnp.interp(
+            q_samples,
+            self.cdf_vals,
+            self.chi_axis,
+        )
+        # 4. add back the sign
+        samples *= sign_samples
+
+        return self.add_name(samples[None])
+
+    def log_prob(self, x: dict) -> Float:
+        variable = x[self.naming[0]]
+        log_p = jnp.where(
+            (variable >= self.amax) | (variable <= -self.amax),
+            jnp.zeros_like(variable) - jnp.inf,
+            jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax),
+        )
+        return log_p
 
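Aside (not part of the commit): the interpolation table encodes the closed-form CDF of the one-sided density, obtained by integrating p(chi) = -log(chi / amax) / amax from 0 to chi. A minimal numerical cross-check, assuming only jax.numpy:

# Sketch (not from the commit): verify the closed-form CDF used for the table.
# For chi in (0, amax], integrating p(chi) = -log(chi / amax) / amax gives
#   cdf(chi) = -chi * (log(chi / amax) - 1) / amax,  with cdf(amax) = 1.
import jax.numpy as jnp

amax = 0.99
chi = jnp.linspace(1e-6, amax, 10_000)
pdf = -jnp.log(chi / amax) / amax

# cumulative trapezoid integration of the pdf
numeric_cdf = jnp.concatenate(
    [jnp.zeros(1), jnp.cumsum((pdf[1:] + pdf[:-1]) / 2 * jnp.diff(chi))]
)
closed_form = -chi * (jnp.log(chi / amax) - 1.0) / amax

# ~1e-5, dominated by truncating the integrable singularity at 0
print(jnp.max(jnp.abs(numeric_cdf - closed_form)))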

+class Powerlaw(Prior):
+    """
+    A power-law prior, p(x) ~ x^alpha, supported on the range [xmin, xmax).
+    """
+
+    xmin: float = 0.0
+    xmax: float = 1.0
+    alpha: float = 0.0
+    normalization: float = 1.0
+
+    def __init__(
+        self,
+        xmin: float,
+        xmax: float,
+        alpha: float,
+        naming: list[str],
+        transforms: dict[tuple[str, Callable]] = {},
+    ):
+        super().__init__(naming, transforms)
+        assert isinstance(xmin, float), "xmin must be a float"
+        assert isinstance(xmax, float), "xmax must be a float"
+        assert isinstance(alpha, float), "alpha must be a float"
+        if alpha < 0.0:
+            assert xmin > 0.0, "With negative alpha, xmin must be > 0"
+        assert self.n_dim == 1, "Powerlaw needs to be a 1D distribution"
+        self.xmax = xmax
+        self.xmin = xmin
+        self.alpha = alpha
+        if alpha == -1:
+            self.normalization = 1.0 / jnp.log(self.xmax / self.xmin)
+        else:
+            self.normalization = (1 + self.alpha) / (
+                self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha)
+            )
+
+    def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
+        """
+        Sample from a power-law distribution.
+
+        Parameters
+        ----------
+        rng_key : jax.random.PRNGKey
+            A random key to use for sampling.
+        n_samples : int
+            The number of samples to draw.
+
+        Returns
+        -------
+        samples : dict
+            Samples from the distribution. The keys are the names of the parameters.
+        """
+        q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
+        if self.alpha == -1:
+            samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin))
+        else:
+            samples = (
+                self.xmin ** (1.0 + self.alpha)
+                + q_samples
+                * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
+            ) ** (1.0 / (1.0 + self.alpha))
+        return self.add_name(samples[None])
+
+    def log_prob(self, x: dict) -> Float:
+        variable = x[self.naming[0]]
+        log_in_range = jnp.where(
+            (variable >= self.xmax) | (variable <= self.xmin),
+            jnp.zeros_like(variable) - jnp.inf,
+            jnp.zeros_like(variable),
+        )
+        log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization)
+        return log_p + log_in_range
+
+
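Aside (not part of the commit): Powerlaw.sample is plain inverse-CDF sampling. A sketch of the inversion it implements, including the alpha == -1 (log-uniform) special case; powerlaw_ppf is an illustrative helper, not part of the module:

# Sketch (not from the commit): the inverse CDF behind Powerlaw.sample.
# For alpha != -1, on [xmin, xmax):
#   cdf(x) = (x**(1+alpha) - xmin**(1+alpha)) / (xmax**(1+alpha) - xmin**(1+alpha))
# Setting cdf(x) = q and solving for x gives the expression used in sample():
#   x = (xmin**(1+alpha) + q * (xmax**(1+alpha) - xmin**(1+alpha))) ** (1/(1+alpha))
# For alpha == -1 the density is 1/x (log-uniform), so
#   x = xmin * exp(q * log(xmax / xmin))
import jax.numpy as jnp

def powerlaw_ppf(q, xmin, xmax, alpha):
    if alpha == -1.0:
        return xmin * jnp.exp(q * jnp.log(xmax / xmin))
    a1 = 1.0 + alpha
    return (xmin**a1 + q * (xmax**a1 - xmin**a1)) ** (1.0 / a1)

# endpoints map correctly: q=0 -> xmin, q=1 -> xmax
print(powerlaw_ppf(jnp.array([0.0, 1.0]), 1.0, 100.0, -2.0))  # [1., 100.]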
+class Composite(Prior):
     priors: list[Prior] = field(default_factory=list)
 
-    def __init__(self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}):
+    def __init__(
+        self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}
+    ):
         naming = []
         self.transforms = {}
         for prior in priors:
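Aside (not part of the commit): a hypothetical usage sketch for the two new priors, assuming the Prior API shown in the diff; the parameter names and ranges below are illustrative only.

# Hypothetical usage sketch (names "chi_1", "d_L" and the ranges are made up).
import jax
from jimgw.prior import Alignedspin, Powerlaw

chi_prior = Alignedspin(amax=0.99, naming=["chi_1"])
dl_prior = Powerlaw(xmin=1.0, xmax=2000.0, alpha=2.0, naming=["d_L"])

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
chi = chi_prior.sample(k1, 5000)  # dict keyed by "chi_1"
dl = dl_prior.sample(k2, 5000)    # dict keyed by "d_L"
print(chi_prior.log_prob(chi), dl_prior.log_prob(dl))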
