Skip to content

Commit

Permalink
Save progress
Browse files Browse the repository at this point in the history
  • Loading branch information
kboyarinov committed Jan 14, 2025
1 parent 8e14ce1 commit b8149b0
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 44 deletions.
7 changes: 6 additions & 1 deletion include/oneapi/tbb/detail/_flow_graph_node_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ class multifunction_node_tag {

multifunction_node_tag& operator=(const multifunction_node_tag&) = delete;
multifunction_node_tag& operator=(multifunction_node_tag&& other) {
// TODO: should this method be thread-safe?
if (this != &other) {
reset();
my_metainfo = std::move(other.my_metainfo);
Expand All @@ -595,6 +596,8 @@ class multifunction_node_tag {
}

void reset() {
tbb::spin_mutex::scoped_lock lock(my_mutex);

for (auto waiter : my_metainfo.waiters()) {
waiter->release();
}
Expand Down Expand Up @@ -660,9 +663,11 @@ class multifunction_input : public function_input_base<Input, Policy, A, multifu
graph_task* apply_body_impl_bypass( const input_type &i
__TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) )
{
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
multifunction_node_tag tag(metainfo);
#endif
fgt_begin_body( my_body );
(*my_body)(i, my_output_ports, std::move(tag));
(*my_body)(i, my_output_ports __TBB_FLOW_GRAPH_METAINFO_ARG(std::move(tag)));
fgt_end_body( my_body );
graph_task* ttask = nullptr;
if(base_type::my_max_concurrency != 0) {
Expand Down
106 changes: 63 additions & 43 deletions test/tbb/test_multifunction_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,61 +725,81 @@ void test_simple_broadcast() {
}

void test_no_broadcast() {
using multinode_type = tbb::flow::multifunction_node<int, std::tuple<int>>;
using ports_type = typename multinode_type::output_ports_type;
using tag_type = typename multinode_type::tag_type;
if (std::thread::hardware_concurrency() == 1) {
return;
}

std::size_t num_items = 10;
std::atomic<std::size_t> num_processed_items = 0;
int wait_message = 42;
tag_type global_tag;
tbb::task_arena arena(std::thread::hardware_concurrency(), 0);

std::vector<int> processed_items;
multinode_type* this_node = nullptr;
arena.execute([]() {
using multinode_type = tbb::flow::multifunction_node<int, std::tuple<int>>;
using ports_type = typename multinode_type::output_ports_type;
using tag_type = typename multinode_type::tag_type;

std::size_t num_items = 10;
std::size_t num_additional_items = 10;

std::atomic<std::size_t> num_processed_items = 0;
std::atomic<std::size_t> num_processed_accumulators = 0;
std::atomic<bool> try_put_and_wait_exit_flag{};

int accumulator_message = 1;
int add_message = 2;

tag_type global_tag;

multinode_type* this_node = nullptr;

std::vector<int> postprocessed_items;
auto global_index = tbb::this_task_arena::current_thread_index();

tbb::flow::graph g;
multinode_type node(g, tbb::flow::unlimited,
[&](int input, ports_type& ports, tag_type&& local_tag) {
if (num_processed_items++ == 0) {
CHECK(input == accumulator_message);
++num_processed_accumulators;

tbb::flow::graph g;
multinode_type node(g, tbb::flow::unlimited,
[&](int input, ports_type& ports, tag_type&& local_tag) {
// std::cout << "Process " << input << std::endl;
printf("Processed %li items\n", num_processed_items.load());
if (num_processed_items < num_items) {
if (input == wait_message) {
global_tag = std::move(local_tag);
for (int i = 0; i < int(num_items - 1); ++i) {
this_node->try_put(i);
for (std::size_t i = 1; i < num_items; ++i) {
this_node->try_put(accumulator_message);
}
for (std::size_t i = 0; i < num_additional_items; ++i) {
this_node->try_put(add_message);
}
} else {
if (input == accumulator_message) {
if (num_processed_accumulators++ == num_items - 1) {
// The last accumulator was received - "cancel" the operation
global_tag.reset();
}
} else {
if (global_index != tbb::this_task_arena::current_thread_index()) {
// Block the worker thread until the try_put_and_wait was exitted
while(!try_put_and_wait_exit_flag.load()) {
std::this_thread::yield();
}
}
std::get<0>(ports).try_put(input);
}
}
std::get<0>(ports).try_put(input);
} else {
global_tag.reset();
}
++num_processed_items;
});

this_node = &node;

tbb::flow::function_node<int, int, tbb::flow::lightweight> write_node(g, tbb::flow::unlimited,
[&](int value) noexcept { processed_items.emplace_back(value); return 0; });
});

tbb::flow::make_edge(tbb::flow::output_port<0>(node), write_node);
this_node = &node;

node.try_put_and_wait(wait_message);
tbb::flow::function_node<int, int> write_node(g, tbb::flow::serial,
[&](int value) noexcept { postprocessed_items.emplace_back(value); return 0; });

CHECK(num_items == num_processed_items);
CHECK(processed_items.size() == num_items - 1);
std::sort(processed_items.begin(), processed_items.end());
tbb::flow::make_edge(tbb::flow::output_port<0>(node), write_node);

for (auto item : processed_items) {
std::cout << item << " ";
}
std::cout << std::endl;
node.try_put_and_wait(accumulator_message);

for (std::size_t i = 0; i < num_items - 1; ++i) {
CHECK_MESSAGE(processed_items[i] == i, "Incorrect items processing order");
}
CHECK_MESSAGE(processed_items.back() == wait_message, "Incorrect items processing");
std::cout << num_processed_accumulators << std::endl;
std::cout << num_processed_items << std::endl;
std::cout << postprocessed_items.size() << std::endl;

g.wait_for_all();
g.wait_for_all();
});
}

void test_reduction() {
Expand Down

0 comments on commit b8149b0

Please sign in to comment.