@@ -87,7 +87,6 @@ def logpdf(self, x: dict) -> Float:
87
87
88
88
89
89
class Uniform (Prior ):
90
-
91
90
xmin : float = 0.0
92
91
xmax : float = 1.0
93
92
@@ -138,7 +137,6 @@ def log_prob(self, x: dict) -> Float:
138
137
139
138
140
139
class Unconstrained_Uniform (Prior ):
141
-
142
140
xmin : float = 0.0
143
141
xmax : float = 1.0
144
142
to_range : Callable = lambda x : x
@@ -228,23 +226,198 @@ def __init__(self, naming: str):
228
226
def sample (self , rng_key : jax .random .PRNGKey , n_samples : int ) -> Array :
229
227
rng_keys = jax .random .split (rng_key , 3 )
230
228
theta = jnp .arccos (
231
- jax .random .uniform (
232
- rng_keys [0 ], (n_samples ,), minval = - 1.0 , maxval = 1.0
233
- )
229
+ jax .random .uniform (rng_keys [0 ], (n_samples ,), minval = - 1.0 , maxval = 1.0 )
234
230
)
235
- phi = jax .random .uniform (rng_keys [1 ], (n_samples ,), minval = 0 , maxval = 2 * jnp .pi )
231
+ phi = jax .random .uniform (rng_keys [1 ], (n_samples ,), minval = 0 , maxval = 2 * jnp .pi )
236
232
mag = jax .random .uniform (rng_keys [2 ], (n_samples ,), minval = 0 , maxval = 1 )
237
233
return self .add_name (jnp .stack ([theta , phi , mag ], axis = 1 ).T )
238
234
239
235
def log_prob (self , x : dict ) -> Float :
240
236
return jnp .log (x [self .naming [2 ]] ** 2 * jnp .sin (x [self .naming [0 ]]))
241
237
242
238
243
- class Composite (Prior ):
239
+ class Alignedspin (Prior ):
240
+
241
+ """
242
+ Prior distribution for the aligned (z) component of the spin.
243
+
244
+ This assume the prior distribution on the spin magnitude to be uniform in [0, amax]
245
+ with its orientation uniform on a sphere
246
+
247
+ p(chi) = -log(|chi| / amax) / 2 / amax
248
+
249
+ This is useful when comparing results between an aligned-spin run and
250
+ a precessing spin run.
251
+
252
+ See (A7) of https://arxiv.org/abs/1805.10457.
253
+ """
254
+
255
+ amax : float = 0.99
256
+ chi_axis : Array = field (default_factory = lambda : jnp .linspace (0 , 1 , num = 1000 ))
257
+ cdf_vals : Array = field (default_factory = lambda : jnp .linspace (0 , 1 , num = 1000 ))
258
+
259
+ def __init__ (
260
+ self ,
261
+ amax : float ,
262
+ naming : list [str ],
263
+ transforms : dict [tuple [str , Callable ]] = {},
264
+ ):
265
+ super ().__init__ (naming , transforms )
266
+ assert isinstance (amax , float ), "xmin must be a float"
267
+ assert self .n_dim == 1 , "Alignedspin needs to be 1D distributions"
268
+ self .amax = amax
269
+
270
+ # build the interpolation table for the ppf of the one-sided distribution
271
+ chi_axis = jnp .linspace (1e-31 , self .amax , num = 1000 )
272
+ cdf_vals = - chi_axis * (jnp .log (chi_axis / self .amax ) - 1.0 ) / self .amax
273
+ self .chi_axis = chi_axis
274
+ self .cdf_vals = cdf_vals
275
+
276
+ def sample (self , rng_key : jax .random .PRNGKey , n_samples : int ) -> dict :
277
+ """
278
+ Sample from the Alignedspin distribution.
279
+
280
+ for chi > 0;
281
+ p(chi) = -log(chi / amax) / amax # halved normalization constant
282
+ cdf(chi) = -chi * (log(chi / amax) - 1) / amax
283
+
284
+ Since there is a pole at chi=0, we will sample with the following steps
285
+ 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise
286
+ 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q)
287
+ 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5)
288
+ 3. Map the quantile to chi via the ppf by checking against the table
289
+ built during the initialization
290
+ 4. add back the sign
291
+
292
+ Parameters
293
+ ----------
294
+ rng_key : jax.random.PRNGKey
295
+ A random key to use for sampling.
296
+ n_samples : int
297
+ The number of samples to draw.
298
+
299
+ Returns
300
+ -------
301
+ samples : dict
302
+ Samples from the distribution. The keys are the names of the parameters.
303
+
304
+ """
305
+ q_samples = jax .random .uniform (rng_key , (n_samples ,), minval = 0.0 , maxval = 1.0 )
306
+ # 1. calculate the sign of chi from the q_samples
307
+ sign_samples = jnp .where (
308
+ q_samples >= 0.5 ,
309
+ jnp .zeros_like (q_samples ) + 1.0 ,
310
+ jnp .zeros_like (q_samples ) - 1.0 ,
311
+ )
312
+ # 2. remap q_samples
313
+ q_samples = jnp .where (
314
+ q_samples >= 0.5 ,
315
+ 2 * (q_samples - 0.5 ),
316
+ 2 * (0.5 - q_samples ),
317
+ )
318
+ # 3. map the quantile to chi via interpolation
319
+ samples = jnp .interp (
320
+ q_samples ,
321
+ self .cdf_vals ,
322
+ self .chi_axis ,
323
+ )
324
+ # 4. add back the sign
325
+ samples *= sign_samples
326
+
327
+ return self .add_name (samples [None ])
328
+
329
+ def log_prob (self , x : dict ) -> Float :
330
+ variable = x [self .naming [0 ]]
331
+ log_p = jnp .where (
332
+ (variable >= self .amax ) | (variable <= - self .amax ),
333
+ jnp .zeros_like (variable ) - jnp .inf ,
334
+ jnp .log (- jnp .log (jnp .absolute (variable ) / self .amax ) / 2.0 / self .amax ),
335
+ )
336
+ return log_p
337
+
244
338
339
+ class Powerlaw (Prior ):
340
+
341
+ """
342
+ A prior following the power-law with alpha in the range [xmin, xmax).
343
+ p(x) ~ x^{\a lpha}
344
+ """
345
+
346
+ xmin : float = 0.0
347
+ xmax : float = 1.0
348
+ alpha : int = 0.0
349
+ normalization : float = 1.0
350
+
351
+ def __init__ (
352
+ self ,
353
+ xmin : float ,
354
+ xmax : float ,
355
+ alpha : float ,
356
+ naming : list [str ],
357
+ transforms : dict [tuple [str , Callable ]] = {},
358
+ ):
359
+ super ().__init__ (naming , transforms )
360
+ assert isinstance (xmin , float ), "xmin must be a float"
361
+ assert isinstance (xmax , float ), "xmax must be a float"
362
+ assert isinstance (alpha , (float )), "alpha must be a float"
363
+ if alpha < 0.0 :
364
+ assert alpha < 0.0 or xmin > 0.0 , "With negative alpha, xmin must > 0"
365
+ assert self .n_dim == 1 , "Powerlaw needs to be 1D distributions"
366
+ self .xmax = xmax
367
+ self .xmin = xmin
368
+ self .alpha = alpha
369
+ if alpha == - 1 :
370
+ self .normalization = 1.0 / jnp .log (self .xmax / self .xmin )
371
+ else :
372
+ self .normalization = (1 + self .alpha ) / (
373
+ self .xmax ** (1 + self .alpha ) - self .xmin ** (1 + self .alpha )
374
+ )
375
+
376
+ def sample (self , rng_key : jax .random .PRNGKey , n_samples : int ) -> dict :
377
+ """
378
+ Sample from a power-law distribution.
379
+
380
+ Parameters
381
+ ----------
382
+ rng_key : jax.random.PRNGKey
383
+ A random key to use for sampling.
384
+ n_samples : int
385
+ The number of samples to draw.
386
+
387
+ Returns
388
+ -------
389
+ samples : dict
390
+ Samples from the distribution. The keys are the names of the parameters.
391
+
392
+ """
393
+ q_samples = jax .random .uniform (rng_key , (n_samples ,), minval = 0.0 , maxval = 1.0 )
394
+ if self .alpha == - 1 :
395
+ samples = self .xmin * jnp .exp (q_samples * jnp .log (self .xmax / self .xmin ))
396
+ else :
397
+ samples = (
398
+ self .xmin ** (1.0 + self .alpha )
399
+ + q_samples
400
+ * (self .xmax ** (1.0 + self .alpha ) - self .xmin ** (1.0 + self .alpha ))
401
+ ) ** (1.0 / (1.0 + self .alpha ))
402
+ return self .add_name (samples [None ])
403
+
404
+ def log_prob (self , x : dict ) -> Float :
405
+ variable = x [self .naming [0 ]]
406
+ log_in_range = jnp .where (
407
+ (variable >= self .xmax ) | (variable <= self .xmin ),
408
+ jnp .zeros_like (variable ) - jnp .inf ,
409
+ jnp .zeros_like (variable ),
410
+ )
411
+ log_p = self .alpha * jnp .log (variable ) + jnp .log (self .normalization )
412
+ return log_p + log_in_range
413
+
414
+
415
+ class Composite (Prior ):
245
416
priors : list [Prior ] = field (default_factory = list )
246
417
247
- def __init__ (self , priors : list [Prior ], transforms : dict [tuple [str , Callable ]] = {}):
418
+ def __init__ (
419
+ self , priors : list [Prior ], transforms : dict [tuple [str , Callable ]] = {}
420
+ ):
248
421
naming = []
249
422
self .transforms = {}
250
423
for prior in priors :
0 commit comments