diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 480f98f3..df9df401 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,5 +1,6 @@ import jax.numpy as jnp from jax.scipy.special import i0e, logsumexp +from jax.scipy.integrate import trapezoid from jax import jit from jaxtyping import Float, Array @@ -35,7 +36,7 @@ def inner_product( # psd_interp = jnp.interp(frequency, psd_frequency, psd) df = frequency[1] - frequency[0] integrand = jnp.conj(h1) * h2 / psd - return 4.0 * jnp.real(jnp.trapz(integrand, dx=df)) + return 4.0 * jnp.real(trapezoid(integrand, dx=df)) @jit