Skip to content

Commit

Permalink
Pass amplitude flags into Offset template kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
tskisner committed Sep 13, 2023
1 parent 67b2474 commit 81254cb
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/toast/_libtoast/template_offset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ void init_template_offset(py::module & m) {
int64_t amp_offset,
py::buffer n_amp_views,
py::buffer amplitudes,
py::buffer amplitude_flags,
int32_t data_index,
py::buffer det_data,
py::buffer intervals,
Expand Down Expand Up @@ -143,6 +144,7 @@ void init_template_offset(py::module & m) {
int64_t amp_offset,
py::buffer n_amp_views,
py::buffer amplitudes,
py::buffer amplitude_flags,
py::buffer intervals,
bool use_accel
) {
Expand Down Expand Up @@ -306,6 +308,7 @@ void init_template_offset(py::module & m) {
"template_offset_apply_diag_precond", [](
py::buffer offset_var,
py::buffer amplitudes_in,
py::buffer amplitude_flags,
py::buffer amplitudes_out,
bool use_accel
) {
Expand Down
9 changes: 9 additions & 0 deletions src/toast/templates/offset/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def offset_add_to_signal(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
data_index,
det_data,
intervals,
Expand All @@ -46,6 +47,7 @@ def offset_add_to_signal(
amp_offset (int64): The first amplitude for this detector.
n_amp_views (array, int): The number of amplitudes for each interval.
amplitudes (array, double): The amplitude data.
amplitude_flags (array, uint8): The amplitude flags.
data_index (int): The detector to process.
det_data (array, double): The array of data for all detectors.
intervals (array, Interval): The intervals to process.
Expand All @@ -60,6 +62,7 @@ def offset_add_to_signal(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
data_index,
det_data,
intervals,
Expand All @@ -78,6 +81,7 @@ def offset_project_signal(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
intervals,
use_accel=False,
):
Expand All @@ -97,6 +101,7 @@ def offset_project_signal(
amp_offset (int64): The first amplitude for this detector.
n_amp_views (array, int): The number of amplitudes for each interval.
amplitudes (array, double): The amplitude data.
amplitude_flags (array, uint8): The amplitude flags.
intervals (array, Interval): The intervals to process.
use_accel (bool): Whether to use the accelerator for this call (if supported).
Expand All @@ -114,6 +119,7 @@ def offset_project_signal(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
intervals,
use_accel,
)
Expand All @@ -123,13 +129,15 @@ def offset_project_signal(
def offset_apply_diag_precond(
offset_var,
amplitudes_in,
amplitude_flags,
amplitudes_out,
use_accel=False,
):
"""
Args:
offset_var (array, double): The variance weight to apply to each amplitude.
amplitudes_in (array, double): Input amplitude data
amplitude_flags (array, uint8): The amplitude flags.
amplitudes_out (array, double): Output amplitude data
use_accel (bool): Whether to use the accelerator for this call (if supported).
Expand All @@ -140,6 +148,7 @@ def offset_apply_diag_precond(
return libtoast_offset_apply_diag_precond(
offset_var,
amplitudes_in,
amplitude_flags,
amplitudes_out,
use_accel,
)
Expand Down
4 changes: 3 additions & 1 deletion src/toast/templates/offset/kernels_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def offset_add_to_signal_jax(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
data_index,
det_data,
intervals,
Expand Down Expand Up @@ -271,6 +272,7 @@ def offset_project_signal_jax(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
intervals,
use_accel,
):
Expand Down Expand Up @@ -342,7 +344,7 @@ def offset_apply_diag_precond_inner(offset_var, amplitudes_in, amplitudes_out):


@kernel(impl=ImplementationType.JAX, name="offset_apply_diag_precond")
def offset_apply_diag_precond_jax(offset_var, amplitudes_in, amplitudes_out, use_accel):
def offset_apply_diag_precond_jax(offset_var, amplitudes_in, amplitude_flags, amplitudes_out, use_accel):
"""
Simple multiplication.
Expand Down
16 changes: 12 additions & 4 deletions src/toast/templates/offset/kernels_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def offset_add_to_signal_numpy(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
data_index,
det_data,
intervals,
Expand All @@ -39,8 +40,10 @@ def offset_add_to_signal_numpy(
for interval, view_offset in zip(intervals, n_amp_views):
samples = slice(interval.first, interval.last + 1, 1)
sampidx = np.arange(0, interval.last - interval.first + 1, dtype=np.int64)
amp_vals = np.array([amplitudes[offset + x] for x in (sampidx // step_length)])
det_data[data_index, samples] += amp_vals
amp_idx = sampidx // step_length
amp_vals = np.array([amplitudes[offset + x] for x in amp_idx])
amp_flags = np.array([amplitude_flags[offset + x] for x in amp_idx])
det_data[data_index, samples] += amp_vals[amp_flags == 0]
offset += view_offset


Expand All @@ -55,6 +58,7 @@ def offset_project_signal_numpy(
amp_offset,
n_amp_views,
amplitudes,
amplitude_flags,
intervals,
use_accel,
):
Expand Down Expand Up @@ -94,6 +98,7 @@ def offset_project_signal_numpy(
((flag_data[flag_index] & flag_mask) == 0), dtype=np.float64
)[samples]
ddata *= det_data[data_index][samples]
ddata[amplitude_flags[ampidx] != 0] = 0.0
# updates amplitude
# using np.add to insure atomicity
np.add.at(amplitudes, ampidx, ddata)
Expand All @@ -102,7 +107,7 @@ def offset_project_signal_numpy(

@kernel(impl=ImplementationType.NUMPY, name="offset_apply_diag_precond")
def offset_apply_diag_precond_numpy(
offset_var, amplitudes_in, amplitudes_out, use_accel
offset_var, amplitudes_in, amplitude_flags, amplitudes_out, use_accel
):
"""
Args:
Expand All @@ -114,7 +119,10 @@ def offset_apply_diag_precond_numpy(
Returns:
None (the result is put in amplitudes_out).
"""
amplitudes_out[:] = amplitudes_in * offset_var
good = amplitude_flags != 0
bad = np.logical_not(good)
amplitudes_out[good] = amplitudes_in[good] * offset_var[good]
amplitudes_out[bad] = 0.0


# To test:
Expand Down
8 changes: 7 additions & 1 deletion src/toast/templates/offset/offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def _add_to_signal(self, detector, amplitudes, use_accel=None, **kwargs):
amp_offset,
n_amp_views,
amplitudes.local,
amplitudes.local_flags,
det_indx[0],
ob.detdata[self.det_data].data,
ob.intervals[self.view].data,
Expand Down Expand Up @@ -628,6 +629,7 @@ def _project_signal(self, detector, amplitudes, use_accel=None, **kwargs):
amp_offset,
n_amp_views,
amplitudes.local,
amplitudes.local_flags,
ob.intervals[self.view].data,
impl=implementation,
use_accel=use_accel,
Expand Down Expand Up @@ -670,9 +672,12 @@ def _add_prior(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs):
if det in self._filters[iob]:
# There is some contribution from this detector
amps_out[:] += scipy.signal.convolve(
amps_in, self._filters[iob][det][ivw], mode="same"
amps_in, self._filters[iob][det][ivw], mode="same", method="direct"
)
amps_out[amp_flags_in != 0] = 0.0
print(f"DBG filter: {self._filters[iob][det][ivw]}")
print(f"DBG amps_in: {amps_in}")
print(f"DBG amps_out: {amps_out}", flush=True)
else:
amps_out[:] = 0.0
offset += n_amp_view
Expand Down Expand Up @@ -743,6 +748,7 @@ def _apply_precond(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs
offset_apply_diag_precond(
self._offsetvar,
amplitudes_in.local,
amplitudes_in.local_flags,
amplitudes_out.local,
impl=implementation,
use_accel=use_accel,
Expand Down

0 comments on commit 81254cb

Please sign in to comment.