-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #27 from a-vartenkov/master
Add message handler and tests: #23
- Loading branch information
Showing
8 changed files
with
554 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/** | ||
* @file message_handlers.cpp | ||
* @brief Implementation of message handler functionality. | ||
* @kaspersky_support A. Vartenkov | ||
* @date 25.11.2024 | ||
* @license Apache 2.0 | ||
* @copyright © 2024 AO Kaspersky Lab | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <knp/framework/message_handlers.h> | ||
|
||
#include <unordered_set> | ||
#include <utility> | ||
|
||
|
||
/** | ||
* @brief namespace for message modifier callables. | ||
*/ | ||
namespace knp::framework::modifier | ||
{ | ||
|
||
knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages) | ||
{ | ||
if (messages.empty()) | ||
{ | ||
return {}; | ||
} | ||
|
||
auto &msg = messages[0]; | ||
if (msg.neuron_indexes_.size() < num_winners_) | ||
{ | ||
return msg.neuron_indexes_; | ||
} | ||
|
||
knp::core::messaging::SpikeData out_spikes; | ||
for (size_t i = 0; i < num_winners_; ++i) | ||
{ | ||
const size_t index = distribution_(random_engine_) % (msg.neuron_indexes_.size() - i); | ||
out_spikes.push_back(msg.neuron_indexes_[index]); | ||
std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]); | ||
} | ||
|
||
return out_spikes; | ||
} | ||
|
||
|
||
knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()( | ||
const std::vector<knp::core::messaging::SpikeMessage> &messages) | ||
{ | ||
if (messages.empty()) | ||
{ | ||
return {}; | ||
} | ||
|
||
if (num_winners_ > group_borders_.size()) | ||
{ | ||
return messages[0].neuron_indexes_; | ||
} | ||
|
||
const auto &spikes = messages[0].neuron_indexes_; | ||
if (spikes.empty()) | ||
{ | ||
return {}; | ||
} | ||
|
||
std::vector<knp::core::messaging::SpikeData> spikes_per_group(group_borders_.size() + 1); | ||
|
||
// Fill groups in. | ||
for (const auto &spike : spikes) | ||
{ | ||
const size_t group_index = | ||
std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin(); | ||
spikes_per_group[group_index].push_back(spike); | ||
} | ||
|
||
// Sort groups by number of elements. | ||
std::sort( | ||
spikes_per_group.begin(), spikes_per_group.end(), | ||
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); }); | ||
|
||
// Find all groups with the same number of spikes as the K-th one. | ||
const auto &last_group = spikes_per_group[num_winners_ - 1]; | ||
auto group_interval = std::equal_range( | ||
spikes_per_group.begin(), spikes_per_group.end(), last_group, | ||
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); }); | ||
const size_t already_decided = group_interval.first - spikes_per_group.begin() + 1; | ||
assert(already_decided <= num_winners_); | ||
// The approach could be more efficient, but I don't think it's necessary. | ||
std::shuffle(group_interval.first, group_interval.second, random_engine_); | ||
knp::core::messaging::SpikeData result; | ||
for (size_t i = 0; i < num_winners_; ++i) | ||
{ | ||
for (const auto &spike : spikes_per_group[i]) | ||
{ | ||
result.push_back(spike); | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
|
||
knp::core::messaging::SpikeData SpikeUnionHandler::operator()( | ||
const std::vector<knp::core::messaging::SpikeMessage> &messages) | ||
{ | ||
std::unordered_set<knp::core::messaging::SpikeIndex> spikes; | ||
for (const auto &msg : messages) | ||
{ | ||
spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end()); | ||
} | ||
knp::core::messaging::SpikeData result; | ||
result.reserve(spikes.size()); | ||
std::copy(spikes.begin(), spikes.end(), std::back_inserter(result)); | ||
return result; | ||
} | ||
} // namespace knp::framework::modifier |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
knp/base-framework/include/knp/framework/message_handlers.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/** | ||
* @file message_handlers.h | ||
* @brief A set of predefined message handling functions to add to model executor. | ||
* @kaspersky_support Vartenkov A. | ||
* @date 19.11.2024 | ||
* @license Apache 2.0 | ||
* @copyright © 2024 AO Kaspersky Lab | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
|
||
#include <knp/core/message_endpoint.h> | ||
#include <knp/core/messaging/messaging.h> | ||
|
||
#include <algorithm> | ||
#include <random> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
|
||
namespace knp::framework::modifier | ||
{ | ||
|
||
/** | ||
* @brief A modifier functor to process spikes and select random K spikes out of the whole set. | ||
* @note Only processes a single message. | ||
*/ | ||
class KWtaRandomHandler | ||
{ | ||
public: | ||
/** | ||
* @brief Constructor. | ||
* @param winners_number Max number of output spikes. | ||
* @param seed random generator seed. | ||
* @note uses mt19937 for random number generation. | ||
*/ | ||
explicit KWtaRandomHandler(size_t winners_number = 1, int seed = 0) | ||
: num_winners_(winners_number), random_engine_(seed) | ||
{ | ||
} | ||
|
||
/** | ||
* @brief Function call operator that takes a number of messages and returns a set of spikes. | ||
* @param messages spike messages. | ||
* @return spikes data containing no more than K spikes. | ||
* @note it's assumed that it gets no more than one message per step, so all messages except first are ignored. | ||
*/ | ||
knp::core::messaging::SpikeData operator()(std::vector<knp::core::messaging::SpikeMessage> &messages); | ||
|
||
private: | ||
size_t num_winners_; | ||
std::mt19937 random_engine_; | ||
std::uniform_int_distribution<size_t> distribution_; | ||
}; | ||
|
||
|
||
/** | ||
* @brief MessageHandler functor that only passes through spikes from no more than a fixed number of groups at once. | ||
* @note Group is considered to be winning if it is in the top K groups sorted by number of spikes in descending order. | ||
* @note If last place in the top K is shared between groups, the functor selects random ones among the sharing groups. | ||
*/ | ||
class GroupWtaRandomHandler | ||
{ | ||
public: | ||
/** | ||
* @brief Functor constructor. | ||
* @param group_borders right borders of the intervals. | ||
* @param num_winning_groups max number of groups that are allowed to pass their spikes further. | ||
* @param seed seed for internal random number generator. | ||
*/ | ||
explicit GroupWtaRandomHandler( | ||
const std::vector<size_t> &group_borders, size_t num_winning_groups = 1, int seed = 0) | ||
: group_borders_(group_borders), num_winners_(num_winning_groups), random_engine_(seed) | ||
{ | ||
std::sort(group_borders_.begin(), group_borders_.end()); | ||
} | ||
|
||
/** | ||
* @brief Function call operator. | ||
* @param messages input messages. | ||
* @return spikes from winning groups. | ||
*/ | ||
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages); | ||
|
||
private: | ||
std::vector<size_t> group_borders_; | ||
size_t num_winners_; | ||
std::mt19937 random_engine_; | ||
std::uniform_int_distribution<size_t> distribution_; | ||
}; | ||
|
||
|
||
/** | ||
* @brief Spike handler functor. An output vector has a spike if that spike was present in at least one input message. | ||
*/ | ||
class SpikeUnionHandler | ||
{ | ||
public: | ||
/** | ||
* @brief Function call operator, receives a vector of messages, returns a union of all spike sets from those | ||
* messages. | ||
* @param messages incoming spike messages. | ||
* @return spikes vector containing the union of input message spike sets. | ||
*/ | ||
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages); | ||
}; | ||
|
||
|
||
} // namespace knp::framework::modifier |
Oops, something went wrong.