Skip to content

Commit

Permalink
fixing precommit complaints
Browse files Browse the repository at this point in the history
  • Loading branch information
Thibeau Wouters committed May 23, 2024
1 parent 5c99869 commit 25dff03
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion example/GW170817_TaylorF2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
p = psutil.Process()
p.cpu_affinity([0])
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"
from jimgw.jim import Jim
from jimgw.single_event.detector import H1, L1, V1
Expand Down
3 changes: 2 additions & 1 deletion src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flowMC.strategy.optimization import optimization_Adam
from jax.scipy.special import logsumexp
from jaxtyping import Array, Float
from typing import Optional
from scipy.interpolate import interp1d

from jimgw.base import LikelihoodBase
Expand Down Expand Up @@ -192,7 +193,7 @@ def __init__(
popsize: int = 100,
n_steps: int = 2000,
ref_params: dict = {},
reference_waveform: Waveform = None,
reference_waveform: Optional[Waveform] = None,
**kwargs,
) -> None:
super().__init__(
Expand Down
16 changes: 8 additions & 8 deletions src/jimgw/single_event/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def __call__(self, frequency: Array, params: dict) -> dict:
output = {}

if self.use_lambda_tildes:
first_lambda_param = params["lambda_tilde"]
second_lambda_param = params["delta_lambda_tilde"]
first_lambda_param = jnp.array(params["lambda_tilde"])
second_lambda_param = jnp.array(params["delta_lambda_tilde"])
else:
first_lambda_param = params["lambda_1"]
second_lambda_param = params["lambda_2"]
first_lambda_param = jnp.array(params["lambda_1"])
second_lambda_param = jnp.array(params["lambda_2"])

theta = [
params["M_c"],
Expand Down Expand Up @@ -143,11 +143,11 @@ def __call__(self, frequency: Array, params: dict) -> dict:
output = {}

if self.use_lambda_tildes:
first_lambda_param = params["lambda_tilde"]
second_lambda_param = params["delta_lambda_tilde"]
first_lambda_param = jnp.array(params["lambda_tilde"])
second_lambda_param = jnp.array(params["delta_lambda_tilde"])
else:
first_lambda_param = params["lambda_1"]
second_lambda_param = params["lambda_2"]
first_lambda_param = jnp.array(params["lambda_1"])
second_lambda_param = jnp.array(params["lambda_2"])

theta = [
params["M_c"],
Expand Down

0 comments on commit 25dff03

Please sign in to comment.