From d08143ff7733a79dc9018eaf7b3bcaf6b1d1c5df Mon Sep 17 00:00:00 2001 From: William Throwe Date: Thu, 23 Feb 2023 00:49:38 -0500 Subject: [PATCH] Add transform for History --- src/Time/History.hpp | 88 ++++++++++++++++++++++++++++++++ tests/Unit/Time/Test_History.cpp | 68 ++++++++++++++++++++++++ 2 files changed, 156 insertions(+) diff --git a/src/Time/History.hpp b/src/Time/History.hpp index 1f3eb7162489..65e532fbe19f 100644 --- a/src/Time/History.hpp +++ b/src/Time/History.hpp @@ -1008,4 +1008,92 @@ template std::ostream& operator<<(std::ostream& os, const History& history) { return history.print(os); } + +/// \ingroup TimeSteppersGroup +/// Initialize a History object based on the contents of another, +/// applying a transformation to each value and derivative. +/// +/// The transformation functions can either take a value from the +/// source history and return a value for the destination history or +/// take a `gsl::not_null` value from the destination history and a +/// value from the source history to initialize it with. For the sake +/// of implementation simplicity, either both transformers must mutate +/// or both must produce values. +/// +/// An overload applying the same transformation to the values and +/// derivatives is provided for convenience. +/// @{ +template +void transform(const gsl::not_null*> dest, + const History& source, + ValueTransformer&& value_transformer, + DerivativeTransformer&& derivative_transformer) { + dest->clear_substeps(); + dest->clear(); + dest->integration_order(source.integration_order()); + if (source.empty()) { + return; + } + auto pre_substep_end = source.end(); + if (not source.substeps().empty() and + source.back().time_step_id > source.substeps().back().time_step_id) { + --pre_substep_end; + } + + const auto transform_record = + [&derivative_transformer, &dest, &value_transformer]( + const typename History::value_type& record) { + if constexpr (std::is_invocable_v, + const SourceVars&>) { + if (record.value.has_value()) { + dest->insert_in_place( + record.time_step_id, + [&](const auto result) { + value_transformer(result, *record.value); + }, + [&](const auto result) { + derivative_transformer(result, record.derivative); + }); + } else { + dest->insert_in_place( + record.time_step_id, History::no_value, + [&](const auto result) { + derivative_transformer(result, record.derivative); + }); + } + } else { + static_assert( + std::is_invocable_v, + "Transform function must either be callable to mutate entries " + "or return the transformed state by value."); + if (record.value.has_value()) { + dest->insert(record.time_step_id, value_transformer(*record.value), + derivative_transformer(record.derivative)); + } else { + dest->insert(record.time_step_id, History::no_value, + derivative_transformer(record.derivative)); + } + } + }; + + auto copying_step = source.begin(); + for (; copying_step != pre_substep_end; ++copying_step) { + transform_record(*copying_step); + } + for (const auto& record : source.substeps()) { + transform_record(record); + } + if (pre_substep_end != source.end()) { + transform_record(*pre_substep_end); + } +} + +template +void transform(const gsl::not_null*> dest, + const History& source, Transformer&& transformer) { + transform(dest, source, transformer, transformer); +} +/// @} } // namespace TimeSteppers diff --git a/tests/Unit/Time/Test_History.cpp b/tests/Unit/Time/Test_History.cpp index 5e6042f8d823..4ef503472728 100644 --- a/tests/Unit/Time/Test_History.cpp +++ b/tests/Unit/Time/Test_History.cpp @@ -381,6 +381,74 @@ void test_history() { CHECK(static_cast(const_history.untyped()) .at_step_start()); + // Test transform when the substeps are not associated with the last + // step. This causes errors in a naive implementation. + { + const auto return_transformer = []() { + if constexpr (tt::is_a_v) { + return [](const auto& input) { return input.data()[0] + 1.0; }; + } else { + return [](const double& input) { return input + 1.0; }; + } + }(); + const auto return_transformer2 = []() { + if constexpr (tt::is_a_v) { + return [](const auto& input) { return input.data()[0] + 2.0; }; + } else { + return [](const double& input) { return input + 2.0; }; + } + }(); + const auto mutate_transformer = []() { + if constexpr (tt::is_a_v) { + return [](const gsl::not_null result, const auto& input) { + *result = input.data()[0] + 1.0; + }; + } else { + return [](const gsl::not_null result, const double& input) { + *result = input + 1.0; + }; + } + }(); + + TimeSteppers::History transformed_history{}; + transform(make_not_null(&transformed_history), history, return_transformer); + + CHECK(transformed_history.integration_order() == + history.integration_order()); + + CHECK(transformed_history.size() == 3); + CHECK(not transformed_history[0].value.has_value()); + CHECK(transformed_history[0].derivative == 11.0); + CHECK(transformed_history[1].value == std::optional{3.0}); + CHECK(transformed_history[1].derivative == 21.0); + CHECK(transformed_history.substeps().size() == 2); + CHECK(transformed_history.substeps()[0].value == 5.0); + CHECK(transformed_history.substeps()[0].derivative == 41.0); + + transform(make_not_null(&transformed_history), history, return_transformer, + return_transformer2); + + CHECK(transformed_history.size() == 3); + CHECK(not transformed_history[0].value.has_value()); + CHECK(transformed_history[0].derivative == 12.0); + CHECK(transformed_history[1].value == std::optional{3.0}); + CHECK(transformed_history[1].derivative == 22.0); + CHECK(transformed_history.substeps().size() == 2); + CHECK(transformed_history.substeps()[0].value == 5.0); + CHECK(transformed_history.substeps()[0].derivative == 42.0); + + transform(make_not_null(&transformed_history), history, mutate_transformer); + + CHECK(transformed_history.size() == 3); + CHECK(not transformed_history[0].value.has_value()); + CHECK(transformed_history[0].derivative == 11.0); + CHECK(transformed_history[1].value == std::optional{3.0}); + CHECK(transformed_history[1].derivative == 21.0); + CHECK(transformed_history.substeps().size() == 2); + CHECK(transformed_history.substeps()[0].value == 5.0); + CHECK(transformed_history.substeps()[0].derivative == 41.0); + } + history.undo_latest(); // [(1/4, X, 10), (1/2, 2, 20)] [1/2: (1, 4, 40), (2, 5, 50)]