Skip to content

Commit

Permalink
Add transform for History
Browse files Browse the repository at this point in the history
  • Loading branch information
wthrowe committed Oct 30, 2023
1 parent 860e258 commit d08143f
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 0 deletions.
88 changes: 88 additions & 0 deletions src/Time/History.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,4 +1008,92 @@ template <typename Vars>
std::ostream& operator<<(std::ostream& os, const History<Vars>& history) {
return history.print(os);
}

/// \ingroup TimeSteppersGroup
/// Initialize a History object based on the contents of another,
/// applying a transformation to each value and derivative.
///
/// The transformation functions can either take a value from the
/// source history and return a value for the destination history or
/// take a `gsl::not_null` value from the destination history and a
/// value from the source history to initialize it with. For the sake
/// of implementation simplicity, either both transformers must mutate
/// or both must produce values.
///
/// An overload applying the same transformation to the values and
/// derivatives is provided for convenience.
/// @{
template <typename DestVars, typename SourceVars, typename ValueTransformer,
typename DerivativeTransformer>
void transform(const gsl::not_null<History<DestVars>*> dest,
const History<SourceVars>& source,
ValueTransformer&& value_transformer,
DerivativeTransformer&& derivative_transformer) {
dest->clear_substeps();
dest->clear();
dest->integration_order(source.integration_order());
if (source.empty()) {
return;
}
auto pre_substep_end = source.end();
if (not source.substeps().empty() and
source.back().time_step_id > source.substeps().back().time_step_id) {
--pre_substep_end;
}

const auto transform_record =
[&derivative_transformer, &dest, &value_transformer](
const typename History<SourceVars>::value_type& record) {
if constexpr (std::is_invocable_v<ValueTransformer,
gsl::not_null<DestVars*>,
const SourceVars&>) {
if (record.value.has_value()) {
dest->insert_in_place(
record.time_step_id,
[&](const auto result) {
value_transformer(result, *record.value);
},
[&](const auto result) {
derivative_transformer(result, record.derivative);
});
} else {
dest->insert_in_place(
record.time_step_id, History<DestVars>::no_value,
[&](const auto result) {
derivative_transformer(result, record.derivative);
});
}
} else {
static_assert(
std::is_invocable_v<ValueTransformer, const SourceVars&>,
"Transform function must either be callable to mutate entries "
"or return the transformed state by value.");
if (record.value.has_value()) {
dest->insert(record.time_step_id, value_transformer(*record.value),
derivative_transformer(record.derivative));
} else {
dest->insert(record.time_step_id, History<DestVars>::no_value,
derivative_transformer(record.derivative));
}
}
};

auto copying_step = source.begin();
for (; copying_step != pre_substep_end; ++copying_step) {
transform_record(*copying_step);
}
for (const auto& record : source.substeps()) {
transform_record(record);
}
if (pre_substep_end != source.end()) {
transform_record(*pre_substep_end);
}
}

template <typename DestVars, typename SourceVars, typename Transformer>
void transform(const gsl::not_null<History<DestVars>*> dest,
const History<SourceVars>& source, Transformer&& transformer) {
transform(dest, source, transformer, transformer);
}
/// @}
} // namespace TimeSteppers
68 changes: 68 additions & 0 deletions tests/Unit/Time/Test_History.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,74 @@ void test_history() {
CHECK(static_cast<const ConstUntyped&>(const_history.untyped())
.at_step_start());

// Test transform when the substeps are not associated with the last
// step. This causes errors in a naive implementation.
{
const auto return_transformer = []() {
if constexpr (tt::is_a_v<Variables, Vars>) {
return [](const auto& input) { return input.data()[0] + 1.0; };
} else {
return [](const double& input) { return input + 1.0; };
}
}();
const auto return_transformer2 = []() {
if constexpr (tt::is_a_v<Variables, Vars>) {
return [](const auto& input) { return input.data()[0] + 2.0; };
} else {
return [](const double& input) { return input + 2.0; };
}
}();
const auto mutate_transformer = []() {
if constexpr (tt::is_a_v<Variables, Vars>) {
return [](const gsl::not_null<double*> result, const auto& input) {
*result = input.data()[0] + 1.0;
};
} else {
return [](const gsl::not_null<double*> result, const double& input) {
*result = input + 1.0;
};
}
}();

TimeSteppers::History<double> transformed_history{};
transform(make_not_null(&transformed_history), history, return_transformer);

CHECK(transformed_history.integration_order() ==
history.integration_order());

CHECK(transformed_history.size() == 3);
CHECK(not transformed_history[0].value.has_value());
CHECK(transformed_history[0].derivative == 11.0);
CHECK(transformed_history[1].value == std::optional{3.0});
CHECK(transformed_history[1].derivative == 21.0);
CHECK(transformed_history.substeps().size() == 2);
CHECK(transformed_history.substeps()[0].value == 5.0);
CHECK(transformed_history.substeps()[0].derivative == 41.0);

transform(make_not_null(&transformed_history), history, return_transformer,
return_transformer2);

CHECK(transformed_history.size() == 3);
CHECK(not transformed_history[0].value.has_value());
CHECK(transformed_history[0].derivative == 12.0);
CHECK(transformed_history[1].value == std::optional{3.0});
CHECK(transformed_history[1].derivative == 22.0);
CHECK(transformed_history.substeps().size() == 2);
CHECK(transformed_history.substeps()[0].value == 5.0);
CHECK(transformed_history.substeps()[0].derivative == 42.0);

transform(make_not_null(&transformed_history), history, mutate_transformer);

CHECK(transformed_history.size() == 3);
CHECK(not transformed_history[0].value.has_value());
CHECK(transformed_history[0].derivative == 11.0);
CHECK(transformed_history[1].value == std::optional{3.0});
CHECK(transformed_history[1].derivative == 21.0);
CHECK(transformed_history.substeps().size() == 2);
CHECK(transformed_history.substeps()[0].value == 5.0);
CHECK(transformed_history.substeps()[0].derivative == 41.0);
}

history.undo_latest();
// [(1/4, X, 10), (1/2, 2, 20)] [1/2: (1, 4, 40), (2, 5, 50)]

Expand Down

0 comments on commit d08143f

Please sign in to comment.