Skip to content

Commit

Permalink
Use FutureMeasurements in control system trigger
Browse files Browse the repository at this point in the history
  • Loading branch information
wthrowe committed Oct 13, 2023
1 parent 6f0422f commit afc99d8
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 152 deletions.
2 changes: 0 additions & 2 deletions src/ControlSystem/Actions/InitializeMeasurements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ struct InitializeMeasurements {
using simple_tags =
tmpl::transform<control_system_groups,
tmpl::bind<Tags::FutureMeasurements, tmpl::_1>>;
using compute_tags = tmpl::list<Parallel::Tags::FromGlobalCache<
::control_system::Tags::MeasurementTimescales>>;
using const_global_cache_tags = tmpl::list<Tags::MeasurementsPerUpdate>;
using mutable_global_cache_tags =
tmpl::list<control_system::Tags::MeasurementTimescales>;
Expand Down
165 changes: 62 additions & 103 deletions src/ControlSystem/Trigger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,20 @@

#pragma once

#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <optional>
#include <pup.h>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ControlSystem/CombinedName.hpp"
#include "ControlSystem/FutureMeasurements.hpp"
#include "ControlSystem/Metafunctions.hpp"
#include "DataStructures/DataVector.hpp"
#include "Domain/FunctionsOfTime/FunctionOfTime.hpp"
#include "Domain/FunctionsOfTime/FunctionsOfTimeAreReady.hpp"
#include "Evolution/EventsAndDenseTriggers/DenseTrigger.hpp"
#include "IO/Logging/Verbosity.hpp"
#include "Utilities/Algorithm.hpp"
#include "Parallel/ArrayComponentId.hpp"
#include "Parallel/Callback.hpp"
#include "Parallel/GlobalCache.hpp"
#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/GetOutput.hpp"
#include "Utilities/Gsl.hpp"
Expand All @@ -31,14 +25,12 @@
#include "Utilities/TMPL.hpp"

/// \cond
namespace Parallel {
template <typename Metavariables>
class GlobalCache;
} // namespace Parallel
namespace Tags {
struct Time;
} // namespace Tags
namespace control_system::Tags {
template <typename ControlSystems>
struct FutureMeasurements;
struct MeasurementTimescales;
struct Verbosity;
} // namespace control_system::Tags
Expand Down Expand Up @@ -83,125 +75,92 @@ class Trigger : public DenseTrigger {

using is_triggered_return_tags = tmpl::list<>;
using is_triggered_argument_tags =
tmpl::list<::Tags::Time, control_system::Tags::MeasurementTimescales>;
tmpl::list<::Tags::Time,
control_system::Tags::FutureMeasurements<ControlSystems>>;

template <typename Metavariables, typename ArrayIndex, typename Component>
std::optional<bool> is_triggered(
Parallel::GlobalCache<Metavariables>& cache,
const ArrayIndex& array_index, const Component* /*component*/,
const double time,
const std::unordered_map<
std::string,
std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>&
measurement_timescales) {
if (UNLIKELY(not next_trigger_.has_value())) {
// First call

// This will happen if an executable has control systems, but
// all functions of time were overriden by ones read in from a
// file. So there is no need to trigger control systems. Since
// we only enter this branch on the first call to the trigger,
// this is the initial time so we can assume the
// measurement_timescales are ready.
if (next_measurement(time, measurement_timescales) ==
std::numeric_limits<double>::infinity()) {
next_trigger_ = std::numeric_limits<double>::infinity();
} else {
next_trigger_ = time;
}
}
const control_system::FutureMeasurements& measurement_times) {
const auto next_measurement = measurement_times.next_measurement();
ASSERT(next_measurement.has_value(),
"Checking trigger without knowing next time.");
const bool triggered = time == *next_measurement;

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Debug) {
Parallel::printf(
"%s, time = %.16f: Trigger for control systems (%s) is%s "
"triggered.\n",
get_output(array_index), time,
pretty_type::list_of_names<ControlSystems>(),
(time == *next_trigger_ ? "" : " not"));
(triggered ? "" : " not"));
}

return time == *next_trigger_;
return triggered;
}

using next_check_time_return_tags = tmpl::list<>;
using next_check_time_argument_tags =
tmpl::list<::Tags::Time, control_system::Tags::MeasurementTimescales>;
using next_check_time_return_tags =
tmpl::list<control_system::Tags::FutureMeasurements<ControlSystems>>;
using next_check_time_argument_tags = tmpl::list<::Tags::Time>;

template <typename Metavariables, typename ArrayIndex, typename Component>
std::optional<double> next_check_time(
Parallel::GlobalCache<Metavariables>& cache,
const ArrayIndex& array_index, const Component* component,
const double time,
const std::unordered_map<
std::string,
std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>&
measurement_timescales) {
// At least one control system is active
const bool is_ready =
domain::functions_of_time_are_ready_algorithm_callback<
control_system::Tags::MeasurementTimescales>(
cache, array_index, component, time,
std::unordered_set{
control_system::combined_name<ControlSystems>()});
if (not is_ready) {
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Debug) {
Parallel::printf(
"%s, time = %.16f: Trigger - Cannot calculate next_check_time\n",
get_output(array_index), time);
}
return std::nullopt;
const ArrayIndex& array_index, const Component* /*component*/,
const gsl::not_null<control_system::FutureMeasurements*>
measurement_times,
const double time) {
if (measurement_times->next_measurement() == std::optional(time)) {
measurement_times->pop_front();
}

const bool triggered = time == *next_trigger_;
if (triggered) {
*next_trigger_ = next_measurement(time, measurement_timescales);
if (not measurement_times->next_measurement().has_value()) {
const auto& proxy =
::Parallel::get_parallel_component<Component>(cache)[array_index];
const bool is_ready = Parallel::mutable_cache_item_is_ready<
control_system::Tags::MeasurementTimescales>(
cache, Parallel::make_array_component_id<Component>(array_index),
[&](const auto& measurement_timescales) {
const std::string& measurement_name =
control_system::combined_name<ControlSystems>();
ASSERT(measurement_timescales.count(measurement_name) == 1,
"Control system trigger expects a measurement timescale "
"with the name '"
<< measurement_name
<< "' but could not find one. Available names are: "
<< keys_of(measurement_timescales));
measurement_times->update(
*measurement_timescales.at(measurement_name));
if (not measurement_times->next_measurement().has_value()) {
return std::unique_ptr<Parallel::Callback>(
new Parallel::PerformAlgorithmCallback(proxy));
}
return std::unique_ptr<Parallel::Callback>{};
});

if (not is_ready) {
if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Debug) {
Parallel::printf(
"%s, time = %.16f: Trigger - Cannot calculate next_check_time\n",
get_output(array_index), time);
}
return std::nullopt;
}
}

const double next_trigger = *measurement_times->next_measurement();
ASSERT(next_trigger > time,
"Next trigger is in the past: " << next_trigger << " > " << time);

if (Parallel::get<Tags::Verbosity>(cache) >= ::Verbosity::Debug) {
Parallel::printf("%s, time = %.16f: Trigger - next check time is %.16f\n",
get_output(array_index), time, *next_trigger_);
get_output(array_index), time, next_trigger);
}
return *next_trigger_;
}

// NOLINTNEXTLINE(google-runtime-references)
void pup(PUP::er& p) override {
DenseTrigger::pup(p);
p | next_trigger_;
return std::optional(next_trigger);

Check failure on line 162 in src/ControlSystem/Trigger.hpp

View workflow job for this annotation

GitHub Actions / Clang-tidy (Debug)

avoid repeating the return type from the declaration; use a braced initializer list instead

Check failure on line 162 in src/ControlSystem/Trigger.hpp

View workflow job for this annotation

GitHub Actions / Clang-tidy (Release)

avoid repeating the return type from the declaration; use a braced initializer list instead
}

private:
double next_measurement(
const double time,
const std::unordered_map<
std::string,
std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>&
measurement_timescales) {
const std::string& measurement_name =
control_system::combined_name<ControlSystems>();
ASSERT(
measurement_timescales.count(measurement_name) == 1,
"Control system trigger expects a measurement timescale with the name '"
<< measurement_name
<< "' but could not find one. Available names are: "
<< keys_of(measurement_timescales));
const DataVector timescale =
measurement_timescales.at(measurement_name)->func(time)[0];
ASSERT(timescale.size() == 1,
"Control system trigger assumes measurement timescale size is 1, "
"but it is "
<< timescale.size() << " instead.");

const double min_measure_time = timescale[0];

if (min_measure_time == std::numeric_limits<double>::infinity()) {
return min_measure_time;
}

return time + min_measure_time;
}

std::optional<double> next_trigger_{};
};

/// \cond
Expand Down
65 changes: 18 additions & 47 deletions tests/Unit/ControlSystem/Test_Trigger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <unordered_map>
#include <utility>

#include "ControlSystem/FutureMeasurements.hpp"
#include "ControlSystem/Tags/FutureMeasurements.hpp"
#include "ControlSystem/Tags/MeasurementTimescales.hpp"
#include "ControlSystem/Tags/SystemTags.hpp"
#include "ControlSystem/Trigger.hpp"
Expand Down Expand Up @@ -43,9 +45,9 @@ using measurement = control_system::TestHelpers::Measurement<LabelA>;
using SystemA = control_system::TestHelpers::System<2, LabelA, measurement>;
using SystemB = control_system::TestHelpers::System<2, LabelB, measurement>;
using SystemC = control_system::TestHelpers::System<2, LabelC, measurement>;
using control_systems = tmpl::list<SystemA, SystemB, SystemC>;

using MeasureTrigger =
control_system::Trigger<tmpl::list<SystemA, SystemB, SystemC>>;
using MeasureTrigger = control_system::Trigger<control_systems>;

using MeasurementFoT = domain::FunctionsOfTime::PiecewisePolynomial<0>;

Expand All @@ -55,9 +57,10 @@ struct Component {
using chare_type = ActionTesting::MockArrayChare;
using array_index = int;

using simple_tags = tmpl::list<Tags::Time>;
using compute_tags = tmpl::list<Parallel::Tags::FromGlobalCache<
control_system::Tags::MeasurementTimescales>>;
using simple_tags =
tmpl::list<Tags::Time,
control_system::Tags::FutureMeasurements<control_systems>>;
using compute_tags = tmpl::list<>;
using const_global_cache_tags = tmpl::list<control_system::Tags::Verbosity>;
using mutable_global_cache_tags =
tmpl::list<control_system::Tags::MeasurementTimescales>;
Expand Down Expand Up @@ -90,11 +93,13 @@ void test_trigger_no_replace() {
measurement_timescales["DifferentMeasurement"] =
std::make_unique<MeasurementFoT>(0.0, std::array{DataVector{1.0}}, 0.1);

control_system::FutureMeasurements future_measurements(6, 0.0);

MockRuntimeSystem runner{{::Verbosity::Silent},
{std::move(measurement_timescales)}};
ActionTesting::emplace_array_component_and_initialize<component>(
make_not_null(&runner), ActionTesting::NodeId{0},
ActionTesting::LocalCoreId{0}, 0, {0.0});
ActionTesting::LocalCoreId{0}, 0, {0.0, std::move(future_measurements)});
ActionTesting::set_phase(make_not_null(&runner), Parallel::Phase::Testing);

auto& box = ActionTesting::get_databox<component>(make_not_null(&runner), 0);
Expand Down Expand Up @@ -136,14 +141,11 @@ void test_trigger_no_replace() {

set_time(0.75);

// Another intermediate time where we shouldn't trigger. At this point, the
// measurement timescale has expired and has not been updated yet, so we
// cannot calculate the next check time. It should be nullopt
// Another intermediate time where we shouldn't trigger.
REQUIRE(trigger.is_triggered(make_not_null(&box), cache, 0, component_p) ==
std::optional{false});
REQUIRE(
not trigger.next_check_time(make_not_null(&box), cache, 0, component_p)
.has_value());
REQUIRE(trigger.next_check_time(make_not_null(&box), cache, 0, component_p) ==
std::optional{1.0});

// Update the measurement timescales
Parallel::mutate<control_system::Tags::MeasurementTimescales,
Expand Down Expand Up @@ -184,10 +186,13 @@ void test_trigger_with_replace() {
0.0, std::array{DataVector{std::numeric_limits<double>::infinity()}},
std::numeric_limits<double>::infinity());

control_system::FutureMeasurements future_measurements(
6, std::numeric_limits<double>::infinity());

MockRuntimeSystem runner{{}, {std::move(measurement_timescales)}};
ActionTesting::emplace_array_component_and_initialize<component>(
make_not_null(&runner), ActionTesting::NodeId{0},
ActionTesting::LocalCoreId{0}, 0, {0.0});
ActionTesting::LocalCoreId{0}, 0, {0.0, std::move(future_measurements)});
ActionTesting::set_phase(make_not_null(&runner), Parallel::Phase::Testing);

auto& box = ActionTesting::get_databox<component>(make_not_null(&runner), 0);
Expand All @@ -203,43 +208,9 @@ void test_trigger_with_replace() {
trigger.next_check_time(make_not_null(&box), cache, 0, component_p);
CHECK(next_check == std::optional{std::numeric_limits<double>::infinity()});
}

void test_errors() {
#ifdef SPECTRE_DEBUG
register_classes_with_charm<MeasurementFoT>();
const component* const component_p = nullptr;

control_system::Tags::MeasurementTimescales::type measurement_timescales{};
measurement_timescales["LabelALabelBLabelC"] =
std::make_unique<MeasurementFoT>(
0.0,
std::array{DataVector{3, std::numeric_limits<double>::infinity()}},
std::numeric_limits<double>::infinity());

MockRuntimeSystem runner{{::Verbosity::Silent},
{std::move(measurement_timescales)}};
ActionTesting::emplace_array_component_and_initialize<component>(
make_not_null(&runner), ActionTesting::NodeId{0},
ActionTesting::LocalCoreId{0}, 0, {0.0});
ActionTesting::set_phase(make_not_null(&runner), Parallel::Phase::Testing);

auto& box = ActionTesting::get_databox<component>(make_not_null(&runner), 0);
auto& cache = ActionTesting::cache<component>(runner, 0);

MeasureTrigger typed_trigger = serialize_and_deserialize(MeasureTrigger{});
DenseTrigger& trigger = typed_trigger;

CHECK_THROWS_WITH(
trigger.next_check_time(make_not_null(&box), cache, 0, component_p),
Catch::Matchers::ContainsSubstring(
"Control system trigger assumes measurement timescale size is 1, but "
"it is 3 instead."));
#endif
}
} // namespace

SPECTRE_TEST_CASE("Unit.ControlSystem.Trigger", "[Domain][Unit]") {
test_trigger_no_replace();
test_trigger_with_replace();
test_errors();
}

0 comments on commit afc99d8

Please sign in to comment.