diff --git a/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h b/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h
index 96e422e2..df4c4b90 100644
--- a/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h
+++ b/knp/backends/cpu/cpu-library/include/knp/backends/cpu-library/impl/synaptic_resource_stdp_impl.h
@@ -320,8 +320,8 @@ void do_dopamine_plasticity(
     if (step - synapse->rule_.last_spike_step_ < synapse->rule_.dopamine_plasticity_period_)
     {
         // Change synapse resource.
-        float d_r =
-            neuron.dopamine_value_ * std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F);
+        float d_r = neuron.dopamine_value_ *
+                    std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F) / 1000.F;
         synapse->rule_.synaptic_resource_ += d_r;
         neuron.free_synaptic_resource_ -= d_r;
     }
diff --git a/knp/base-framework/CMakeLists.txt b/knp/base-framework/CMakeLists.txt
index 2ee1af22..7c6126e6 100644
--- a/knp/base-framework/CMakeLists.txt
+++ b/knp/base-framework/CMakeLists.txt
@@ -59,6 +59,7 @@ knp_add_library("${PROJECT_NAME}-core"
         impl/model.cpp
         impl/model_executor.cpp
         impl/model_loader.cpp
+        impl/message_handlers.cpp
        impl/input_converter.cpp
        impl/output_channel.cpp
        impl/synchronization.cpp
diff --git a/knp/base-framework/impl/message_handlers.cpp b/knp/base-framework/impl/message_handlers.cpp
new file mode 100644
index 00000000..ad900fe0
--- /dev/null
+++ b/knp/base-framework/impl/message_handlers.cpp
@@ -0,0 +1,127 @@
+/**
+ * @file message_handlers.cpp
+ * @brief Implementation of message handler functionality.
+ * @kaspersky_support A. Vartenkov
+ * @date 25.11.2024
+ * @license Apache 2.0
+ * @copyright © 2024 AO Kaspersky Lab
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <knp/framework/message_handlers.h>
+
+#include <algorithm>
+#include <unordered_set>
+
+
+/**
+ * @brief Namespace for message modifier callables.
+ */
+namespace knp::framework::modifier
+{
+
+knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages)
+{
+    if (messages.empty())
+    {
+        return {};
+    }
+
+    auto &msg = messages[0];
+    if (msg.neuron_indexes_.size() < num_winners_)
+    {
+        return msg.neuron_indexes_;
+    }
+
+    knp::core::messaging::SpikeData out_spikes;
+    for (size_t i = 0; i < num_winners_; ++i)
+    {
+        const size_t index = distribution_(random_engine_) % (msg.neuron_indexes_.size() - i);
+        out_spikes.push_back(msg.neuron_indexes_[index]);
+        std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]);
+    }
+
+    return out_spikes;
+}
+
+
+knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
+    const std::vector<knp::core::messaging::SpikeMessage> &messages)
+{
+    if (messages.empty())
+    {
+        return {};
+    }
+
+    if (num_winners_ > group_borders_.size())
+    {
+        return messages[0].neuron_indexes_;
+    }
+
+    const auto &spikes = messages[0].neuron_indexes_;
+    if (spikes.empty())
+    {
+        return {};
+    }
+
+    std::vector<knp::core::messaging::SpikeData> spikes_per_group(group_borders_.size() + 1);
+
+    // Distribute spikes among the groups.
+    for (const auto &spike : spikes)
+    {
+        const size_t group_index =
+            std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin();
+        spikes_per_group[group_index].push_back(spike);
+    }
+
+    // Sort groups by number of elements.
+    std::sort(
+        spikes_per_group.begin(), spikes_per_group.end(),
+        [](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });
+
+    // Find all groups with the same number of spikes as the K-th one.
+    const auto &last_group = spikes_per_group[num_winners_ - 1];
+    auto group_interval = std::equal_range(
+        spikes_per_group.begin(), spikes_per_group.end(), last_group,
+        [](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });
+    const size_t already_decided = group_interval.first - spikes_per_group.begin() + 1;
+    assert(already_decided <= num_winners_);
+    // This could be done more efficiently, but performance is not critical here.
+    std::shuffle(group_interval.first, group_interval.second, random_engine_);
+    knp::core::messaging::SpikeData result;
+    for (size_t i = 0; i < num_winners_; ++i)
+    {
+        for (const auto &spike : spikes_per_group[i])
+        {
+            result.push_back(spike);
+        }
+    }
+    return result;
+}
+
+
+knp::core::messaging::SpikeData SpikeUnionHandler::operator()(
+    const std::vector<knp::core::messaging::SpikeMessage> &messages)
+{
+    std::unordered_set<knp::core::messaging::SpikeIndex> spikes;
+    for (const auto &msg : messages)
+    {
+        spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end());
+    }
+    knp::core::messaging::SpikeData result;
+    result.reserve(spikes.size());
+    std::copy(spikes.begin(), spikes.end(), std::back_inserter(result));
+    return result;
+}
+} // namespace knp::framework::modifier
diff --git a/knp/base-framework/impl/model_executor.cpp b/knp/base-framework/impl/model_executor.cpp
index 4914f24a..71c18bea 100644
--- a/knp/base-framework/impl/model_executor.cpp
+++ b/knp/base-framework/impl/model_executor.cpp
@@ -27,6 +27,62 @@ namespace knp::framework
 {
+
+class ModelExecutor::SpikeMessageHandler
+{
+public:
+    using MessageIn = knp::core::messaging::SpikeMessage;
+    using MessageData = knp::core::messaging::SpikeData;
+    using FunctionType = std::function<MessageData(std::vector<MessageIn> &)>;
+
+    SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {})
+        : message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid}
+    {
+    }
+
+    SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default;
+
+    SpikeMessageHandler(const SpikeMessageHandler &) = delete;
+
+    void subscribe(const std::vector<knp::core::UID> &entities) { endpoint_.subscribe<MessageIn>(base_.uid_, entities); }
+
+    void update(size_t step);
+
+    [[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; };
+
+    ~SpikeMessageHandler() = default;
+
+private:
+    FunctionType message_handler_function_;
+    knp::core::MessageEndpoint endpoint_;
+    knp::core::BaseData base_;
+};
+
+
+void ModelExecutor::SpikeMessageHandler::update(size_t step)
+{
+    endpoint_.receive_all_messages();
+    auto incoming_messages = endpoint_.unload_messages<MessageIn>(base_.uid_);
+    knp::core::messaging::SpikeMessage outgoing_message = {
+        {base_.uid_, step}, message_handler_function_(incoming_messages)};
+    if (!(outgoing_message.neuron_indexes_.empty()))
+    {
+        endpoint_.send_message(outgoing_message);
+    }
+}
+
+
+ModelExecutor::ModelExecutor(
+    knp::framework::Model &model, std::shared_ptr<core::Backend> backend, ModelLoader::InputChannelMap i_map)
+    : loader_(backend, i_map)
+{
+    loader_.load(model);
+}
+
+
+ModelExecutor::~ModelExecutor() = default;
+
+
 void ModelExecutor::start()
 {
     start([](knp::core::Step) { return true; });
@@ -55,6 +111,11 @@ void ModelExecutor::start(core::Backend::RunPredicate run_predicate)
         {
             o_ch.update();
         }
+        // Run message handlers.
+        for (auto &handler : message_handlers_)
+        {
+            handler->update(get_backend()->get_step());
+        }
         // Run monitoring observers.
         for (auto &observer : observers_)
         {
@@ -72,4 +133,19 @@ void ModelExecutor::stop()
     get_backend()->stop();
 }
+
+void ModelExecutor::add_spike_message_handler(
+    typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector<knp::core::UID> &senders,
+    const std::vector<knp::core::UID> &receivers, const knp::core::UID &uid)
+{
+    knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint();
+    message_handlers_.emplace_back(
+        std::make_unique<SpikeMessageHandler>(std::move(message_handler_function), std::move(endpoint), uid));
+    message_handlers_.back()->subscribe(senders);
+    for (const knp::core::UID &rec_uid : receivers)
+    {
+        get_backend()->subscribe<knp::core::messaging::SpikeMessage>(rec_uid, {uid});
+    }
+}
+
 } // namespace knp::framework
diff --git a/knp/base-framework/include/knp/framework/message_handlers.h b/knp/base-framework/include/knp/framework/message_handlers.h
new file mode 100644
index 00000000..c786a947
--- /dev/null
+++ b/knp/base-framework/include/knp/framework/message_handlers.h
@@ -0,0 +1,122 @@
+/**
+ * @file message_handlers.h
+ * @brief A set of predefined message handling functions to add to the model executor.
+ * @kaspersky_support Vartenkov A.
+ * @date 19.11.2024
+ * @license Apache 2.0
+ * @copyright © 2024 AO Kaspersky Lab
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+
+#include <knp/core/messaging/messaging.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <random>
+#include <utility>
+#include <vector>
+
+
+namespace knp::framework::modifier
+{
+
+/**
+ * @brief A modifier functor that processes spike messages and selects K random spikes out of the whole set.
+ * @note Only processes a single message.
+ */
+class KWtaRandomHandler
+{
+public:
+    /**
+     * @brief Constructor.
+     * @param winners_number Max number of output spikes.
+     * @param seed Random generator seed.
+     * @note Uses the mt19937 engine for random number generation.
+     */
+    explicit KWtaRandomHandler(size_t winners_number = 1, int seed = 0)
+        : num_winners_(winners_number), random_engine_(seed)
+    {
+    }
+
+    /**
+     * @brief Function call operator that takes a number of messages and returns a set of spikes.
+     * @param messages Spike messages.
+     * @return Spike data containing no more than K spikes.
+     * @note The handler is assumed to receive no more than one message per step, so all messages except the first
+     * one are ignored.
+     */
+    knp::core::messaging::SpikeData operator()(std::vector<knp::core::messaging::SpikeMessage> &messages);
+
+private:
+    size_t num_winners_;
+    std::mt19937 random_engine_;
+    std::uniform_int_distribution<size_t> distribution_;
+};
+
+
+/**
+ * @brief Message handler functor that only passes through spikes from no more than a fixed number of groups at once.
+ * @note A group is considered winning if it is among the top K groups sorted by spike count in descending order.
+ * @note If the last place in the top K is shared by several groups, the functor selects among them at random.
+ */
+class GroupWtaRandomHandler
+{
+public:
+    /**
+     * @brief Functor constructor.
+     * @param group_borders Right borders of the neuron index intervals that define the groups.
+     * @param num_winning_groups Max number of groups that are allowed to pass their spikes further.
+     * @param seed Seed for the internal random number generator.
+     */
+    explicit GroupWtaRandomHandler(
+        const std::vector<size_t> &group_borders, size_t num_winning_groups = 1, int seed = 0)
+        : group_borders_(group_borders), num_winners_(num_winning_groups), random_engine_(seed)
+    {
+        std::sort(group_borders_.begin(), group_borders_.end());
+    }
+
+    /**
+     * @brief Function call operator.
+     * @param messages Input messages.
+     * @return Spikes from the winning groups.
+     */
+    knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);
+
+private:
+    std::vector<size_t> group_borders_;
+    size_t num_winners_;
+    std::mt19937 random_engine_;
+    std::uniform_int_distribution<size_t> distribution_;
+};
+
+
+/**
+ * @brief Spike handler functor. The output vector contains a spike if that spike was present in at least one
+ * input message.
+ */
+class SpikeUnionHandler
+{
+public:
+    /**
+     * @brief Function call operator that receives a vector of messages and returns the union of all spike sets from
+     * those messages.
+     * @param messages Incoming spike messages.
+     * @return Spike vector containing the union of the input message spike sets.
+     */
+    knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);
+};
+
+
+} // namespace knp::framework::modifier
diff --git a/knp/base-framework/include/knp/framework/model_executor.h b/knp/base-framework/include/knp/framework/model_executor.h
index 3281c6cc..f6e82d72 100644
--- a/knp/base-framework/include/knp/framework/model_executor.h
+++ b/knp/base-framework/include/knp/framework/model_executor.h
@@ -4,18 +4,18 @@
  * @kaspersky_support Artiom N.
  * @date 21.04.2023
  * @license Apache 2.0
- * @copyright © 2024 AO Kaspersky Lab
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
+ * @copyright © 2024 AO Kaspersky Lab
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <knp/framework/message_handlers.h>
 #include
 #include
 #include
@@ -54,11 +55,12 @@ class KNP_DECLSPEC ModelExecutor
      * @param i_map input channel map.
      */
     ModelExecutor(
-        knp::framework::Model &model, std::shared_ptr<core::Backend> backend, ModelLoader::InputChannelMap i_map)
-        : loader_(backend, i_map)
-    {
-        loader_.load(model);
-    }
+        knp::framework::Model &model, std::shared_ptr<core::Backend> backend, ModelLoader::InputChannelMap i_map);
+
+    /**
+     * @brief ModelExecutor destructor.
+     */
+    ~ModelExecutor();
 
 public:
     /**
@@ -93,6 +95,23 @@ class KNP_DECLSPEC ModelExecutor
         std::visit([&senders](auto &entity) { entity.subscribe(senders); }, observers_.back());
     }
 
+    /**
+     * @brief Function type for message handlers.
+     */
+    using SpikeHandlerFunction =
+        std::function<knp::core::messaging::SpikeData(std::vector<knp::core::messaging::SpikeMessage> &)>;
+
+    /**
+     * @brief Add a spike message handler to the executor.
+     * @param message_handler_function Functor that processes received messages.
+     * @param senders List of entities that send messages to the handler.
+     * @param receivers List of entities that receive messages from the handler.
+     * @param uid Handler UID.
+     */
+    void add_spike_message_handler(
+        SpikeHandlerFunction &&message_handler_function, const std::vector<knp::core::UID> &senders,
+        const std::vector<knp::core::UID> &receivers, const knp::core::UID &uid = knp::core::UID{});
+
     /**
      * @brief Unlock synapse weights.
      */
@@ -116,9 +135,12 @@
     auto &get_loader() { return loader_; }
 
 private:
+    class SpikeMessageHandler;
+
     knp::core::BaseData base_;
     ModelLoader loader_;
     std::vector observers_;
+    std::vector<std::unique_ptr<SpikeMessageHandler>> message_handlers_;
 };
 
 } // namespace knp::framework
diff --git a/knp/base-framework/include/knp/framework/monitoring/observer.h b/knp/base-framework/include/knp/framework/monitoring/observer.h
index d91a1afc..05277bb0 100644
--- a/knp/base-framework/include/knp/framework/monitoring/observer.h
+++ b/knp/base-framework/include/knp/framework/monitoring/observer.h
@@ -44,7 +44,7 @@ namespace knp::framework::monitoring
  * @tparam Message type of messages the functor processes.
  */
 template <class Message>
-using MessageProcessor = std::function<void(std::vector<Message>)>;
+using MessageProcessor = std::function<void(std::vector<Message> &)>;
 
 
 /**
diff --git a/knp/tests/framework/message_handler_test.cpp b/knp/tests/framework/message_handler_test.cpp
new file mode 100644
index 00000000..ccf528bb
--- /dev/null
+++ b/knp/tests/framework/message_handler_test.cpp
@@ -0,0 +1,186 @@
+/**
+ * @file message_handler_test.cpp
+ * @brief Message handler class testing.
+ * @kaspersky_support A. Vartenkov
+ * @date 27.11.2024
+ * @license Apache 2.0
+ * @copyright © 2024 AO Kaspersky Lab
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include <knp/backends/cpu-single-threaded/backend.h>
+#include <knp/framework/message_handlers.h>
+#include <knp/framework/model_executor.h>
+#include <knp/framework/network.h>
+#include <knp/neuron-traits/blifat.h>
+#include <knp/synapse-traits/delta.h>
+#include <spdlog/spdlog.h>
+
+#include <tests_common.h>
+
+
+TEST(MessageHandlerSuite, MessageHandlerWTA)
+{
+    std::random_device rd;
+    int seed = rd();
+    SPDLOG_DEBUG("Seed is {}", seed);
+    knp::framework::modifier::KWtaRandomHandler kwta_handler(2, seed);
+    knp::core::messaging::SpikeMessage message({{knp::core::UID{}}, {1, 2, 3, 4, 5}});
+    std::vector<knp::core::messaging::SpikeMessage> msg_vec{message};
+    auto out_data = kwta_handler(msg_vec);
+    ASSERT_EQ(out_data.size(), 2);
+    ASSERT_NE(out_data[0], out_data[1]);
+    SPDLOG_DEBUG("Selected spikes are {} and {}", out_data[0], out_data[1]);
+
+    msg_vec[0].neuron_indexes_ = {7};
+    out_data = kwta_handler(msg_vec);
+    ASSERT_EQ(out_data.size(), 1);
+    ASSERT_EQ(out_data[0], 7);
+}
+
+
+TEST(MessageHandlerSuite, MessageHandlerGroupWTASingle)
+{
+    std::random_device rd;
+    int seed = rd();
+    SPDLOG_DEBUG("Seed is {}", seed);
+    // The intervals are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, ...].
+    knp::framework::modifier::GroupWtaRandomHandler group_handler({3, 6, 9}, 1, seed);
+
+    // The message contains two spikes in group 0, one in group 1 and one in group 2, so group 0 should be selected.
+    knp::core::messaging::SpikeMessage message({{knp::core::UID{}}, {1, 2, 3, 6}});
+    auto out_data = group_handler({message});
+    // Spikes 1 and 2 are to be passed.
+    ASSERT_EQ(out_data.size(), 2);
+    ASSERT_EQ(out_data[0] + out_data[1], 3);
+
+    // The message contains two spikes in group 1 and two in group 3, so either group can be selected.
+    message = {{knp::core::UID{}}, {1, 3, 5, 6, 9, 10}};
+    out_data = group_handler({message});
+    ASSERT_EQ(out_data.size(), 2);
+    // Either group 1 or group 3 is selected.
+    ASSERT_TRUE(out_data[0] + out_data[1] == 8 || out_data[0] + out_data[1] == 19);
+    SPDLOG_DEBUG("Selected values are {} and {}.", out_data[0], out_data[1]);
+}
+
+
+TEST(MessageHandlerSuite, SpikeUnionHandler)
+{
+    knp::framework::modifier::SpikeUnionHandler union_handler;
+    knp::core::messaging::SpikeMessage message_1 = {{knp::core::UID{}, 0}, {1, 3, 5}};
+    knp::core::messaging::SpikeMessage message_2 = {{knp::core::UID{}, 0}, {0, 1, 3}};
+    knp::core::messaging::SpikeMessage message_3 = {{knp::core::UID{}, 0}, {3, 4, 7}};
+    auto result = union_handler({message_1, message_2, message_3});
+    std::sort(result.begin(), result.end());
+    const decltype(result) expected_result{0, 1, 3, 4, 5, 7};
+    ASSERT_EQ(result, expected_result);
+}
+
+
+namespace knp::testing
+{
+
+class STestingBack : public knp::backends::single_threaded_cpu::SingleThreadedCPUBackend
+{
+public:
+    STestingBack() = default;
+    void _init() override { knp::backends::single_threaded_cpu::SingleThreadedCPUBackend::_init(); }
+};
+
+} // namespace knp::testing
+
+
+using BlifatParams = knp::neuron_traits::neuron_parameters<knp::neuron_traits::BLIFATNeuron>;
+using DeltaProjection = knp::core::Projection<knp::synapse_traits::DeltaSynapse>;
+using BlifatPopulation = knp::core::Population<knp::neuron_traits::BLIFATNeuron>;
+
+
+DeltaProjection::Synapse input_synapse_generator(size_t index)
+{
+    return {{1.0, 1, knp::synapse_traits::OutputType::EXCITATORY}, 0, index};
+}
+
+
+DeltaProjection::Synapse intermediate_synapse_generator(size_t index)
+{
+    return {{1.0, 1, knp::synapse_traits::OutputType::EXCITATORY}, index, index};
+}
+
+
+TEST(MessageHandlerSuite, NetworkIntegrationTest)
+{
+    // This test builds a network that consists of:
+    // - Input projection. It takes a zero input and translates it to all input neurons, instantly activating them.
+    // - Input population that consists of 6 regular BLIFAT neurons.
+    // - Modifier.
+    //   It divides incoming spikes into three groups (neurons 0-1, 2-3 and 4-5) and selects only one of them.
+    // - Intermediate projection that receives spikes from the modifier and activates a corresponding neuron in the
+    //   output population.
+    // - Output population: still 6 neurons, but only two neighbouring neurons should ever be activated, at random.
+    auto back_path = knp::testing::get_backend_path();
+    constexpr int num_neurons = 6;
+    // A population of 3 groups of 2 neurons per group.
+    BlifatPopulation population([](size_t) { return BlifatParams{}; }, num_neurons);
+
+    BlifatPopulation output_population([](size_t) { return BlifatParams{}; }, num_neurons);
+    const knp::core::UID in_pop_uid = population.get_uid();
+    const knp::core::UID out_pop_uid = output_population.get_uid();
+    // A projection that activates all neurons simultaneously.
+    DeltaProjection input_projection{knp::core::UID{false}, population.get_uid(), input_synapse_generator, num_neurons};
+
+    // A projection that should later be connected to a modifier.
+    DeltaProjection inter_projection{
+        knp::core::UID{false}, output_population.get_uid(), intermediate_synapse_generator, num_neurons};
+
+    const knp::core::UID input_proj_uid = input_projection.get_uid();
+    const knp::core::UID inter_proj_uid = inter_projection.get_uid();
+
+    // Create network.
+    knp::framework::Network network;
+    network.add_population(std::move(population));
+    network.add_population(std::move(output_population));
+    network.add_projection(std::move(input_projection));
+    network.add_projection(std::move(inter_projection));
+
+    // Move the network to a model and add input and output channels.
+    knp::framework::Model model(std::move(network));
+    const knp::core::UID input_uid, output_uid;
+    model.add_input_channel(input_uid, input_proj_uid);
+    model.add_output_channel(output_uid, out_pop_uid);
+
+    // Generate an input spike at each step.
+    auto input_gen = [](size_t step) { return knp::core::messaging::SpikeData{0}; };
+
+    knp::framework::BackendLoader backend_loader;
+    knp::framework::ModelExecutor model_executor(
+        model, backend_loader.load(knp::testing::get_backend_path()), {{input_uid, input_gen}});
+
+    auto &out_channel = model_executor.get_loader().get_output_channel(output_uid);
+    const std::vector<size_t> group_borders{2, 4};
+    const knp::core::UID handler_uid;
+    model_executor.add_spike_message_handler(
+        knp::framework::modifier::GroupWtaRandomHandler{group_borders}, {in_pop_uid}, {inter_proj_uid}, handler_uid);
+    constexpr int num_steps = 20;
+    model_executor.start([num_steps](size_t step) { return step < num_steps; });
+
+    const auto &spikes = out_channel.update();
+
+    ASSERT_GE(spikes.size(), 10);
+    for (const auto &msg : spikes)
+    {
+        ASSERT_GE(msg.header_.send_time_, 3);
+        ASSERT_EQ(msg.neuron_indexes_.size(), 2);
+        ASSERT_EQ(std::abs(static_cast<int>(msg.neuron_indexes_[0]) - static_cast<int>(msg.neuron_indexes_[1])), 1);
+    }
+}
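Usage sketch (not part of the patch): the snippet below shows how the pieces added in this diff fit together, mirroring the NetworkIntegrationTest above. The model, backend, input_uid, input_gen, in_pop_uid and inter_proj_uid variables are placeholders taken from that test; only add_spike_message_handler() and GroupWtaRandomHandler are the new API being exercised.

    // Wire a group-WTA filter between the input population and the intermediate projection.
    // Group borders {2, 4} split neuron indices into the groups [0, 2), [2, 4) and [4, ...);
    // each step, only one winning group's spikes are forwarded under the handler's UID.
    knp::framework::ModelExecutor model_executor(model, backend, {{input_uid, input_gen}});
    model_executor.add_spike_message_handler(
        knp::framework::modifier::GroupWtaRandomHandler{{2, 4}}, {in_pop_uid}, {inter_proj_uid});
    model_executor.start([](knp::core::Step step) { return step < 20; });

The handler subscribes to the listed senders, filters their spike messages on every step, and re-sends the result under its own UID, to which each listed receiver is subscribed.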