Skip to content

Commit

Permalink
zmq: implement alignment on sender side
Browse files Browse the repository at this point in the history
  • Loading branch information
rbx committed Dec 17, 2020
1 parent cc40063 commit aae766d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 15 deletions.
7 changes: 5 additions & 2 deletions fairmq/shmem/Message.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ class Message final : public fair::mq::Message
fManager.IncrementMsgCounter();
}

Message(Manager& manager, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
Message(Manager& manager, Alignment alignment, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
Expand All @@ -78,10 +79,11 @@ class Message final : public fair::mq::Message
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
InitializeChunk(size, static_cast<size_t>(alignment));
InitializeChunk(size, static_cast<size_t>(fAlignment));
fManager.IncrementMsgCounter();
}

Expand Down Expand Up @@ -242,6 +244,7 @@ class Message final : public fair::mq::Message
Manager& fManager;
bool fQueued;
MetaHeader fMeta;
size_t fAlignment;
mutable Region* fRegionPtr;
mutable char* fLocalPtr;

Expand Down
45 changes: 37 additions & 8 deletions fairmq/zeromq/Message.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <zmq.h>

#include <cstddef>
#include <cstdlib> // malloc, aligned_alloc
#include <cstring>
#include <memory>
#include <string>
Expand All @@ -39,18 +40,20 @@ class Message final : public fair::mq::Message
Message(FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fUsedSizeModified(false)
, fUsedSize()
, fUsedSize(0)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
, fViewMsg(nullptr)
{
if (zmq_msg_init(fMsg.get()) != 0) {
LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno);
}
}
Message(Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
Message(Alignment alignment, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fUsedSizeModified(false)
, fUsedSize()
, fUsedSize(0)
, fAlignment(alignment.alignment)
, fMsg(tools::make_unique<zmq_msg_t>())
, fViewMsg(nullptr)
{
Expand All @@ -63,30 +66,46 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fUsedSizeModified(false)
, fUsedSize(size)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
, fViewMsg(nullptr)
{
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
void* ptr = malloc(size);
if (!ptr) {
LOG(error) << "failed to allocate buffer with provided size (" << size << ").";
}
if (zmq_msg_init_data(fMsg.get(), ptr, size, [](void* data, void*) { free(data); }, nullptr) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
}

Message(const size_t size, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
Message(const size_t size, Alignment alignment, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fUsedSizeModified(false)
, fUsedSize(size)
, fAlignment(alignment.alignment)
, fMsg(tools::make_unique<zmq_msg_t>())
, fViewMsg(nullptr)
{
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
void* ptr = nullptr;
if (alignment.alignment != 0) {
ptr = aligned_alloc(size, fAlignment);
} else {
ptr = malloc(size);
}
if (!ptr) {
LOG(error) << "failed to allocate buffer with provided size (" << size << ") and alignment (" << alignment.alignment << ").";
}
if (zmq_msg_init_data(fMsg.get(), ptr, size, [](void* data, void*) { free(data); }, nullptr) != 0) {
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
}
}

Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fUsedSizeModified(false)
, fUsedSize()
, fUsedSize(0)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
, fViewMsg(nullptr)
{
Expand All @@ -98,7 +117,8 @@ class Message final : public fair::mq::Message
Message(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fUsedSizeModified(false)
, fUsedSize()
, fUsedSize(0)
, fAlignment(0)
, fMsg(tools::make_unique<zmq_msg_t>())
, fViewMsg(nullptr)
{
Expand Down Expand Up @@ -212,6 +232,14 @@ class Message final : public fair::mq::Message
}
}

// void Realign()
// {
// if (fAlignment != 0) {
// if (reinterpret_cast<uintptr_t>(GetData()) % fAlignment) {
// }
// }
// }

Transport GetType() const override { return Transport::ZMQ; }

void Copy(const fair::mq::Message& msg) override
Expand All @@ -235,6 +263,7 @@ class Message final : public fair::mq::Message
private:
bool fUsedSizeModified;
size_t fUsedSize;
size_t fAlignment;
std::unique_ptr<zmq_msg_t> fMsg;
std::unique_ptr<zmq_msg_t> fViewMsg; // view on a subset of fMsg (treating it as user buffer)

Expand Down
29 changes: 24 additions & 5 deletions test/message/_message.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,31 @@ void RunMsgRebuild(const string& transport)
EXPECT_EQ(string(static_cast<char*>(msg->GetData()), msg->GetSize()), string("asdf"));
}

void Alignment(const string& transport)
void Alignment(const string& transport, const string& _address)
{
size_t session{fair::mq::tools::UuidHash()};
std::string address(fair::mq::tools::ToString(_address, "_", transport));

fair::mq::ProgOptions config;
config.SetProperty<string>("session", to_string(session));

auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config);

FairMQMessagePtr msg(factory->CreateMessage(100, fair::mq::Alignment{64}));
ASSERT_EQ(reinterpret_cast<uintptr_t>(msg->GetData()) % 64, 0);
FairMQChannel push{"Push", "push", factory};
push.Bind(address);

FairMQChannel pull{"Pull", "pull", factory};
pull.Connect(address);

size_t alignment = 64;

FairMQMessagePtr outMsg(push.NewMessage(100, fair::mq::Alignment{alignment}));
ASSERT_EQ(reinterpret_cast<uintptr_t>(outMsg->GetData()) % alignment, 0);
ASSERT_EQ(push.Send(outMsg), 100);

FairMQMessagePtr inMsg(pull.NewMessage());
ASSERT_EQ(pull.Receive(inMsg), 100);
// ASSERT_EQ(reinterpret_cast<uintptr_t>(inMsg->GetData()) % alignment, 0);
}

void EmptyMessage(const string& transport, const string& _address)
Expand Down Expand Up @@ -136,9 +150,14 @@ TEST(Rebuild, shmem)
RunMsgRebuild("shmem");
}

TEST(Alignment, shmem) // TODO: add test for ZeroMQ once it is implemented
TEST(Alignment, shmem)
{
Alignment("shmem", "ipc://test_message_alignment");
}

TEST(Alignment, zeromq)
{
Alignment("shmem");
Alignment("zeromq", "ipc://test_message_alignment");
}

TEST(EmptyMessage, zeromq)
Expand Down

0 comments on commit aae766d

Please sign in to comment.