Skip to content

Commit 733d802

Browse files
authored
Merge pull request #48 from kazewong/46-restyling-the-code-to-work-with-pre-commit
46 restyling the code to work with pre commit
2 parents 282f409 + 37f0b96 commit 733d802

12 files changed

+478
-288
lines changed

.pre-commit-config.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
files: src/
12
repos:
23
- repo: https://github.com/ambv/black
34
rev: 23.9.1
@@ -12,7 +13,7 @@ repos:
1213
rev: v1.1.338
1314
hooks:
1415
- id: pyright
15-
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions]
16+
additional_dependencies: [beartype, jax, jaxtyping, pytest, typing_extensions, flowMC, ripplegw, gwpy, astropy]
1617
- repo: https://github.com/nbQA-dev/nbQA
1718
rev: 1.7.1
1819
hooks:

ruff.toml

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ignore = ["F722"]

src/jimgw/constants.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from astropy.constants import c,au,G,pc
1+
from astropy.constants import c, pc # type: ignore TODO: fix astropy stubs
22
from astropy.units import year as yr
3-
from astropy.cosmology import WMAP9 as cosmo
43

54
Msun = 4.9255e-6
6-
year = (1*yr).cgs.value
7-
Mpc = 1e6*pc.value/c.value
5+
year = (1 * yr).cgs.value # type: ignore
6+
Mpc = 1e6 * pc.value / c.value
87
euler_gamma = 0.577215664901532860606512090082
98
MR_sun = 1.476625061404649406193430731479084713e3
109
C_SI = 299792458.0
@@ -13,4 +12,4 @@
1312
EARTH_SEMI_MINOR_AXIS = 6356752.314 # in m
1413

1514
DAYSID_SI = 86164.09053133354
16-
DAYJUL_SI = 86400.0
15+
DAYJUL_SI = 86400.0

src/jimgw/data.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import equinox as eqx
21
from abc import ABC, abstractmethod
3-
from jaxtyping import Array
42

5-
class Data(ABC):
63

4+
class Data(ABC):
75
@abstractmethod
86
def __init__(self):
97
raise NotImplementedError

src/jimgw/detector.py

+105-50
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import numpy as np
66
import requests
77
from gwpy.timeseries import TimeSeries
8-
from jaxtyping import Array, PRNGKeyArray
8+
from jaxtyping import Array, PRNGKeyArray, Float
99
from scipy.interpolate import interp1d
1010
from scipy.signal.windows import tukey
1111

12-
from jimgw.constants import *
12+
from jimgw.constants import EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS, C_SI
1313
from jimgw.wave import Polarization
1414

1515
DEG_TO_RAD = jnp.pi / 180
@@ -39,39 +39,51 @@ class Detector(ABC):
3939

4040
name: str
4141

42+
data: Float[Array, " n_sample"]
43+
psd: Float[Array, " n_sample"]
44+
4245
@abstractmethod
4346
def load_data(self, data):
4447
raise NotImplementedError
4548

4649
@abstractmethod
47-
def fd_response(self, frequency: Array, h: Array, params: dict) -> Array:
50+
def fd_response(
51+
self,
52+
frequency: Float[Array, " n_sample"],
53+
h: dict[str, Float[Array, " n_sample"]],
54+
params: dict,
55+
) -> Float[Array, " n_sample"]:
4856
"""
4957
Modulate the waveform in the sky frame by the detector response
5058
in the frequency domain."""
5159
pass
5260

5361
@abstractmethod
54-
def td_response(self, time: Array, h: Array, params: dict) -> Array:
62+
def td_response(
63+
self,
64+
time: Float[Array, " n_sample"],
65+
h: dict[str, Float[Array, " n_sample"]],
66+
params: dict,
67+
) -> Float[Array, " n_sample"]:
5568
"""
5669
Modulate the waveform in the sky frame by the detector response
5770
in the time domain."""
5871
pass
5972

6073

6174
class GroundBased2G(Detector):
62-
6375
polarization_mode: list[Polarization]
64-
frequencies: Array = None
65-
data: Array = None
66-
psd: Array = None
67-
68-
latitude: float = 0
69-
longitude: float = 0
70-
xarm_azimuth: float = 0
71-
yarm_azimuth: float = 0
72-
xarm_tilt: float = 0
73-
yarm_tilt: float = 0
74-
elevation: float = 0
76+
frequencies: Float[Array, " n_sample"]
77+
data: Float[Array, " n_sample"]
78+
psd: Float[Array, " n_sample"]
79+
80+
latitude: Float = 0
81+
longitude: Float = 0
82+
xarm_azimuth: Float = 0
83+
yarm_azimuth: Float = 0
84+
xarm_tilt: Float = 0
85+
yarm_tilt: Float = 0
86+
elevation: Float = 0
7587

7688
def __init__(self, name: str, **kwargs) -> None:
7789
self.name = name
@@ -86,22 +98,32 @@ def __init__(self, name: str, **kwargs) -> None:
8698
modes = kwargs.get("mode", "pc")
8799

88100
self.polarization_mode = [Polarization(m) for m in modes]
101+
self.frequencies = jnp.array([])
102+
self.data = jnp.array([])
103+
self.psd = jnp.array([])
89104

90105
@staticmethod
91-
def _get_arm(lat, lon, tilt, azimuth):
106+
def _get_arm(
107+
lat: Float, lon: Float, tilt: Float, azimuth: Float
108+
) -> Float[Array, " 3"]:
92109
"""
93110
Construct detector-arm vectors in Earth-centric Cartesian coordinates.
94111
95112
Parameters
96113
---------
97-
lat : float
114+
lat : Float
98115
vertex latitude in rad.
99-
lon : float
116+
lon : Float
100117
vertex longitude in rad.
101-
tilt : float
118+
tilt : Float
102119
arm tilt in rad.
103-
azimuth : float
120+
azimuth : Float
104121
arm azimuth in rad.
122+
123+
Returns
124+
-------
125+
arm : Float[Array, " 3"]
126+
detector arm vector in Earth-centric Cartesian coordinates.
105127
"""
106128
e_lon = jnp.array([-jnp.sin(lon), jnp.cos(lon), 0])
107129
e_lat = jnp.array(
@@ -118,9 +140,16 @@ def _get_arm(lat, lon, tilt, azimuth):
118140
)
119141

120142
@property
121-
def arms(self):
143+
def arms(self) -> tuple[Float[Array, " 3"], Float[Array, " 3"]]:
122144
"""
123145
Detector arm vectors (x, y).
146+
147+
Returns
148+
-------
149+
x : Float[Array, " 3"]
150+
x-arm vector.
151+
y : Float[Array, " 3"]
152+
y-arm vector.
124153
"""
125154
x = self._get_arm(
126155
self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth
@@ -131,9 +160,14 @@ def arms(self):
131160
return x, y
132161

133162
@property
134-
def tensor(self):
163+
def tensor(self) -> Float[Array, " 3, 3"]:
135164
"""
136165
Detector tensor defining the strain measurement.
166+
167+
Returns
168+
-------
169+
tensor : Float[Array, " 3, 3"]
170+
detector tensor.
137171
"""
138172
# TODO: this could easily be generalized for other detector geometries
139173
arm1, arm2 = self.arms
@@ -142,11 +176,16 @@ def tensor(self):
142176
)
143177

144178
@property
145-
def vertex(self):
179+
def vertex(self) -> Float[Array, " 3"]:
146180
"""
147181
Detector vertex coordinates in the reference celestial frame. Based
148182
on arXiv:gr-qc/0008066 Eqs. (B11-B13) except for a typo in the
149183
definition of the local radius; see Section 2.1 of LIGO-T980044-10.
184+
185+
Returns
186+
-------
187+
vertex : Float[Array, " 3"]
188+
detector vertex coordinates.
150189
"""
151190
# get detector and Earth parameters
152191
lat = self.latitude
@@ -164,33 +203,33 @@ def vertex(self):
164203

165204
def load_data(
166205
self,
167-
trigger_time: float,
206+
trigger_time: Float,
168207
gps_start_pad: int,
169208
gps_end_pad: int,
170-
f_min: float,
171-
f_max: float,
209+
f_min: Float,
210+
f_max: Float,
172211
psd_pad: int = 16,
173-
tukey_alpha: float = 0.2,
212+
tukey_alpha: Float = 0.2,
174213
gwpy_kwargs: dict = {"cache": True},
175214
) -> None:
176215
"""
177216
Load data from the detector.
178217
179218
Parameters
180219
----------
181-
trigger_time : float
220+
trigger_time : Float
182221
The GPS time of the trigger.
183222
gps_start_pad : int
184223
The amount of time before the trigger to fetch data.
185224
gps_end_pad : int
186225
The amount of time after the trigger to fetch data.
187-
f_min : float
226+
f_min : Float
188227
The minimum frequency to fetch data.
189-
f_max : float
228+
f_max : Float
190229
The maximum frequency to fetch data.
191230
psd_pad : int
192231
The amount of time to pad the PSD data.
193-
tukey_alpha : float
232+
tukey_alpha : Float
194233
The alpha parameter for the Tukey window.
195234
196235
"""
@@ -202,6 +241,7 @@ def load_data(
202241
trigger_time + gps_end_pad,
203242
**gwpy_kwargs
204243
)
244+
assert isinstance(data_td, TimeSeries), "Data is not a TimeSeries object."
205245
segment_length = data_td.duration.value
206246
n = len(data_td)
207247
delta_t = data_td.dt.value
@@ -217,6 +257,9 @@ def load_data(
217257
psd_data_td = TimeSeries.fetch_open_data(
218258
self.name, start_psd, end_psd, **gwpy_kwargs
219259
)
260+
assert isinstance(
261+
psd_data_td, TimeSeries
262+
), "PSD data is not a TimeSeries object."
220263
psd = psd_data_td.psd(
221264
fftlength=segment_length
222265
).value # TODO: Check whether this is sright.
@@ -227,9 +270,15 @@ def load_data(
227270
self.data = data[(freq > f_min) & (freq < f_max)]
228271
self.psd = psd[(freq > f_min) & (freq < f_max)]
229272

230-
def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array:
273+
def fd_response(
274+
self,
275+
frequency: Float[Array, " n_sample"],
276+
h_sky: dict[str, Float[Array, " n_sample"]],
277+
params: dict[str, Float],
278+
) -> Array:
279+
"""
280+
Modulate the waveform in the sky frame by the detector response in the frequency domain.
231281
"""
232-
Modulate the waveform in the sky frame by the detector response in the frequency domain."""
233282
ra, dec, psi, gmst = params["ra"], params["dec"], params["psi"], params["gmst"]
234283
antenna_pattern = self.antenna_pattern(ra, dec, psi, gmst)
235284
timeshift = self.delay_from_geocenter(ra, dec, gmst)
@@ -244,10 +293,11 @@ def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array:
244293

245294
def td_response(self, time: Array, h: Array, params: Array) -> Array:
246295
"""
247-
Modulate the waveform in the sky frame by the detector response in the time domain."""
248-
pass
296+
Modulate the waveform in the sky frame by the detector response in the time domain.
297+
"""
298+
raise NotImplementedError
249299

250-
def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
300+
def delay_from_geocenter(self, ra: Float, dec: Float, gmst: Float) -> Float:
251301
"""
252302
Calculate time delay between two detectors in geocentric
253303
coordinates based on XLALArrivaTimeDiff in TimeDelay.c
@@ -256,16 +306,16 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
256306
257307
Parameters
258308
---------
259-
ra : float
309+
ra : Float
260310
right ascension of the source in rad.
261-
dec : float
311+
dec : Float
262312
declination of the source in rad.
263-
gmst : float
313+
gmst : Float
264314
Greenwich mean sidereal time in rad.
265315
266316
Returns
267317
-------
268-
float: time delay from Earth center.
318+
Float: time delay from Earth center.
269319
"""
270320
delta_d = -self.vertex
271321
gmst = jnp.mod(gmst, 2 * jnp.pi)
@@ -280,7 +330,7 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
280330
)
281331
return jnp.dot(omega, delta_d) / C_SI
282332

283-
def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dict:
333+
def antenna_pattern(self, ra: Float, dec: Float, psi: Float, gmst: Float) -> dict:
284334
"""
285335
Computes {name} antenna patterns for {modes} polarizations
286336
at the specified sky location, orientation and GMST.
@@ -291,13 +341,13 @@ def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dic
291341
292342
Parameters
293343
---------
294-
ra : float
344+
ra : Float
295345
source right ascension in radians.
296-
dec : float
346+
dec : Float
297347
source declination in radians.
298-
psi : float
348+
psi : Float
299349
source polarization angle in radians.
300-
gmst : float
350+
gmst : Float
301351
Greenwich mean sidereal time (GMST) in radians.
302352
modes : str
303353
string of polarizations to include, defaults to tensor modes: 'pc'.
@@ -324,7 +374,7 @@ def inject_signal(
324374
freqs: Array,
325375
h_sky: dict,
326376
params: dict,
327-
psd_file: str = None,
377+
psd_file: str = "",
328378
) -> None:
329379
""" """
330380
self.frequencies = freqs
@@ -339,8 +389,10 @@ def inject_signal(
339389
signal = self.fd_response(freqs, h_sky, params) * align_time
340390
self.data = signal + noise_real + 1j * noise_imag
341391

342-
def load_psd(self, freqs: Array, psd_file: str = None) -> None:
343-
if psd_file is None:
392+
def load_psd(
393+
self, freqs: Float[Array, " n_sample"], psd_file: str = ""
394+
) -> Float[Array, " n_sample"]:
395+
if psd_file == "":
344396
print("Grabbing GWTC-2 PSD for " + self.name)
345397
url = psd_file_dict[self.name]
346398
data = requests.get(url)
@@ -349,7 +401,10 @@ def load_psd(self, freqs: Array, psd_file: str = None) -> None:
349401
else:
350402
f, asd_vals = np.loadtxt(psd_file, unpack=True)
351403
psd_vals = asd_vals**2
352-
psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs)
404+
assert isinstance(f, Float[Array, "n_sample"])
405+
assert isinstance(psd_vals, Float[Array, "n_sample"])
406+
psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) # type: ignore
407+
psd = jnp.array(psd)
353408
return psd
354409

355410

0 commit comments

Comments
 (0)