diff --git a/setup.cfg b/setup.cfg index 357b339b..e2fbd986 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,12 +17,13 @@ packages = find: install_requires = jax>=0.4.12 jaxlib>=0.4.12 - flowMC>=0.2.4 + flowMC>=0.3.4 ripplegw gwpy corner astropy typed-argument-parser + jaxtyping>=0.2.31 beartype python_requires = >=3.9 diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e5675c8f..f1827753 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -234,8 +234,9 @@ class Sphere(Prior): def __repr__(self): return f"Sphere(naming={self.naming})" - def __init__(self, naming: str, **kwargs): - self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"] + def __init__(self, naming: list[str], **kwargs): + name = naming[0] + self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] self.transforms = { self.naming[0]: ( f"{naming}_x", @@ -402,24 +403,30 @@ class EarthFrame(Prior): Prior distribution for sky location in Earth frame. """ + ifos: list = field(default_factory=list) + gmst: float = 0.0 + delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + def __repr__(self): return f"EarthFrame(naming={self.naming})" - def __init__(self, naming: str, gps: Float, ifos: list, **kwargs): - self.naming = ["azimuth", "zenith"] + def __init__(self, gps: Float, ifos: list, **kwargs): + self.naming = ["zenith", "azimuth"] if len(ifos) < 2: return ValueError( "At least two detectors are needed to define the Earth frame" ) elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifo] for ifo in ifos[:2]] + self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:2] + self.ifos = ifos[:1] else: return ValueError( - "ifos should be a list of detector names or Detector objects" + "ifos should be a list of detector names or GroundBased2G objects" ) - self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + self.gmst = float( + Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + ) self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex self.transforms = { @@ -453,16 +460,17 @@ def sample( azimuth = jax.random.uniform( rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi ) - return self.add_name(jnp.stack([azimuth, zenith], axis=1).T) + return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) def log_prob(self, x: dict[str, Float]) -> Float: zenith = x["zenith"] azimuth = x["azimuth"] output = jnp.where( - (azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0), + (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), jnp.zeros_like(0) - jnp.inf, + jnp.zeros_like(0), ) - return output + return output + jnp.log(jnp.sin(zenith)) @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 56eb9925..aeb2304a 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -196,9 +196,17 @@ def initialize_prior(self) -> prior.Prior: for name, parameters in self.run.priors.items(): if parameters["name"] not in prior_presets: raise ValueError(f"Prior {name} not recognized.") - priors.append( - prior_presets[parameters["name"]](naming=[name], **parameters) - ) + if parameters["name"] == "EarthFrame": + priors.append( + prior.EarthFrame( + gps=self.run.data_parameters["trigger_time"], + ifos=self.run.detectors, + ) + ) + else: + priors.append( + prior_presets[parameters["name"]](naming=[name], **parameters) + ) return prior.Composite(priors) def initialize_detector(self) -> list[Detector]: diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index e2084ea1..6edbfd3f 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -2,10 +2,10 @@ import jax.numpy as jnp from jaxtyping import Array, Float -from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc -from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc -from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc -from ripple.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc +from ripplegw.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc +from ripplegw.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc +from ripplegw.waveforms.TaylorF2 import gen_TaylorF2_hphc +from ripplegw.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc class Waveform(ABC):