Skip to content

Commit

Permalink
Make gate-and-inpainting starting time precise (gwastro#4403)
Browse files Browse the repository at this point in the history
* apply time shift to account for the gating start time is a subsampling data point. fix a few typos

* fix int type error

* correct a misunderstanding for apply_fd_time_shift

* correct a unused variable

* correct the shift time

* address cc issues

* apply static method

* complete the gated_gaussian_noise

* typo

* typo

* re express time shift

* space

* start to work on pol marg model

* complete shifting for marg pol model

* instead of modifying models, now implementing time shifting only in gate-and-inpaint session

* remove all changes in the model

* remove changes in model

* fix some bugs when fft/ifft converting, now productive

* remove rindex+1 in gate and inpainting, add if condition for time shifting

* get rindex_time via actually doing the math, this is faster than retrieve the data sample time array

* remove the proj and projslc attributes in gated_gaussian_model, it's used for debugging before
  • Loading branch information
yi-fan-wang authored Sep 22, 2023
1 parent 4da9e64 commit caac90b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
7 changes: 0 additions & 7 deletions pycbc/inference/models/gated_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ def __init__(self, variable_params, data, low_frequency_cutoff, psds=None,
static_params=None, highpass_waveforms=False, **kwargs):
# we'll want the time-domain data, so store that
self._td_data = {}
# cache the current projection for debugging
self.current_proj = {}
self.current_nproj = {}
# cache the overwhitened data
self._overwhitened_data = {}
# cache the current gated data
Expand Down Expand Up @@ -383,7 +380,6 @@ def _lognl(self):
self._det_lognls.clear()
# get the times of the gates
gate_times = self.get_gate_times()
self.current_nproj.clear()
for det, invpsd in self._invpsds.items():
norm = self.det_lognorm(det)
gatestartdelay, dgatedelay = gate_times[det]
Expand All @@ -395,7 +391,6 @@ def _lognl(self):
gated_dt = data.gate(gatestartdelay + dgatedelay/2,
window=dgatedelay/2, copy=True,
invpsd=invpsd, method='paint')
self.current_nproj[det] = (gated_dt.proj, gated_dt.projslc)
# convert to the frequency series
gated_d = gated_dt.to_frequencyseries()
# overwhiten
Expand Down Expand Up @@ -546,7 +541,6 @@ def _loglikelihood(self):
# get the times of the gates
gate_times = self.get_gate_times()
logl = 0.
self.current_proj.clear()
for det, h in wfs.items():
invpsd = self._invpsds[det]
norm = self.det_lognorm(det)
Expand All @@ -562,7 +556,6 @@ def _loglikelihood(self):
gated_res = res.gate(gatestartdelay + dgatedelay/2,
window=dgatedelay/2, copy=True,
invpsd=invpsd, method='paint')
self.current_proj[det] = (gated_res.proj, gated_res.projslc)
gated_rtilde = gated_res.to_frequencyseries()
# overwhiten
gated_rtilde *= invpsd
Expand Down
4 changes: 2 additions & 2 deletions pycbc/inference/models/marginalized_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __init__(self, variable_params,
generator_class=generator.FDomainDetFrameTwoPolNoRespGenerator,
gates=self.gates, **kwargs['static_params'])
else:
# create a waveform generator for each ifo respestively
# create a waveform generator for each ifo respectively
self.waveform_generator = {}
for det in self.data:
self.waveform_generator[det] = create_waveform_generator(
Expand Down Expand Up @@ -386,7 +386,7 @@ def __init__(self, variable_params, data, low_frequency_cutoff, psds=None,
generator_class=generator.FDomainDetFrameTwoPolGenerator,
gates=self.gates, **kwargs['static_params'])
else:
# create a waveform generator for each ifo respestively
# create a waveform generator for each ifo respectively
self.waveform_generator = {}
for det in self.data:
self.waveform_generator[det] = create_waveform_generator(
Expand Down
3 changes: 1 addition & 2 deletions pycbc/strain/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def gate_and_paint(data, lindex, rindex, invpsd, copy=True):
# Copy the data and zero inside the hole
if copy:
data = data.copy()
# Here's ambiguity about when gate end time exactly is, rindex-1 or rindex?
data[lindex:rindex] = 0

# get the over-whitened gated data
Expand All @@ -178,6 +179,4 @@ def gate_and_paint(data, lindex, rindex, invpsd, copy=True):
proj = linalg.solve_toeplitz(tdfilter[:(rindex - lindex)],
owhgated_data[lindex:rindex])
data[lindex:rindex] -= proj
data.projslc = (lindex, rindex)
data.proj = proj
return data
20 changes: 18 additions & 2 deletions pycbc/types/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,15 +595,31 @@ def gate(self, time, window=0.25, method='taper', copy=True,
# Uses the hole-filling method of
# https://arxiv.org/pdf/1908.05644.pdf
from pycbc.strain.gate import gate_and_paint
from pycbc.waveform.utils import apply_fd_time_shift
if invpsd is None:
# These are some bare minimum settings, normally you
# should probably provide a psd
invpsd = 1. / self.filter_psd(self.duration/32, self.delta_f, 0)
lindex = int((time - window - self.start_time) / self.delta_t)
rindex = lindex + int(2 * window / self.delta_t)
rindex = int((time + window - self.start_time) / self.delta_t)
lindex = lindex if lindex >= 0 else 0
rindex = rindex if rindex <= len(self) else len(self)
return gate_and_paint(data, lindex, rindex, invpsd, copy=False)
rindex_time = float(self.start_time + rindex * self.delta_t)
offset = rindex_time - (time + window)
if offset == 0:
return gate_and_paint(data, lindex, rindex, invpsd, copy=False)
else:
# time shift such that gate end time lands on a specific data sample
fdata = data.to_frequencyseries()
fdata = apply_fd_time_shift(fdata, offset + fdata.epoch, copy=False)
# gate and paint in time domain
data = fdata.to_timeseries()
data = gate_and_paint(data, lindex, rindex, invpsd, copy=False)
# shift back to the original time
fdata = data.to_frequencyseries()
fdata = apply_fd_time_shift(fdata, -offset + fdata.epoch, copy=False)
tdata = fdata.to_timeseries()
return tdata
elif method == 'hard':
tslice = data.time_slice(time - window, time + window)
tslice[:] = 0
Expand Down

0 comments on commit caac90b

Please sign in to comment.