Skip to content

Commit

Permalink
Check FoT before UpdateMessageQueue
Browse files Browse the repository at this point in the history
Which is typicall called on control system components
  • Loading branch information
knelli2 committed Oct 4, 2023
1 parent 776ee2c commit a812c7e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/ParallelAlgorithms/Actions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ target_link_libraries(
DataStructures
Domain
ErrorHandling
FunctionsOfTime
Serialization
Utilities
)
Expand Down
10 changes: 10 additions & 0 deletions src/ParallelAlgorithms/Actions/UpdateMessageQueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
#include "DataStructures/DataBox/DataBox.hpp"
#include "DataStructures/LinkedMessageId.hpp"
#include "DataStructures/LinkedMessageQueue.hpp"
#include "Domain/FunctionsOfTime/FunctionsOfTimeAreReady.hpp"
#include "Utilities/ErrorHandling/Error.hpp"
#include "Utilities/Gsl.hpp"

/// \cond
namespace domain::Tags {
struct FunctionsOfTime;
} // namespace domain::Tags
namespace Parallel {
template <typename Metavariables>
struct GlobalCache;
Expand Down Expand Up @@ -40,6 +44,12 @@ struct UpdateMessageQueue {
const LinkedMessageId<typename LinkedMessageQueueTag::type::IdType>&
id_and_previous,
typename QueueTag::type message) {
if (not domain::functions_of_time_are_ready_simple_action_callback<
domain::Tags::FunctionsOfTime, UpdateMessageQueue>(
cache, array_index, std::add_pointer_t<ParallelComponent>{nullptr},
id_and_previous.id, std::nullopt, id_and_previous, message)) {
return;
}
auto& queue =
db::get_mutable_reference<LinkedMessageQueueTag>(make_not_null(&box));
queue.template insert<QueueTag>(id_and_previous, std::move(message));
Expand Down
3 changes: 3 additions & 0 deletions tests/Unit/ParallelAlgorithms/Actions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ target_link_libraries(
${LIBRARY}
PRIVATE
DataStructures
DomainCreators
DomainStructure
FunctionsOfTime
Parallel
)
67 changes: 54 additions & 13 deletions tests/Unit/ParallelAlgorithms/Actions/Test_UpdateMessageQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@

#include "Framework/TestingFramework.hpp"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "DataStructures/DataBox/DataBox.hpp"
#include "DataStructures/DataBox/Tag.hpp"
#include "DataStructures/LinkedMessageId.hpp"
#include "DataStructures/LinkedMessageQueue.hpp"
#include "Domain/Creators/Tags/FunctionsOfTime.hpp"
#include "Domain/FunctionsOfTime/FunctionOfTime.hpp"
#include "Domain/FunctionsOfTime/PiecewisePolynomial.hpp"
#include "Domain/FunctionsOfTime/RegisterDerivedWithCharm.hpp"
#include "Domain/FunctionsOfTime/Tags.hpp"
#include "Framework/ActionTesting.hpp"
#include "Parallel/GlobalCache.hpp"
#include "Parallel/Phase.hpp"
#include "Parallel/PhaseDependentActionList.hpp"
#include "ParallelAlgorithms/Actions/UpdateMessageQueue.hpp"
Expand Down Expand Up @@ -42,19 +50,19 @@ struct ProcessorCalls : db::SimpleTag {
};

struct Processor {
// [Processor::apply]
template <typename DbTags, typename Metavariables, typename ArrayIndex>
static void apply(const gsl::not_null<db::DataBox<DbTags>*> box,
Parallel::GlobalCache<Metavariables>& /*cache*/,
const ArrayIndex& /*array_index*/, const int id,
tuples::TaggedTuple<Queue1, Queue2> data) {
// [Processor::apply]
db::mutate<ProcessorCalls>(
[&id, &data](const gsl::not_null<ProcessorCalls::type*> calls) {
calls->emplace_back(id, std::move(data));
},
box);
}
template <typename DbTags, typename Metavariables, typename ArrayIndex>
static void apply(const gsl::not_null<db::DataBox<DbTags>*> box,
Parallel::GlobalCache<Metavariables>& /*cache*/,
const ArrayIndex& /*array_index*/, const int id,
tuples::TaggedTuple<Queue1, Queue2> data) {
// [Processor::apply]
db::mutate<ProcessorCalls>(
[&id, &data](const gsl::not_null<ProcessorCalls::type*> calls) {
calls->emplace_back(id, std::move(data));
},
box);
}
};

template <typename Metavariables>
Expand All @@ -64,24 +72,48 @@ struct Component {
using array_index = int;
using simple_tags_from_options =
tmpl::list<LinkedMessageQueueTag, ProcessorCalls>;
using mutable_global_cache_tags =
tmpl::list<domain::Tags::FunctionsOfTimeInitialize>;
using phase_dependent_action_list = tmpl::list<
Parallel::PhaseActions<Parallel::Phase::Initialization, tmpl::list<>>>;
};

struct Metavariables {
using component_list = tmpl::list<Component<Metavariables>>;
};

using FunctionMap = domain::Tags::FunctionsOfTimeInitialize::type;
struct UpdateFoT {
static void apply(const gsl::not_null<FunctionMap*> functions,
const std::string& name, const double expiration) {
const double current_expiration = functions->at(name)->time_bounds()[1];
// Update value doesn't matter
(*functions)
.at(name)
->update(current_expiration, DataVector{0.0}, expiration);
}
};
} // namespace

SPECTRE_TEST_CASE("Unit.Actions.UpdateMessageQueue", "[Unit][Actions]") {
using component = Component<Metavariables>;
domain::FunctionsOfTime::register_derived_with_charm();

ActionTesting::MockRuntimeSystem<Metavariables> runner{{}};
FunctionMap functions_of_time{};
const std::string name{"Smaug"};
functions_of_time[name] =
std::make_unique<domain::FunctionsOfTime::PiecewisePolynomial<0>>(
0.0, std::array{DataVector{0.0}}, 1.0);

ActionTesting::MockRuntimeSystem<Metavariables> runner{
{}, {std::move(functions_of_time)}};
ActionTesting::emplace_component<component>(
&runner, 0, LinkedMessageQueueTag::type{}, ProcessorCalls::type{});

ActionTesting::set_phase(make_not_null(&runner), Parallel::Phase::Testing);

auto& cache = ActionTesting::cache<component>(runner, 0);

const auto processed_by_call = [&runner](auto queue_v,
const LinkedMessageId<int>& id,
auto data) -> decltype(auto) {
Expand All @@ -108,7 +140,16 @@ SPECTRE_TEST_CASE("Unit.Actions.UpdateMessageQueue", "[Unit][Actions]") {
CHECK(get<Queue1>(processed[0].second) == 1.23);
CHECK(get<Queue2>(processed[0].second) == 2.34);
}
// Nothing should have been inserted because 2 is after expiration of 1.
CHECK(processed_by_call(Queue1{}, {2, 1}, 2.2).empty());
Parallel::mutate<domain::Tags::FunctionsOfTime, UpdateFoT>(cache, name, 5.0);
CHECK(ActionTesting::number_of_queued_simple_actions<component>(runner, 0) ==
1);
// Now things should have been inserted
ActionTesting::invoke_queued_simple_action<component>(make_not_null(&runner),
0);
CHECK(ActionTesting::get_databox_tag<component, ProcessorCalls>(runner, 0)
.empty());
CHECK(processed_by_call(Queue2{}, {1, 0}, 1.1).empty());
CHECK(processed_by_call(Queue2{}, {2, 1}, 2.2).empty());
{
Expand Down

0 comments on commit a812c7e

Please sign in to comment.