Merge pull request #27 from a-vartenkov/master
Add message handler and tests: #23
artiomn authored Dec 4, 2024
2 parents 2746bcf + 636fd86 commit 365b789
Showing 8 changed files with 554 additions and 20 deletions.
@@ -320,8 +320,8 @@ void do_dopamine_plasticity(
if (step - synapse->rule_.last_spike_step_ < synapse->rule_.dopamine_plasticity_period_)
{
// Change synapse resource.
-float d_r =
-    neuron.dopamine_value_ * std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F);
+float d_r = neuron.dopamine_value_ *
+    std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F) / 1000.F;
synapse->rule_.synaptic_resource_ += d_r;
neuron.free_synaptic_resource_ -= d_r;
}
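In effect, the per-step dopamine-driven resource change becomes d_r = dopamine_value · min(2^(-stability), 1) / 1000; the only substantive difference from the previous expression is the additional 1/1000 scaling factor.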
1 change: 1 addition & 0 deletions knp/base-framework/CMakeLists.txt
@@ -59,6 +59,7 @@ knp_add_library("${PROJECT_NAME}-core"
impl/model.cpp
impl/model_executor.cpp
impl/model_loader.cpp
impl/message_handlers.cpp
impl/input_converter.cpp
impl/output_channel.cpp
impl/synchronization.cpp
127 changes: 127 additions & 0 deletions knp/base-framework/impl/message_handlers.cpp
@@ -0,0 +1,127 @@
/**
* @file message_handlers.cpp
* @brief Implementation of message handler functionality.
* @kaspersky_support A. Vartenkov
* @date 25.11.2024
* @license Apache 2.0
* @copyright © 2024 AO Kaspersky Lab
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <knp/framework/message_handlers.h>

#include <unordered_set>
#include <utility>


/**
* @brief Namespace for message modifier callables.
*/
namespace knp::framework::modifier
{

knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty())
{
return {};
}

auto &msg = messages[0];
if (msg.neuron_indexes_.size() < num_winners_)
{
return msg.neuron_indexes_;
}

knp::core::messaging::SpikeData out_spikes;
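// Select num_winners_ distinct spikes with a partial Fisher-Yates shuffle: each chosen index is
// swapped to the not-yet-processed back of the vector so it cannot be drawn again.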
for (size_t i = 0; i < num_winners_; ++i)
{
const size_t index = distribution_(random_engine_) % (msg.neuron_indexes_.size() - i);
out_spikes.push_back(msg.neuron_indexes_[index]);
std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]);
}

return out_spikes;
}


knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty())
{
return {};
}

if (num_winners_ > group_borders_.size())
{
return messages[0].neuron_indexes_;
}

const auto &spikes = messages[0].neuron_indexes_;
if (spikes.empty())
{
return {};
}

std::vector<knp::core::messaging::SpikeData> spikes_per_group(group_borders_.size() + 1);

// Fill groups in.
for (const auto &spike : spikes)
{
const size_t group_index =
std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin();
spikes_per_group[group_index].push_back(spike);
}

// Sort groups by number of elements.
std::sort(
spikes_per_group.begin(), spikes_per_group.end(),
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });

// Find all groups with the same number of spikes as the K-th one.
const auto &last_group = spikes_per_group[num_winners_ - 1];
auto group_interval = std::equal_range(
spikes_per_group.begin(), spikes_per_group.end(), last_group,
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });
const size_t already_decided = group_interval.first - spikes_per_group.begin() + 1;
assert(already_decided <= num_winners_);
// The approach could be more efficient, but I don't think it's necessary.
std::shuffle(group_interval.first, group_interval.second, random_engine_);
knp::core::messaging::SpikeData result;
for (size_t i = 0; i < num_winners_; ++i)
{
for (const auto &spike : spikes_per_group[i])
{
result.push_back(spike);
}
}
return result;
}


knp::core::messaging::SpikeData SpikeUnionHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
std::unordered_set<knp::core::messaging::SpikeIndex> spikes;
for (const auto &msg : messages)
{
spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end());
}
knp::core::messaging::SpikeData result;
result.reserve(spikes.size());
std::copy(spikes.begin(), spikes.end(), std::back_inserter(result));
return result;
}
} // namespace knp::framework::modifier
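
A minimal invocation sketch for the handlers above. The aggregate construction of SpikeMessage from a {sender UID, step} header and an index list is assumed from its use later in this commit; indexes and the seed are illustrative.

#include <knp/framework/message_handlers.h>

#include <vector>

using knp::core::messaging::SpikeData;
using knp::core::messaging::SpikeMessage;

// Two assumed input messages for the same step.
std::vector<SpikeMessage> messages{
    SpikeMessage{{knp::core::UID{}, 0}, {1, 4, 7}}, SpikeMessage{{knp::core::UID{}, 0}, {4, 9}}};

// Keep at most two random spikes; only messages[0] is read (and reordered in place).
knp::framework::modifier::KWtaRandomHandler k_wta{2, 42};
SpikeData winners = k_wta(messages);

// Union across all messages: {1, 4, 7, 9} in unspecified order.
knp::framework::modifier::SpikeUnionHandler union_handler;
SpikeData united = union_handler(messages);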
76 changes: 76 additions & 0 deletions knp/base-framework/impl/model_executor.cpp
@@ -27,6 +27,62 @@
namespace knp::framework
{


class ModelExecutor::SpikeMessageHandler
{
public:
using MessageIn = knp::core::messaging::SpikeMessage;
using MessageData = knp::core::messaging::SpikeData;
using FunctionType = std::function<MessageData(std::vector<MessageIn> &)>;

SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {})
: message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid}
{
}

SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default;

SpikeMessageHandler(const SpikeMessageHandler &) = delete;

void subscribe(const std::vector<core::UID> &entities) { endpoint_.subscribe<MessageIn>(base_.uid_, entities); }

void update(size_t step);

[[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; };

~SpikeMessageHandler() = default;

private:
FunctionType message_handler_function_;
knp::core::MessageEndpoint endpoint_;
knp::core::BaseData base_;
};


void ModelExecutor::SpikeMessageHandler::update(size_t step)
{
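// Drain this step's spike messages addressed to the handler, apply the user-provided function,
// and re-emit the result under the handler's own UID so that subscribed receivers get it.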
endpoint_.receive_all_messages();
auto incoming_messages = endpoint_.unload_messages<MessageIn>(base_.uid_);
knp::core::messaging::SpikeMessage outgoing_message = {
{base_.uid_, step}, message_handler_function_(incoming_messages)};
if (!(outgoing_message.neuron_indexes_.empty()))
{
endpoint_.send_message(outgoing_message);
}
}


ModelExecutor::ModelExecutor(
knp::framework::Model &model, std::shared_ptr<core::Backend> backend, ModelLoader::InputChannelMap i_map)
: loader_(backend, i_map)
{
loader_.load(model);
}


ModelExecutor::~ModelExecutor() = default;


void ModelExecutor::start()
{
start([](knp::core::Step) { return true; });
@@ -55,6 +111,11 @@ void ModelExecutor::start(core::Backend::RunPredicate run_predicate)
{
o_ch.update();
}
// Run message handlers.
for (auto &handler : message_handlers_)
{
handler->update(get_backend()->get_step());
}
// Run monitoring observers.
for (auto &observer : observers_)
{
@@ -72,4 +133,19 @@ void ModelExecutor::stop()
get_backend()->stop();
}


void ModelExecutor::add_spike_message_handler(
typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector<core::UID> &senders,
const std::vector<core::UID> &receivers, const knp::core::UID &uid)
{
knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint();
message_handlers_.emplace_back(
std::make_unique<SpikeMessageHandler>(std::move(message_handler_function), std::move(endpoint), uid));
message_handlers_.back()->subscribe(senders);
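// Subscribe each receiver on the backend to spike messages emitted under the handler's UID.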
for (const knp::core::UID &rec_uid : receivers)
{
get_backend()->subscribe<knp::core::messaging::SpikeMessage>(rec_uid, {uid});
}
}

} // namespace knp::framework
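
A hypothetical registration sketch for the new API: executor is assumed to be an already constructed ModelExecutor, and src_uid/dst_uid are assumed UIDs of existing sender and receiver entities.

// Insert a 3-winners WTA stage between a sender and a receiver; the handler gets its own UID.
const knp::core::UID handler_uid;
executor.add_spike_message_handler(
    knp::framework::modifier::KWtaRandomHandler{3},  // wrapped into the handler's std::function
    {src_uid},      // senders whose spike messages the handler consumes
    {dst_uid},      // receivers subscribed to the handler's output
    handler_uid);
executor.start();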
122 changes: 122 additions & 0 deletions knp/base-framework/include/knp/framework/message_handlers.h
@@ -0,0 +1,122 @@
/**
* @file message_handlers.h
* @brief A set of predefined message handling functions to add to a model executor.
* @kaspersky_support Vartenkov A.
* @date 19.11.2024
* @license Apache 2.0
* @copyright © 2024 AO Kaspersky Lab
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once


#include <knp/core/message_endpoint.h>
#include <knp/core/messaging/messaging.h>

#include <algorithm>
#include <random>
#include <string>
#include <utility>
#include <vector>


namespace knp::framework::modifier
{

/**
* @brief A modifier functor that processes spikes and selects K random spikes out of the whole set.
* @note Only processes a single message.
*/
class KWtaRandomHandler
{
public:
/**
* @brief Constructor.
* @param winners_number Maximum number of output spikes.
* @param seed Random number generator seed.
* @note Uses std::mt19937 for random number generation.
*/
explicit KWtaRandomHandler(size_t winners_number = 1, int seed = 0)
: num_winners_(winners_number), random_engine_(seed)
{
}

/**
* @brief Function call operator that takes a vector of spike messages and returns a set of spikes.
* @param messages Spike messages.
* @return Spike data containing no more than K spikes.
* @note The handler is assumed to receive no more than one message per step, so all messages except the first are ignored.
*/
knp::core::messaging::SpikeData operator()(std::vector<knp::core::messaging::SpikeMessage> &messages);

private:
size_t num_winners_;
std::mt19937 random_engine_;
std::uniform_int_distribution<size_t> distribution_;
};


/**
* @brief Message handler functor that only passes through spikes from no more than a fixed number of groups at once.
* @note A group is considered winning if it is among the top K groups when sorted by spike count in descending order.
* @note If the last place in the top K is tied between groups, the functor selects winners at random from the tied groups.
*/
class GroupWtaRandomHandler
{
public:
/**
* @brief Functor constructor.
* @param group_borders Right borders of the group index intervals.
* @param num_winning_groups Maximum number of groups that are allowed to pass their spikes further.
* @param seed Seed for the internal random number generator.
*/
explicit GroupWtaRandomHandler(
const std::vector<size_t> &group_borders, size_t num_winning_groups = 1, int seed = 0)
: group_borders_(group_borders), num_winners_(num_winning_groups), random_engine_(seed)
{
std::sort(group_borders_.begin(), group_borders_.end());
}

/**
* @brief Function call operator.
* @param messages Input spike messages.
* @return Spikes from the winning groups.
*/
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);

private:
std::vector<size_t> group_borders_;
size_t num_winners_;
std::mt19937 random_engine_;
std::uniform_int_distribution<size_t> distribution_;
};


/**
* @brief Spike handler functor whose output vector contains a spike if that spike was present in at least one input message.
*/
class SpikeUnionHandler
{
public:
/**
* @brief Function call operator that receives a vector of messages and returns the union of all spike sets from
* those messages.
* @param messages Incoming spike messages.
* @return Spike vector containing the union of the input message spike sets.
*/
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);
};


} // namespace knp::framework::modifier
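
A sketch of the group semantics, following the std::upper_bound partitioning in the implementation; message construction and values are illustrative. With right borders {10, 20}, indexes 0-9 form group 0, indexes 10-19 form group 1, and indexes 20 and above form group 2.

// Three groups from right borders {10, 20}; a single winning group and a fixed seed.
knp::framework::modifier::GroupWtaRandomHandler group_wta{{10, 20}, 1, 42};

std::vector<knp::core::messaging::SpikeMessage> messages{
    knp::core::messaging::SpikeMessage{{knp::core::UID{}, 0}, {2, 5, 12, 13, 14, 25}}};

// Group 1 (indexes 10-19) has the most spikes, so only {12, 13, 14} pass through.
knp::core::messaging::SpikeData winners = group_wta(messages);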