Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor changes for priors & adding prerequisite #90

Merged
merged 21 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ packages = find:
install_requires =
jax>=0.4.12
jaxlib>=0.4.12
flowMC>=0.2.4
flowMC>=0.3.4
ripplegw
gwpy
corner
astropy
typed-argument-parser
jaxtyping>=0.2.31
beartype
python_requires = >=3.9

Expand Down
30 changes: 19 additions & 11 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,9 @@ class Sphere(Prior):
def __repr__(self):
return f"Sphere(naming={self.naming})"

def __init__(self, naming: str, **kwargs):
self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"]
def __init__(self, naming: list[str], **kwargs):
name = naming[0]
self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"]
self.transforms = {
self.naming[0]: (
f"{naming}_x",
Expand Down Expand Up @@ -402,24 +403,30 @@ class EarthFrame(Prior):
Prior distribution for sky location in Earth frame.
"""

ifos: list = field(default_factory=list)
gmst: float = 0.0
delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3))

def __repr__(self):
return f"EarthFrame(naming={self.naming})"

def __init__(self, naming: str, gps: Float, ifos: list, **kwargs):
self.naming = ["azimuth", "zenith"]
def __init__(self, gps: Float, ifos: list, **kwargs):
self.naming = ["zenith", "azimuth"]
if len(ifos) < 2:
return ValueError(
"At least two detectors are needed to define the Earth frame"
)
elif isinstance(ifos[0], str):
self.ifos = [detector_preset[ifo] for ifo in ifos[:2]]
self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]]
elif isinstance(ifos[0], GroundBased2G):
self.ifos = ifos[:2]
self.ifos = ifos[:1]
else:
return ValueError(
"ifos should be a list of detector names or Detector objects"
"ifos should be a list of detector names or GroundBased2G objects"
)
self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
self.gmst = float(
Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
)
self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex

self.transforms = {
Expand Down Expand Up @@ -453,16 +460,17 @@ def sample(
azimuth = jax.random.uniform(
rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi
)
return self.add_name(jnp.stack([azimuth, zenith], axis=1).T)
return self.add_name(jnp.stack([zenith, azimuth], axis=1).T)

def log_prob(self, x: dict[str, Float]) -> Float:
zenith = x["zenith"]
azimuth = x["azimuth"]
output = jnp.where(
(azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0),
(zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0),
jnp.zeros_like(0) - jnp.inf,
jnp.zeros_like(0),
)
return output
return output + jnp.log(jnp.sin(zenith))


@jaxtyped(typechecker=typechecker)
Expand Down
14 changes: 11 additions & 3 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,17 @@ def initialize_prior(self) -> prior.Prior:
for name, parameters in self.run.priors.items():
if parameters["name"] not in prior_presets:
raise ValueError(f"Prior {name} not recognized.")
priors.append(
prior_presets[parameters["name"]](naming=[name], **parameters)
)
if parameters["name"] == "EarthFrame":
priors.append(
prior.EarthFrame(
gps=self.run.data_parameters["trigger_time"],
ifos=self.run.detectors,
)
)
else:
priors.append(
prior_presets[parameters["name"]](naming=[name], **parameters)
)
return prior.Composite(priors)

def initialize_detector(self) -> list[Detector]:
Expand Down
8 changes: 4 additions & 4 deletions src/jimgw/single_event/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import jax.numpy as jnp
from jaxtyping import Array, Float
from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc
from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc
from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc
from ripple.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc
from ripplegw.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc
from ripplegw.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc
from ripplegw.waveforms.TaylorF2 import gen_TaylorF2_hphc
from ripplegw.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc


class Waveform(ABC):
Expand Down
Loading