From 25dff03e1a4c97ca828fabea2b4ddd1ecd6ad0b9 Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 23 May 2024 02:23:54 -0700 Subject: [PATCH] fixing precommit complaints --- example/GW170817_TaylorF2.py | 2 +- src/jimgw/single_event/likelihood.py | 3 ++- src/jimgw/single_event/waveform.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/example/GW170817_TaylorF2.py b/example/GW170817_TaylorF2.py index 554e0679..499f2d6d 100644 --- a/example/GW170817_TaylorF2.py +++ b/example/GW170817_TaylorF2.py @@ -2,7 +2,7 @@ p = psutil.Process() p.cpu_affinity([0]) import os -os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["CUDA_VISIBLE_DEVICES"] = "2" os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" from jimgw.jim import Jim from jimgw.single_event.detector import H1, L1, V1 diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index d6347f9e..295e2694 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -6,6 +6,7 @@ from flowMC.strategy.optimization import optimization_Adam from jax.scipy.special import logsumexp from jaxtyping import Array, Float +from typing import Optional from scipy.interpolate import interp1d from jimgw.base import LikelihoodBase @@ -192,7 +193,7 @@ def __init__( popsize: int = 100, n_steps: int = 2000, ref_params: dict = {}, - reference_waveform: Waveform = None, + reference_waveform: Optional[Waveform] = None, **kwargs, ) -> None: super().__init__( diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index 29ba3a00..aa3ae6e9 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -95,11 +95,11 @@ def __call__(self, frequency: Array, params: dict) -> dict: output = {} if self.use_lambda_tildes: - first_lambda_param = params["lambda_tilde"] - second_lambda_param = params["delta_lambda_tilde"] + first_lambda_param = jnp.array(params["lambda_tilde"]) + second_lambda_param = jnp.array(params["delta_lambda_tilde"]) else: - first_lambda_param = params["lambda_1"] - second_lambda_param = params["lambda_2"] + first_lambda_param = jnp.array(params["lambda_1"]) + second_lambda_param = jnp.array(params["lambda_2"]) theta = [ params["M_c"], @@ -143,11 +143,11 @@ def __call__(self, frequency: Array, params: dict) -> dict: output = {} if self.use_lambda_tildes: - first_lambda_param = params["lambda_tilde"] - second_lambda_param = params["delta_lambda_tilde"] + first_lambda_param = jnp.array(params["lambda_tilde"]) + second_lambda_param = jnp.array(params["delta_lambda_tilde"]) else: - first_lambda_param = params["lambda_1"] - second_lambda_param = params["lambda_2"] + first_lambda_param = jnp.array(params["lambda_1"]) + second_lambda_param = jnp.array(params["lambda_2"]) theta = [ params["M_c"],