Skip to content

Commit

Permalink
more precommit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Thibeau Wouters committed May 23, 2024
1 parent 25dff03 commit 8aef32c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(
)

print("Initializing heterodyned likelihood..")

# Can use another waveform to use as reference waveform, but if not provided, use the same waveform
if reference_waveform is None:
reference_waveform = waveform
Expand Down
109 changes: 65 additions & 44 deletions src/jimgw/single_event/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc
from ripple.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc


class Waveform(ABC):
def __init__(self):
return NotImplemented
Expand Down Expand Up @@ -82,6 +83,7 @@ def __call__(
def __repr__(self):
return f"RippleIMRPhenomPv2(f_ref={self.f_ref})"


class RippleTaylorF2(Waveform):

f_ref: float
Expand All @@ -93,40 +95,50 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False):

def __call__(self, frequency: Array, params: dict) -> dict:
output = {}

if self.use_lambda_tildes:
first_lambda_param = jnp.array(params["lambda_tilde"])
second_lambda_param = jnp.array(params["delta_lambda_tilde"])
first_lambda_param = params["lambda_tilde"]
second_lambda_param = params["delta_lambda_tilde"]
else:
first_lambda_param = jnp.array(params["lambda_1"])
second_lambda_param = jnp.array(params["lambda_2"])

theta = [
params["M_c"],
params["eta"],
params["s1_z"],
params["s2_z"],
first_lambda_param,
second_lambda_param,
params["d_L"],
0,
params["phase_c"],
params["iota"],
]
hp, hc = gen_TaylorF2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes)
first_lambda_param = params["lambda_1"]
second_lambda_param = params["lambda_2"]

theta = jnp.array(
[
params["M_c"],
params["eta"],
params["s1_z"],
params["s2_z"],
first_lambda_param,
second_lambda_param,
params["d_L"],
0,
params["phase_c"],
params["iota"],
]
)
hp, hc = gen_TaylorF2_hphc(
frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes
)
output["p"] = hp
output["c"] = hc
return output

def __repr__(self):
return f"RippleTaylorF2(f_ref={self.f_ref})"



class RippleIMRPhenomD_NRTidalv2(Waveform):

f_ref: float
use_lambda_tildes: bool

def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False, no_taper: bool = False):
def __init__(
self,
f_ref: float = 20.0,
use_lambda_tildes: bool = False,
no_taper: bool = False,
):
"""
Initialize the waveform.
Expand All @@ -141,35 +153,44 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False, no_tape

def __call__(self, frequency: Array, params: dict) -> dict:
output = {}

if self.use_lambda_tildes:
first_lambda_param = jnp.array(params["lambda_tilde"])
second_lambda_param = jnp.array(params["delta_lambda_tilde"])
first_lambda_param = params["lambda_tilde"]
second_lambda_param = params["delta_lambda_tilde"]
else:
first_lambda_param = jnp.array(params["lambda_1"])
second_lambda_param = jnp.array(params["lambda_2"])

theta = [
params["M_c"],
params["eta"],
params["s1_z"],
params["s2_z"],
first_lambda_param,
second_lambda_param,
params["d_L"],
0,
params["phase_c"],
params["iota"],
]

hp, hc = gen_IMRPhenomD_NRTidalv2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes, no_taper=self.no_taper)
first_lambda_param = params["lambda_1"]
second_lambda_param = params["lambda_2"]

theta = jnp.array(
[
params["M_c"],
params["eta"],
params["s1_z"],
params["s2_z"],
first_lambda_param,
second_lambda_param,
params["d_L"],
0,
params["phase_c"],
params["iota"],
]
)

hp, hc = gen_IMRPhenomD_NRTidalv2_hphc(
frequency,
theta,
self.f_ref,
use_lambda_tildes=self.use_lambda_tildes,
no_taper=self.no_taper,
)
output["p"] = hp
output["c"] = hc
return output

def __repr__(self):
return f"RippleIMRPhenomD_NRTidalv2(f_ref={self.f_ref})"



waveform_preset = {
"RippleIMRPhenomD": RippleIMRPhenomD,
"RippleIMRPhenomPv2": RippleIMRPhenomPv2,
Expand Down

0 comments on commit 8aef32c

Please sign in to comment.