Skip to content

Commit

Permalink
Merge pull request #90 from thomasckng/main
Browse files Browse the repository at this point in the history
Minor changes for priors & adding prerequisite
  • Loading branch information
kazewong authored Jul 22, 2024
2 parents fe0e767 + 395d2a6 commit 2a51d8c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 19 deletions.
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ packages = find:
install_requires =
jax>=0.4.12
jaxlib>=0.4.12
flowMC>=0.2.4
flowMC>=0.3.4
ripplegw
gwpy
corner
astropy
typed-argument-parser
jaxtyping>=0.2.31
beartype
python_requires = >=3.9

Expand Down
30 changes: 19 additions & 11 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,9 @@ class Sphere(Prior):
def __repr__(self):
return f"Sphere(naming={self.naming})"

def __init__(self, naming: str, **kwargs):
self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"]
def __init__(self, naming: list[str], **kwargs):
name = naming[0]
self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"]
self.transforms = {
self.naming[0]: (
f"{naming}_x",
Expand Down Expand Up @@ -402,24 +403,30 @@ class EarthFrame(Prior):
Prior distribution for sky location in Earth frame.
"""

ifos: list = field(default_factory=list)
gmst: float = 0.0
delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3))

def __repr__(self):
return f"EarthFrame(naming={self.naming})"

def __init__(self, naming: str, gps: Float, ifos: list, **kwargs):
self.naming = ["azimuth", "zenith"]
def __init__(self, gps: Float, ifos: list, **kwargs):
self.naming = ["zenith", "azimuth"]
if len(ifos) < 2:
return ValueError(
"At least two detectors are needed to define the Earth frame"
)
elif isinstance(ifos[0], str):
self.ifos = [detector_preset[ifo] for ifo in ifos[:2]]
self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]]
elif isinstance(ifos[0], GroundBased2G):
self.ifos = ifos[:2]
self.ifos = ifos[:1]
else:
return ValueError(
"ifos should be a list of detector names or Detector objects"
"ifos should be a list of detector names or GroundBased2G objects"
)
self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
self.gmst = float(
Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
)
self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex

self.transforms = {
Expand Down Expand Up @@ -453,16 +460,17 @@ def sample(
azimuth = jax.random.uniform(
rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi
)
return self.add_name(jnp.stack([azimuth, zenith], axis=1).T)
return self.add_name(jnp.stack([zenith, azimuth], axis=1).T)

def log_prob(self, x: dict[str, Float]) -> Float:
zenith = x["zenith"]
azimuth = x["azimuth"]
output = jnp.where(
(azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0),
(zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0),
jnp.zeros_like(0) - jnp.inf,
jnp.zeros_like(0),
)
return output
return output + jnp.log(jnp.sin(zenith))


@jaxtyped(typechecker=typechecker)
Expand Down
14 changes: 11 additions & 3 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,17 @@ def initialize_prior(self) -> prior.Prior:
for name, parameters in self.run.priors.items():
if parameters["name"] not in prior_presets:
raise ValueError(f"Prior {name} not recognized.")
priors.append(
prior_presets[parameters["name"]](naming=[name], **parameters)
)
if parameters["name"] == "EarthFrame":
priors.append(
prior.EarthFrame(
gps=self.run.data_parameters["trigger_time"],
ifos=self.run.detectors,
)
)
else:
priors.append(
prior_presets[parameters["name"]](naming=[name], **parameters)
)
return prior.Composite(priors)

def initialize_detector(self) -> list[Detector]:
Expand Down
8 changes: 4 additions & 4 deletions src/jimgw/single_event/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import jax.numpy as jnp
from jaxtyping import Array, Float
from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc
from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc
from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc
from ripple.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc
from ripplegw.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc
from ripplegw.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc
from ripplegw.waveforms.TaylorF2 import gen_TaylorF2_hphc
from ripplegw.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc


class Waveform(ABC):
Expand Down

0 comments on commit 2a51d8c

Please sign in to comment.