5
5
import numpy as np
6
6
import requests
7
7
from gwpy .timeseries import TimeSeries
8
- from jaxtyping import Array , PRNGKeyArray
8
+ from jaxtyping import Array , PRNGKeyArray , Float
9
9
from scipy .interpolate import interp1d
10
10
from scipy .signal .windows import tukey
11
11
12
- from jimgw .constants import *
12
+ from jimgw .constants import EARTH_SEMI_MAJOR_AXIS , EARTH_SEMI_MINOR_AXIS , C_SI
13
13
from jimgw .wave import Polarization
14
14
15
15
DEG_TO_RAD = jnp .pi / 180
@@ -39,39 +39,51 @@ class Detector(ABC):
39
39
40
40
name : str
41
41
42
+ data : Float [Array , " n_sample" ]
43
+ psd : Float [Array , " n_sample" ]
44
+
42
45
@abstractmethod
43
46
def load_data (self , data ):
44
47
raise NotImplementedError
45
48
46
49
@abstractmethod
47
- def fd_response (self , frequency : Array , h : Array , params : dict ) -> Array :
50
+ def fd_response (
51
+ self ,
52
+ frequency : Float [Array , " n_sample" ],
53
+ h : dict [str , Float [Array , " n_sample" ]],
54
+ params : dict ,
55
+ ) -> Float [Array , " n_sample" ]:
48
56
"""
49
57
Modulate the waveform in the sky frame by the detector response
50
58
in the frequency domain."""
51
59
pass
52
60
53
61
@abstractmethod
54
- def td_response (self , time : Array , h : Array , params : dict ) -> Array :
62
+ def td_response (
63
+ self ,
64
+ time : Float [Array , " n_sample" ],
65
+ h : dict [str , Float [Array , " n_sample" ]],
66
+ params : dict ,
67
+ ) -> Float [Array , " n_sample" ]:
55
68
"""
56
69
Modulate the waveform in the sky frame by the detector response
57
70
in the time domain."""
58
71
pass
59
72
60
73
61
74
class GroundBased2G (Detector ):
62
-
63
75
polarization_mode : list [Polarization ]
64
- frequencies : Array = None
65
- data : Array = None
66
- psd : Array = None
67
-
68
- latitude : float = 0
69
- longitude : float = 0
70
- xarm_azimuth : float = 0
71
- yarm_azimuth : float = 0
72
- xarm_tilt : float = 0
73
- yarm_tilt : float = 0
74
- elevation : float = 0
76
+ frequencies : Float [ Array , " n_sample" ]
77
+ data : Float [ Array , " n_sample" ]
78
+ psd : Float [ Array , " n_sample" ]
79
+
80
+ latitude : Float = 0
81
+ longitude : Float = 0
82
+ xarm_azimuth : Float = 0
83
+ yarm_azimuth : Float = 0
84
+ xarm_tilt : Float = 0
85
+ yarm_tilt : Float = 0
86
+ elevation : Float = 0
75
87
76
88
def __init__ (self , name : str , ** kwargs ) -> None :
77
89
self .name = name
@@ -86,22 +98,32 @@ def __init__(self, name: str, **kwargs) -> None:
86
98
modes = kwargs .get ("mode" , "pc" )
87
99
88
100
self .polarization_mode = [Polarization (m ) for m in modes ]
101
+ self .frequencies = jnp .array ([])
102
+ self .data = jnp .array ([])
103
+ self .psd = jnp .array ([])
89
104
90
105
@staticmethod
91
- def _get_arm (lat , lon , tilt , azimuth ):
106
+ def _get_arm (
107
+ lat : Float , lon : Float , tilt : Float , azimuth : Float
108
+ ) -> Float [Array , " 3" ]:
92
109
"""
93
110
Construct detector-arm vectors in Earth-centric Cartesian coordinates.
94
111
95
112
Parameters
96
113
---------
97
- lat : float
114
+ lat : Float
98
115
vertex latitude in rad.
99
- lon : float
116
+ lon : Float
100
117
vertex longitude in rad.
101
- tilt : float
118
+ tilt : Float
102
119
arm tilt in rad.
103
- azimuth : float
120
+ azimuth : Float
104
121
arm azimuth in rad.
122
+
123
+ Returns
124
+ -------
125
+ arm : Float[Array, " 3"]
126
+ detector arm vector in Earth-centric Cartesian coordinates.
105
127
"""
106
128
e_lon = jnp .array ([- jnp .sin (lon ), jnp .cos (lon ), 0 ])
107
129
e_lat = jnp .array (
@@ -118,9 +140,16 @@ def _get_arm(lat, lon, tilt, azimuth):
118
140
)
119
141
120
142
@property
121
- def arms (self ):
143
+ def arms (self ) -> tuple [ Float [ Array , " 3" ], Float [ Array , " 3" ]] :
122
144
"""
123
145
Detector arm vectors (x, y).
146
+
147
+ Returns
148
+ -------
149
+ x : Float[Array, " 3"]
150
+ x-arm vector.
151
+ y : Float[Array, " 3"]
152
+ y-arm vector.
124
153
"""
125
154
x = self ._get_arm (
126
155
self .latitude , self .longitude , self .xarm_tilt , self .xarm_azimuth
@@ -131,9 +160,14 @@ def arms(self):
131
160
return x , y
132
161
133
162
@property
134
- def tensor (self ):
163
+ def tensor (self ) -> Float [ Array , " 3, 3" ] :
135
164
"""
136
165
Detector tensor defining the strain measurement.
166
+
167
+ Returns
168
+ -------
169
+ tensor : Float[Array, " 3, 3"]
170
+ detector tensor.
137
171
"""
138
172
# TODO: this could easily be generalized for other detector geometries
139
173
arm1 , arm2 = self .arms
@@ -142,11 +176,16 @@ def tensor(self):
142
176
)
143
177
144
178
@property
145
- def vertex (self ):
179
+ def vertex (self ) -> Float [ Array , " 3" ] :
146
180
"""
147
181
Detector vertex coordinates in the reference celestial frame. Based
148
182
on arXiv:gr-qc/0008066 Eqs. (B11-B13) except for a typo in the
149
183
definition of the local radius; see Section 2.1 of LIGO-T980044-10.
184
+
185
+ Returns
186
+ -------
187
+ vertex : Float[Array, " 3"]
188
+ detector vertex coordinates.
150
189
"""
151
190
# get detector and Earth parameters
152
191
lat = self .latitude
@@ -164,33 +203,33 @@ def vertex(self):
164
203
165
204
def load_data (
166
205
self ,
167
- trigger_time : float ,
206
+ trigger_time : Float ,
168
207
gps_start_pad : int ,
169
208
gps_end_pad : int ,
170
- f_min : float ,
171
- f_max : float ,
209
+ f_min : Float ,
210
+ f_max : Float ,
172
211
psd_pad : int = 16 ,
173
- tukey_alpha : float = 0.2 ,
212
+ tukey_alpha : Float = 0.2 ,
174
213
gwpy_kwargs : dict = {"cache" : True },
175
214
) -> None :
176
215
"""
177
216
Load data from the detector.
178
217
179
218
Parameters
180
219
----------
181
- trigger_time : float
220
+ trigger_time : Float
182
221
The GPS time of the trigger.
183
222
gps_start_pad : int
184
223
The amount of time before the trigger to fetch data.
185
224
gps_end_pad : int
186
225
The amount of time after the trigger to fetch data.
187
- f_min : float
226
+ f_min : Float
188
227
The minimum frequency to fetch data.
189
- f_max : float
228
+ f_max : Float
190
229
The maximum frequency to fetch data.
191
230
psd_pad : int
192
231
The amount of time to pad the PSD data.
193
- tukey_alpha : float
232
+ tukey_alpha : Float
194
233
The alpha parameter for the Tukey window.
195
234
196
235
"""
@@ -202,6 +241,7 @@ def load_data(
202
241
trigger_time + gps_end_pad ,
203
242
** gwpy_kwargs
204
243
)
244
+ assert isinstance (data_td , TimeSeries ), "Data is not a TimeSeries object."
205
245
segment_length = data_td .duration .value
206
246
n = len (data_td )
207
247
delta_t = data_td .dt .value
@@ -217,6 +257,9 @@ def load_data(
217
257
psd_data_td = TimeSeries .fetch_open_data (
218
258
self .name , start_psd , end_psd , ** gwpy_kwargs
219
259
)
260
+ assert isinstance (
261
+ psd_data_td , TimeSeries
262
+ ), "PSD data is not a TimeSeries object."
220
263
psd = psd_data_td .psd (
221
264
fftlength = segment_length
222
265
).value # TODO: Check whether this is sright.
@@ -227,9 +270,15 @@ def load_data(
227
270
self .data = data [(freq > f_min ) & (freq < f_max )]
228
271
self .psd = psd [(freq > f_min ) & (freq < f_max )]
229
272
230
- def fd_response (self , frequency : Array , h_sky : dict , params : dict ) -> Array :
273
+ def fd_response (
274
+ self ,
275
+ frequency : Float [Array , " n_sample" ],
276
+ h_sky : dict [str , Float [Array , " n_sample" ]],
277
+ params : dict [str , Float ],
278
+ ) -> Array :
279
+ """
280
+ Modulate the waveform in the sky frame by the detector response in the frequency domain.
231
281
"""
232
- Modulate the waveform in the sky frame by the detector response in the frequency domain."""
233
282
ra , dec , psi , gmst = params ["ra" ], params ["dec" ], params ["psi" ], params ["gmst" ]
234
283
antenna_pattern = self .antenna_pattern (ra , dec , psi , gmst )
235
284
timeshift = self .delay_from_geocenter (ra , dec , gmst )
@@ -244,10 +293,11 @@ def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array:
244
293
245
294
def td_response (self , time : Array , h : Array , params : Array ) -> Array :
246
295
"""
247
- Modulate the waveform in the sky frame by the detector response in the time domain."""
248
- pass
296
+ Modulate the waveform in the sky frame by the detector response in the time domain.
297
+ """
298
+ raise NotImplementedError
249
299
250
- def delay_from_geocenter (self , ra : float , dec : float , gmst : float ) -> float :
300
+ def delay_from_geocenter (self , ra : Float , dec : Float , gmst : Float ) -> Float :
251
301
"""
252
302
Calculate time delay between two detectors in geocentric
253
303
coordinates based on XLALArrivaTimeDiff in TimeDelay.c
@@ -256,16 +306,16 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
256
306
257
307
Parameters
258
308
---------
259
- ra : float
309
+ ra : Float
260
310
right ascension of the source in rad.
261
- dec : float
311
+ dec : Float
262
312
declination of the source in rad.
263
- gmst : float
313
+ gmst : Float
264
314
Greenwich mean sidereal time in rad.
265
315
266
316
Returns
267
317
-------
268
- float : time delay from Earth center.
318
+ Float : time delay from Earth center.
269
319
"""
270
320
delta_d = - self .vertex
271
321
gmst = jnp .mod (gmst , 2 * jnp .pi )
@@ -280,7 +330,7 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
280
330
)
281
331
return jnp .dot (omega , delta_d ) / C_SI
282
332
283
- def antenna_pattern (self , ra : float , dec : float , psi : float , gmst : float ) -> dict :
333
+ def antenna_pattern (self , ra : Float , dec : Float , psi : Float , gmst : Float ) -> dict :
284
334
"""
285
335
Computes {name} antenna patterns for {modes} polarizations
286
336
at the specified sky location, orientation and GMST.
@@ -291,13 +341,13 @@ def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dic
291
341
292
342
Parameters
293
343
---------
294
- ra : float
344
+ ra : Float
295
345
source right ascension in radians.
296
- dec : float
346
+ dec : Float
297
347
source declination in radians.
298
- psi : float
348
+ psi : Float
299
349
source polarization angle in radians.
300
- gmst : float
350
+ gmst : Float
301
351
Greenwich mean sidereal time (GMST) in radians.
302
352
modes : str
303
353
string of polarizations to include, defaults to tensor modes: 'pc'.
@@ -324,7 +374,7 @@ def inject_signal(
324
374
freqs : Array ,
325
375
h_sky : dict ,
326
376
params : dict ,
327
- psd_file : str = None ,
377
+ psd_file : str = "" ,
328
378
) -> None :
329
379
""" """
330
380
self .frequencies = freqs
@@ -339,8 +389,10 @@ def inject_signal(
339
389
signal = self .fd_response (freqs , h_sky , params ) * align_time
340
390
self .data = signal + noise_real + 1j * noise_imag
341
391
342
- def load_psd (self , freqs : Array , psd_file : str = None ) -> None :
343
- if psd_file is None :
392
+ def load_psd (
393
+ self , freqs : Float [Array , " n_sample" ], psd_file : str = ""
394
+ ) -> Float [Array , " n_sample" ]:
395
+ if psd_file == "" :
344
396
print ("Grabbing GWTC-2 PSD for " + self .name )
345
397
url = psd_file_dict [self .name ]
346
398
data = requests .get (url )
@@ -349,7 +401,10 @@ def load_psd(self, freqs: Array, psd_file: str = None) -> None:
349
401
else :
350
402
f , asd_vals = np .loadtxt (psd_file , unpack = True )
351
403
psd_vals = asd_vals ** 2
352
- psd = interp1d (f , psd_vals , fill_value = (psd_vals [0 ], psd_vals [- 1 ]))(freqs )
404
+ assert isinstance (f , Float [Array , "n_sample" ])
405
+ assert isinstance (psd_vals , Float [Array , "n_sample" ])
406
+ psd = interp1d (f , psd_vals , fill_value = (psd_vals [0 ], psd_vals [- 1 ]))(freqs ) # type: ignore
407
+ psd = jnp .array (psd )
353
408
return psd
354
409
355
410
0 commit comments