From b10ddc1b943e61018f33e9e6a31f15c2aa664b1f Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Fri, 29 Nov 2024 13:11:48 +0300 Subject: [PATCH 1/8] Add message handler and tests: #0000023 --- .../impl/synaptic_resource_stdp_impl.h | 4 +- knp/base-framework/CMakeLists.txt | 1 + knp/base-framework/impl/message_handler.cpp | 97 +++++++++ knp/base-framework/impl/model_executor.cpp | 20 ++ .../include/knp/framework/message_handler.h | 194 ++++++++++++++++++ .../include/knp/framework/model_executor.h | 38 ++-- .../knp/framework/monitoring/observer.h | 2 +- knp/tests/framework/message_handler_test.cpp | 174 ++++++++++++++++ 8 files changed, 515 insertions(+), 15 deletions(-) create mode 100644 knp/base-framework/impl/message_handler.cpp create mode 100644 knp/base-framework/include/knp/framework/message_handler.h create mode 100644 knp/tests/framework/message_handler_test.cpp diff --git a/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h b/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h index 96e422e2..df4c4b90 100644 --- a/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h +++ b/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h @@ -320,8 +320,8 @@ void do_dopamine_plasticity( if (step - synapse->rule_.last_spike_step_ < synapse->rule_.dopamine_plasticity_period_) { // Change synapse resource. - float d_r = - neuron.dopamine_value_ * std::min(static_cast(std::pow(2, -neuron.stability_)), 1.F); + float d_r = neuron.dopamine_value_ * + std::min(static_cast(std::pow(2, -neuron.stability_)), 1.F) / 1000.F; synapse->rule_.synaptic_resource_ += d_r; neuron.free_synaptic_resource_ -= d_r; } diff --git a/knp/base-framework/CMakeLists.txt b/knp/base-framework/CMakeLists.txt index e4f52f8c..df9a9b8e 100644 --- a/knp/base-framework/CMakeLists.txt +++ b/knp/base-framework/CMakeLists.txt @@ -59,6 +59,7 @@ knp_add_library("${PROJECT_NAME}-core" impl/model.cpp impl/model_executor.cpp impl/model_loader.cpp + impl/message_handler.cpp impl/input_converter.cpp impl/output_channel.cpp impl/synchronization.cpp diff --git a/knp/base-framework/impl/message_handler.cpp b/knp/base-framework/impl/message_handler.cpp new file mode 100644 index 00000000..4dc33376 --- /dev/null +++ b/knp/base-framework/impl/message_handler.cpp @@ -0,0 +1,97 @@ +// +// Created by an_vartenkov on 22.11.24. +// +#include + +#include +#include + +namespace knp::framework::modifier +{ + +void SpikeMessageHandler::update(size_t step) +{ + endpoint_.receive_all_messages(); + auto incoming_messages = endpoint_.unload_messages(base_.uid_); + MessageOut outgoing_message = {{base_.uid_, step}, message_handler_function_(incoming_messages)}; + if (!(outgoing_message.neuron_indexes_.empty())) + { + endpoint_.send_message(outgoing_message); + } +} + + +knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) +{ + if (messages.empty()) return {}; + + auto &msg = messages[0]; + if (msg.neuron_indexes_.size() < num_winners_) return msg.neuron_indexes_; + + knp::core::messaging::SpikeData out_spikes; + for (size_t i = 0; i < num_winners_; ++i) + { + const size_t index = distribution_(random_engine_) % (msg.neuron_indexes_.size() - i); + out_spikes.push_back(msg.neuron_indexes_[index]); + std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]); + } + return out_spikes; +} + + +knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( + const std::vector &messages) +{ + if (messages.empty()) return {}; + if (num_winners_ > group_borders_.size()) return messages[0].neuron_indexes_; + + const auto &spikes = messages[0].neuron_indexes_; + if (spikes.empty()) return {}; + + std::vector> spikes_per_group(group_borders_.size() + 1); + + // Fill groups in + for (auto spike : spikes) + { + const size_t group_index = + std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin(); + spikes_per_group[group_index].push_back(spike); + } + + // Sort groups by number of elements + std::sort( + spikes_per_group.begin(), spikes_per_group.end(), + [](const auto &el1, const auto &el2) { return el1.size() > el2.size(); }); + + // Find all groups with the same number of spikes as the K-th one + const auto &last_group = spikes_per_group[num_winners_ - 1]; + auto group_interval = std::equal_range( + spikes_per_group.begin(), spikes_per_group.end(), last_group, + [](const auto &el1, const auto &el2) { return el1.size() > el2.size(); }); + const size_t already_decided = group_interval.first - spikes_per_group.begin() + 1; + assert(already_decided <= num_winners_); + // The approach could be more efficient, but I don't think it's necessary. + std::shuffle(group_interval.first, group_interval.second, random_engine_); + knp::core::messaging::SpikeData result; + for (size_t i = 0; i < num_winners_; ++i) + { + for (auto spike : spikes_per_group[i]) result.push_back(spike); + } + return result; +} + + +knp::core::messaging::SpikeData SpikeUnionHandler::operator()( + const std::vector &messages) +{ + std::unordered_set spikes; + for (const auto &msg : messages) + { + spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end()); + } + knp::core::messaging::SpikeData result; + result.reserve(spikes.size()); + std::copy(spikes.begin(), spikes.end(), std::back_inserter(result)); + return result; +} +} // namespace knp::framework::modifier diff --git a/knp/base-framework/impl/model_executor.cpp b/knp/base-framework/impl/model_executor.cpp index 4914f24a..1ed01dc6 100644 --- a/knp/base-framework/impl/model_executor.cpp +++ b/knp/base-framework/impl/model_executor.cpp @@ -55,6 +55,11 @@ void ModelExecutor::start(core::Backend::RunPredicate run_predicate) { o_ch.update(); } + // Running handlers + for (auto &handler : message_handlers_) + { + handler.update(get_backend()->get_step()); + } // Run monitoring observers. for (auto &observer : observers_) { @@ -72,4 +77,19 @@ void ModelExecutor::stop() get_backend()->stop(); } + +void ModelExecutor::add_message_handler( + typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function, + const std::vector &senders, const std::vector &receivers, const knp::core::UID &uid) +{ + knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint(); + message_handlers_.emplace_back( + modifier::SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid}); + message_handlers_.back().subscribe(senders); + for (const knp::core::UID &rec_uid : receivers) + { + get_backend()->subscribe(rec_uid, {uid}); + } +} + } // namespace knp::framework diff --git a/knp/base-framework/include/knp/framework/message_handler.h b/knp/base-framework/include/knp/framework/message_handler.h new file mode 100644 index 00000000..09fe4c7e --- /dev/null +++ b/knp/base-framework/include/knp/framework/message_handler.h @@ -0,0 +1,194 @@ +/** + * @file message_handler.h + * @brief A class that processes a number of messages then sends messages of its own. + * @kaspersky_support Vartenkov A. + * @date 19.11.2024 + * @license Apache 2.0 + * @copyright © 2024 AO Kaspersky Lab + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + + +#include +#include + +#include +#include +#include +#include +#include + + +namespace knp::framework::modifier +{ + +/** + * @brief An object that receives and processes messages. + */ +class SpikeMessageHandler +{ +public: + using MessageIn = knp::core::messaging::SpikeMessage; + using MessageOut = knp::core::messaging::SpikeMessage; + using FunctionType = std::function &)>; + + /** + * @brief Handler constructor. + * @param function a function that takes a vector of spike messages and returns a vector of spikes. + * @param endpoint message endpoint. + * @param uid the uid of this object. + */ + SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) + : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} + { + } + + /** + * @brief Default move constructor. + * @param other object to move from. + */ + SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; + + + /** + * @brief Is not copyable. + */ + SpikeMessageHandler(const SpikeMessageHandler &) = delete; + + + /** + * @brief Subscribe handler to a number of other entities. + * @param entities network uids. + * @note For internal use, don't try to call it manually. + */ + void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } + + + /** + * @brief Read, process and send messages. + * @param step current step. + * @note for internal use, don't try to call it manually. + */ + void update(size_t step); + + + /** + * @brief Get handler UID. + * @return object UID. + */ + [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; + + + /** + * @brief Get a tag. + * @param tag_name tag name. + * @return tag value reference. + */ + [[nodiscard]] std::any &get_tag(const std::string &tag_name) { return base_.tags_[tag_name]; }; + +private: + FunctionType message_handler_function_; + core::MessageEndpoint endpoint_; + knp::core::BaseData base_; +}; + + +/** + * @brief A modifier functor to process spikes and select random K spikes out of the whole set. + * @note Only processes a single message. + */ +class KWtaRandomHandler +{ +public: + /** + * @brief Constructor + * @param winners_number Max number of output spikes. + * @param seed random generator seed. + * @note uses mt19937 for random number generation. + */ + explicit KWtaRandomHandler(size_t winners_number = 1, int seed = 0) + : num_winners_(winners_number), random_engine_(seed) + { + } + + + /** + * @brief operator that takes a number of messages and returns a set of spikes. + * @param messages spike messages. + * @return spikes data containing no more than K spikes. + * @note it's assumed that it gets no more than one message per step, so all messages except first are ignored. + */ + knp::core::messaging::SpikeData operator()(std::vector &messages); + +private: + size_t num_winners_; + std::mt19937 random_engine_; + std::uniform_int_distribution distribution_; +}; + + +/** + * @brief MessageHandler functor that only passes through spikes from no more than a fixed number of groups at once. + * @note Group is considered to be winning if it is in the top K groups sorted by number of spikes in descending order. + * @note If last place in the top K is shared between groups, the functor selects random ones among the sharing groups. + */ +class GroupWtaRandomHandler +{ +public: + /** + * @brief Functor constructor. + * @param group_borders right borders of the intervals. + * @param num_winning_groups max number of groups that are allowed to pass their spikes further. + * @param seed seed for internal random number generator. + */ + explicit GroupWtaRandomHandler( + const std::vector &group_borders, size_t num_winning_groups = 1, int seed = 0) + : group_borders_(group_borders), num_winners_(num_winning_groups), random_engine_(seed) + { + std::sort(group_borders_.begin(), group_borders_.end()); + } + + + /** + * @brief Functor operator. + * @param messages input messages. + * @return spikes from winning groups. + */ + knp::core::messaging::SpikeData operator()(const std::vector &messages); + +private: + std::vector group_borders_; + size_t num_winners_; + std::mt19937 random_engine_; + std::uniform_int_distribution distribution_; +}; + + +/** + * @brief Spike handler functor. An output vector has a spike if that spike was present in at least one input message. + */ +class SpikeUnionHandler +{ +public: + /** + * @brief Functor operator, receives a vector of messages, returns a union of all spike sets from those messages. + * @param messages incoming spike messages. + * @return spikes vector containing the union of input message spike sets. + */ + knp::core::messaging::SpikeData operator()(const std::vector &messages); +}; + + +} // namespace knp::framework::modifier diff --git a/knp/base-framework/include/knp/framework/model_executor.h b/knp/base-framework/include/knp/framework/model_executor.h index 3281c6cc..ff5a17fd 100644 --- a/knp/base-framework/include/knp/framework/model_executor.h +++ b/knp/base-framework/include/knp/framework/model_executor.h @@ -4,18 +4,18 @@ * @kaspersky_support Artiom N. * @date 21.04.2023 * @license Apache 2.0 - * @copyright © 2024 AO Kaspersky Lab - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and + * @copyright © 2024 AO Kaspersky Lab + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and * limitations under the License. */ @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -93,6 +94,18 @@ class KNP_DECLSPEC ModelExecutor std::visit([&senders](auto &entity) { entity.subscribe(senders); }, observers_.back()); } + /** + * @brief Add spike message handler to executor. + * @param message_handler_function functor to process received messages. + * @param senders list of entities sending messages to the handler. + * @param receivers list of entities receiving messages from handler. + * @param uid handler uid. + */ + void add_message_handler( + typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function, + const std::vector &senders, const std::vector &receivers, + const knp::core::UID &uid = knp::core::UID{}); + /** * @brief Unlock synapse weights. */ @@ -120,5 +133,6 @@ class KNP_DECLSPEC ModelExecutor ModelLoader loader_; std::vector observers_; + std::vector message_handlers_; }; } // namespace knp::framework diff --git a/knp/base-framework/include/knp/framework/monitoring/observer.h b/knp/base-framework/include/knp/framework/monitoring/observer.h index d91a1afc..05277bb0 100644 --- a/knp/base-framework/include/knp/framework/monitoring/observer.h +++ b/knp/base-framework/include/knp/framework/monitoring/observer.h @@ -44,7 +44,7 @@ namespace knp::framework::monitoring * @tparam Message type of messages the functor processes. */ template -using MessageProcessor = std::function)>; +using MessageProcessor = std::function &)>; /** diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp new file mode 100644 index 00000000..3d2f7bf6 --- /dev/null +++ b/knp/tests/framework/message_handler_test.cpp @@ -0,0 +1,174 @@ +/** + * @file message_handler_test.cpp + * @brief Model handler class testing. + * @kaspersky_support A. Vartenkov + * @date 27.11.2024 + * @license Apache 2.0 + * @copyright © 2024 AO Kaspersky Lab + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include +#include +#include +#include +#include +#include + +#include + + +TEST(MessageHandlerSuite, MessageHandlerWTA) +{ + std::random_device rd; + int seed = rd(); + SPDLOG_DEBUG("Seed is {}", seed); + knp::framework::modifier::KWtaRandomHandler kwta_handler(2, seed); + knp::core::messaging::SpikeMessage message({{knp::core::UID{}}, {1, 2, 3, 4, 5}}); + std::vector msg_vec{message}; + auto out_data = kwta_handler(msg_vec); + ASSERT_EQ(out_data.size(), 2); + ASSERT_NE(out_data[0], out_data[1]); + SPDLOG_DEBUG("Selected spikes are {} and {}", out_data[0], out_data[1]); + + msg_vec[0].neuron_indexes_ = {7}; + out_data = kwta_handler(msg_vec); + ASSERT_EQ(out_data.size(), 1); + ASSERT_EQ(out_data[0], 7); +} + + +TEST(MessageHandlerSuite, MessageHandlerGroupWTASingle) +{ + std::random_device rd; + int seed = rd(); + SPDLOG_DEBUG("Seed is {}", seed); + // the intervals are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, ...] + knp::framework::modifier::GroupWtaRandomHandler group_handler({3, 6, 9}, 1, seed); + + // Message contains two spikes in group 0, one in group 1 and one in group 2, group 0 should be selected. + knp::core::messaging::SpikeMessage message({{knp::core::UID{}}, {1, 2, 3, 6}}); + auto out_data = group_handler({message}); + // Spikes 1 and 2 are to be passed. + ASSERT_EQ(out_data.size(), 2); + ASSERT_EQ(out_data[0] + out_data[1], 3); + + // Message contains two spikes in group 1 and two in group 3, either group can be selected. + message = {{knp::core::UID{}}, {1, 3, 5, 6, 9, 10}}; + out_data = group_handler({message}); + ASSERT_EQ(out_data.size(), 2); + // Either group 1 or group 3 is selected. + ASSERT_TRUE(out_data[0] + out_data[1] == 8 || out_data[0] + out_data[1] == 19); + SPDLOG_DEBUG("Selected values are {} and {}.", out_data[0], out_data[1]); +} + + +namespace knp::testing +{ + +class STestingBack : public knp::backends::single_threaded_cpu::SingleThreadedCPUBackend +{ +public: + STestingBack() = default; + void _init() override { knp::backends::single_threaded_cpu::SingleThreadedCPUBackend::_init(); } +}; + +} // namespace knp::testing + + +using BlifatParams = knp::neuron_traits::neuron_parameters; +using DeltaParams = knp::synapse_traits::synapse_parameters; +using DeltaProjection = knp::core::Projection; +using BlifatPopulation = knp::core::Population; + + +DeltaProjection::Synapse input_synapse_generator(size_t index) +{ + return {{1.0, 1, knp::synapse_traits::OutputType::EXCITATORY}, 0, index}; +} + + +DeltaProjection::Synapse intermediate_synapse_generator(size_t index) +{ + return {{1.0, 1, knp::synapse_traits::OutputType::EXCITATORY}, index, index}; +} + + +TEST(MessageHandlerSuite, NetworkIntegrationTest) +{ + // In this test we are making a network that consists of: + // - Input projection. It takes a zero input and translates it to all input neurons, instantly activating them. + // - Input population, that consists of 6 regular BLIFAT neurons. + // - Modifier. It divides incoming spikes into 3 groups: 0, 1; 2, 3; 4, 5 and selects only one of them. + // - Intermediate projection that receives spikes from modifier and activates a corresponding neuron in output + // population. + // - Output population, still 6 neurons, but only two neighbouring neurons should ever be activated, at random. + auto back_path = knp::testing::get_backend_path(); + constexpr int num_neurons = 6; + // A population of 3 groups of 2 neurons per group. + BlifatPopulation population([](size_t) { return BlifatParams{}; }, num_neurons); + + BlifatPopulation output_population([](size_t) { return BlifatParams{}; }, num_neurons); + const knp::core::UID in_pop_uid = population.get_uid(); + const knp::core::UID out_pop_uid = output_population.get_uid(); + // A projection that activates all neurons simultaneously. + DeltaProjection input_projection{knp::core::UID{false}, population.get_uid(), input_synapse_generator, num_neurons}; + + // A projection that should later be connected to a modifier. + DeltaProjection inter_projection{ + knp::core::UID{false}, output_population.get_uid(), intermediate_synapse_generator, num_neurons}; + + const knp::core::UID input_proj_uid = input_projection.get_uid(); + const knp::core::UID inter_proj_uid = inter_projection.get_uid(); + + // Create network. + knp::framework::Network network; + network.add_population(std::move(population)); + network.add_population(std::move(output_population)); + network.add_projection(std::move(input_projection)); + network.add_projection(std::move(inter_projection)); + + // Move network to model and add input and output channels. + knp::framework::Model model(std::move(network)); + const knp::core::UID input_uid, output_uid; + model.add_input_channel(input_uid, input_proj_uid); + model.add_output_channel(output_uid, out_pop_uid); + + // Generate an input spike at each step. + auto input_gen = [](size_t step) { return knp::core::messaging::SpikeData{0}; }; + + knp::framework::BackendLoader backend_loader; + knp::framework::ModelExecutor model_executor( + model, backend_loader.load(knp::testing::get_backend_path()), {{input_uid, input_gen}}); + + auto &out_channel = model_executor.get_loader().get_output_channel(output_uid); + const std::vector group_borders{2, 4}; + const knp::core::UID handler_uid; + model_executor.add_message_handler( + knp::framework::modifier::GroupWtaRandomHandler{group_borders}, {in_pop_uid}, {inter_proj_uid}, handler_uid); + constexpr int num_steps = 20; + model_executor.start([](size_t step) { return step < num_steps; }); + + const auto &spikes = out_channel.update(); + + ASSERT_GE(spikes.size(), 10); + for (const auto &msg : spikes) + { + ASSERT_GE(msg.header_.send_time_, 3); + ASSERT_EQ(msg.neuron_indexes_.size(), 2); + ASSERT_EQ(std::abs(static_cast(msg.neuron_indexes_[0]) - static_cast(msg.neuron_indexes_[1])), 1); + } +} From 85199e8b0a22dbbf4acfd6e72f17bc7b7eeb1dbf Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Fri, 29 Nov 2024 13:22:11 +0300 Subject: [PATCH 2/8] Add test for union: #0000023 --- knp/tests/framework/message_handler_test.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp index 3d2f7bf6..89d01a74 100644 --- a/knp/tests/framework/message_handler_test.cpp +++ b/knp/tests/framework/message_handler_test.cpp @@ -76,6 +76,19 @@ TEST(MessageHandlerSuite, MessageHandlerGroupWTASingle) } +TEST(MessageHandlerSuite, SpikeUnionHandler) +{ + knp::framework::modifier::SpikeUnionHandler union_handler; + knp::core::messaging::SpikeMessage message_1 = {{knp::core::UID{}, 0}, {1, 3, 5}}; + knp::core::messaging::SpikeMessage message_2 = {{knp::core::UID{}, 0}, {0, 1, 3}}; + knp::core::messaging::SpikeMessage message_3 = {{knp::core::UID{}, 0}, {3, 4, 7}}; + auto result = union_handler({message_1, message_2, message_3}); + std::sort(result.begin(), result.end()); + const decltype(result) expected_result{0, 1, 3, 4, 5, 7}; + ASSERT_EQ(result, expected_result); +} + + namespace knp::testing { @@ -90,7 +103,6 @@ class STestingBack : public knp::backends::single_threaded_cpu::SingleThreadedCP using BlifatParams = knp::neuron_traits::neuron_parameters; -using DeltaParams = knp::synapse_traits::synapse_parameters; using DeltaProjection = knp::core::Projection; using BlifatPopulation = knp::core::Population; From 6480b11dfbef6f4b9b6595f3acc96a98b47cde12 Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Fri, 29 Nov 2024 13:40:29 +0300 Subject: [PATCH 3/8] Add docstrings for members: #0000023 --- .../include/knp/framework/message_handler.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/knp/base-framework/include/knp/framework/message_handler.h b/knp/base-framework/include/knp/framework/message_handler.h index 09fe4c7e..a97fe9ac 100644 --- a/knp/base-framework/include/knp/framework/message_handler.h +++ b/knp/base-framework/include/knp/framework/message_handler.h @@ -40,8 +40,19 @@ namespace knp::framework::modifier class SpikeMessageHandler { public: + /** + * @brief Input message type. + */ using MessageIn = knp::core::messaging::SpikeMessage; + + /** + * @brief Output message type. + */ using MessageOut = knp::core::messaging::SpikeMessage; + + /** + * @brief Functor type. + */ using FunctionType = std::function &)>; /** From f65efcaaa5f96200866eb5f5a45d2c24965bbeda Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Fri, 29 Nov 2024 16:53:58 +0300 Subject: [PATCH 4/8] Fix comment points and newlines: #0000023 --- knp/base-framework/impl/message_handler.cpp | 33 +++++++++++++++---- .../include/knp/framework/message_handler.h | 16 +++------ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/knp/base-framework/impl/message_handler.cpp b/knp/base-framework/impl/message_handler.cpp index 4dc33376..a396d279 100644 --- a/knp/base-framework/impl/message_handler.cpp +++ b/knp/base-framework/impl/message_handler.cpp @@ -1,11 +1,30 @@ -// -// Created by an_vartenkov on 22.11.24. -// +/** + * @file message_handler.cpp + * @brief Implementation of message handler functionality. + * @kaspersky_support A. Vartenkov + * @date 25.11.2024 + * @license Apache 2.0 + * @copyright © 2024 AO Kaspersky Lab + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include #include + namespace knp::framework::modifier { @@ -50,20 +69,20 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( std::vector> spikes_per_group(group_borders_.size() + 1); - // Fill groups in - for (auto spike : spikes) + // Fill groups in. + for (const auto &spike : spikes) { const size_t group_index = std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin(); spikes_per_group[group_index].push_back(spike); } - // Sort groups by number of elements + // Sort groups by number of elements. std::sort( spikes_per_group.begin(), spikes_per_group.end(), [](const auto &el1, const auto &el2) { return el1.size() > el2.size(); }); - // Find all groups with the same number of spikes as the K-th one + // Find all groups with the same number of spikes as the K-th one. const auto &last_group = spikes_per_group[num_winners_ - 1]; auto group_interval = std::equal_range( spikes_per_group.begin(), spikes_per_group.end(), last_group, diff --git a/knp/base-framework/include/knp/framework/message_handler.h b/knp/base-framework/include/knp/framework/message_handler.h index a97fe9ac..7a9fcf39 100644 --- a/knp/base-framework/include/knp/framework/message_handler.h +++ b/knp/base-framework/include/knp/framework/message_handler.h @@ -72,13 +72,11 @@ class SpikeMessageHandler */ SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; - /** * @brief Is not copyable. */ SpikeMessageHandler(const SpikeMessageHandler &) = delete; - /** * @brief Subscribe handler to a number of other entities. * @param entities network uids. @@ -86,7 +84,6 @@ class SpikeMessageHandler */ void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } - /** * @brief Read, process and send messages. * @param step current step. @@ -94,14 +91,12 @@ class SpikeMessageHandler */ void update(size_t step); - /** * @brief Get handler UID. * @return object UID. */ [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; - /** * @brief Get a tag. * @param tag_name tag name. @@ -124,7 +119,7 @@ class KWtaRandomHandler { public: /** - * @brief Constructor + * @brief Constructor. * @param winners_number Max number of output spikes. * @param seed random generator seed. * @note uses mt19937 for random number generation. @@ -134,9 +129,8 @@ class KWtaRandomHandler { } - /** - * @brief operator that takes a number of messages and returns a set of spikes. + * @brief Function call operator that takes a number of messages and returns a set of spikes. * @param messages spike messages. * @return spikes data containing no more than K spikes. * @note it's assumed that it gets no more than one message per step, so all messages except first are ignored. @@ -171,9 +165,8 @@ class GroupWtaRandomHandler std::sort(group_borders_.begin(), group_borders_.end()); } - /** - * @brief Functor operator. + * @brief Function call operator. * @param messages input messages. * @return spikes from winning groups. */ @@ -194,7 +187,8 @@ class SpikeUnionHandler { public: /** - * @brief Functor operator, receives a vector of messages, returns a union of all spike sets from those messages. + * @brief Function call operator, receives a vector of messages, returns a union of all spike sets from those + * messages. * @param messages incoming spike messages. * @return spikes vector containing the union of input message spike sets. */ From e70e236af69301b0cc714bdbe3326b99c71e9e34 Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Fri, 29 Nov 2024 19:47:30 +0300 Subject: [PATCH 5/8] Add full stop to comment in test: #0000023 --- knp/tests/framework/message_handler_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp index 89d01a74..d55f0fb1 100644 --- a/knp/tests/framework/message_handler_test.cpp +++ b/knp/tests/framework/message_handler_test.cpp @@ -56,7 +56,7 @@ TEST(MessageHandlerSuite, MessageHandlerGroupWTASingle) std::random_device rd; int seed = rd(); SPDLOG_DEBUG("Seed is {}", seed); - // the intervals are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, ...] + // the intervals are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, ...]. knp::framework::modifier::GroupWtaRandomHandler group_handler({3, 6, 9}, 1, seed); // Message contains two spikes in group 0, one in group 1 and one in group 2, group 0 should be selected. From e5fe53d8f19a5929c4a3e9ac95594e381cc75b14 Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Mon, 2 Dec 2024 20:32:28 +0300 Subject: [PATCH 6/8] Moved MessageHandler inside ModelExecutor: #0000023 --- knp/base-framework/CMakeLists.txt | 2 +- ...ssage_handler.cpp => message_handlers.cpp} | 53 +++++++----- knp/base-framework/impl/model_executor.cpp | 24 ++++-- .../{message_handler.h => message_handlers.h} | 81 +------------------ .../include/knp/framework/model_executor.h | 49 +++++++++-- knp/tests/framework/message_handler_test.cpp | 4 +- 6 files changed, 98 insertions(+), 115 deletions(-) rename knp/base-framework/impl/{message_handler.cpp => message_handlers.cpp} (78%) rename knp/base-framework/include/knp/framework/{message_handler.h => message_handlers.h} (63%) diff --git a/knp/base-framework/CMakeLists.txt b/knp/base-framework/CMakeLists.txt index c9d18047..7c6126e6 100644 --- a/knp/base-framework/CMakeLists.txt +++ b/knp/base-framework/CMakeLists.txt @@ -59,7 +59,7 @@ knp_add_library("${PROJECT_NAME}-core" impl/model.cpp impl/model_executor.cpp impl/model_loader.cpp - impl/message_handler.cpp + impl/message_handlers.cpp impl/input_converter.cpp impl/output_channel.cpp impl/synchronization.cpp diff --git a/knp/base-framework/impl/message_handler.cpp b/knp/base-framework/impl/message_handlers.cpp similarity index 78% rename from knp/base-framework/impl/message_handler.cpp rename to knp/base-framework/impl/message_handlers.cpp index a396d279..ad900fe0 100644 --- a/knp/base-framework/impl/message_handler.cpp +++ b/knp/base-framework/impl/message_handlers.cpp @@ -1,5 +1,5 @@ /** - * @file message_handler.cpp + * @file message_handlers.cpp * @brief Implementation of message handler functionality. * @kaspersky_support A. Vartenkov * @date 25.11.2024 @@ -19,33 +19,30 @@ * limitations under the License. */ -#include +#include #include #include +/** + * @brief namespace for message modifier callables. + */ namespace knp::framework::modifier { -void SpikeMessageHandler::update(size_t step) +knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) { - endpoint_.receive_all_messages(); - auto incoming_messages = endpoint_.unload_messages(base_.uid_); - MessageOut outgoing_message = {{base_.uid_, step}, message_handler_function_(incoming_messages)}; - if (!(outgoing_message.neuron_indexes_.empty())) + if (messages.empty()) { - endpoint_.send_message(outgoing_message); + return {}; } -} - - -knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) -{ - if (messages.empty()) return {}; auto &msg = messages[0]; - if (msg.neuron_indexes_.size() < num_winners_) return msg.neuron_indexes_; + if (msg.neuron_indexes_.size() < num_winners_) + { + return msg.neuron_indexes_; + } knp::core::messaging::SpikeData out_spikes; for (size_t i = 0; i < num_winners_; ++i) @@ -54,6 +51,7 @@ knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector &messages) { - if (messages.empty()) return {}; - if (num_winners_ > group_borders_.size()) return messages[0].neuron_indexes_; + if (messages.empty()) + { + return {}; + } + + if (num_winners_ > group_borders_.size()) + { + return messages[0].neuron_indexes_; + } const auto &spikes = messages[0].neuron_indexes_; - if (spikes.empty()) return {}; + if (spikes.empty()) + { + return {}; + } - std::vector> spikes_per_group(group_borders_.size() + 1); + std::vector spikes_per_group(group_borders_.size() + 1); // Fill groups in. for (const auto &spike : spikes) @@ -94,7 +102,10 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( knp::core::messaging::SpikeData result; for (size_t i = 0; i < num_winners_; ++i) { - for (auto spike : spikes_per_group[i]) result.push_back(spike); + for (const auto &spike : spikes_per_group[i]) + { + result.push_back(spike); + } } return result; } @@ -103,7 +114,7 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( knp::core::messaging::SpikeData SpikeUnionHandler::operator()( const std::vector &messages) { - std::unordered_set spikes; + std::unordered_set spikes; for (const auto &msg : messages) { spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end()); diff --git a/knp/base-framework/impl/model_executor.cpp b/knp/base-framework/impl/model_executor.cpp index 1ed01dc6..22f037d7 100644 --- a/knp/base-framework/impl/model_executor.cpp +++ b/knp/base-framework/impl/model_executor.cpp @@ -78,17 +78,29 @@ void ModelExecutor::stop() } -void ModelExecutor::add_message_handler( - typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function, - const std::vector &senders, const std::vector &receivers, const knp::core::UID &uid) +void ModelExecutor::SpikeMessageHandler::update(size_t step) +{ + endpoint_.receive_all_messages(); + auto incoming_messages = endpoint_.unload_messages(base_.uid_); + knp::core::messaging::SpikeMessage outgoing_message = { + {base_.uid_, step}, message_handler_function_(incoming_messages)}; + if (!(outgoing_message.neuron_indexes_.empty())) + { + endpoint_.send_message(outgoing_message); + } +} + + +void ModelExecutor::add_spike_message_handler( + typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector &senders, + const std::vector &receivers, const knp::core::UID &uid) { knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint(); - message_handlers_.emplace_back( - modifier::SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid}); + message_handlers_.emplace_back(SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid}); message_handlers_.back().subscribe(senders); for (const knp::core::UID &rec_uid : receivers) { - get_backend()->subscribe(rec_uid, {uid}); + get_backend()->subscribe(rec_uid, {uid}); } } diff --git a/knp/base-framework/include/knp/framework/message_handler.h b/knp/base-framework/include/knp/framework/message_handlers.h similarity index 63% rename from knp/base-framework/include/knp/framework/message_handler.h rename to knp/base-framework/include/knp/framework/message_handlers.h index 7a9fcf39..c786a947 100644 --- a/knp/base-framework/include/knp/framework/message_handler.h +++ b/knp/base-framework/include/knp/framework/message_handlers.h @@ -1,6 +1,6 @@ /** - * @file message_handler.h - * @brief A class that processes a number of messages then sends messages of its own. + * @file message_handlers.h + * @brief A set of predefined message handling functions to add to model executor. * @kaspersky_support Vartenkov A. * @date 19.11.2024 * @license Apache 2.0 @@ -34,83 +34,6 @@ namespace knp::framework::modifier { -/** - * @brief An object that receives and processes messages. - */ -class SpikeMessageHandler -{ -public: - /** - * @brief Input message type. - */ - using MessageIn = knp::core::messaging::SpikeMessage; - - /** - * @brief Output message type. - */ - using MessageOut = knp::core::messaging::SpikeMessage; - - /** - * @brief Functor type. - */ - using FunctionType = std::function &)>; - - /** - * @brief Handler constructor. - * @param function a function that takes a vector of spike messages and returns a vector of spikes. - * @param endpoint message endpoint. - * @param uid the uid of this object. - */ - SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) - : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} - { - } - - /** - * @brief Default move constructor. - * @param other object to move from. - */ - SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; - - /** - * @brief Is not copyable. - */ - SpikeMessageHandler(const SpikeMessageHandler &) = delete; - - /** - * @brief Subscribe handler to a number of other entities. - * @param entities network uids. - * @note For internal use, don't try to call it manually. - */ - void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } - - /** - * @brief Read, process and send messages. - * @param step current step. - * @note for internal use, don't try to call it manually. - */ - void update(size_t step); - - /** - * @brief Get handler UID. - * @return object UID. - */ - [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; - - /** - * @brief Get a tag. - * @param tag_name tag name. - * @return tag value reference. - */ - [[nodiscard]] std::any &get_tag(const std::string &tag_name) { return base_.tags_[tag_name]; }; - -private: - FunctionType message_handler_function_; - core::MessageEndpoint endpoint_; - knp::core::BaseData base_; -}; - - /** * @brief A modifier functor to process spikes and select random K spikes out of the whole set. * @note Only processes a single message. diff --git a/knp/base-framework/include/knp/framework/model_executor.h b/knp/base-framework/include/knp/framework/model_executor.h index ff5a17fd..f25bedbb 100644 --- a/knp/base-framework/include/knp/framework/model_executor.h +++ b/knp/base-framework/include/knp/framework/model_executor.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -94,6 +94,12 @@ class KNP_DECLSPEC ModelExecutor std::visit([&senders](auto &entity) { entity.subscribe(senders); }, observers_.back()); } + /** + * @brief Function type for message handlers. + */ + using SpikeHandlerFunction = + std::function &)>; + /** * @brief Add spike message handler to executor. * @param message_handler_function functor to process received messages. @@ -101,10 +107,9 @@ class KNP_DECLSPEC ModelExecutor * @param receivers list of entities receiving messages from handler. * @param uid handler uid. */ - void add_message_handler( - typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function, - const std::vector &senders, const std::vector &receivers, - const knp::core::UID &uid = knp::core::UID{}); + void add_spike_message_handler( + SpikeHandlerFunction &&message_handler_function, const std::vector &senders, + const std::vector &receivers, const knp::core::UID &uid = knp::core::UID{}); /** * @brief Unlock synapse weights. @@ -129,10 +134,42 @@ class KNP_DECLSPEC ModelExecutor auto &get_loader() { return loader_; } private: + /** + * @brief An object that receives and processes messages. + */ + class SpikeMessageHandler + { + public: + using MessageIn = knp::core::messaging::SpikeMessage; + using MessageData = knp::core::messaging::SpikeData; + using FunctionType = std::function &)>; + + SpikeMessageHandler( + FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) + : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} + { + } + + SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; + + SpikeMessageHandler(const SpikeMessageHandler &) = delete; + + void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } + + void update(size_t step); + + [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; + + private: + FunctionType message_handler_function_; + core::MessageEndpoint endpoint_; + knp::core::BaseData base_; + }; + knp::core::BaseData base_; ModelLoader loader_; std::vector observers_; - std::vector message_handlers_; + std::vector message_handlers_; }; } // namespace knp::framework diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp index d55f0fb1..b14fe722 100644 --- a/knp/tests/framework/message_handler_test.cpp +++ b/knp/tests/framework/message_handler_test.cpp @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include #include @@ -169,7 +169,7 @@ TEST(MessageHandlerSuite, NetworkIntegrationTest) auto &out_channel = model_executor.get_loader().get_output_channel(output_uid); const std::vector group_borders{2, 4}; const knp::core::UID handler_uid; - model_executor.add_message_handler( + model_executor.add_spike_message_handler( knp::framework::modifier::GroupWtaRandomHandler{group_borders}, {in_pop_uid}, {inter_proj_uid}, handler_uid); constexpr int num_steps = 20; model_executor.start([](size_t step) { return step < num_steps; }); From 1b9635f591c3e7a17caa9da3ef981d3ef218cca9 Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Tue, 3 Dec 2024 11:03:29 +0300 Subject: [PATCH 7/8] Add num_steps to capture: #0000023 --- knp/tests/framework/message_handler_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp index b14fe722..ccf528bb 100644 --- a/knp/tests/framework/message_handler_test.cpp +++ b/knp/tests/framework/message_handler_test.cpp @@ -172,7 +172,7 @@ TEST(MessageHandlerSuite, NetworkIntegrationTest) model_executor.add_spike_message_handler( knp::framework::modifier::GroupWtaRandomHandler{group_borders}, {in_pop_uid}, {inter_proj_uid}, handler_uid); constexpr int num_steps = 20; - model_executor.start([](size_t step) { return step < num_steps; }); + model_executor.start([num_steps](size_t step) { return step < num_steps; }); const auto &spikes = out_channel.update(); From 636fd86c2cec16ad26b773fbee67a3e4baff5500 Mon Sep 17 00:00:00 2001 From: Andrey Vartenkov Date: Wed, 4 Dec 2024 11:04:27 +0300 Subject: [PATCH 8/8] Hide declaration of SpikeMessageHandler: #0000023 --- knp/base-framework/impl/model_executor.cpp | 76 +++++++++++++++---- .../include/knp/framework/model_executor.h | 45 ++--------- 2 files changed, 68 insertions(+), 53 deletions(-) diff --git a/knp/base-framework/impl/model_executor.cpp b/knp/base-framework/impl/model_executor.cpp index 22f037d7..71c18bea 100644 --- a/knp/base-framework/impl/model_executor.cpp +++ b/knp/base-framework/impl/model_executor.cpp @@ -27,6 +27,62 @@ namespace knp::framework { + +class ModelExecutor::SpikeMessageHandler +{ +public: + using MessageIn = knp::core::messaging::SpikeMessage; + using MessageData = knp::core::messaging::SpikeData; + using FunctionType = std::function &)>; + + SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) + : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} + { + } + + SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; + + SpikeMessageHandler(const SpikeMessageHandler &) = delete; + + void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } + + void update(size_t step); + + [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; + + ~SpikeMessageHandler() = default; + +private: + FunctionType message_handler_function_; + knp::core::MessageEndpoint endpoint_; + knp::core::BaseData base_; +}; + + +void ModelExecutor::SpikeMessageHandler::update(size_t step) +{ + endpoint_.receive_all_messages(); + auto incoming_messages = endpoint_.unload_messages(base_.uid_); + knp::core::messaging::SpikeMessage outgoing_message = { + {base_.uid_, step}, message_handler_function_(incoming_messages)}; + if (!(outgoing_message.neuron_indexes_.empty())) + { + endpoint_.send_message(outgoing_message); + } +} + + +ModelExecutor::ModelExecutor( + knp::framework::Model &model, std::shared_ptr backend, ModelLoader::InputChannelMap i_map) + : loader_(backend, i_map) +{ + loader_.load(model); +} + + +ModelExecutor::~ModelExecutor() = default; + + void ModelExecutor::start() { start([](knp::core::Step) { return true; }); @@ -58,7 +114,7 @@ void ModelExecutor::start(core::Backend::RunPredicate run_predicate) // Running handlers for (auto &handler : message_handlers_) { - handler.update(get_backend()->get_step()); + handler->update(get_backend()->get_step()); } // Run monitoring observers. for (auto &observer : observers_) @@ -78,26 +134,14 @@ void ModelExecutor::stop() } -void ModelExecutor::SpikeMessageHandler::update(size_t step) -{ - endpoint_.receive_all_messages(); - auto incoming_messages = endpoint_.unload_messages(base_.uid_); - knp::core::messaging::SpikeMessage outgoing_message = { - {base_.uid_, step}, message_handler_function_(incoming_messages)}; - if (!(outgoing_message.neuron_indexes_.empty())) - { - endpoint_.send_message(outgoing_message); - } -} - - void ModelExecutor::add_spike_message_handler( typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector &senders, const std::vector &receivers, const knp::core::UID &uid) { knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint(); - message_handlers_.emplace_back(SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid}); - message_handlers_.back().subscribe(senders); + message_handlers_.emplace_back( + std::make_unique(std::move(message_handler_function), std::move(endpoint), uid)); + message_handlers_.back()->subscribe(senders); for (const knp::core::UID &rec_uid : receivers) { get_backend()->subscribe(rec_uid, {uid}); diff --git a/knp/base-framework/include/knp/framework/model_executor.h b/knp/base-framework/include/knp/framework/model_executor.h index f25bedbb..f6e82d72 100644 --- a/knp/base-framework/include/knp/framework/model_executor.h +++ b/knp/base-framework/include/knp/framework/model_executor.h @@ -55,11 +55,12 @@ class KNP_DECLSPEC ModelExecutor * @param i_map input channel map. */ ModelExecutor( - knp::framework::Model &model, std::shared_ptr backend, ModelLoader::InputChannelMap i_map) - : loader_(backend, i_map) - { - loader_.load(model); - } + knp::framework::Model &model, std::shared_ptr backend, ModelLoader::InputChannelMap i_map); + + /** + * @brief ModelExecutor destructor. + */ + ~ModelExecutor(); public: /** @@ -134,42 +135,12 @@ class KNP_DECLSPEC ModelExecutor auto &get_loader() { return loader_; } private: - /** - * @brief An object that receives and processes messages. - */ - class SpikeMessageHandler - { - public: - using MessageIn = knp::core::messaging::SpikeMessage; - using MessageData = knp::core::messaging::SpikeData; - using FunctionType = std::function &)>; - - SpikeMessageHandler( - FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {}) - : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid} - { - } - - SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default; - - SpikeMessageHandler(const SpikeMessageHandler &) = delete; - - void subscribe(const std::vector &entities) { endpoint_.subscribe(base_.uid_, entities); } - - void update(size_t step); - - [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; }; - - private: - FunctionType message_handler_function_; - core::MessageEndpoint endpoint_; - knp::core::BaseData base_; - }; + class SpikeMessageHandler; knp::core::BaseData base_; ModelLoader loader_; std::vector observers_; - std::vector message_handlers_; + std::vector> message_handlers_; }; } // namespace knp::framework