Skip to content

Commit

Permalink
Moved MessageHandler inside ModelExecutor: #23
Browse files Browse the repository at this point in the history
  • Loading branch information
a-vartenkov committed Dec 2, 2024
1 parent e70e236 commit e5fe53d
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 115 deletions.
2 changes: 1 addition & 1 deletion knp/base-framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ knp_add_library("${PROJECT_NAME}-core"
impl/model.cpp
impl/model_executor.cpp
impl/model_loader.cpp
impl/message_handler.cpp
impl/message_handlers.cpp
impl/input_converter.cpp
impl/output_channel.cpp
impl/synchronization.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* @file message_handler.cpp
* @file message_handlers.cpp
* @brief Implementation of message handler functionality.
* @kaspersky_support A. Vartenkov
* @date 25.11.2024
Expand All @@ -19,33 +19,30 @@
* limitations under the License.
*/

#include <knp/framework/message_handler.h>
#include <knp/framework/message_handlers.h>

#include <unordered_set>
#include <utility>


/**
* @brief namespace for message modifier callables.
*/
namespace knp::framework::modifier
{

void SpikeMessageHandler::update(size_t step)
knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages)
{
endpoint_.receive_all_messages();
auto incoming_messages = endpoint_.unload_messages<MessageIn>(base_.uid_);
MessageOut outgoing_message = {{base_.uid_, step}, message_handler_function_(incoming_messages)};
if (!(outgoing_message.neuron_indexes_.empty()))
if (messages.empty())
{
endpoint_.send_message(outgoing_message);
return {};
}
}


knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty()) return {};

auto &msg = messages[0];
if (msg.neuron_indexes_.size() < num_winners_) return msg.neuron_indexes_;
if (msg.neuron_indexes_.size() < num_winners_)
{
return msg.neuron_indexes_;
}

knp::core::messaging::SpikeData out_spikes;
for (size_t i = 0; i < num_winners_; ++i)
Expand All @@ -54,20 +51,31 @@ knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::c
out_spikes.push_back(msg.neuron_indexes_[index]);
std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]);
}

return out_spikes;
}


knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty()) return {};
if (num_winners_ > group_borders_.size()) return messages[0].neuron_indexes_;
if (messages.empty())
{
return {};
}

if (num_winners_ > group_borders_.size())
{
return messages[0].neuron_indexes_;
}

const auto &spikes = messages[0].neuron_indexes_;
if (spikes.empty()) return {};
if (spikes.empty())
{
return {};
}

std::vector<std::vector<size_t>> spikes_per_group(group_borders_.size() + 1);
std::vector<knp::core::messaging::SpikeData> spikes_per_group(group_borders_.size() + 1);

// Fill groups in.
for (const auto &spike : spikes)
Expand All @@ -94,7 +102,10 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
knp::core::messaging::SpikeData result;
for (size_t i = 0; i < num_winners_; ++i)
{
for (auto spike : spikes_per_group[i]) result.push_back(spike);
for (const auto &spike : spikes_per_group[i])
{
result.push_back(spike);
}
}
return result;
}
Expand All @@ -103,7 +114,7 @@ knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
knp::core::messaging::SpikeData SpikeUnionHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
std::unordered_set<size_t> spikes;
std::unordered_set<knp::core::messaging::SpikeIndex> spikes;
for (const auto &msg : messages)
{
spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end());
Expand Down
24 changes: 18 additions & 6 deletions knp/base-framework/impl/model_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,29 @@ void ModelExecutor::stop()
}


void ModelExecutor::add_message_handler(
typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function,
const std::vector<core::UID> &senders, const std::vector<core::UID> &receivers, const knp::core::UID &uid)
void ModelExecutor::SpikeMessageHandler::update(size_t step)
{
endpoint_.receive_all_messages();
auto incoming_messages = endpoint_.unload_messages<MessageIn>(base_.uid_);
knp::core::messaging::SpikeMessage outgoing_message = {
{base_.uid_, step}, message_handler_function_(incoming_messages)};
if (!(outgoing_message.neuron_indexes_.empty()))
{
endpoint_.send_message(outgoing_message);
}
}


void ModelExecutor::add_spike_message_handler(
typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector<core::UID> &senders,
const std::vector<core::UID> &receivers, const knp::core::UID &uid)
{
knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint();
message_handlers_.emplace_back(
modifier::SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid});
message_handlers_.emplace_back(SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid});
message_handlers_.back().subscribe(senders);
for (const knp::core::UID &rec_uid : receivers)
{
get_backend()->subscribe<typename modifier::SpikeMessageHandler::MessageOut>(rec_uid, {uid});
get_backend()->subscribe<knp::core::messaging::SpikeMessage>(rec_uid, {uid});
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @file message_handler.h
* @brief A class that processes a number of messages then sends messages of its own.
* @file message_handlers.h
* @brief A set of predefined message handling functions to add to model executor.
* @kaspersky_support Vartenkov A.
* @date 19.11.2024
* @license Apache 2.0
Expand Down Expand Up @@ -34,83 +34,6 @@
namespace knp::framework::modifier
{

/**
* @brief An object that receives and processes messages.
*/
class SpikeMessageHandler
{
public:
/**
* @brief Input message type.
*/
using MessageIn = knp::core::messaging::SpikeMessage;

/**
* @brief Output message type.
*/
using MessageOut = knp::core::messaging::SpikeMessage;

/**
* @brief Functor type.
*/
using FunctionType = std::function<core::messaging::SpikeData(std::vector<MessageIn> &)>;

/**
* @brief Handler constructor.
* @param function a function that takes a vector of spike messages and returns a vector of spikes.
* @param endpoint message endpoint.
* @param uid the uid of this object.
*/
SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {})
: message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid}
{
}

/**
* @brief Default move constructor.
* @param other object to move from.
*/
SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default;

/**
* @brief Is not copyable.
*/
SpikeMessageHandler(const SpikeMessageHandler &) = delete;

/**
* @brief Subscribe handler to a number of other entities.
* @param entities network uids.
* @note For internal use, don't try to call it manually.
*/
void subscribe(const std::vector<core::UID> &entities) { endpoint_.subscribe<MessageIn>(base_.uid_, entities); }

/**
* @brief Read, process and send messages.
* @param step current step.
* @note for internal use, don't try to call it manually.
*/
void update(size_t step);

/**
* @brief Get handler UID.
* @return object UID.
*/
[[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; };

/**
* @brief Get a tag.
* @param tag_name tag name.
* @return tag value reference.
*/
[[nodiscard]] std::any &get_tag(const std::string &tag_name) { return base_.tags_[tag_name]; };

private:
FunctionType message_handler_function_;
core::MessageEndpoint endpoint_;
knp::core::BaseData base_;
};


/**
* @brief A modifier functor to process spikes and select random K spikes out of the whole set.
* @note Only processes a single message.
Expand Down
49 changes: 43 additions & 6 deletions knp/base-framework/include/knp/framework/model_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <knp/core/impexp.h>
#include <knp/framework/backend_loader.h>
#include <knp/framework/io/input_converter.h>
#include <knp/framework/message_handler.h>
#include <knp/framework/message_handlers.h>
#include <knp/framework/model.h>
#include <knp/framework/model_loader.h>
#include <knp/framework/monitoring/observer.h>
Expand Down Expand Up @@ -94,17 +94,22 @@ class KNP_DECLSPEC ModelExecutor
std::visit([&senders](auto &entity) { entity.subscribe(senders); }, observers_.back());
}

/**
* @brief Function type for message handlers.
*/
using SpikeHandlerFunction =
std::function<knp::core::messaging::SpikeData(std::vector<knp::core::messaging::SpikeMessage> &)>;

/**
* @brief Add spike message handler to executor.
* @param message_handler_function functor to process received messages.
* @param senders list of entities sending messages to the handler.
* @param receivers list of entities receiving messages from handler.
* @param uid handler uid.
*/
void add_message_handler(
typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function,
const std::vector<core::UID> &senders, const std::vector<core::UID> &receivers,
const knp::core::UID &uid = knp::core::UID{});
void add_spike_message_handler(
SpikeHandlerFunction &&message_handler_function, const std::vector<core::UID> &senders,
const std::vector<core::UID> &receivers, const knp::core::UID &uid = knp::core::UID{});

/**
* @brief Unlock synapse weights.
Expand All @@ -129,10 +134,42 @@ class KNP_DECLSPEC ModelExecutor
auto &get_loader() { return loader_; }

private:
/**
* @brief An object that receives and processes messages.
*/
class SpikeMessageHandler
{
public:
using MessageIn = knp::core::messaging::SpikeMessage;
using MessageData = knp::core::messaging::SpikeData;
using FunctionType = std::function<MessageData(std::vector<MessageIn> &)>;

SpikeMessageHandler(
FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {})
: message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid}
{
}

SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default;

SpikeMessageHandler(const SpikeMessageHandler &) = delete;

void subscribe(const std::vector<core::UID> &entities) { endpoint_.subscribe<MessageIn>(base_.uid_, entities); }

void update(size_t step);

[[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; };

private:
FunctionType message_handler_function_;
core::MessageEndpoint endpoint_;
knp::core::BaseData base_;
};

knp::core::BaseData base_;
ModelLoader loader_;

std::vector<monitoring::AnyObserverVariant> observers_;
std::vector<modifier::SpikeMessageHandler> message_handlers_;
std::vector<SpikeMessageHandler> message_handlers_;
};
} // namespace knp::framework
4 changes: 2 additions & 2 deletions knp/tests/framework/message_handler_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include <knp/backends/cpu-single-threaded/backend.h>
#include <knp/core/messaging/messaging.h>
#include <knp/framework/message_handler.h>
#include <knp/framework/message_handlers.h>
#include <knp/framework/model_executor.h>
#include <knp/framework/network.h>
#include <knp/neuron-traits/blifat.h>
Expand Down Expand Up @@ -169,7 +169,7 @@ TEST(MessageHandlerSuite, NetworkIntegrationTest)
auto &out_channel = model_executor.get_loader().get_output_channel(output_uid);
const std::vector<size_t> group_borders{2, 4};
const knp::core::UID handler_uid;
model_executor.add_message_handler(
model_executor.add_spike_message_handler(
knp::framework::modifier::GroupWtaRandomHandler{group_borders}, {in_pop_uid}, {inter_proj_uid}, handler_uid);
constexpr int num_steps = 20;
model_executor.start([](size_t step) { return step < num_steps; });
Expand Down

0 comments on commit e5fe53d

Please sign in to comment.