Skip to content

Commit

Permalink
Fix flagging of noise model fits (#698)
Browse files Browse the repository at this point in the history
* Fix flagging of noise model fits

* Some additional comments

* Add comment for error helper function
  • Loading branch information
tskisner authored Sep 13, 2023
1 parent 485736a commit 4de4bc3
Showing 1 changed file with 83 additions and 45 deletions.
128 changes: 83 additions & 45 deletions src/toast/ops/noise_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ class FitNoiseModel(Operator):
None, allow_none=True, help="Create a new noise model with this name"
)

det_flag_mask = Int(
defaults.det_mask_invalid,
help="Bit mask value for excluding bad detectors",
)

bad_fit_mask = Int(
defaults.det_mask_processing, help="Bit mask to raise for bad fits"
)

f_min = Quantity(1.0e-5 * u.Hz, help="Low-frequency rolloff of model in the fit")

white_noise_min = Quantity(
Expand Down Expand Up @@ -190,19 +199,27 @@ def _exec(self, data, detectors=None, **kwargs):
nse_alpha = dict()
nse_NET = dict()
nse_indx = dict()
for det in ob.local_detectors:
dets = ob.select_local_detectors(detectors, flagmask=self.det_flag_mask)
if len(dets) == 0:
# Nothing to do for this observation
continue
for det in dets:
freqs = in_model.freq(det)
in_psd = in_model.psd(det)
cur_flag = ob.local_detector_flags[det]
props = self._fit_log_psd(freqs, in_psd, guess=params)
if props["fit_result"].success:
# This was a good fit
params = props["fit_result"].x
else:
params = None
msg = f"FitNoiseModel observation {ob.name}, det {det} failed, "
msg += f"using white noise with NET = {props['NET']}"
log.warning(msg)
msg = f" Best Result = {props['fit_result']}"
log.verbose(msg)
new_flag = cur_flag | self.bad_fit_mask
ob.update_local_detector_flags({det: new_flag})
nse_indx[det] = in_model.index(det)
nse_rate[det] = 2.0 * freqs[-1]
nse_fmin[det] = props["fmin"]
Expand Down Expand Up @@ -408,6 +425,18 @@ def _fit_log_jac(self, x, *args, **kwargs):
) / fmalpha
return J

def _get_err_ret(self, psd_unit):
# Internal function to build a fake return result
# when the fitting fails for some reason.
eret = dict()
eret["fit_result"] = types.SimpleNamespace()
eret["fit_result"].success = False
eret["NET"] = 0.0 * np.sqrt(1.0 * psd_unit)
eret["fmin"] = 0.0 * u.Hz
eret["fknee"] = 0.0 * u.Hz
eret["alpha"] = 0.0
return eret

def _fit_log_psd(self, freqs, data, guess=None):
"""Perform a log-space fit to model PSD parameters.
Expand All @@ -424,16 +453,6 @@ def _fit_log_psd(self, freqs, data, guess=None):
psd_unit = data.unit
ret = dict()

def _get_err_ret():
eret = dict()
eret["fit_result"] = types.SimpleNamespace()
eret["fit_result"].success = False
eret["NET"] = 0.0 * np.sqrt(1.0 * psd_unit)
eret["fmin"] = 0.0 * u.Hz
eret["fknee"] = 0.0 * u.Hz
eret["alpha"] = 0.0
return eret

# We cut the lowest frequency bin value, and any leading negative values,
# since these are usually due to poor estimation. If the user has specified
# a maximum frequency for the white noise plateau, then we also stop our
Expand All @@ -447,7 +466,7 @@ def _get_err_ret():
if n_skip == n_raw:
msg = f"All {n_raw} PSD values were negative. Giving up."
log.warning(msg)
ret = _get_err_ret()
ret = self._get_err_ret(psd_unit)
return ret

n_trim = 0
Expand All @@ -460,19 +479,27 @@ def _get_err_ret():
if n_skip + n_trim >= n_raw:
msg = f"All {n_raw} PSD values either negative or above plateau."
log.warning(msg)
ret = _get_err_ret()
ret = self._get_err_ret(psd_unit)
return ret

input_freqs = raw_freqs[n_skip : n_raw - n_trim]
input_data = raw_data[n_skip : n_raw - n_trim]
# Force all points to be positive
bad = input_data <= 0
good = input_data > 0
if np.count_nonzero(good) == 0:
# All PSD values zero, must be flagged
msg = f"All PSD values zero, skipping fit."
log.warning(msg)
ret = self._get_err_ret(psd_unit)
return ret
bad = np.logical_not(good)
n_bad = np.count_nonzero(bad)
if n_bad > 0:
msg = "Some PSDs have negative values. Consider changing "
msg += "noise estimation parameters."
log.warning(msg)
input_data[bad] = 1.0e-6
good_min = np.min(input_data[good])
input_data[bad] = 1.0e-6 * good_min
input_log_data = np.log(input_data)

# print(f"FIT: input {input_freqs} {input_data} {input_log_data}")
Expand Down Expand Up @@ -502,23 +529,28 @@ def _get_err_ret():

# print(f"FIT: starting guess = {x_0}")

result = least_squares(
self._fit_log_fun,
x_0,
jac=self._fit_log_jac,
bounds=bounds,
xtol=self.least_squares_xtol,
gtol=self.least_squares_gtol,
ftol=self.least_squares_ftol,
max_nfev=500,
verbose=0,
kwargs={
"freqs": input_freqs,
"logdata": input_log_data,
"fmin": raw_fmin,
"net": net,
},
)
try:
result = least_squares(
self._fit_log_fun,
x_0,
jac=self._fit_log_jac,
bounds=bounds,
xtol=self.least_squares_xtol,
gtol=self.least_squares_gtol,
ftol=self.least_squares_ftol,
max_nfev=500,
verbose=0,
kwargs={
"freqs": input_freqs,
"logdata": input_log_data,
"fmin": raw_fmin,
"net": net,
},
)
except Exception:
log.verbose(f"PSD fit raised exception, skipping")
ret = self._get_err_ret(psd_unit)
return ret

# print(f"FIT: [{n_skip}:{n_raw}-{n_trim}] {result}")

Expand All @@ -533,8 +565,7 @@ def _get_err_ret():
ret["alpha"] = 1.0

# print(f"FIT ret = {ret}")

log.verbose(f"PSD fit NET={net}, bounds={bounds}, guess={x_0}, result={result}")
# print(f"PSD fit NET={net}, bounds={bounds}, guess={x_0}, result={result}")
return ret

def _finalize(self, data, **kwargs):
Expand Down Expand Up @@ -567,6 +598,11 @@ class FlagNoiseFit(Operator):
)

det_flag_mask = Int(
defaults.det_mask_invalid | defaults.det_mask_processing,
help="Bit mask for considering detectors",
)

outlier_flag_mask = Int(
defaults.det_mask_processing, help="Bit mask to raise flags with"
)

Expand All @@ -586,7 +622,7 @@ def _exec(self, data, detectors=None, **kwargs):
raise RuntimeError("You must set det_flags before calling exec()")

for obs in data.obs:
dets = obs.select_local_detectors(detectors)
dets = obs.select_local_detectors(detectors, flagmask=self.det_flag_mask)
if len(dets) == 0:
# Nothing to do for this observation
continue
Expand Down Expand Up @@ -629,6 +665,9 @@ def _exec(self, data, detectors=None, **kwargs):
net_std = None
fknee_mean = None
fknee_std = None
# If the noise model came from fitting, then detectors with a bad
# fit are already flagged in addition to NET being set to zero.
# This check is just an additional safeguard.
good_fit = local_net > 0.0
if obs.comm_col is None:
all_net = np.array(local_net[good_fit])
Expand Down Expand Up @@ -659,22 +698,21 @@ def _exec(self, data, detectors=None, **kwargs):

# Flag outlier detectors

local_cut_NET = list()
local_cut_fknee = list()
new_flags = dict()
for idet, det in enumerate(dets):
cur_flag = obs.local_detector_flags[det]
if not good_fit[idet]:
msg = f"obs {obs.name}, det {det} has NET=0 (bad model fit)"
log.info(msg)
obs.detdata[self.det_flags][det, :] |= self.det_flag_mask
new_flags[det] = self.det_flag_mask
log.debug(msg)
obs.detdata[self.det_flags][det, :] |= self.outlier_flag_mask
new_flags[det] = cur_flag | self.outlier_flag_mask
continue
if np.absolute(local_net[idet] - net_mean) > net_std * self.sigma_NET:
msg = f"obs {obs.name}, det {det} has NET {local_net[idet]} "
msg += f" that is > {self.sigma_NET} x {net_std} from {net_mean}"
log.info(msg)
obs.detdata[self.det_flags][det, :] |= self.det_flag_mask
new_flags[det] = self.det_flag_mask
log.debug(msg)
obs.detdata[self.det_flags][det, :] |= self.outlier_flag_mask
new_flags[det] = cur_flag | self.outlier_flag_mask
if self.sigma_fknee is not None:
if (
np.absolute(local_fknee[idet] - fknee_mean)
Expand All @@ -683,9 +721,9 @@ def _exec(self, data, detectors=None, **kwargs):
msg = f"obs {obs.name}, det {det} has fknee "
msg += f"{local_fknee[idet]} that is > {self.sigma_fknee} "
msg += f"x {fknee_std} from {fknee_mean}"
log.info(msg)
log.debug(msg)
obs.detdata[self.det_flags][det, :] |= self.det_flag_mask
new_flags[det] = self.det_flag_mask
new_flags[det] = cur_flag | self.det_flag_mask
obs.update_local_detector_flags(new_flags)

def _finalize(self, data, **kwargs):
Expand Down

0 comments on commit 4de4bc3

Please sign in to comment.