diff --git a/fairmq/FairMQMessage.h b/fairmq/FairMQMessage.h index 6eb97c1a4..f9eacaa71 100644 --- a/fairmq/FairMQMessage.h +++ b/fairmq/FairMQMessage.h @@ -39,7 +39,9 @@ class FairMQMessage FairMQMessage(FairMQTransportFactory* factory) : fTransport(factory) {} virtual void Rebuild() = 0; + virtual void Rebuild(fair::mq::Alignment alignment) = 0; virtual void Rebuild(const size_t size) = 0; + virtual void Rebuild(const size_t size, fair::mq::Alignment alignment) = 0; virtual void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) = 0; virtual void* GetData() const = 0; diff --git a/fairmq/ofi/Message.cxx b/fairmq/ofi/Message.cxx index d27d59240..cdd56117a 100644 --- a/fairmq/ofi/Message.cxx +++ b/fairmq/ofi/Message.cxx @@ -110,6 +110,12 @@ auto Message::Rebuild() -> void fHint = nullptr; } +auto Message::Rebuild(Alignment /* alignment */) -> void +{ + // TODO: implement alignment + Rebuild(); +} + auto Message::Rebuild(const size_t size) -> void { if (fFreeFunction) { @@ -131,6 +137,12 @@ auto Message::Rebuild(const size_t size) -> void fHint = nullptr; } +auto Message::Rebuild(const size_t size, Alignment /* alignment */) -> void +{ + // TODO: implement alignment + Rebuild(size); +} + auto Message::Rebuild(void* /*data*/, const size_t size, fairmq_free_fn* ffn, void* hint) -> void { if (fFreeFunction) { diff --git a/fairmq/ofi/Message.h b/fairmq/ofi/Message.h index 8af0b4999..e1c1daab1 100644 --- a/fairmq/ofi/Message.h +++ b/fairmq/ofi/Message.h @@ -52,7 +52,9 @@ class Message final : public fair::mq::Message Message operator=(const Message&) = delete; auto Rebuild() -> void override; + auto Rebuild(Alignment alignment) -> void override; auto Rebuild(const size_t size) -> void override; + auto Rebuild(const size_t size, Alignment alignment) -> void override; auto Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) -> void override; auto GetData() const -> void* override; diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index f2b2fa55d..b8e5eb093 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -50,11 +50,12 @@ class Message final : public fair::mq::Message fManager.IncrementMsgCounter(); } - Message(Manager& manager, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr) + Message(Manager& manager, Alignment alignment, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) , fQueued(false) , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) { @@ -78,10 +79,11 @@ class Message final : public fair::mq::Message , fManager(manager) , fQueued(false) , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) { - InitializeChunk(size, static_cast(alignment)); + InitializeChunk(size, fAlignment); fManager.IncrementMsgCounter(); } @@ -142,6 +144,13 @@ class Message final : public fair::mq::Message fQueued = false; } + void Rebuild(Alignment alignment) override + { + CloseMessage(); + fQueued = false; + fAlignment = alignment.alignment; + } + void Rebuild(const size_t size) override { CloseMessage(); @@ -149,6 +158,14 @@ class Message final : public fair::mq::Message InitializeChunk(size); } + void Rebuild(const size_t size, Alignment alignment) override + { + CloseMessage(); + fQueued = false; + fAlignment = alignment.alignment; + InitializeChunk(size, fAlignment); + } + void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override { CloseMessage(); @@ -242,6 +259,7 @@ class Message final : public fair::mq::Message Manager& fManager; bool fQueued; MetaHeader fMeta; + size_t fAlignment; // TODO: put this to debug mode mutable Region* fRegionPtr; mutable char* fLocalPtr; @@ -276,8 +294,9 @@ class Message final : public fair::mq::Message } fLocalPtr = nullptr; fMeta.fSize = 0; + fAlignment = 0; - fManager.DecrementMsgCounter(); + fManager.DecrementMsgCounter(); // TODO: put this to debug mode } }; diff --git a/fairmq/zeromq/Message.h b/fairmq/zeromq/Message.h index 73bf80ec2..344855f94 100644 --- a/fairmq/zeromq/Message.h +++ b/fairmq/zeromq/Message.h @@ -18,8 +18,10 @@ #include #include +#include // malloc #include #include +#include // bad_alloc #include namespace fair @@ -38,14 +40,17 @@ class Message final : public fair::mq::Message public: Message(FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) + , fAlignment(0) , fMsg(tools::make_unique()) { if (zmq_msg_init(fMsg.get()) != 0) { LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno); } } - Message(Alignment /* alignment */, FairMQTransportFactory* factory = nullptr) + + Message(Alignment alignment, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) + , fAlignment(alignment.alignment) , fMsg(tools::make_unique()) { if (zmq_msg_init(fMsg.get()) != 0) { @@ -55,6 +60,7 @@ class Message final : public fair::mq::Message Message(const size_t size, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) + , fAlignment(0) , fMsg(tools::make_unique()) { if (zmq_msg_init_size(fMsg.get(), size) != 0) { @@ -62,17 +68,40 @@ class Message final : public fair::mq::Message } } - Message(const size_t size, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr) + static std::pair AllocateAligned(size_t size, size_t alignment) + { + char* fullBufferPtr = static_cast(malloc(size + alignment)); + if (!fullBufferPtr) { + LOG(error) << "failed to allocate buffer with provided size (" << size << ") and alignment (" << alignment << ")."; + throw std::bad_alloc(); + } + + size_t offset = alignment - (reinterpret_cast(fullBufferPtr) % alignment); + char* alignedPartPtr = fullBufferPtr + offset; + + return {static_cast(fullBufferPtr), static_cast(alignedPartPtr)}; + } + + Message(const size_t size, Alignment alignment, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) + , fAlignment(alignment.alignment) , fMsg(tools::make_unique()) { - if (zmq_msg_init_size(fMsg.get(), size) != 0) { - LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); + if (fAlignment != 0) { + auto ptrs = AllocateAligned(size, fAlignment); + if (zmq_msg_init_data(fMsg.get(), ptrs.second, size, [](void* /* data */, void* hint) { free(hint); }, ptrs.first) != 0) { + LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); + } + } else { + if (zmq_msg_init_size(fMsg.get(), size) != 0) { + LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); + } } } Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) + , fAlignment(0) , fMsg(tools::make_unique()) { if (zmq_msg_init_data(fMsg.get(), data, size, ffn, hint) != 0) { @@ -82,6 +111,7 @@ class Message final : public fair::mq::Message Message(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr) : fair::mq::Message(factory) + , fAlignment(0) , fMsg(tools::make_unique()) { // FIXME: make this zero-copy: @@ -116,6 +146,16 @@ class Message final : public fair::mq::Message } } + void Rebuild(Alignment alignment) override + { + CloseMessage(); + fAlignment = alignment.alignment; + fMsg = tools::make_unique(); + if (zmq_msg_init(fMsg.get()) != 0) { + LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno); + } + } + void Rebuild(const size_t size) override { CloseMessage(); @@ -125,6 +165,24 @@ class Message final : public fair::mq::Message } } + void Rebuild(const size_t size, Alignment alignment) override + { + CloseMessage(); + fAlignment = alignment.alignment; + fMsg = tools::make_unique(); + + if (fAlignment != 0) { + auto ptrs = AllocateAligned(size, fAlignment); + if (zmq_msg_init_data(fMsg.get(), ptrs.second, size, [](void* /* data */, void* hint) { free(hint); }, ptrs.first) != 0) { + LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); + } + } else { + if (zmq_msg_init_size(fMsg.get(), size) != 0) { + LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno); + } + } + } + void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override { CloseMessage(); @@ -174,6 +232,23 @@ class Message final : public fair::mq::Message } } + void Realign() + { + // if alignment is provided + if (fAlignment != 0) { + void* data = GetData(); + size_t size = GetSize(); + // if buffer is valid && not already aligned with the given alignment + if (data != nullptr && reinterpret_cast(GetData()) % fAlignment) { + // create new aligned buffer + auto ptrs = AllocateAligned(size, fAlignment); + std::memcpy(ptrs.second, zmq_msg_data(fMsg.get()), size); + // rebuild the message with the new buffer + Rebuild(ptrs.second, size, [](void* /* buf */, void* hint) { free(hint); }, ptrs.first); + } + } + } + Transport GetType() const override { return Transport::ZMQ; } void Copy(const fair::mq::Message& msg) override @@ -189,6 +264,7 @@ class Message final : public fair::mq::Message ~Message() override { CloseMessage(); } private: + size_t fAlignment; std::unique_ptr fMsg; zmq_msg_t* GetMessage() const { return fMsg.get(); } @@ -200,6 +276,7 @@ class Message final : public fair::mq::Message } // reset the message object to allow reuse in Rebuild fMsg.reset(nullptr); + fAlignment = 0; } }; diff --git a/fairmq/zeromq/Socket.h b/fairmq/zeromq/Socket.h index 96aed8549..4ad5ef73d 100644 --- a/fairmq/zeromq/Socket.h +++ b/fairmq/zeromq/Socket.h @@ -173,6 +173,7 @@ class Socket final : public fair::mq::Socket while (true) { int nbytes = zmq_msg_recv(static_cast(msg.get())->GetMessage(), fSocket, flags); if (nbytes >= 0) { + static_cast(msg.get())->Realign(); int64_t actualBytes = zmq_msg_size(static_cast(msg.get())->GetMessage()); fBytesRx += actualBytes; ++fMessagesRx; @@ -261,6 +262,7 @@ class Socket final : public fair::mq::Socket int nbytes = zmq_msg_recv(static_cast(part.get())->GetMessage(), fSocket, flags); if (nbytes >= 0) { + static_cast(part.get())->Realign(); msgVec.push_back(move(part)); totalSize += nbytes; } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { diff --git a/test/message/_message.cxx b/test/message/_message.cxx index b49edd34e..fdb23fb2e 100644 --- a/test/message/_message.cxx +++ b/test/message/_message.cxx @@ -91,17 +91,54 @@ void RunMsgRebuild(const string& transport) EXPECT_EQ(string(static_cast(msg->GetData()), msg->GetSize()), string("asdf")); } -void Alignment(const string& transport) +void Alignment(const string& transport, const string& _address) { size_t session{fair::mq::tools::UuidHash()}; + std::string address(fair::mq::tools::ToString(_address, "_", transport)); fair::mq::ProgOptions config; config.SetProperty("session", to_string(session)); auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config); - FairMQMessagePtr msg(factory->CreateMessage(100, fair::mq::Alignment{64})); - ASSERT_EQ(reinterpret_cast(msg->GetData()) % 64, 0); + FairMQChannel push{"Push", "push", factory}; + push.Bind(address); + + FairMQChannel pull{"Pull", "pull", factory}; + pull.Connect(address); + + size_t alignment = 64; + + FairMQMessagePtr outMsg1(push.NewMessage(100, fair::mq::Alignment{alignment})); + ASSERT_EQ(reinterpret_cast(outMsg1->GetData()) % alignment, 0); + ASSERT_EQ(push.Send(outMsg1), 100); + + FairMQMessagePtr inMsg1(pull.NewMessage(fair::mq::Alignment{alignment})); + ASSERT_EQ(pull.Receive(inMsg1), 100); + ASSERT_EQ(reinterpret_cast(inMsg1->GetData()) % alignment, 0); + + FairMQMessagePtr outMsg2(push.NewMessage(32, fair::mq::Alignment{alignment})); + ASSERT_EQ(reinterpret_cast(outMsg2->GetData()) % alignment, 0); + ASSERT_EQ(push.Send(outMsg2), 32); + + FairMQMessagePtr inMsg2(pull.NewMessage(fair::mq::Alignment{alignment})); + ASSERT_EQ(pull.Receive(inMsg2), 32); + ASSERT_EQ(reinterpret_cast(inMsg2->GetData()) % alignment, 0); + + FairMQMessagePtr outMsg3(push.NewMessage(100, fair::mq::Alignment{0})); + ASSERT_EQ(push.Send(outMsg3), 100); + + FairMQMessagePtr inMsg3(pull.NewMessage(fair::mq::Alignment{0})); + ASSERT_EQ(pull.Receive(inMsg3), 100); + + FairMQMessagePtr msg1(push.NewMessage(25)); + msg1->Rebuild(50, fair::mq::Alignment{alignment}); + ASSERT_EQ(reinterpret_cast(msg1->GetData()) % alignment, 0); + + size_t alignment2 = 32; + FairMQMessagePtr msg2(push.NewMessage(25, fair::mq::Alignment{alignment})); + msg2->Rebuild(50, fair::mq::Alignment{alignment2}); + ASSERT_EQ(reinterpret_cast(msg2->GetData()) % alignment2, 0); } void EmptyMessage(const string& transport, const string& _address) @@ -149,9 +186,14 @@ TEST(Rebuild, shmem) RunMsgRebuild("shmem"); } -TEST(Alignment, shmem) // TODO: add test for ZeroMQ once it is implemented +TEST(Alignment, shmem) +{ + Alignment("shmem", "ipc://test_message_alignment"); +} + +TEST(Alignment, zeromq) { - Alignment("shmem"); + Alignment("zeromq", "ipc://test_message_alignment"); } TEST(EmptyMessage, zeromq)