Skip to content

Commit

Permalink
fix normalization of sky and time marginalization to reduce integral uncertainty (gwastro#4804)
Browse files Browse the repository at this point in the history

* fix normalization of sky and time marginalization to reduce integral uncertainty

* remove debug lines

* cc

* update

* edge cases
  • Loading branch information
ahnitz authored Jul 1, 2024
1 parent adef96d commit 72cfa7a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 24 deletions.
4 changes: 3 additions & 1 deletion pycbc/inference/models/relbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,9 @@ def _loglr(self):
# Note, this includes complex conjugation already
# as our stored inner products were hp* x data
htf = (f.real * ip + 1.0j * f.imag * ic)
sh = self.sh[ifo].at_time(dts, interpolate='quadratic')
sh = self.sh[ifo].at_time(dts,
interpolate='quadratic',
extrapolate=0.0j)
sh_total += sh * htf
hh_total += self.hh[ifo] * abs(htf) ** 2.0

Expand Down
58 changes: 35 additions & 23 deletions pycbc/inference/models/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ def draw_sample(loglr, size=None):
class DistMarg():
"""Help class to add bookkeeping for likelihood marginalization"""

marginalize_phase = None
distance_marginalization = None
distance_interpolator = None

def setup_marginalization(self,
variable_params,
marginalize_phase=False,
Expand Down Expand Up @@ -287,7 +283,11 @@ def premarg_draw(self):

# Update the current proposed times and the marginalization values
logw = self.premarg['logw_partial']
choice = numpy.random.randint(0, len(logw), size=self.vsamples)
if self.vsamples == len(logw):
choice = slice(None, None)
else:
choice = numpy.random.choice(len(logw), size=self.vsamples,
replace=False)

for k in self.snr_params:
self.marginalize_vector_params[k] = self.premarg[k][choice]
Expand Down Expand Up @@ -339,7 +339,8 @@ def draw_times(self, snrs, size=None):
"""
if not hasattr(self, 'tinfo'):
# determine the rough time offsets for this sky location
tcmin, tcmax = self.marginalized_vector_priors['tc'].bounds['tc']
tcprior = self.marginalized_vector_priors['tc']
tcmin, tcmax = tcprior.bounds['tc']
tcave = (tcmax + tcmin) / 2.0
ifos = list(snrs.keys())
if hasattr(self, 'keep_ifos'):
Expand All @@ -363,13 +364,17 @@ def draw_times(self, snrs, size=None):
starts = []
ends = []

tmin, tmax = tcmin - dt, tcmax + dt
delt = snrs[iref].delta_t
tmin = tcmin + dt - delt
tmax = tcmax + dt + delt
if hasattr(self, 'tstart'):
tmin = self.tstart[iref]
tmax = self.tend[iref]

starts.append(max(tmin, snrs[iref].start_time))
ends.append(min(tmax, snrs[iref].end_time))
# Make sure we draw from times within prior and that have enough
# SNR calculated to do later interpolation
starts.append(max(tmin, snrs[iref].start_time + delt))
ends.append(min(tmax, snrs[iref].end_time - delt * 2))

idels = {}
for ifo in ifos[1:]:
Expand All @@ -394,6 +399,7 @@ def draw_times(self, snrs, size=None):
mode='nearest')
logweight += snrv.squared_norm().numpy()
logweight /= 2.0
logweight -= logsumexp(logweight) # Normalize to PDF

# Draw proportional to the incoherent likelihood
# Draw first which time sample
Expand All @@ -405,15 +411,15 @@ def draw_times(self, snrs, size=None):
tc = tct + tci * snr.delta_t + float(snr.start_time) - dt

# Update the current proposed times and the marginalization values
logw = - logweight[tci]
# assumes uniform prior!
logw = - logweight[tci] + numpy.log(1.0 / len(logweight))
self.marginalize_vector_params['tc'] = tc
self.marginalize_vector_params['logw_partial'] = logw

if self._current_params is not None:
# Update the importance weights for each vector sample
logw = self.marginalize_vector_weights + logw
self._current_params.update(self.marginalize_vector_params)
self.marginalize_vector_weights = logw - logsumexp(logw)
self.marginalize_vector_weights += logw

return self.marginalize_vector_params

Expand Down Expand Up @@ -465,7 +471,10 @@ def make_init():
if len(ifos) == 1:
dmap[()] = numpy.arange(0, size, 1).astype(int)

return dmap, tcmin, tcmax, fp, fc, ra, dec, dtc
# Sky prior by bin
bin_prior = {t: len(dmap[t]) / size for t in dmap}

return dmap, tcmin, tcmax, fp, fc, ra, dec, dtc, bin_prior

if not hasattr(self, 'tinfo'):
self.tinfo = {}
Expand All @@ -474,7 +483,7 @@ def make_init():
logging.info('pregenerating sky pointings')
self.tinfo[ikey] = make_init()

dmap, tcmin, tcmax, fp, fc, ra, dec, dtc = self.tinfo[ikey]
dmap, tcmin, tcmax, fp, fc, ra, dec, dtc, bin_prior = self.tinfo[ikey]

vsamples = size if size is not None else self.vsamples

Expand All @@ -492,25 +501,25 @@ def make_init():
tmin = self.tstart[ifo]
tmax = self.tend[ifo]

start = max(tmin, snrs[ifo].start_time)
end = min(tmax, snrs[ifo].end_time)

start = max(tmin, snr.start_time + snr.delta_t)
end = min(tmax, snr.end_time - snr.delta_t * 2)
snr = snr.time_slice(start, end, mode='nearest')

w = snr.squared_norm().numpy() / 2.0
i = draw_sample(w, size=vsamples)

if sref is not None:
mcweight -= w[i]
mcweight += w[i]
delt = float(snr.start_time - sref.start_time)
i += round(delt / sref.delta_t)
dx.append(iref - i)
else:
sref = snr
iref = i
mcweight = -w[i]
mcweight = w[i]

idx.append(i)
mcweight -= logsumexp(mcweight)

# check if delay is in dict, if not, throw out
ti = []
Expand All @@ -522,7 +531,7 @@ def make_init():
if t in dmap:
randi = int(rand[i] * (len(dmap[t])))
ix.append(dmap[t][randi])
wi.append(len(dmap[t]))
wi.append(bin_prior[t])
ti.append(i)

# If we had really poor efficiency at finding a point, we should
Expand All @@ -535,6 +544,7 @@ def make_init():
ix = numpy.resize(numpy.array(ix, dtype=int), vsamples)
self.sample_idx = ix
self.precalc_antenna_factors = fp, fc, dtc
resize_factor = len(ti) / vsamples

ra = ra[ix]
dec = dec[ix]
Expand All @@ -551,17 +561,19 @@ def make_init():
tc = tct + iref[ti] * snr.delta_t + float(sref.start_time) - dtc[ifos[0]]

# Update the current proposed times and the marginalization values
logw_sky = mcweight[ti] + numpy.log(wi)
# There's an overall normalization here which may introduce a constant
# factor at the moment.
logw_sky = -mcweight[ti] + numpy.log(wi) - numpy.log(resize_factor)

self.marginalize_vector_params['tc'] = tc
self.marginalize_vector_params['ra'] = ra
self.marginalize_vector_params['dec'] = dec
self.marginalize_vector_params['logw_partial'] = logw_sky

if self._current_params is not None:
# Update the importance weights for each vector sample
logw = self.marginalize_vector_weights + logw_sky
self._current_params.update(self.marginalize_vector_params)
self.marginalize_vector_weights = logw - logsumexp(logw)
self.marginalize_vector_weights += logw_sky

return self.marginalize_vector_params

Expand Down

0 comments on commit 72cfa7a

Please sign in to comment.