Skip to content

Commit

Permalink
zmq: implement alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
rbx committed Jan 13, 2021
1 parent 02a3980 commit 6815c9c
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 12 deletions.
2 changes: 2 additions & 0 deletions fairmq/FairMQMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class FairMQMessage
FairMQMessage(FairMQTransportFactory* factory) : fTransport(factory) {}

virtual void Rebuild() = 0;
virtual void Rebuild(fair::mq::Alignment alignment) = 0;
virtual void Rebuild(const size_t size) = 0;
virtual void Rebuild(const size_t size, fair::mq::Alignment alignment) = 0;
virtual void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) = 0;

virtual void* GetData() const = 0;
Expand Down
12 changes: 12 additions & 0 deletions fairmq/ofi/Message.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ auto Message::Rebuild() -> void
fHint = nullptr;
}

auto Message::Rebuild(Alignment /* alignment */) -> void
{
// TODO: implement alignment
Rebuild();
}

auto Message::Rebuild(const size_t size) -> void
{
if (fFreeFunction) {
Expand All @@ -131,6 +137,12 @@ auto Message::Rebuild(const size_t size) -> void
fHint = nullptr;
}

auto Message::Rebuild(const size_t size, Alignment /* alignment */) -> void
{
// TODO: implement alignment
Rebuild(size);
}

auto Message::Rebuild(void* /*data*/, const size_t size, fairmq_free_fn* ffn, void* hint) -> void
{
if (fFreeFunction) {
Expand Down
2 changes: 2 additions & 0 deletions fairmq/ofi/Message.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ class Message final : public fair::mq::Message
Message operator=(const Message&) = delete;

auto Rebuild() -> void override;
auto Rebuild(Alignment alignment) -> void override;
auto Rebuild(const size_t size) -> void override;
auto Rebuild(const size_t size, Alignment alignment) -> void override;
auto Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) -> void override;

auto GetData() const -> void* override;
Expand Down
25 changes: 22 additions & 3 deletions fairmq/shmem/Message.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ class Message final : public fair::mq::Message
fManager.IncrementMsgCounter();
}

Message(Manager& manager, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
Message(Manager& manager, Alignment alignment, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
Expand All @@ -78,10 +79,11 @@ class Message final : public fair::mq::Message
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
InitializeChunk(size, static_cast<size_t>(alignment));
InitializeChunk(size, fAlignment);
fManager.IncrementMsgCounter();
}

Expand Down Expand Up @@ -142,13 +144,28 @@ class Message final : public fair::mq::Message
fQueued = false;
}

void Rebuild(Alignment alignment) override
{
CloseMessage();
fQueued = false;
fAlignment = alignment.alignment;
}

void Rebuild(const size_t size) override
{
CloseMessage();
fQueued = false;
InitializeChunk(size);
}

void Rebuild(const size_t size, Alignment alignment) override
{
CloseMessage();
fQueued = false;
fAlignment = alignment.alignment;
InitializeChunk(size, fAlignment);
}

void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override
{
CloseMessage();
Expand Down Expand Up @@ -242,6 +259,7 @@ class Message final : public fair::mq::Message
Manager& fManager;
bool fQueued;
MetaHeader fMeta;
size_t fAlignment; // TODO: put this to debug mode
mutable Region* fRegionPtr;
mutable char* fLocalPtr;

Expand Down Expand Up @@ -276,8 +294,9 @@ class Message final : public fair::mq::Message
}
fLocalPtr = nullptr;
fMeta.fSize = 0;
fAlignment = 0;

fManager.DecrementMsgCounter();
fManager.DecrementMsgCounter(); // TODO: put this to debug mode
}
};

Expand Down
85 changes: 81 additions & 4 deletions fairmq/zeromq/Message.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
#include <zmq.h>

#include <cstddef>
#include <cstdlib> // malloc
#include <cstring>
#include <memory>
#include <new> // bad_alloc
#include <string>

namespace fair
Expand All @@ -38,14 +40,17 @@ class Message final : public fair::mq::Message
public:
Message(FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
{
if (zmq_msg_init(fMsg.get()) != 0) {
LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno);
}
}
Message(Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)

Message(Alignment alignment, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fAlignment(alignment.alignment)
, fMsg(tools::make_unique<zmq_msg_t>())
{
if (zmq_msg_init(fMsg.get()) != 0) {
Expand All @@ -55,24 +60,48 @@ class Message final : public fair::mq::Message

Message(const size_t size, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
{
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
}

Message(const size_t size, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
static std::pair<void*, void*> AllocateAligned(size_t size, size_t alignment)
{
char* fullBufferPtr = static_cast<char*>(malloc(size + alignment));
if (!fullBufferPtr) {
LOG(error) << "failed to allocate buffer with provided size (" << size << ") and alignment (" << alignment << ").";
throw std::bad_alloc();
}

size_t offset = alignment - (reinterpret_cast<uintptr_t>(fullBufferPtr) % alignment);
char* alignedPartPtr = fullBufferPtr + offset;

return {static_cast<void*>(fullBufferPtr), static_cast<void*>(alignedPartPtr)};
}

Message(const size_t size, Alignment alignment, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fAlignment(alignment.alignment)
, fMsg(tools::make_unique<zmq_msg_t>())
{
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
if (fAlignment != 0) {
auto ptrs = AllocateAligned(size, fAlignment);
if (zmq_msg_init_data(fMsg.get(), ptrs.second, size, [](void* /* data */, void* hint) { free(hint); }, ptrs.first) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
} else {
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
}
}

Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
{
if (zmq_msg_init_data(fMsg.get(), data, size, ffn, hint) != 0) {
Expand All @@ -82,6 +111,7 @@ class Message final : public fair::mq::Message

Message(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
{
// FIXME: make this zero-copy:
Expand Down Expand Up @@ -116,6 +146,16 @@ class Message final : public fair::mq::Message
}
}

void Rebuild(Alignment alignment) override
{
CloseMessage();
fAlignment = alignment.alignment;
fMsg = tools::make_unique<zmq_msg_t>();
if (zmq_msg_init(fMsg.get()) != 0) {
LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno);
}
}

void Rebuild(const size_t size) override
{
CloseMessage();
Expand All @@ -125,6 +165,24 @@ class Message final : public fair::mq::Message
}
}

void Rebuild(const size_t size, Alignment alignment) override
{
CloseMessage();
fAlignment = alignment.alignment;
fMsg = tools::make_unique<zmq_msg_t>();

if (fAlignment != 0) {
auto ptrs = AllocateAligned(size, fAlignment);
if (zmq_msg_init_data(fMsg.get(), ptrs.second, size, [](void* /* data */, void* hint) { free(hint); }, ptrs.first) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
} else {
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
}
}

void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override
{
CloseMessage();
Expand Down Expand Up @@ -174,6 +232,23 @@ class Message final : public fair::mq::Message
}
}

void Realign()
{
// if alignment is provided
if (fAlignment != 0) {
void* data = GetData();
size_t size = GetSize();
// if buffer is valid && not already aligned with the given alignment
if (data != nullptr && reinterpret_cast<uintptr_t>(GetData()) % fAlignment) {
// create new aligned buffer
auto ptrs = AllocateAligned(size, fAlignment);
std::memcpy(ptrs.second, zmq_msg_data(fMsg.get()), size);
// rebuild the message with the new buffer
Rebuild(ptrs.second, size, [](void* /* buf */, void* hint) { free(hint); }, ptrs.first);
}
}
}

Transport GetType() const override { return Transport::ZMQ; }

void Copy(const fair::mq::Message& msg) override
Expand All @@ -189,6 +264,7 @@ class Message final : public fair::mq::Message
~Message() override { CloseMessage(); }

private:
size_t fAlignment;
std::unique_ptr<zmq_msg_t> fMsg;

zmq_msg_t* GetMessage() const { return fMsg.get(); }
Expand All @@ -200,6 +276,7 @@ class Message final : public fair::mq::Message
}
// reset the message object to allow reuse in Rebuild
fMsg.reset(nullptr);
fAlignment = 0;
}
};

Expand Down
2 changes: 2 additions & 0 deletions fairmq/zeromq/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class Socket final : public fair::mq::Socket
while (true) {
int nbytes = zmq_msg_recv(static_cast<Message*>(msg.get())->GetMessage(), fSocket, flags);
if (nbytes >= 0) {
static_cast<Message*>(msg.get())->Realign();
int64_t actualBytes = zmq_msg_size(static_cast<Message*>(msg.get())->GetMessage());
fBytesRx += actualBytes;
++fMessagesRx;
Expand Down Expand Up @@ -261,6 +262,7 @@ class Socket final : public fair::mq::Socket

int nbytes = zmq_msg_recv(static_cast<Message*>(part.get())->GetMessage(), fSocket, flags);
if (nbytes >= 0) {
static_cast<Message*>(part.get())->Realign();
msgVec.push_back(move(part));
totalSize += nbytes;
} else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
Expand Down
52 changes: 47 additions & 5 deletions test/message/_message.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,54 @@ void RunMsgRebuild(const string& transport)
EXPECT_EQ(string(static_cast<char*>(msg->GetData()), msg->GetSize()), string("asdf"));
}

void Alignment(const string& transport)
void Alignment(const string& transport, const string& _address)
{
size_t session{fair::mq::tools::UuidHash()};
std::string address(fair::mq::tools::ToString(_address, "_", transport));

fair::mq::ProgOptions config;
config.SetProperty<string>("session", to_string(session));

auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config);

FairMQMessagePtr msg(factory->CreateMessage(100, fair::mq::Alignment{64}));
ASSERT_EQ(reinterpret_cast<uintptr_t>(msg->GetData()) % 64, 0);
FairMQChannel push{"Push", "push", factory};
push.Bind(address);

FairMQChannel pull{"Pull", "pull", factory};
pull.Connect(address);

size_t alignment = 64;

FairMQMessagePtr outMsg1(push.NewMessage(100, fair::mq::Alignment{alignment}));
ASSERT_EQ(reinterpret_cast<uintptr_t>(outMsg1->GetData()) % alignment, 0);
ASSERT_EQ(push.Send(outMsg1), 100);

FairMQMessagePtr inMsg1(pull.NewMessage(fair::mq::Alignment{alignment}));
ASSERT_EQ(pull.Receive(inMsg1), 100);
ASSERT_EQ(reinterpret_cast<uintptr_t>(inMsg1->GetData()) % alignment, 0);

FairMQMessagePtr outMsg2(push.NewMessage(32, fair::mq::Alignment{alignment}));
ASSERT_EQ(reinterpret_cast<uintptr_t>(outMsg2->GetData()) % alignment, 0);
ASSERT_EQ(push.Send(outMsg2), 32);

FairMQMessagePtr inMsg2(pull.NewMessage(fair::mq::Alignment{alignment}));
ASSERT_EQ(pull.Receive(inMsg2), 32);
ASSERT_EQ(reinterpret_cast<uintptr_t>(inMsg2->GetData()) % alignment, 0);

FairMQMessagePtr outMsg3(push.NewMessage(100, fair::mq::Alignment{0}));
ASSERT_EQ(push.Send(outMsg3), 100);

FairMQMessagePtr inMsg3(pull.NewMessage(fair::mq::Alignment{0}));
ASSERT_EQ(pull.Receive(inMsg3), 100);

FairMQMessagePtr msg1(push.NewMessage(25));
msg1->Rebuild(50, fair::mq::Alignment{alignment});
ASSERT_EQ(reinterpret_cast<uintptr_t>(msg1->GetData()) % alignment, 0);

size_t alignment2 = 32;
FairMQMessagePtr msg2(push.NewMessage(25, fair::mq::Alignment{alignment}));
msg2->Rebuild(50, fair::mq::Alignment{alignment2});
ASSERT_EQ(reinterpret_cast<uintptr_t>(msg2->GetData()) % alignment2, 0);
}

void EmptyMessage(const string& transport, const string& _address)
Expand Down Expand Up @@ -149,9 +186,14 @@ TEST(Rebuild, shmem)
RunMsgRebuild("shmem");
}

TEST(Alignment, shmem) // TODO: add test for ZeroMQ once it is implemented
TEST(Alignment, shmem)
{
Alignment("shmem", "ipc://test_message_alignment");
}

TEST(Alignment, zeromq)
{
Alignment("shmem");
Alignment("zeromq", "ipc://test_message_alignment");
}

TEST(EmptyMessage, zeromq)
Expand Down

0 comments on commit 6815c9c

Please sign in to comment.