From 72cfa7ab03238763442a7cb227ec421ce90e87d0 Mon Sep 17 00:00:00 2001
From: Alex Nitz
Date: Mon, 1 Jul 2024 19:09:41 +0100
Subject: [PATCH] fix normalization of sky and time marginalization to reduce
 integral … (#4804)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix normalization of sky and time marginalization to reduce integral uncertainty

* remove debug lines

* cc

* update

* edge cases

---
 pycbc/inference/models/relbin.py |  4 ++-
 pycbc/inference/models/tools.py  | 58 +++++++++++++++++++-------------
 2 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/pycbc/inference/models/relbin.py b/pycbc/inference/models/relbin.py
index 7f7445f78e5..4f574992fa3 100644
--- a/pycbc/inference/models/relbin.py
+++ b/pycbc/inference/models/relbin.py
@@ -868,7 +868,9 @@ def _loglr(self):
             # Note, this includes complex conjugation already
             # as our stored inner products were hp* x data
             htf = (f.real * ip + 1.0j * f.imag * ic)
-            sh = self.sh[ifo].at_time(dts, interpolate='quadratic')
+            sh = self.sh[ifo].at_time(dts,
+                                      interpolate='quadratic',
+                                      extrapolate=0.0j)
             sh_total += sh * htf
             hh_total += self.hh[ifo] * abs(htf) ** 2.0
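The relbin.py hunk handles an edge case: the proposed detector-frame times dts can fall at or beyond the ends of the stored complex series sh, where a quadratic interpolant would otherwise extrapolate. Passing extrapolate=0.0j makes out-of-range lookups return zero instead. A toy sketch of that edge behaviour (not the PyCBC implementation; it uses scipy's interp1d on made-up data):

    # Illustration only (not PyCBC code): quadratic interpolation of a
    # complex series that returns 0.0j outside the sampled range, the
    # behaviour the at_time call above requests via extrapolate=0.0j.
    import numpy
    from scipy.interpolate import interp1d

    t = numpy.linspace(0.0, 1.0, 101)            # sample times
    s = numpy.exp(2.0j * numpy.pi * 3.0 * t)     # toy complex series

    def at_time_zero_fill(times):
        # Interpolate real and imaginary parts separately; queries outside
        # [t[0], t[-1]] return 0.0 instead of an extrapolated value.
        re = interp1d(t, s.real, kind='quadratic',
                      bounds_error=False, fill_value=0.0)
        im = interp1d(t, s.imag, kind='quadratic',
                      bounds_error=False, fill_value=0.0)
        return re(times) + 1.0j * im(times)

    print(at_time_zero_fill([0.5, 1.5]))         # in-range sample, then 0j

With the zero fill, proposed times just outside the stored SNR segment simply contribute nothing to sh_total, which pairs with the tightened time_slice bounds in tools.py below.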
diff --git a/pycbc/inference/models/tools.py b/pycbc/inference/models/tools.py
index b07f09c3c5f..044c716440f 100644
--- a/pycbc/inference/models/tools.py
+++ b/pycbc/inference/models/tools.py
@@ -51,10 +51,6 @@ def draw_sample(loglr, size=None):
 class DistMarg():
     """Help class to add bookkeeping for likelihood marginalization"""
 
-    marginalize_phase = None
-    distance_marginalization = None
-    distance_interpolator = None
-
     def setup_marginalization(self,
                               variable_params,
                               marginalize_phase=False,
@@ -287,7 +283,11 @@ def premarg_draw(self):
 
         # Update the current proposed times and the marginalization values
         logw = self.premarg['logw_partial']
-        choice = numpy.random.randint(0, len(logw), size=self.vsamples)
+        if self.vsamples == len(logw):
+            choice = slice(None, None)
+        else:
+            choice = numpy.random.choice(len(logw), size=self.vsamples,
+                                         replace=False)
 
         for k in self.snr_params:
             self.marginalize_vector_params[k] = self.premarg[k][choice]
@@ -339,7 +339,8 @@ def draw_times(self, snrs, size=None):
         """
         if not hasattr(self, 'tinfo'):
             # determine the rough time offsets for this sky location
-            tcmin, tcmax = self.marginalized_vector_priors['tc'].bounds['tc']
+            tcprior = self.marginalized_vector_priors['tc']
+            tcmin, tcmax = tcprior.bounds['tc']
             tcave = (tcmax + tcmin) / 2.0
             ifos = list(snrs.keys())
             if hasattr(self, 'keep_ifos'):
@@ -363,13 +364,17 @@ def draw_times(self, snrs, size=None):
 
         starts = []
         ends = []
-        tmin, tmax = tcmin - dt, tcmax + dt
+        delt = snrs[iref].delta_t
+        tmin = tcmin + dt - delt
+        tmax = tcmax + dt + delt
         if hasattr(self, 'tstart'):
             tmin = self.tstart[iref]
             tmax = self.tend[iref]
 
-        starts.append(max(tmin, snrs[iref].start_time))
-        ends.append(min(tmax, snrs[iref].end_time))
+        # Make sure we draw from times within prior and that have enough
+        # SNR calculated to do later interpolation
+        starts.append(max(tmin, snrs[iref].start_time + delt))
+        ends.append(min(tmax, snrs[iref].end_time - delt * 2))
 
         idels = {}
         for ifo in ifos[1:]:
@@ -394,6 +399,7 @@ def draw_times(self, snrs, size=None):
                                   mode='nearest')
             logweight += snrv.squared_norm().numpy()
         logweight /= 2.0
+        logweight -= logsumexp(logweight)  # Normalize to PDF
 
         # Draw proportional to the incoherent likelihood
         # Draw first which time sample
@@ -405,15 +411,15 @@ def draw_times(self, snrs, size=None):
         tc = tct + tci * snr.delta_t + float(snr.start_time) - dt
 
         # Update the current proposed times and the marginalization values
-        logw = - logweight[tci]
+        # assumes uniform prior!
+        logw = - logweight[tci] + numpy.log(1.0 / len(logweight))
         self.marginalize_vector_params['tc'] = tc
         self.marginalize_vector_params['logw_partial'] = logw
 
         if self._current_params is not None:
             # Update the importance weights for each vector sample
-            logw = self.marginalize_vector_weights + logw
             self._current_params.update(self.marginalize_vector_params)
-            self.marginalize_vector_weights = logw - logsumexp(logw)
+            self.marginalize_vector_weights += logw
 
         return self.marginalize_vector_params
@@ -465,7 +471,10 @@ def make_init():
             if len(ifos) == 1:
                 dmap[()] = numpy.arange(0, size, 1).astype(int)
 
-            return dmap, tcmin, tcmax, fp, fc, ra, dec, dtc
+            # Sky prior by bin
+            bin_prior = {t: len(dmap[t]) / size for t in dmap}
+
+            return dmap, tcmin, tcmax, fp, fc, ra, dec, dtc, bin_prior
 
         if not hasattr(self, 'tinfo'):
             self.tinfo = {}
@@ -474,7 +483,7 @@ def make_init():
             logging.info('pregenerating sky pointings')
             self.tinfo[ikey] = make_init()
 
-        dmap, tcmin, tcmax, fp, fc, ra, dec, dtc = self.tinfo[ikey]
+        dmap, tcmin, tcmax, fp, fc, ra, dec, dtc, bin_prior = self.tinfo[ikey]
 
         vsamples = size if size is not None else self.vsamples
@@ -492,25 +501,25 @@ def make_init():
                 tmin = self.tstart[ifo]
                 tmax = self.tend[ifo]
 
-            start = max(tmin, snrs[ifo].start_time)
-            end = min(tmax, snrs[ifo].end_time)
-
+            start = max(tmin, snr.start_time + snr.delta_t)
+            end = min(tmax, snr.end_time - snr.delta_t * 2)
             snr = snr.time_slice(start, end, mode='nearest')
 
             w = snr.squared_norm().numpy() / 2.0
             i = draw_sample(w, size=vsamples)
 
             if sref is not None:
-                mcweight -= w[i]
+                mcweight += w[i]
                 delt = float(snr.start_time - sref.start_time)
                 i += round(delt / sref.delta_t)
                 dx.append(iref - i)
             else:
                 sref = snr
                 iref = i
-                mcweight = -w[i]
+                mcweight = w[i]
 
             idx.append(i)
+        mcweight -= logsumexp(mcweight)
 
         # check if delay is in dict, if not, throw out
         ti = []
@@ -522,7 +531,7 @@ def make_init():
             if t in dmap:
                 randi = int(rand[i] * (len(dmap[t])))
                 ix.append(dmap[t][randi])
-                wi.append(len(dmap[t]))
+                wi.append(bin_prior[t])
                 ti.append(i)
 
         # If we had really poor efficiency at finding a point, we should
@@ -535,6 +544,7 @@ def make_init():
         ix = numpy.resize(numpy.array(ix, dtype=int), vsamples)
         self.sample_idx = ix
         self.precalc_antenna_factors = fp, fc, dtc
+        resize_factor = len(ti) / vsamples
 
         ra = ra[ix]
         dec = dec[ix]
@@ -551,7 +561,10 @@ def make_init():
         tc = tct + iref[ti] * snr.delta_t + float(sref.start_time) - dtc[ifos[0]]
 
         # Update the current proposed times and the marginalization values
-        logw_sky = mcweight[ti] + numpy.log(wi)
+        # There's an overall normalization here which may introduce a constant
+        # factor at the moment.
+        logw_sky = -mcweight[ti] + numpy.log(wi) - numpy.log(resize_factor)
+
         self.marginalize_vector_params['tc'] = tc
         self.marginalize_vector_params['ra'] = ra
         self.marginalize_vector_params['dec'] = dec
@@ -559,9 +572,8 @@ def make_init():
 
         if self._current_params is not None:
             # Update the importance weights for each vector sample
-            logw = self.marginalize_vector_weights + logw_sky
             self._current_params.update(self.marginalize_vector_params)
-            self.marginalize_vector_weights = logw - logsumexp(logw)
+            self.marginalize_vector_weights += logw_sky
 
         return self.marginalize_vector_params
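The common thread in the tools.py hunks is that the proposal distributions used for time and sky marginalization are now normalized before the importance weights are formed: logweight and mcweight become proper log-PDFs via logsumexp, the uniform prior enters explicitly as log(1/N) or as bin_prior, and the partial weights accumulate (+=) rather than being renormalized on every call. A minimal numpy sketch of this self-normalized importance-sampling bookkeeping (illustrative only; the variable names follow the patch, the data are toy values):

    # Build a proposal from per-sample log weights, normalize it to a
    # log-PDF, then weight each draw by log(prior) - log(proposal).
    import numpy
    from scipy.special import logsumexp

    rng = numpy.random.default_rng(42)

    # Toy stand-in for the incoherent-SNR log weights over N time samples
    logweight = rng.normal(size=1024)
    logweight -= logsumexp(logweight)      # normalize to a proposal log-PDF

    vsamples = 256
    tci = rng.choice(len(logweight), size=vsamples,
                     p=numpy.exp(logweight))

    # Importance weight per draw for a uniform prior over the N samples
    logw = numpy.log(1.0 / len(logweight)) - logweight[tci]

    # Monte Carlo estimate of the prior volume (integrand = 1): close to 1.
    # Skipping the logsumexp normalization above would offset this estimate
    # by an arbitrary constant factor.
    est = numpy.exp(logsumexp(logw) - numpy.log(vsamples))
    print(est)

This is also why the update becomes self.marginalize_vector_weights += logw: once each partial weight is itself normalized, renormalizing with logsumexp on every call would discard the overall normalization that the marginalized likelihood integral needs, which is the "integral uncertainty" the patch title refers to.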