From 880520554796104d1f0762ad4d839a7182a0c313 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 5 Sep 2024 23:36:45 +0200 Subject: [PATCH] Revert "Remove if-else function" This reverts commit 762b7e0f7738284f8943fb842eab430e709a4a93. --- src/jimgw/single_event/transforms.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 0af0fe42..ad91a56a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,4 +1,3 @@ -import jax.lax import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -261,10 +260,10 @@ def __init__( ) ) - self.get_iota = lambda x: jax.lax.cond( - "iota" in conditional_names, - lambda _: x["iota"], - lambda _: spin_to_iota( + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -276,9 +275,7 @@ def __init__( x["q"], self.freq_ref, 0.0, - ), - operand=None, - ) + ) @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -371,10 +368,10 @@ def __init__( and "M_c" in conditional_names ) - self.get_iota = lambda x: jax.lax.cond( - "iota" in conditional_names, - lambda _: x["iota"], - lambda _: spin_to_iota( + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -386,9 +383,7 @@ def __init__( x["q"], self.freq_ref, 0.0, - ), - operand=None, - ) + ) @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota):