Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax templates #726

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/libtoast/src/toast_tod_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ void toast::filter_polynomial(int64_t order, size_t n, uint8_t * flags,
auto log = toast::Logger::get();
std::ostringstream o;
o << "DGELLS: " << ngood << "/" << scanlen << " good samples, order " <<
norder;
norder;
o << " failed with info " << info;
log.error(o.str().c_str(), TOAST_HERE());
throw std::runtime_error(o.str().c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/toast/_libtoast/ops_filterbin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ void build_template_covariance(std::vector <int64_t> & starts,
// templates.templates,
// fgood,
// invcov,
// )
// )
//
// void build_template_covariance(
// py::buffer starts,
Expand Down
94 changes: 64 additions & 30 deletions src/toast/_libtoast/template_offset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ void init_template_offset(py::module & m) {
);
int64_t n_amp = temp_shape[0];

uint8_t * raw_amplitude_flags = extract_buffer <uint8_t> (
amplitude_flags, "amplitude_flags", 1, temp_shape, {n_amp}
);

double * raw_det_data = extract_buffer <double> (
det_data, "det_data", 2, temp_shape, {-1, -1}
);
Expand Down Expand Up @@ -63,6 +67,7 @@ void init_template_offset(py::module & m) {
double * dev_det_data = omgr.device_ptr(raw_det_data);
Interval * dev_intervals = omgr.device_ptr(raw_intervals);
double * dev_amplitudes = omgr.device_ptr(raw_amplitudes);
uint8_t * dev_amp_flags = omgr.device_ptr(raw_amplitude_flags);

// Calculate the maximum interval size on the CPU
int64_t max_interval_size = 0;
Expand All @@ -86,6 +91,7 @@ void init_template_offset(py::module & m) {
# pragma omp target teams distribute parallel for collapse(2) \
is_device_ptr( \
dev_amplitudes, \
dev_amp_flags, \
dev_det_data, \
dev_intervals \
)
Expand All @@ -105,7 +111,9 @@ void init_template_offset(py::module & m) {
(adjusted_isamp - dev_intervals[iview].first) /
step_length
);
dev_det_data[d] += dev_amplitudes[amp];
if (dev_amp_flags[amp] == 0) {
dev_det_data[d] += dev_amplitudes[amp];
}
}
}
}
Expand All @@ -123,7 +131,9 @@ void init_template_offset(py::module & m) {
int64_t amp = amp_offset + amp_view_off[iview] + (int64_t)(
(isamp - raw_intervals[iview].first) / step_length
);
raw_det_data[d] += raw_amplitudes[amp];
if (raw_amplitude_flags[amp] == 0) {
raw_det_data[d] += raw_amplitudes[amp];
}
}
}
}
Expand Down Expand Up @@ -160,6 +170,10 @@ void init_template_offset(py::module & m) {
);
int64_t n_amp = temp_shape[0];

uint8_t * raw_amplitude_flags = extract_buffer <uint8_t> (
amplitude_flags, "amplitude_flags", 1, temp_shape, {n_amp}
);

double * raw_det_data = extract_buffer <double> (
det_data, "det_data", 2, temp_shape, {-1, -1}
);
Expand Down Expand Up @@ -198,6 +212,7 @@ void init_template_offset(py::module & m) {
uint8_t * dev_det_flags = omgr.device_ptr(raw_det_flags);
Interval * dev_intervals = omgr.device_ptr(raw_intervals);
double * dev_amplitudes = omgr.device_ptr(raw_amplitudes);
uint8_t * dev_amp_flags = omgr.device_ptr(raw_amplitude_flags);

// Calculate the maximum interval size on the CPU
int64_t max_interval_size = 0;
Expand All @@ -224,6 +239,7 @@ void init_template_offset(py::module & m) {
# pragma omp target teams distribute collapse(2) \
is_device_ptr( \
dev_amplitudes, \
dev_amp_flags, \
dev_det_data, \
dev_det_flags, \
dev_intervals \
Expand All @@ -249,25 +265,27 @@ void init_template_offset(py::module & m) {
dev_intervals[iview].last - adjusted_isamp + 1
);

// Reduce on a chunk of `step_length` samples.
double contrib = 0.0;
# pragma omp parallel for reduction(+ : contrib)
for (int64_t i = 0; i < max_step_length; i++) {
int64_t d = data_index * n_samp + adjusted_isamp + i;
if (use_flags) {
int64_t f = flag_index * n_samp + adjusted_isamp + i;
uint8_t check = dev_det_flags[f] & flag_mask;
if (check == 0) {
int64_t amp = amp_offset + amp_view_off[iview] +
(int64_t)(isamp / step_length);

if (dev_amp_flags[amp] == 0) {
// Reduce on a chunk of `step_length` samples.
double contrib = 0.0;
# pragma omp parallel for reduction(+ : contrib)
for (int64_t i = 0; i < max_step_length; i++) {
int64_t d = data_index * n_samp + adjusted_isamp + i;
if (use_flags) {
int64_t f = flag_index * n_samp + adjusted_isamp + i;
uint8_t check = dev_det_flags[f] & flag_mask;
if (check == 0) {
contrib += dev_det_data[d];
}
} else {
contrib += dev_det_data[d];
}
} else {
contrib += dev_det_data[d];
}
dev_amplitudes[amp] += contrib;
}

int64_t amp = amp_offset + amp_view_off[iview] +
(int64_t)(isamp / step_length);
dev_amplitudes[amp] += contrib;
}
}
}
Expand All @@ -285,18 +303,20 @@ void init_template_offset(py::module & m) {
int64_t amp = amp_offset + amp_view_off[iview] + (int64_t)(
(isamp - raw_intervals[iview].first) / step_length
);
double contrib = 0.0;
if (use_flags) {
int64_t f = flag_index * n_samp + isamp;
uint8_t check = raw_det_flags[f] & flag_mask;
if (check == 0) {
if (raw_amplitude_flags[amp] == 0) {
double contrib = 0.0;
if (use_flags) {
int64_t f = flag_index * n_samp + isamp;
uint8_t check = raw_det_flags[f] & flag_mask;
if (check == 0) {
contrib = raw_det_data[d];
}
} else {
contrib = raw_det_data[d];
}
} else {
contrib = raw_det_data[d];
#pragma omp atomic update
raw_amplitudes[amp] += contrib;
}
#pragma omp atomic update
raw_amplitudes[amp] += contrib;
}
}
}
Expand Down Expand Up @@ -324,6 +344,10 @@ void init_template_offset(py::module & m) {
);
int64_t n_amp = temp_shape[0];

uint8_t * raw_amplitude_flags = extract_buffer <uint8_t> (
amplitude_flags, "amplitude_flags", 1, temp_shape, {n_amp}
);

double * raw_amp_out = extract_buffer <double> (
amplitudes_out, "amplitudes_out", 1, temp_shape, {n_amp}
);
Expand All @@ -338,27 +362,37 @@ void init_template_offset(py::module & m) {
double * dev_amp_in = omgr.device_ptr(raw_amp_in);
double * dev_amp_out = omgr.device_ptr(raw_amp_out);
double * dev_offset_var = omgr.device_ptr(raw_offset_var);
uint8_t * dev_amp_flags = omgr.device_ptr(raw_amplitude_flags);

# pragma omp target data map(to : n_amp)
{
# pragma omp target teams distribute parallel for \
is_device_ptr( \
dev_amp_in, \
dev_amp_out, \
dev_amp_flags, \
dev_offset_var \
)
for (int64_t iamp = 0; iamp < n_amp; iamp++) {
dev_amp_out[iamp] = dev_amp_in[iamp];
dev_amp_out[iamp] *= dev_offset_var[iamp];
if (dev_amp_flags[iamp] == 0) {
dev_amp_out[iamp] = dev_amp_in[iamp];
dev_amp_out[iamp] *= dev_offset_var[iamp];
} else {
dev_amp_out[iamp] = 0.0;
}
}
}

#endif // ifdef HAVE_OPENMP_TARGET
} else {
#pragma omp parallel for default(shared)
for (int64_t iamp = 0; iamp < n_amp; iamp++) {
raw_amp_out[iamp] = raw_amp_in[iamp];
raw_amp_out[iamp] *= raw_offset_var[iamp];
if (raw_amplitude_flags[iamp] == 0) {
raw_amp_out[iamp] = raw_amp_in[iamp];
raw_amp_out[iamp] *= raw_offset_var[iamp];
} else {
raw_amp_out[iamp] = 0.0;
}
}
}
return;
Expand Down
8 changes: 3 additions & 5 deletions src/toast/ops/cadence_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
if check < 0:
raise traitlets.TraitError("Det flag mask should be a positive integer")
return check

@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
check = proposal["value"]
Expand Down Expand Up @@ -181,9 +181,7 @@ def _exec(self, data, detectors=None, **kwargs):
buf[:, :] = False
for obs in data.obs:
obs_data = data.select(obs_uid=obs.uid)
dets = obs.select_local_detectors(
detectors, flagmask=self.det_mask
)
dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
times = obs.shared[self.times].data
days = to_MJD(times).astype(int)
if days[0] >= day_stop or days[-1] < day_start:
Expand Down
4 changes: 1 addition & 3 deletions src/toast/ops/conviqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,7 @@ def _get_all_detectors(self, data, detectors):
my_dets = set()
for obs in data.obs:
# Get the detectors we are using for this observation
obs_dets = obs.select_local_detectors(
detectors, flagmask=self.det_mask
)
obs_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
for det in obs_dets:
my_dets.add(det)
# Make sure detector data output exists
Expand Down
4 changes: 2 additions & 2 deletions src/toast/ops/crosslinking.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
if check < 0:
raise traitlets.TraitError("Det flag mask should be a positive integer")
return check

@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
check = proposal["value"]
Expand Down
8 changes: 3 additions & 5 deletions src/toast/ops/demodulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
if check < 0:
raise traitlets.TraitError("Det flag mask should be a positive integer")
return check

@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
check = proposal["value"]
Expand Down Expand Up @@ -219,9 +219,7 @@ def _exec(self, data, detectors=None, **kwargs):
timer.start()
for obs in demodulate_obs:
# Get the detectors which are not cut with per-detector flags
local_dets = obs.select_local_detectors(
detectors, flagmask=self.det_mask
)
local_dets = obs.select_local_detectors(detectors, flagmask=self.det_mask)
if obs.comm.comm_group is None:
all_dets = local_dets
else:
Expand Down
8 changes: 3 additions & 5 deletions src/toast/ops/filterbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,14 +635,14 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
if check < 0:
raise traitlets.TraitError("Det flag mask should be a positive integer")
return check

@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
check = proposal["value"]
Expand Down Expand Up @@ -1624,9 +1624,7 @@ def _bin_map(self, data, detectors, filtered):
save_pointing=self.binning.full_pointing,
)
cov.apply(data, detectors=detectors)
log.info_rank(
f"Binned covariance and hits in", comm=self.comm, timer=timer
)
log.info_rank(f"Binned covariance and hits in", comm=self.comm, timer=timer)

self.binning.apply(data, detectors=detectors)
log.info_rank(f"Binned signal in", comm=self.comm, timer=timer)
Expand Down
4 changes: 1 addition & 3 deletions src/toast/ops/groundfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,7 @@ def _exec(self, data, detectors=None, **kwargs):
last_cov = None
last_rcond = None

for det in obs.select_local_detectors(
detectors, flagmask=self.det_mask
):
for det in obs.select_local_detectors(detectors, flagmask=self.det_mask):
if data.comm.group_rank == 0:
msg = f"{log_prefix} OpGroundFilter: " f"Processing detector {det}"
log.verbose(msg)
Expand Down
4 changes: 1 addition & 3 deletions src/toast/ops/hwpfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,7 @@ def _exec(self, data, detectors=None, **kwargs):
last_invcov = None
last_cov = None
last_rcond = None
for det in obs.select_local_detectors(
detectors, flagmask=self.det_mask
):
for det in obs.select_local_detectors(detectors, flagmask=self.det_mask):
if data.comm.group_rank == 0:
msg = f"{log_prefix} OpHWPFilter: " f"Processing detector {det}"
log.verbose(msg)
Expand Down
2 changes: 1 addition & 1 deletion src/toast/ops/madam.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def _prepare(self, params, data, detectors):
# for plocal in pdets:
# for d in plocal:
# obs_dets.add(d)
# obs_dets = ob.comm.comm_group.bcast(obs_dets, root=0)
# obs_dets = ob.comm.comm_group.bcast(obs_dets, root=0)
# else:
# obs_dets = set(local_dets)
# all_dets.update(obs_dets)
Expand Down
2 changes: 1 addition & 1 deletion src/toast/ops/mapmaker_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _check_flag_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Flag mask should be a positive integer")
return check

@traitlets.validate("shared_flag_mask")
def _check_shared_mask(self, proposal):
check = proposal["value"]
Expand Down
6 changes: 3 additions & 3 deletions src/toast/ops/scan_map/scan_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
Expand Down Expand Up @@ -269,7 +269,7 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
Expand Down Expand Up @@ -404,7 +404,7 @@ def _check_det_mask(self, proposal):
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("det_flag_mask")
def _check_det_flag_mask(self, proposal):
check = proposal["value"]
Expand Down
Loading
Loading