Skip to content

Commit

Permalink
Merge pull request #85 from ThibeauWouters/normal
Browse files Browse the repository at this point in the history
Added Normal prior
  • Loading branch information
kazewong authored Jul 1, 2024
2 parents 9ba9715 + 113dd0f commit 5a5e45d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ install_requires =
jax>=0.4.12
jaxlib>=0.4.12
flowMC>=0.2.4
numpy>=1,<2
ripplegw
gwpy
corner
Expand Down
54 changes: 54 additions & 0 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,60 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return log_p + log_in_range


@jaxtyped(typechecker=typechecker)
class Normal(Prior):
mean: Float = 0.0
std: Float = 1.0

def __repr__(self):
return f"Normal(mean={self.mean}, std={self.std})"

def __init__(
self,
mean: Float,
std: Float,
naming: list[str],
transforms: dict[str, tuple[str, Callable]] = {},
**kwargs,
):
super().__init__(naming, transforms)
assert self.n_dim == 1, "Normal needs to be 1D distributions"
self.mean = mean
self.std = std

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
"""
Sample from a normal distribution.
Parameters
----------
rng_key : PRNGKeyArray
A random key to use for sampling.
n_samples : int
The number of samples to draw.
Returns
-------
samples : dict
Samples from the distribution. The keys are the names of the parameters.
"""
samples = jax.random.normal(rng_key, (n_samples,))
samples = self.mean + samples * self.std
return self.add_name(samples[None])

def log_prob(self, x: dict[str, Array]) -> Float:
variable = x[self.naming[0]]
output = (
-0.5 * jnp.log(2 * jnp.pi)
- jnp.log(self.std)
- 0.5 * ((variable - self.mean) / self.std) ** 2
)
return output


class Composite(Prior):
priors: list[Prior] = field(default_factory=list)

Expand Down
4 changes: 2 additions & 2 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def evaluate_original(

@staticmethod
def max_phase_diff(
f: npt.NDArray[np.float_],
f: npt.NDArray[np.floating],
f_low: float,
f_high: float,
chi: Float = 1.0,
Expand Down Expand Up @@ -469,7 +469,7 @@ def max_phase_diff(
return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1)

def make_binning_scheme(
self, freqs: npt.NDArray[np.float_], n_bins: int, chi: float = 1
self, freqs: npt.NDArray[np.floating], n_bins: int, chi: float = 1
) -> tuple[Float[Array, " n_bins+1"], Float[Array, " n_bins"]]:
"""
Make a binning scheme based on the maximum phase difference between the
Expand Down

0 comments on commit 5a5e45d

Please sign in to comment.