From aae766d2ff932b7bb031fa0c1311f93bcfce416e Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Tue, 15 Dec 2020 10:03:17 +0100 Subject: [PATCH] zmq: implement alignment on sender side --- fairmq/shmem/Message.h | 7 ++++-- fairmq/zeromq/Message.h | 45 ++++++++++++++++++++++++++++++++------- test/message/_message.cxx | 29 ++++++++++++++++++++----- 3 files changed, 66 insertions(+), 15 deletions(-) diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index fcd5a3880..ddb2512b7 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -50,11 +50,12 @@ class Message final : public fair::mq::Message fManager.IncrementMsgCounter(); } - Message(Manager& manager, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr) + Message(Manager& manager, Alignment alignment, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) , fQueued(false) , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) { @@ -78,10 +79,11 @@ class Message final : public fair::mq::Message , fManager(manager) , fQueued(false) , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) { - InitializeChunk(size, static_cast(alignment)); + InitializeChunk(size, static_cast(fAlignment)); fManager.IncrementMsgCounter(); } @@ -242,6 +244,7 @@ class Message final : public fair::mq::Message Manager& fManager; bool fQueued; MetaHeader fMeta; + size_t fAlignment; mutable Region* fRegionPtr; mutable char* fLocalPtr; diff --git a/fairmq/zeromq/Message.h b/fairmq/zeromq/Message.h index 64ed96f37..0ee738092 100644 --- a/fairmq/zeromq/Message.h +++ b/fairmq/zeromq/Message.h @@ -18,6 +18,7 @@ #include #include +#include // malloc, aligned_alloc #include #include #include @@ -39,7 +40,8 @@ class Message final : public fair::mq::Message Message(FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fUsedSizeModified(false) - , fUsedSize() + , fUsedSize(0) + , fAlignment(0) , fMsg(tools::make_unique()) , fViewMsg(nullptr) { @@ -47,10 +49,11 @@ class Message final : public fair::mq::Message LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno); } } - Message(Alignment /* alignment */, FairMQTransportFactory* factory = nullptr) + Message(Alignment alignment, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fUsedSizeModified(false) - , fUsedSize() + , fUsedSize(0) + , fAlignment(alignment.alignment) , fMsg(tools::make_unique()) , fViewMsg(nullptr) { @@ -63,22 +66,37 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fUsedSizeModified(false) , fUsedSize(size) + , fAlignment(0) , fMsg(tools::make_unique()) , fViewMsg(nullptr) { - if (zmq_msg_init_size(fMsg.get(), size) != 0) { + void* ptr = malloc(size); + if (!ptr) { + LOG(error) << "failed to allocate buffer with provided size (" << size << ")."; + } + if (zmq_msg_init_data(fMsg.get(), ptr, size, [](void* data, void*) { free(data); }, nullptr) != 0) { LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); } } - Message(const size_t size, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr) + Message(const size_t size, Alignment alignment, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fUsedSizeModified(false) , fUsedSize(size) + , fAlignment(alignment.alignment) , fMsg(tools::make_unique()) , fViewMsg(nullptr) { - if (zmq_msg_init_size(fMsg.get(), size) != 0) { + void* ptr = nullptr; + if (alignment.alignment != 0) { + ptr = aligned_alloc(size, fAlignment); + } else { + ptr = malloc(size); + } + if (!ptr) { + LOG(error) << "failed to allocate buffer with provided size (" << size << ") and alignment (" << alignment.alignment << ")."; + } + if (zmq_msg_init_data(fMsg.get(), ptr, size, [](void* data, void*) { free(data); }, nullptr) != 0) { LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); } } @@ -86,7 +104,8 @@ class Message final : public fair::mq::Message Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fUsedSizeModified(false) - , fUsedSize() + , fUsedSize(0) + , fAlignment(0) , fMsg(tools::make_unique()) , fViewMsg(nullptr) { @@ -98,7 +117,8 @@ class Message final : public fair::mq::Message Message(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fUsedSizeModified(false) - , fUsedSize() + , fUsedSize(0) + , fAlignment(0) , fMsg(tools::make_unique()) , fViewMsg(nullptr) { @@ -212,6 +232,14 @@ class Message final : public fair::mq::Message } } + // void Realign() + // { + // if (fAlignment != 0) { + // if (reinterpret_cast(GetData()) % fAlignment) { + // } + // } + // } + Transport GetType() const override { return Transport::ZMQ; } void Copy(const fair::mq::Message& msg) override @@ -235,6 +263,7 @@ class Message final : public fair::mq::Message private: bool fUsedSizeModified; size_t fUsedSize; + size_t fAlignment; std::unique_ptr fMsg; std::unique_ptr fViewMsg; // view on a subset of fMsg (treating it as user buffer) diff --git a/test/message/_message.cxx b/test/message/_message.cxx index 6cec36b6b..ff2468731 100644 --- a/test/message/_message.cxx +++ b/test/message/_message.cxx @@ -78,17 +78,31 @@ void RunMsgRebuild(const string& transport) EXPECT_EQ(string(static_cast(msg->GetData()), msg->GetSize()), string("asdf")); } -void Alignment(const string& transport) +void Alignment(const string& transport, const string& _address) { size_t session{fair::mq::tools::UuidHash()}; + std::string address(fair::mq::tools::ToString(_address, "_", transport)); fair::mq::ProgOptions config; config.SetProperty("session", to_string(session)); auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config); - FairMQMessagePtr msg(factory->CreateMessage(100, fair::mq::Alignment{64})); - ASSERT_EQ(reinterpret_cast(msg->GetData()) % 64, 0); + FairMQChannel push{"Push", "push", factory}; + push.Bind(address); + + FairMQChannel pull{"Pull", "pull", factory}; + pull.Connect(address); + + size_t alignment = 64; + + FairMQMessagePtr outMsg(push.NewMessage(100, fair::mq::Alignment{alignment})); + ASSERT_EQ(reinterpret_cast(outMsg->GetData()) % alignment, 0); + ASSERT_EQ(push.Send(outMsg), 100); + + FairMQMessagePtr inMsg(pull.NewMessage()); + ASSERT_EQ(pull.Receive(inMsg), 100); + // ASSERT_EQ(reinterpret_cast(inMsg->GetData()) % alignment, 0); } void EmptyMessage(const string& transport, const string& _address) @@ -136,9 +150,14 @@ TEST(Rebuild, shmem) RunMsgRebuild("shmem"); } -TEST(Alignment, shmem) // TODO: add test for ZeroMQ once it is implemented +TEST(Alignment, shmem) +{ + Alignment("shmem", "ipc://test_message_alignment"); +} + +TEST(Alignment, zeromq) { - Alignment("shmem"); + Alignment("zeromq", "ipc://test_message_alignment"); } TEST(EmptyMessage, zeromq)