Skip to content

Commit

Permalink
Simplify some intrp target detail functions
Browse files Browse the repository at this point in the history
  • Loading branch information
knelli2 committed Oct 18, 2023
1 parent 80adcea commit 59dcaa0
Showing 1 changed file with 50 additions and 62 deletions.
112 changes: 50 additions & 62 deletions src/ParallelAlgorithms/Interpolation/InterpolationTargetDetail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,82 +94,70 @@ double get_temporal_id_value(const T& id) {
}
}

CREATE_IS_CALLABLE(apply)
CREATE_IS_CALLABLE_V(apply)

// apply_callback accomplishes the overload for the
// two signatures of callback functions.
// Uses SFINAE on return type.
template <typename T, typename DbTags, typename Metavariables,
typename TemporalId>
auto apply_callback(
const gsl::not_null<db::DataBox<DbTags>*> box,
const gsl::not_null<Parallel::GlobalCache<Metavariables>*> cache,
const TemporalId& temporal_id)
-> decltype(T::post_interpolation_callback::apply(box, cache, temporal_id),
bool()) {
return T::post_interpolation_callback::apply(box, cache, temporal_id);
}

template <typename T, typename DbTags, typename Metavariables,
typename TemporalId>
auto apply_callback(
bool apply_callback(
const gsl::not_null<db::DataBox<DbTags>*> box,
const gsl::not_null<Parallel::GlobalCache<Metavariables>*> cache,
const TemporalId& temporal_id)
-> decltype(T::post_interpolation_callback::apply(*box, *cache,
temporal_id),
bool()) {
T::post_interpolation_callback::apply(*box, *cache, temporal_id);
// For the simpler callback function, we will always clean up volume data, so
// we return true here.
return true;
const TemporalId& temporal_id) {
if constexpr (is_apply_callable_v<typename T::post_interpolation_callback,
decltype(*box), decltype(*cache),
decltype(temporal_id)>) {
T::post_interpolation_callback::apply(*box, *cache, temporal_id);

// For the simpler callback function, we will always clean up volume data,
// so we return true here.
return true;
} else {
return T::post_interpolation_callback::apply(box, cache, temporal_id);
}
}

CREATE_HAS_STATIC_MEMBER_VARIABLE(fill_invalid_points_with)
CREATE_HAS_STATIC_MEMBER_VARIABLE_V(fill_invalid_points_with)

// Fills invalid points with some constant value.
template <typename InterpolationTargetTag, typename TemporalId, typename DbTags,
Requires<not has_fill_invalid_points_with_v<
typename InterpolationTargetTag::post_interpolation_callback>> =
nullptr>
void fill_invalid_points(const gsl::not_null<db::DataBox<DbTags>*> /*box*/,
const TemporalId& /*temporal_id*/) {}

template <typename InterpolationTargetTag, typename TemporalId, typename DbTags,
Requires<has_fill_invalid_points_with_v<
typename InterpolationTargetTag::post_interpolation_callback>> =
nullptr>
template <typename InterpolationTargetTag, typename TemporalId, typename DbTags>
void fill_invalid_points(const gsl::not_null<db::DataBox<DbTags>*> box,
const TemporalId& temporal_id) {
const auto& invalid_indices =
db::get<Tags::IndicesOfInvalidInterpPoints<TemporalId>>(*box);
if (invalid_indices.find(temporal_id) != invalid_indices.end() and
not invalid_indices.at(temporal_id).empty()) {
db::mutate<Tags::IndicesOfInvalidInterpPoints<TemporalId>,
Tags::InterpolatedVars<InterpolationTargetTag, TemporalId>>(
[&temporal_id](
const gsl::not_null<
std::unordered_map<TemporalId, std::unordered_set<size_t>>*>
indices_of_invalid_points,
const gsl::not_null<std::unordered_map<
TemporalId, Variables<typename InterpolationTargetTag::
vars_to_interpolate_to_target>>*>
vars_dest_all_times) {
auto& vars_dest = vars_dest_all_times->at(temporal_id);
const size_t npts_dest = vars_dest.number_of_grid_points();
const size_t nvars = vars_dest.number_of_independent_components;
for (auto index : indices_of_invalid_points->at(temporal_id)) {
for (size_t v = 0; v < nvars; ++v) {
// clang-tidy: no pointer arithmetic
vars_dest.data()[index + v * npts_dest] = // NOLINT
InterpolationTargetTag::post_interpolation_callback::
fill_invalid_points_with;
if constexpr (has_fill_invalid_points_with_v<
typename InterpolationTargetTag::
post_interpolation_callback>) {
const auto& invalid_indices =
db::get<Tags::IndicesOfInvalidInterpPoints<TemporalId>>(*box);
if (invalid_indices.find(temporal_id) != invalid_indices.end() and
not invalid_indices.at(temporal_id).empty()) {
db::mutate<Tags::IndicesOfInvalidInterpPoints<TemporalId>,
Tags::InterpolatedVars<InterpolationTargetTag, TemporalId>>(
[&temporal_id](
const gsl::not_null<
std::unordered_map<TemporalId, std::unordered_set<size_t>>*>
indices_of_invalid_points,
const gsl::not_null<std::unordered_map<
TemporalId, Variables<typename InterpolationTargetTag::
vars_to_interpolate_to_target>>*>
vars_dest_all_times) {
auto& vars_dest = vars_dest_all_times->at(temporal_id);
const size_t npts_dest = vars_dest.number_of_grid_points();
const size_t nvars = vars_dest.number_of_independent_components;
for (auto index : indices_of_invalid_points->at(temporal_id)) {
for (size_t v = 0; v < nvars; ++v) {
// clang-tidy: no pointer arithmetic
vars_dest.data()[index + v * npts_dest] = // NOLINT
InterpolationTargetTag::post_interpolation_callback::
fill_invalid_points_with;
}
}
}
// Further functions may test if there are invalid points.
// Clear the invalid points now, since we have filled them.
indices_of_invalid_points->erase(temporal_id);
},
box);
// Further functions may test if there are invalid points.
// Clear the invalid points now, since we have filled them.
indices_of_invalid_points->erase(temporal_id);
},
box);
}
}
}

Expand Down

0 comments on commit 59dcaa0

Please sign in to comment.