diff --git a/src/util/raft/node.cpp b/src/util/raft/node.cpp index 4a45593c7..e7cb7f652 100644 --- a/src/util/raft/node.cpp +++ b/src/util/raft/node.cpp @@ -5,6 +5,8 @@ #include "node.hpp" +#include + namespace cbdc::raft { node::node(int node_id, std::vector raft_endpoints, @@ -99,11 +101,26 @@ namespace cbdc::raft { auto node::replicate_sync(const nuraft::ptr& new_log) const -> std::optional> { auto ret = m_raft_instance->append_entries({new_log}); - if(!ret->get_accepted() - || ret->get_result_code() != nuraft::cmd_result_code::OK) { + if(!ret->get_accepted()) { + return std::nullopt; + } + auto result_code = nuraft::cmd_result_code::RESULT_NOT_EXIST_YET; + auto blocking_promise = std::promise(); + auto blocking_future = blocking_promise.get_future(); + ret->when_ready([&result_code, + &blocking_promise](raft::result_type& r, + nuraft::ptr& err) { + if(err) { + result_code = nuraft::cmd_result_code::FAILED; + } else { + result_code = r.get_result_code(); + } + blocking_promise.set_value(); + }); + blocking_future.wait(); + if(result_code != nuraft::cmd_result_code::OK) { return std::nullopt; } - return ret->get(); } diff --git a/tests/unit/raft_test.cpp b/tests/unit/raft_test.cpp index de16cd508..c21f5e0c6 100644 --- a/tests/unit/raft_test.cpp +++ b/tests/unit/raft_test.cpp @@ -173,6 +173,10 @@ class raft_test : public ::testing::Test { auto new_log = cbdc::make_buffer>(1); + auto res = nodes[0]->replicate_sync(new_log); + ASSERT_TRUE(res.has_value()); + ASSERT_EQ(nodes[0]->last_log_idx(), 2UL); + cbdc::raft::callback_type result_fn = nullptr; auto result_done = std::atomic(false); if(!blocking) { @@ -190,14 +194,7 @@ class raft_test : public ::testing::Test { while(!result_done) { std::this_thread::sleep_for(std::chrono::milliseconds(250)); } - ASSERT_EQ(nodes[0]->last_log_idx(), 2UL); - - if(blocking) { - // Replicate sync will only return a value in the blocking context - auto res = nodes[0]->replicate_sync(new_log); - ASSERT_TRUE(res.has_value()); - ASSERT_EQ(nodes[0]->last_log_idx(), 3UL); - } + ASSERT_EQ(nodes[0]->last_log_idx(), 3UL); for(size_t i{0}; i < nodes.size(); i++) { ASSERT_EQ(nodes[i]->get_sm(), sms[i].get());