Skip to content

Commit

Permalink
Use promises from coroio
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 30, 2023
1 parent 7e43676 commit cbd49a0
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 121 deletions.
2 changes: 1 addition & 1 deletion coroio
4 changes: 2 additions & 2 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "messages.h"

template<typename TSocket>
TPromise<void>::TTask TWriter<TSocket>::Write(TMessageHolder<TMessage> message) {
NNet::TValueTask<void> TWriter<TSocket>::Write(TMessageHolder<TMessage> message) {
auto payload = std::move(message.Payload);
char* p = (char*)message.Mes; // TODO: const char
uint32_t len = message->Len;
Expand All @@ -32,7 +32,7 @@ TPromise<void>::TTask TWriter<TSocket>::Write(TMessageHolder<TMessage> message)
}

template<typename TSocket>
TPromise<TMessageHolder<TMessage>>::TTask TReader<TSocket>::Read() {
NNet::TValueTask<TMessageHolder<TMessage>> TReader<TSocket>::Read() {
decltype(TMessage::Type) type;
decltype(TMessage::Len) len;
auto s = co_await Socket.ReadSome((char*)&type, sizeof(type));
Expand Down
108 changes: 2 additions & 106 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,118 +13,14 @@
#include "messages.h"
#include "raft.h"


template<typename T>
struct TPromise
{
struct TTask : std::coroutine_handle<TPromise<T>>
{
using promise_type = TPromise<T>;

~TTask() { this->destroy(); /*TODO:*/ }

bool await_ready() {
return this->promise().Value != nullptr;
}

void await_suspend(std::coroutine_handle<> caller) {
this->promise().Caller = caller;
}

T await_resume() {
if (this->promise().Exception) {
std::rethrow_exception(this->promise().Exception);
} else {
return *this->promise().Value;
}
}
};

TTask get_return_object() { return { TTask::from_promise(*this) }; }
std::suspend_never initial_suspend() { return {}; }
auto final_suspend() noexcept {
struct TAwaitable {
bool await_ready() noexcept { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<TPromise<T>> h) noexcept {
return h.promise().Caller;
}
void await_resume() noexcept { }
};
return TAwaitable {};
}

void return_value(const T& t) {
Value = std::make_shared<T>(t);
}

void unhandled_exception() {
Exception = std::current_exception();
}

std::shared_ptr<T> Value;
std::exception_ptr Exception;
std::coroutine_handle<> Caller = std::noop_coroutine();
};

template<>
struct TPromise<void>
{
struct TTask : std::coroutine_handle<TPromise<void>>
{
using promise_type = TPromise<void>;

~TTask() { destroy(); /* TODO: */ }

bool await_ready() {
return this->promise().Ready;
}

void await_suspend(std::coroutine_handle<> caller) {
this->promise().Caller = caller;
}

void await_resume() {
if (this->promise().Exception) {
std::rethrow_exception(this->promise().Exception);
}
}
};

TTask get_return_object() { return { TTask::from_promise(*this) }; }
std::suspend_never initial_suspend() { return {}; }
auto final_suspend() noexcept {
struct TAwaitable {
bool await_ready() noexcept { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<TPromise<void>> h) noexcept {
return h.promise().Caller;
}
void await_resume() noexcept { }
};
return TAwaitable {};
}


void return_void() {
Ready = true;
}
void unhandled_exception() {
Exception = std::current_exception();
Ready = true;
}

bool Ready = false;
std::exception_ptr Exception;
std::coroutine_handle<> Caller = std::noop_coroutine();
};

template<typename TSocket>
class TReader {
public:
TReader(TSocket& socket)
: Socket(socket)
{ }

TPromise<TMessageHolder<TMessage>>::TTask Read();
NNet::TValueTask<TMessageHolder<TMessage>> Read();

private:
TSocket& Socket;
Expand All @@ -137,7 +33,7 @@ class TWriter {
: Socket(socket)
{ }

TPromise<void>::TTask Write(TMessageHolder<TMessage> message);
NNet::TValueTask<void> Write(TMessageHolder<TMessage> message);

private:
TSocket& Socket;
Expand Down
22 changes: 10 additions & 12 deletions test/test_read_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ extern "C" {
#include <cmocka.h>
}

using namespace NNet;

namespace {

TMessageHolder<TLogEntry> MakeEntry(const char* text) {
Expand All @@ -36,22 +34,22 @@ TMessageHolder<TLogEntry> MakeEntry(const char* text) {
void test_read_write(void**) {
auto mes = MakeEntry("MESSAGE");

TLoop<TPoll> loop;
TSocket socket(TAddress{"127.0.0.1", 8888}, loop.Poller());
NNet::TLoop<NNet::TPoll> loop;
NNet::TSocket socket(NNet::TAddress{"127.0.0.1", 8888}, loop.Poller());
socket.Bind();
socket.Listen();

TSocket client(TAddress{"127.0.0.1", 8888}, loop.Poller());
NNet::TSocket client(NNet::TAddress{"127.0.0.1", 8888}, loop.Poller());

TTestTask h1 = [](TSocket& client, TMessageHolder<TLogEntry> mes) -> TTestTask
NNet::TTestTask h1 = [](NNet::TSocket& client, TMessageHolder<TLogEntry> mes) -> NNet::TTestTask
{
co_await client.Connect();
co_await TWriter(client).Write(std::move(mes));
co_return;
}(client, mes);

TMessageHolder<TMessage> received;
TTestTask h2 = [](TSocket& server, TMessageHolder<TMessage>& received) -> TTestTask
NNet::TTestTask h2 = [](NNet::TSocket& server, TMessageHolder<TMessage>& received) -> NNet::TTestTask
{
auto client = std::move(co_await server.Accept());
received = co_await TReader(client).Read();
Expand Down Expand Up @@ -79,22 +77,22 @@ void test_read_write_payload(void**) {
}
mes->Nentries = mes.Payload.size();

TLoop<TPoll> loop;
TSocket socket(TAddress{"127.0.0.1", 8889}, loop.Poller());
NNet::TLoop<NNet::TPoll> loop;
NNet::TSocket socket(NNet::TAddress{"127.0.0.1", 8889}, loop.Poller());
socket.Bind();
socket.Listen();

TSocket client(TAddress{"127.0.0.1", 8889}, loop.Poller());
NNet::TSocket client(NNet::TAddress{"127.0.0.1", 8889}, loop.Poller());

TTestTask h1 = [](TSocket& client, TMessageHolder<TMessage> mes) -> TTestTask
NNet::TTestTask h1 = [](NNet::TSocket& client, TMessageHolder<TMessage> mes) -> NNet::TTestTask
{
co_await client.Connect();
co_await TWriter(client).Write(std::move(mes));
co_return;
}(client, mes);

TMessageHolder<TMessage> received;
TTestTask h2 = [](TSocket& server, TMessageHolder<TMessage>& received) -> TTestTask
NNet::TTestTask h2 = [](NNet::TSocket& server, TMessageHolder<TMessage>& received) -> NNet::TTestTask
{
auto client = std::move(co_await server.Accept());
received = co_await TReader(client).Read();
Expand Down

0 comments on commit cbd49a0

Please sign in to comment.