Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre commit check #103

Merged
merged 7 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
[lint]
ignore = ["F722"]
2 changes: 1 addition & 1 deletion src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,8 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center):

def maximize_likelihood(
self,
bounds: Float[Array, " n_dim 2"],
prior: Prior,
bounds: Float[Array, " n_dim 2"],
popsize: int = 100,
n_steps: int = 2000,
):
Expand Down
1 change: 0 additions & 1 deletion src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl):
@dataclass
class SingleEventRun:
seed: int
path: str

detectors: list[str]
priors: dict[
Expand Down
19 changes: 4 additions & 15 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import jax.numpy as jnp
from jax import jit
from jax.scipy.integrate import trapezoid
from jax.scipy.special import i0e
from jaxtyping import Array, Float


@jit
def inner_product(
h1: Float[Array, " n_sample"],
h2: Float[Array, " n_sample"],
Expand Down Expand Up @@ -39,7 +37,6 @@ def inner_product(
return 4.0 * jnp.real(trapezoid(integrand, dx=df))


@jit
def m1m2_to_Mq(m1: Float, m2: Float):
"""
Transforming the primary mass m1 and secondary mass m2 to the Total mass M
Expand All @@ -64,7 +61,6 @@ def m1m2_to_Mq(m1: Float, m2: Float):
return M_tot, q


@jit
def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float):
"""
Transforming the Total mass M and mass ratio q to the primary mass m1 and
Expand All @@ -91,7 +87,6 @@ def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float):
return m1, m2


@jit
def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]:
"""
Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and
Expand All @@ -118,7 +113,6 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]:
return m1, m2


@jit
def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]:
"""
Transforming the right ascension ra and declination dec to the polar angle
Expand All @@ -145,8 +139,7 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa
return theta, phi


@jit
def euler_rotation(delta_x: tuple[Float, Float, Float]):
def euler_rotation(delta_x: Float[Array, " 3"]):
"""
Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x
while preserving the origin of the azimuthal angle.
Expand Down Expand Up @@ -189,9 +182,8 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]):
return rotation


@jit
def zenith_azimuth_to_theta_phi(
zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]
zenith: Float, azimuth: Float, delta_x: Float[Array, " 3"]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame.
Expand Down Expand Up @@ -241,7 +233,6 @@ def zenith_azimuth_to_theta_phi(
return theta, phi


@jit
def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]:
"""
Transforming the polar angle and azimuthal angle to right ascension and declination.
Expand All @@ -267,9 +258,8 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F
return ra, dec


@jit
def zenith_azimuth_to_ra_dec(
zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]
zenith: Float, azimuth: Float, gmst: Float, delta_x: Float[Array, " 3"]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination.
Expand Down Expand Up @@ -300,8 +290,7 @@ def zenith_azimuth_to_ra_dec(
return ra, dec


@jit
def log_i0(x):
def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]:
"""
A numerically stable method to evaluate log of
a modified Bessel function of order 0.
Expand Down
Loading