From e5fe53d8f19a5929c4a3e9ac95594e381cc75b14 Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Mon, 2 Dec 2024 20:32:28 +0300 Subject: [PATCH] Moved MessageHandler inside ModelExecutor: #0000023 --- knp/base-framework/CMakeLists.txt | 2 +- ...ssage_handler.cpp => message_handlers.cpp} | 53 +++++++----- knp/base-framework/impl/model_executor.cpp | 24 ++++-- .../{message_handler.h => message_handlers.h} | 81 +------------------ .../include/knp/framework/model_executor.h | 49 +++++++++-- knp/tests/framework/message_handler_test.cpp | 4 +- 6 files changed, 98 insertions(+), 115 deletions(-) rename knp/base-framework/impl/{message_handler.cpp => message_handlers.cpp} (78%) rename knp/base-framework/include/knp/framework/{message_handler.h => message_handlers.h} (63%) diff --git a/knp/base-framework/CMakeLists.txt b/knp/base-framework/CMakeLists.txt index c9d18047..7c6126e6 100644 --- a/knp/base-framework/CMakeLists.txt +++ b/knp/base-framework/CMakeLists.txt @@ -59,7 +59,7 @@ knp_add_library("${PROJECT_NAME}-core" impl/model.cpp impl/model_executor.cpp impl/model_loader.cpp - impl/message_handler.cpp + impl/message_handlers.cpp impl/input_converter.cpp impl/output_channel.cpp impl/synchronization.cpp diff --git a/knp/base-framework/impl/message_handler.cpp b/knp/base-framework/impl/message_handlers.cpp similarity index 78% rename from knp/base-framework/impl/message_handler.cpp rename to knp/base-framework/impl/message_handlers.cpp index a396d279..ad900fe0 100644 --- a/knp/base-framework/impl/message_handler.cpp +++ b/knp/base-framework/impl/message_handlers.cpp @@ -1,5 +1,5 @@ /** - * @file message_handler.cpp + * @file message_handlers.cpp * @brief Implementation of message handler functionality. * @kaspersky_support A. Vartenkov * @date 25.11.2024 @@ -19,33 +19,30 @@ * limitations under the License. */ -#include +#include #include #include +/** + * @brief namespace for message modifier callables. + */ namespace knp::framework::modifier { -void SpikeMessageHandler::update(size_t step) +knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) { - endpoint_.receive_all_messages(); - auto incoming_messages = endpoint_.unload_messages(base_.uid_); - MessageOut outgoing_message = {{base_.uid_, step}, message_handler_function_(incoming_messages)}; - if (!(outgoing_message.neuron_indexes_.empty())) + if (messages.empty()) { - endpoint_.send_message(outgoing_message); + return {}; } -} - - -knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) -{ - if (messages.empty()) return {}; auto &msg = messages[0]; - if (msg.neuron_indexes_.size() < num_winners_) return msg.neuron_indexes_; + if (msg.neuron_indexes_.size() < num_winners_) + { + return msg.neuron_indexes_; + } knp::core::messaging::SpikeData out_spikes; for (size_t i = 0; i < num_winners_; ++i) @@ -54,6 +51,7 @@ knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) { - if (messages.empty()) return {}; - if (num_winners_ > group_borders_.size()) return messages[0].neuron_indexes_; + if (messages.empty()) + { + return {}; + } + + if (num_winners_ > group_borders_.size()) + { + return messages[0].neuron_indexes_; + } const auto &spikes = messages[0].neuron_indexes_; - if (spikes.empty()) return {}; + if (spikes.empty()) + { + return {}; + } - std::vector> spikes_per_group(group_borders_.size() + 1); + std::vector spikes_per_group(group_borders_.size() + 1); // Fill groups in. for (const auto &spike : spikes) @@ -94,7 +102,10 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( knp::core::messaging::SpikeData result; for (size_t i = 0; i < num_winners_; ++i) { - for (auto spike : spikes_per_group[i]) result.push_back(spike); + for (const auto &spike : spikes_per_group[i]) + { + result.push_back(spike); + } } return result; } @@ -103,7 +114,7 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( knp::core::messaging::SpikeData SpikeUnionHandler::operator()( const std::vector &messages) { - std::unordered_set spikes; + std::unordered_set spikes; for (const auto &msg : messages) { spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end()); diff --git a/knp/base-framework/impl/model_executor.cpp b/knp/base-framework/impl/model_executor.cpp index 1ed01dc6..22f037d7 100644 --- a/knp/base-framework/impl/model_executor.cpp +++ b/knp/base-framework/impl/model_executor.cpp @@ -78,17 +78,29 @@ void ModelExecutor::stop() } -void ModelExecutor::add_message_handler( - typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function, - const std::vector &senders, const std::vector &receivers, const knp::core::UID &uid) +void ModelExecutor::SpikeMessageHandler::update(size_t step) +{ + endpoint_.receive_all_messages(); + auto incoming_messages = endpoint_.unload_messages(base_.uid_); + knp::core::messaging::SpikeMessage outgoing_message = { + {base_.uid_, step}, message_handler_function_(incoming_messages)}; + if (!(outgoing_message.neuron_indexes_.empty())) + { + endpoint_.send_message(outgoing_message); + } +} + + +void ModelExecutor::add_spike_message_handler( + typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector &senders, + const std::vector &receivers, const knp::core::UID &uid) { knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint(); - message_handlers_.emplace_back( - modifier::SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid}); + message_handlers_.emplace_back(SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid}); message_handlers_.back().subscribe(senders); for (const knp::core::UID &rec_uid : receivers) { - get_backend()->subscribe(rec_uid, {uid}); + get_backend()->subscribe(rec_uid, {uid}); } } diff --git a/knp/base-framework/include/knp/framework/message_handler.h b/knp/base-framework/include/knp/framework/message_handlers.h similarity index 63% rename from knp/base-framework/include/knp/framework/message_handler.h rename to knp/base-framework/include/knp/framework/message_handlers.h index 7a9fcf39..c786a947 100644 --- a/knp/base-framework/include/knp/framework/message_handler.h +++ b/knp/base-framework/include/knp/framework/message_handlers.h @@ -1,6 +1,6 @@ /** - * @file message_handler.h - * @brief A class that processes a number of messages then sends messages of its own. + * @file message_handlers.h + * @brief A set of predefined message handling functions to add to model executor. * @kaspersky_support Vartenkov A. * @date 19.11.2024 * @license Apache 2.0 @@ -34,83 +34,6 @@ namespace knp::framework::modifier { -/** - * @brief An object that receives and processes messages. - */ -class SpikeMessageHandler -{ -public: - /** - * @brief Input message type. - */ - using MessageIn = knp::core::messaging::SpikeMessage; - - /** - * @brief Output message type. - */ - using MessageOut = knp::core::messaging::SpikeMessage; - - /** - * @brief Functor type. - */ - using FunctionType = std::function &)>; - - /** - * @brief Handler constructor. - * @param function a function that takes a vector of spike messages and returns a vector of spikes. - * @param endpoint message endpoint. - * @param uid the uid of this object. - */ - SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) - : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} - { - } - - /** - * @brief Default move constructor. - * @param other object to move from. - */ - SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; - - /** - * @brief Is not copyable. - */ - SpikeMessageHandler(const SpikeMessageHandler &) = delete; - - /** - * @brief Subscribe handler to a number of other entities. - * @param entities network uids. - * @note For internal use, don't try to call it manually. - */ - void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } - - /** - * @brief Read, process and send messages. - * @param step current step. - * @note for internal use, don't try to call it manually. - */ - void update(size_t step); - - /** - * @brief Get handler UID. - * @return object UID. - */ - [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; - - /** - * @brief Get a tag. - * @param tag_name tag name. - * @return tag value reference. - */ - [[nodiscard]] std::any &get_tag(const std::string &tag_name) { return base_.tags_[tag_name]; }; - -private: - FunctionType message_handler_function_; - core::MessageEndpoint endpoint_; - knp::core::BaseData base_; -}; - - /** * @brief A modifier functor to process spikes and select random K spikes out of the whole set. * @note Only processes a single message. diff --git a/knp/base-framework/include/knp/framework/model_executor.h b/knp/base-framework/include/knp/framework/model_executor.h index ff5a17fd..f25bedbb 100644 --- a/knp/base-framework/include/knp/framework/model_executor.h +++ b/knp/base-framework/include/knp/framework/model_executor.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -94,6 +94,12 @@ class KNP_DECLSPEC ModelExecutor std::visit([&senders](auto &entity) { entity.subscribe(senders); }, observers_.back()); } + /** + * @brief Function type for message handlers. + */ + using SpikeHandlerFunction = + std::function &)>; + /** * @brief Add spike message handler to executor. * @param message_handler_function functor to process received messages. @@ -101,10 +107,9 @@ class KNP_DECLSPEC ModelExecutor * @param receivers list of entities receiving messages from handler. * @param uid handler uid. */ - void add_message_handler( - typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function, - const std::vector &senders, const std::vector &receivers, - const knp::core::UID &uid = knp::core::UID{}); + void add_spike_message_handler( + SpikeHandlerFunction &&message_handler_function, const std::vector &senders, + const std::vector &receivers, const knp::core::UID &uid = knp::core::UID{}); /** * @brief Unlock synapse weights. @@ -129,10 +134,42 @@ class KNP_DECLSPEC ModelExecutor auto &get_loader() { return loader_; } private: + /** + * @brief An object that receives and processes messages. + */ + class SpikeMessageHandler + { + public: + using MessageIn = knp::core::messaging::SpikeMessage; + using MessageData = knp::core::messaging::SpikeData; + using FunctionType = std::function &)>; + + SpikeMessageHandler( + FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) + : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} + { + } + + SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; + + SpikeMessageHandler(const SpikeMessageHandler &) = delete; + + void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } + + void update(size_t step); + + [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; + + private: + FunctionType message_handler_function_; + core::MessageEndpoint endpoint_; + knp::core::BaseData base_; + }; + knp::core::BaseData base_; ModelLoader loader_; std::vector observers_; - std::vector message_handlers_; + std::vector message_handlers_; }; } // namespace knp::framework diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp index d55f0fb1..b14fe722 100644 --- a/knp/tests/framework/message_handler_test.cpp +++ b/knp/tests/framework/message_handler_test.cpp @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include #include @@ -169,7 +169,7 @@ TEST(MessageHandlerSuite, NetworkIntegrationTest) auto &out_channel = model_executor.get_loader().get_output_channel(output_uid); const std::vector group_borders{2, 4}; const knp::core::UID handler_uid; - model_executor.add_message_handler( + model_executor.add_spike_message_handler( knp::framework::modifier::GroupWtaRandomHandler{group_borders}, {in_pop_uid}, {inter_proj_uid}, handler_uid); constexpr int num_steps = 20; model_executor.start([](size_t step) { return step < num_steps; });