diff --git a/.gitignore b/.gitignore index b29a93b..ed143f7 100755 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ third-party # Build /*build* + diff --git a/CMakeLists.txt b/CMakeLists.txt index 4568f46..fb5803e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,8 @@ target_include_directories(rome PUBLIC $ $) target_compile_definitions(rome PUBLIC ROME_LOG_LEVEL=${LOG_LEVEL}) +# We use an external fmt package, so we don't need spdlog's bundled one +target_compile_definitions(rome PUBLIC SPDLOG_FMT_EXTERNAL=ON) target_link_libraries(rome PUBLIC rome::protos rdma::ibverbs rdma::cm fmt::fmt std::coroutines) target_link_libraries(rome PUBLIC absl::status absl::statusor absl::synchronization) diff --git a/DevDockerfile b/DevDockerfile new file mode 100644 index 0000000..288678c --- /dev/null +++ b/DevDockerfile @@ -0,0 +1,6 @@ +FROM ubuntu:22.04 +RUN apt-get update +RUN apt-get install libprotobuf-dev protobuf-compiler -y +RUN apt-get install cmake -y +RUN apt-get install clang-15 libabsl-dev librdmacm-dev libibverbs-dev libgtest-dev libbenchmark-dev libfmt-dev libspdlog-dev libgmock-dev -y +RUN apt-get install libc6-dev-i386 -y diff --git a/README.md b/README.md index efbd49c..118a939 100644 --- a/README.md +++ b/README.md @@ -77,14 +77,14 @@ librome uses: For Ubuntu 22.04 the following packages can be installed through apt: -* libabsl-dev -* librdmacm-dev -* libibverbs-dev -* libgtest-dev -* libbenchmark-dev -* libfmt-dev -* libspdlog-dev -* protobuf-compiler +* libabsl-dev +* librdmacm-dev +* libibverbs-dev +* libgtest-dev +* libbenchmark-dev +* libfmt-dev +* libspdlog-dev +* protobuf-compiler * libgmock-dev `cicd/install_dependencies_ubuntu.sh` is a script for installing these on Ubuntu 22.04. @@ -111,6 +111,17 @@ Make sure to clear your build directory if recompiling with a different compiler `make install` will install librome in your default installation location or the directory passed through defining `CMAKE_INSTALL_PREFIX`. +## Docker + +Create and run a Docker container to emulate the build environment with the minimum dependencies installed. + +```{bash} +docker build --tag sss-dev --file DevDockerfile . +docker run --privileged --rm -v {MOUNT_DIR}:/home --name sss -it sss-dev +``` + +You can then develop inside that container using the Dev Containers extension, which gives you full syntax highlighting and lets you build locally. + # Old Setup instructions (unsure if this still works) The Dockerfile contains all the dependencies required by this project and handles automatically setting up the correct development environment. There are enough comments in the Dockerfile itself to understand what is going on, but at a high level its main purpose is to install the tooling necessary to build the project. @@ -150,5 +161,3 @@ One peculiarity for UTM's `davfs` setup is that it requires a username and password When prompted, just hit enter. To avoid the prompt altogether, you can update `/etc/davfs2/secrets` to include a line for the hosted files. In my configuration, I simply put the following line: `http://localhost:9843 user passwd`.
- -## Docker diff --git a/gladiators/CMakeLists.txt b/gladiators/CMakeLists.txt index 3081160..1b44929 100644 --- a/gladiators/CMakeLists.txt +++ b/gladiators/CMakeLists.txt @@ -1,6 +1,6 @@ -add_executable(coroutines coroutines/main.cc) -target_link_libraries(coroutines PRIVATE rome::rome) +add_executable(coroutines_out coroutines/main.cc) +target_link_libraries(coroutines_out PRIVATE rome::rome) -add_executable(hello_world hello_world/main.cc) -target_link_libraries(hello_world PRIVATE rome::rome) -target_link_libraries(hello_world PRIVATE absl::flags absl::flags_parse) +add_executable(hello_world_out hello_world/main.cc) +target_link_libraries(hello_world_out PRIVATE rome::rome) +target_link_libraries(hello_world_out PRIVATE absl::flags absl::flags_parse) diff --git a/include/rome/rdma/connection_manager/connection.h b/include/rome/rdma/connection_manager/connection.h new file mode 100644 index 0000000..2e7a486 --- /dev/null +++ b/include/rome/rdma/connection_manager/connection.h @@ -0,0 +1,61 @@ +#pragma once + +#include + +#include +#include + +#include "rome/rdma/channel/rdma_accessor.h" +#include "rome/rdma/channel/rdma_channel.h" +#include "rome/rdma/channel/twosided_messenger.h" + +namespace rome::rdma { + +// Contains the necessary information for communicating between nodes. This +// class wraps a unique pointer to the `rdma_cm_id` that holds the QP used for +// communication, along with the `RdmaChannel` that represents the memory used +// for 2-sided message-passing. +template > +class Connection { + public: + typedef Channel channel_type; + + Connection() + : terminated_(false), + src_id_(std::numeric_limits::max()), + dst_id_(std::numeric_limits::max()), + channel_(nullptr) {} + Connection(uint32_t src_id, uint32_t dst_id, + std::unique_ptr channel) + : terminated_(false), + src_id_(src_id), + dst_id_(dst_id), + channel_(std::move(channel)) {} + + Connection(const Connection&) = delete; + Connection(Connection&& c) + : terminated_(c.terminated_), + src_id_(c.src_id_), + dst_id_(c.dst_id_), + channel_(std::move(c.channel_)) {} + + // Getters. + inline bool terminated() const { return terminated_; } + uint32_t src_id() const { return src_id_; } + uint32_t dst_id() const { return dst_id_; } + rdma_cm_id* id() const { return channel_->id(); } + channel_type* channel() const { return channel_.get(); } + + void Terminate() { terminated_ = true; } + + private: + volatile bool terminated_; + + uint32_t src_id_; + uint32_t dst_id_; + + // Remotely accessible memory that is used for 2-sided message-passing. 
+ std::unique_ptr channel_; +}; + +} // namespace rome::rdma \ No newline at end of file diff --git a/include/rome/rdma/connection_manager/connection_manager.h b/include/rome/rdma/connection_manager/connection_manager.h new file mode 100644 index 0000000..964e573 --- /dev/null +++ b/include/rome/rdma/connection_manager/connection_manager.h @@ -0,0 +1,169 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "connection.h" +#include "rome/rdma/channel/rdma_accessor.h" +#include "rome/rdma/channel/rdma_channel.h" +#include "rome/rdma/channel/twosided_messenger.h" +#include "rome/rdma/rdma_broker.h" +#include "rome/rdma/rdma_device.h" +#include "rome/rdma/rdma_memory.h" +#include "rome/rdma/rdma_receiver.h" +#include "rome/util/coroutine.h" + +namespace rome::rdma { + +template +class ConnectionManager : public RdmaReceiverInterface { + public: + typedef Connection conn_type; + + ~ConnectionManager(); + explicit ConnectionManager(uint32_t my_id); + + absl::Status Start(std::string_view addr, std::optional port); + + // Getters. + std::string address() const { return broker_->address(); } + uint16_t port() const { return broker_->port(); } + ibv_pd* pd() const { return broker_->pd(); } + + int GetNumConnections() { + Acquire(my_id_); + int size = established_.size(); + Release(); + return size; + } + + // `RdmaReceiverInterface` implementation + void OnConnectRequest(rdma_cm_id* id, rdma_cm_event* event) override; + void OnEstablished(rdma_cm_id* id, rdma_cm_event* event) override; + void OnDisconnect(rdma_cm_id* id) override; + + // `RdmaClientInterface` implementation + absl::StatusOr Connect(uint32_t node_id, std::string_view server, + uint16_t port); + + absl::StatusOr GetConnection(uint32_t node_id); + + void Shutdown(); + + private: + // The size of each memory region dedicated to a single connection. + static constexpr int kCapacity = 1 << 12; // 4 KiB + static constexpr int kMaxRecvBytes = 64; + + static constexpr int kMaxWr = kCapacity / kMaxRecvBytes; + static constexpr int kMaxSge = 1; + static constexpr int kMaxInlineData = 0; + + static constexpr char kPdId[] = "ConnectionManager"; + + static constexpr int kUnlocked = -1; + + static constexpr uint32_t kMinBackoffUs = 100; + static constexpr uint32_t kMaxBackoffUs = 5000000; + + // Each `rdma_cm_id` can be associated with some context, which is represented + // by `IdContext`. `node_id` is the numerical identifier for the peer node of + // the connection, and `conn_param` is used to send the local node identifier + // as private data during connection setup. + struct IdContext { + uint32_t node_id; + rdma_conn_param conn_param; + ChannelType* channel; + + static inline uint32_t GetNodeId(void* ctx) { + return reinterpret_cast(ctx)->node_id; + } + + static inline ChannelType* GetRdmaChannel(void* ctx) { + return reinterpret_cast(ctx)->channel; + } + }; + + // Lock acquisition will spin until either the lock is acquired successfully + // or the lock is already held by an outgoing connection request from this node.
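+ // Concretely (see the implementation that follows): Acquire(peer_id) loops on a compare-and-swap of `mu_` from kUnlocked to `peer_id`; if it ever observes `my_id_` as the current holder, an outgoing Connect() from this node already owns the lock, so it gives up and returns false, letting the caller reject the incoming request and avoid distributed deadlock.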
+ inline bool Acquire(int peer_id) { + for (int expected = kUnlocked; + !mu_.compare_exchange_weak(expected, peer_id); expected = kUnlocked) { + if (expected == my_id_) { + ROME_DEBUG( + "[Acquire] (Node {}) Giving up lock acquisition: actual={}, " + "swap={}", + my_id_, expected, peer_id); + return false; + } + } + return true; + } + + inline void Release() { mu_ = kUnlocked; } + + constexpr ibv_qp_init_attr DefaultQpInitAttr() { + ibv_qp_init_attr init_attr; + std::memset(&init_attr, 0, sizeof(init_attr)); + init_attr.cap.max_send_wr = init_attr.cap.max_recv_wr = kMaxWr; + init_attr.cap.max_send_sge = init_attr.cap.max_recv_sge = kMaxSge; + init_attr.cap.max_inline_data = kMaxInlineData; + init_attr.sq_sig_all = 0; // Must request completions. + init_attr.qp_type = IBV_QPT_RC; + return init_attr; + } + + constexpr ibv_qp_attr DefaultQpAttr() { + ibv_qp_attr attr; + std::memset(&attr, 0, sizeof(attr)); + attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC; + attr.max_dest_rd_atomic = 8; + attr.path_mtu = IBV_MTU_4096; + attr.min_rnr_timer = 12; + attr.rq_psn = 0; + attr.sq_psn = 0; + attr.timeout = 12; + attr.retry_cnt = 7; + attr.rnr_retry = 1; + attr.max_rd_atomic = 8; + return attr; + } + + absl::StatusOr ConnectLoopback(rdma_cm_id* id); + + // Whether or not to stop handling requests. + volatile bool accepting_; + + // Current status + absl::Status status_; + + uint32_t my_id_; + std::unique_ptr broker_; + ibv_pd* pd_; // Convenience ptr to protection domain of `broker_` + + // Maintains connection information for a given Internet address. A connection + // manager only maintains a single connection per node. Nodes are identified + // by a string representing their IP address. + std::atomic mu_; + std::unordered_map> requested_; + std::unordered_map> established_; + + uint32_t backoff_us_{0}; + + rdma_cm_id* loopback_id_ = nullptr; +}; + +} // namespace rome::rdma + +#include "connection_manager_impl.h" \ No newline at end of file diff --git a/include/rome/rdma/connection_manager/connection_manager_impl.h b/include/rome/rdma/connection_manager/connection_manager_impl.h new file mode 100644 index 0000000..ff372bf --- /dev/null +++ b/include/rome/rdma/connection_manager/connection_manager_impl.h @@ -0,0 +1,450 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "connection_manager.h" +#include "rome/rdma/channel/rdma_channel.h" +#include "rome/rdma/rdma_util.h" +#include "rome/util/coroutine.h" +#include "rome/util/status_util.h" + +#define LOOPBACK_PORT_NUM 1 + +namespace rome::rdma { + +using ::util::InternalErrorBuilder; + +template +ConnectionManager::~ConnectionManager() { + ROME_DEBUG("Shutting down: {}", fmt::ptr(this)); + Acquire(my_id_); + Shutdown(); + + ROME_DEBUG("Stopping broker..."); + if (broker_ != nullptr) auto s = broker_->Stop(); + + auto cleanup = [this](auto& iter) { + // A loopback connection is made manually, so we do not need to deal with + // the regular `rdma_cm` handling. Similarly, we avoid destroying the event + // channel below since it is destroyed along with the id. 
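+ // In short, summarizing the cleanup code below: for non-loopback peers we disconnect and drain/ack any outstanding CM events; every endpoint is then destroyed, after which we free the `IdContext` if the receiver path allocated one, or destroy the event channel we created for an outgoing connect.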
+ auto id = iter.second->id(); + if (iter.first != my_id_) { + rdma_disconnect(id); + rdma_cm_event* event; + auto result = rdma_get_cm_event(id->channel, &event); + while (result == 0) { + RDMA_CM_ASSERT(rdma_ack_cm_event, event); + result = rdma_get_cm_event(id->channel, &event); + } + } + + // We only allocate contexts for connections that were created by the + // `RdmaReceiver` callbacks. Otherwise, we created an event channel so + // that we could asynchronously connect (and avoid distributed deadlock). + auto* context = id->context; + auto* channel = id->channel; + rdma_destroy_ep(id); + + if (iter.first != my_id_ && context != nullptr) { + free(context); + } else if (iter.first != my_id_) { + rdma_destroy_event_channel(channel); + } + }; + + std::for_each(established_.begin(), established_.end(), cleanup); + Release(); +} + +template +void ConnectionManager::Shutdown() { + // Stop accepting new requests. + accepting_ = false; +} + +template +ConnectionManager::ConnectionManager(uint32_t my_id) + : accepting_(false), my_id_(my_id), broker_(nullptr), mu_(-1) {} + +template +absl::Status ConnectionManager::Start( + std::string_view addr, std::optional port) { + if (accepting_) { + return InternalErrorBuilder() << "Cannot start broker twice"; + } + accepting_ = true; + + broker_ = RdmaBroker::Create(addr, port, this); + ROME_CHECK_QUIET( + ROME_RETURN(InternalErrorBuilder() << "Failed to create broker"), + broker_ != nullptr) + return absl::OkStatus(); +} + +namespace { + +[[maybe_unused]] inline std::string GetDestinationAsString(rdma_cm_id* id) { + char addr_str[INET_ADDRSTRLEN]; + ROME_ASSERT(inet_ntop(AF_INET, &(id->route.addr.dst_sin.sin_addr), addr_str, + INET_ADDRSTRLEN) != nullptr, + "inet_ntop(): {}", strerror(errno)); + std::stringstream ss; + ss << addr_str << ":" << rdma_get_dst_port(id); + return ss.str(); +} + +} // namespace + +template +void ConnectionManager::OnConnectRequest(rdma_cm_id* id, + rdma_cm_event* event) { + if (!accepting_) return; + + // The private data is used to understand from what node the connection + // request is coming from. + ROME_ASSERT_DEBUG(event->param.conn.private_data != nullptr, + "Received connect request without private data."); + uint32_t peer_id = + *reinterpret_cast(event->param.conn.private_data); + ROME_DEBUG("[OnConnectRequest] (Node {}) Got connection request from: {}", + my_id_, peer_id); + + if (peer_id != my_id_) { + // Attempt to acquire lock when not originating from same node + if (!Acquire(peer_id)) { + ROME_DEBUG("Lock acquisition failed: {}", mu_); + rdma_reject(event->id, nullptr, 0); + rdma_destroy_ep(id); + rdma_ack_cm_event(event); + return; + } + + // Check if the connection has already been established. + if (auto conn = established_.find(peer_id); + conn != established_.end() || requested_.contains(peer_id)) { + rdma_reject(event->id, nullptr, 0); + rdma_destroy_ep(id); + rdma_ack_cm_event(event); + if (peer_id != my_id_) Release(); + auto status = + util::AlreadyExistsErrorBuilder() + << "[OnConnectRequest] (Node " << my_id_ << ") Connection already " + << (conn != established_.end() ? "established" : "requested") << ": " + << peer_id; + ROME_DEBUG(absl::Status(status).ToString()); + return; + } + + // Create a new QP for the connection. 
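+ // The QP uses the defaults from DefaultQpInitAttr(): a reliable-connected (RC) QP with up to kMaxWr send/recv work requests, kMaxSge scatter/gather entries per request, no inline data, and explicitly requested completions (sq_sig_all = 0).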
+ ibv_qp_init_attr init_attr = DefaultQpInitAttr(); + ROME_ASSERT(id->qp == nullptr, "QP already allocated...?"); + RDMA_CM_ASSERT(rdma_create_qp, id, pd(), &init_attr); + } else { + // rdma_destroy_id(id); + id = loopback_id_; + } + + // Prepare the necessary resources for this connection. Includes a + // `RdmaChannel` that holds the QP and memory for 2-sided communication. + // The underlying QP is RC, so we reuse it for issuing 1-sided RDMA too. We + // also store the `peer_id` associated with this id so that we can reference + // it later. + auto context = new IdContext{peer_id, {}, {}}; + std::memset(&context->conn_param, 0, sizeof(context->conn_param)); + context->conn_param.private_data = &context->node_id; + context->conn_param.private_data_len = sizeof(context->node_id); + context->conn_param.rnr_retry_count = 1; // Retry forever + context->conn_param.retry_count = 7; + context->conn_param.responder_resources = 8; + context->conn_param.initiator_depth = 8; + id->context = context; + + auto iter = established_.emplace( + peer_id, + new Connection{my_id_, peer_id, std::make_unique(id)}); + ROME_ASSERT_DEBUG(iter.second, "Insertion failed"); + + ROME_DEBUG("[OnConnectRequest] (Node {}) peer={}, id={}", my_id_, peer_id, + fmt::ptr(id)); + RDMA_CM_ASSERT(rdma_accept, id, + peer_id == my_id_ ? nullptr : &context->conn_param); + rdma_ack_cm_event(event); + if (peer_id != my_id_) Release(); +} + +template +void ConnectionManager::OnEstablished(rdma_cm_id* id, + rdma_cm_event* event) { + rdma_ack_cm_event(event); +} + +template +void ConnectionManager::OnDisconnect(rdma_cm_id* id) { + // This disconnection originated from the peer, so we simply disconnect the + // local endpoint and clean it up. + // + // NOTE: The event is already ack'ed by the caller. 
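+ // Teardown below: disconnect the local endpoint, remove the peer's entry from `established_` while holding the lock, then destroy the endpoint together with its event channel.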
+ rdma_disconnect(id); + + uint32_t peer_id = IdContext::GetNodeId(id->context); + Acquire(peer_id); + if (auto conn = established_.find(peer_id); + conn != established_.end() && conn->second->id() == id) { + ROME_DEBUG("(Node {}) Disconnected from node {}", my_id_, peer_id); + established_.erase(peer_id); + } + Release(); + auto* event_channel = id->channel; + rdma_destroy_ep(id); + rdma_destroy_event_channel(event_channel); +} + +template +absl::StatusOr::conn_type*> +ConnectionManager::ConnectLoopback(rdma_cm_id* id) { + ROME_ASSERT_DEBUG(id->qp != nullptr, "No QP associated with endpoint"); + ROME_DEBUG("Connecting loopback..."); + ibv_qp_attr attr; + int attr_mask; + + attr = DefaultQpAttr(); + attr.qp_state = IBV_QPS_INIT; + attr.port_num = LOOPBACK_PORT_NUM; // id->port_num; + attr_mask = + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; + ROME_TRACE("Loopback: IBV_QPS_INIT"); + RDMA_CM_CHECK(ibv_modify_qp, id->qp, &attr, attr_mask); + + ibv_port_attr port_attr; + RDMA_CM_CHECK(ibv_query_port, id->verbs, LOOPBACK_PORT_NUM, &port_attr); // RDMA_CM_CHECK(ibv_query_port, id->verbs, id->port_num, &port_attr); + attr.ah_attr.dlid = port_attr.lid; + attr.qp_state = IBV_QPS_RTR; + attr.dest_qp_num = id->qp->qp_num; + attr.ah_attr.port_num = LOOPBACK_PORT_NUM; // id->port_num; + attr_mask = + (IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + ROME_TRACE("Loopback: IBV_QPS_RTR"); + RDMA_CM_CHECK(ibv_modify_qp, id->qp, &attr, attr_mask); + + attr.qp_state = IBV_QPS_RTS; + attr_mask = (IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_MAX_QP_RD_ATOMIC); + ROME_TRACE("Loopback: IBV_QPS_RTS"); + RDMA_CM_CHECK(ibv_modify_qp, id->qp, &attr, attr_mask); + + RDMA_CM_CHECK(fcntl, id->recv_cq->channel->fd, F_SETFL, + fcntl(id->recv_cq->channel->fd, F_GETFL) | O_NONBLOCK); + RDMA_CM_CHECK(fcntl, id->send_cq->channel->fd, F_SETFL, + fcntl(id->send_cq->channel->fd, F_GETFL) | O_NONBLOCK); + + // Allocate a new control channel to be used with this connection + auto channel = std::make_unique(id); + auto iter = established_.emplace( + my_id_, new Connection{my_id_, my_id_, std::move(channel)}); + ROME_ASSERT(iter.second, "Unexepected error"); + Release(); + return established_[my_id_].get(); +} + +template +absl::StatusOr::conn_type*> +ConnectionManager::Connect(uint32_t peer_id, + std::string_view server, + uint16_t port) { + if (Acquire(my_id_)) { + auto conn = established_.find(peer_id); + if (conn != established_.end()) { + Release(); + return conn->second.get(); + } + + auto port_str = std::to_string(htons(port)); + rdma_cm_id* id = nullptr; + rdma_addrinfo hints, *resolved = nullptr; + + std::memset(&hints, 0, sizeof(hints)); + hints.ai_port_space = RDMA_PS_TCP; + hints.ai_qp_type = IBV_QPT_RC; + hints.ai_family = AF_IB; + + struct sockaddr_in src; + std::memset(&src, 0, sizeof(src)); + src.sin_family = AF_INET; + auto src_addr_str = broker_->address(); + inet_aton(src_addr_str.data(), &src.sin_addr); + + hints.ai_src_addr = reinterpret_cast(&src); + hints.ai_src_len = sizeof(src); + + // Resolve the server's address. 
If this connection request is for the + // loopback connection, then we are going to + int gai_ret = + rdma_getaddrinfo(server.data(), port_str.data(), &hints, &resolved); + ROME_CHECK_QUIET( + ROME_RETURN(InternalErrorBuilder() + << "rdma_getaddrinfo(): " << gai_strerror(gai_ret)), + gai_ret == 0); + + ibv_qp_init_attr init_attr = DefaultQpInitAttr(); + auto err = rdma_create_ep(&id, resolved, pd(), &init_attr); + rdma_freeaddrinfo(resolved); + if (err) { + Release(); + return util::InternalErrorBuilder() + << "rdma_create_ep(): " << strerror(errno) << " (" << errno << ")"; + } + ROME_DEBUG("[Connect] (Node {}) Trying to connect to: {} (id={})", my_id_, + peer_id, fmt::ptr(id)); + + if (peer_id == my_id_) return ConnectLoopback(id); + + auto* event_channel = rdma_create_event_channel(); + RDMA_CM_CHECK(fcntl, event_channel->fd, F_SETFL, + fcntl(event_channel->fd, F_GETFL) | O_NONBLOCK); + RDMA_CM_CHECK(rdma_migrate_id, id, event_channel); + + rdma_conn_param conn_param; + std::memset(&conn_param, 0, sizeof(conn_param)); + conn_param.private_data = &my_id_; + conn_param.private_data_len = sizeof(my_id_); + conn_param.retry_count = 7; + conn_param.rnr_retry_count = 1; + conn_param.responder_resources = 8; + conn_param.initiator_depth = 8; + + RDMA_CM_CHECK(rdma_connect, id, &conn_param); + + // Handle events. + while (true) { + rdma_cm_event* event; + auto result = rdma_get_cm_event(id->channel, &event); + while (result < 0 && errno == EAGAIN) { + result = rdma_get_cm_event(id->channel, &event); + } + ROME_DEBUG("[Connect] (Node {}) Got event: {} (id={})", my_id_, + rdma_event_str(event->event), fmt::ptr(id)); + + absl::StatusOr conn_or; + switch (event->event) { + case RDMA_CM_EVENT_ESTABLISHED: { + RDMA_CM_CHECK(rdma_ack_cm_event, event); + auto conn = established_.find(peer_id); + if (bool is_established = (conn != established_.end()); + is_established && peer_id != my_id_) { + Release(); + + // Since we are initiating the disconnection, we must get and ack + // the event. + ROME_DEBUG("[Connect] (Node {}) Disconnecting: (id={})", my_id_, + fmt::ptr(id)); + RDMA_CM_CHECK(rdma_disconnect, id); + rdma_cm_event* event; + auto result = rdma_get_cm_event(id->channel, &event); + while (result < 0 && errno == EAGAIN) { + result = rdma_get_cm_event(id->channel, &event); + } + RDMA_CM_CHECK(rdma_ack_cm_event, event); + + rdma_destroy_ep(id); + rdma_destroy_event_channel(event_channel); + + if (is_established) { + ROME_DEBUG("[Connect] Already connected: {}", peer_id); + return conn->second.get(); + } else { + return util::UnavailableErrorBuilder() + << "[Connect] (Node " << my_id_ + << ") Connection is already requested: " << peer_id; + } + } + + // If this code block is reached, then the connection established by + // this call is the first successful connection to be established and + // therefore we must add it to the set of established connections. 
+ ROME_DEBUG( + "Connected: dev={}, addr={}, port={}", id->verbs->device->name, + inet_ntoa(reinterpret_cast(rdma_get_local_addr(id)) + ->sin_addr), + rdma_get_src_port(id)); + + RDMA_CM_CHECK(fcntl, event_channel->fd, F_SETFL, + fcntl(event_channel->fd, F_GETFL) | O_SYNC); + RDMA_CM_CHECK(fcntl, id->recv_cq->channel->fd, F_SETFL, + fcntl(id->recv_cq->channel->fd, F_GETFL) | O_NONBLOCK); + RDMA_CM_CHECK(fcntl, id->send_cq->channel->fd, F_SETFL, + fcntl(id->send_cq->channel->fd, F_GETFL) | O_NONBLOCK); + + // Allocate a new control channel to be used with this connection + auto channel = std::make_unique(id); + auto iter = established_.emplace( + peer_id, new Connection{my_id_, peer_id, std::move(channel)}); + ROME_ASSERT(iter.second, "Unexepected error"); + auto* new_conn = established_[peer_id].get(); + Release(); + return new_conn; + } + case RDMA_CM_EVENT_ADDR_RESOLVED: + ROME_WARN("Got addr resolved..."); + RDMA_CM_CHECK(rdma_ack_cm_event, event); + break; + default: { + auto cm_event = event->event; + RDMA_CM_CHECK(rdma_ack_cm_event, event); + backoff_us_ = + backoff_us_ > 0 + ? std::min((backoff_us_ + (100 * my_id_)) * 2, kMaxBackoffUs) + : kMinBackoffUs; + Release(); + rdma_destroy_ep(id); + rdma_destroy_event_channel(event_channel); + if (cm_event == RDMA_CM_EVENT_REJECTED) { + std::this_thread::sleep_for(std::chrono::microseconds(backoff_us_)); + return util::UnavailableErrorBuilder() + << "Connection request rejected"; + } + return InternalErrorBuilder() + << "Got unexpected event: " << rdma_event_str(cm_event); + } + } + } + } else { + return util::UnavailableErrorBuilder() << "Lock acquisition failed"; + } +} + +template +absl::StatusOr::conn_type*> +ConnectionManager::GetConnection(uint32_t peer_id) { + while (!Acquire(my_id_)) { + std::this_thread::yield(); + } + auto conn = established_.find(peer_id); + if (conn != established_.end()) { + auto result = conn->second.get(); + Release(); + return result; + } else { + Release(); + return util::NotFoundErrorBuilder() << "Connection not found: " << peer_id; + } +} + +} // namespace rome::rdma \ No newline at end of file diff --git a/include/rome/rdma/memory_pool/memory_pool.h b/include/rome/rdma/memory_pool/memory_pool.h new file mode 100644 index 0000000..b5f1413 --- /dev/null +++ b/include/rome/rdma/memory_pool/memory_pool.h @@ -0,0 +1,230 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "protos/rdma.pb.h" +#include "remote_ptr.h" +#include "rome/metrics/summary.h" +#include "rome/rdma/channel/twosided_messenger.h" +#include "rome/rdma/connection_manager/connection.h" +#include "rome/rdma/connection_manager/connection_manager.h" +#include "rome/rdma/rmalloc/rmalloc.h" + +namespace rome::rdma { + +class MemoryPool { +#ifndef ROME_MEMORY_POOL_MESSENGER_CAPACITY + static constexpr size_t kMemoryPoolMessengerCapacity = 1 << 12; +#else + static constexpr size_t kMemoryPoolMessengerCapacity = + ROME_MEMORY_POOL_MESSENGER_CAPACITY; +#endif +#ifndef ROME_MEMORY_POOL_MESSAGE_SIZE + static constexpr size_t kMemoryPoolMessageSize = 1 << 8; +#else + static constexpr size_t kMemoryPoolMessageSize = + ROME_MEMORY_POOL_MESSAGE_SIZE; +#endif + public: + typedef RdmaChannel, + EmptyRdmaAccessor> + channel_type; + typedef ConnectionManager cm_type; + typedef cm_type::conn_type conn_type; + + struct Peer { + uint16_t id; + std::string address; + uint16_t port; + + Peer() : Peer(0, "", 0) {} + Peer(uint16_t id, std::string address, uint16_t port) + : id(id), address(address), 
port(port) {} + }; + + struct conn_info_t { + conn_type *conn; + uint32_t rkey; + uint32_t lkey; + }; + + inline MemoryPool( + const Peer &self, + std::unique_ptr> connection_manager); + + class DoorbellBatch { + public: + ~DoorbellBatch() { + delete wrs_; + delete[] sges_; + } + + explicit DoorbellBatch(const conn_info_t &conn_info, int capacity) + : conn_info_(conn_info), capacity_(capacity) { + wrs_ = new ibv_send_wr[capacity]; + sges_ = new ibv_sge *[capacity]; + std::memset(wrs_, 0, sizeof(ibv_send_wr) * capacity); + wrs_[capacity - 1].send_flags = IBV_SEND_SIGNALED; + for (auto i = 1; i < capacity; ++i) { + wrs_[i - 1].next = &wrs_[i]; + } + } + + std::pair Add(int num_sge = 1) { + if (size_ == capacity_) return {nullptr, nullptr}; + auto *sge = new ibv_sge[num_sge]; + sges_[size_] = sge; + auto wr = &wrs_[size_]; + wr->num_sge = num_sge; + wr->sg_list = sge; + std::memset(sge, 0, sizeof(*sge) * num_sge); + ++size_; + return {wr, sge}; + } + + void SetKillSwitch(std::atomic *kill_switch) { + kill_switch_ = kill_switch; + } + + ibv_send_wr *GetWr(int i) { return &wrs_[i]; } + + inline int size() const { return size_; } + inline int capacity() const { return capacity_; } + inline conn_info_t conn_info() const { return conn_info_; } + inline bool is_mortal() const { return kill_switch_ != nullptr; } + + friend class MemoryPool; + + private: + conn_info_t conn_info_; + + int size_ = 0; + const int capacity_; + + ibv_send_wr *wrs_; + ibv_sge **sges_; + std::atomic *kill_switch_ = nullptr; + }; + + class DoorbellBatchBuilder { + public: + DoorbellBatchBuilder(MemoryPool *pool, uint16_t id, int num_ops = 1) + : pool_(pool) { + batch_ = std::make_unique(pool->conn_info(id), num_ops); + } + + template + remote_ptr AddRead(remote_ptr rptr, bool fence = false, + remote_ptr prealloc = remote_nullptr); + + template + remote_ptr AddPartialRead(remote_ptr ptr, size_t offset, size_t bytes, + bool fence, + remote_ptr prealloc = remote_nullptr); + + template + void AddWrite(remote_ptr rptr, const T &t, bool fence = false); + + template + void AddWrite(remote_ptr rptr, remote_ptr prealloc = remote_nullptr, + bool fence = false); + + template + void AddWriteBytes(remote_ptr rptr, remote_ptr prealloc, int bytes, + bool fence = false); + + void AddKillSwitch(std::atomic *kill_switch) { + batch_->SetKillSwitch(kill_switch); + } + + std::unique_ptr Build(); + + private: + template + void AddReadInternal(remote_ptr rptr, size_t offset, size_t bytes, + size_t chunk, bool fence, remote_ptr prealloc); + std::unique_ptr batch_; + MemoryPool *pool_; + }; + + MemoryPool(const MemoryPool &) = delete; + MemoryPool(MemoryPool &&) = delete; + + // Getters. 
+ cm_type *connection_manager() const { return connection_manager_.get(); } + rome::metrics::MetricProto rdma_per_read_proto() { + return rdma_per_read_.ToProto(); + } + conn_info_t conn_info(uint16_t id) const { return conn_info_.at(id); } + + inline absl::Status Init(uint32_t capacity, const std::vector &peers); + + template + remote_ptr Allocate(size_t size = 1); + + template + void Deallocate(remote_ptr p, size_t size = 1); + + void Execute(DoorbellBatch *batch); + + template + remote_ptr Read(remote_ptr ptr, remote_ptr prealloc = remote_nullptr, + std::atomic *kill = nullptr); + + template + remote_ptr PartialRead(remote_ptr ptr, size_t offset, size_t bytes, + remote_ptr prealloc = remote_nullptr); + + template + void Write(remote_ptr ptr, const T &val, + remote_ptr prealloc = remote_nullptr); + + template + T AtomicSwap(remote_ptr ptr, uint64_t swap, uint64_t hint = 0); + + template + T CompareAndSwap(remote_ptr ptr, uint64_t expected, uint64_t swap); + + template + inline remote_ptr GetRemotePtr(const T *ptr) const { + return remote_ptr(self_.id, reinterpret_cast(ptr)); + } + + template + inline remote_ptr GetBaseAddress() const { + return GetRemotePtr(reinterpret_cast(mr_->addr)); + } + + private: + template + inline void ReadInternal(remote_ptr ptr, size_t offset, size_t bytes, + size_t chunk_size, remote_ptr prealloc, + std::atomic *kill = nullptr); + + Peer self_; + + volatile uint64_t *prev_ = nullptr; + + std::unique_ptr> connection_manager_; + std::unique_ptr rdma_memory_; + ibv_mr *mr_; + + std::unordered_map conn_info_; + ibv_send_wr send_wr_{}; + + rome::metrics::Summary rdma_per_read_; +}; + +} // namespace rome::rdma + +#include "memory_pool_impl.h" \ No newline at end of file diff --git a/include/rome/rdma/memory_pool/memory_pool_impl.h b/include/rome/rdma/memory_pool/memory_pool_impl.h new file mode 100644 index 0000000..afcd305 --- /dev/null +++ b/include/rome/rdma/memory_pool/memory_pool_impl.h @@ -0,0 +1,393 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "memory_pool.h" +#include "rome/util/thread_util.h" + +namespace rome::rdma { + +template +remote_ptr MemoryPool::DoorbellBatchBuilder::AddRead( + remote_ptr rptr, bool fence, remote_ptr prealloc) { + auto local = (prealloc == remote_nullptr) ? pool_->Allocate() : prealloc; + AddReadInternal(rptr, 0, sizeof(T), sizeof(T), fence, local); + return local; +} + +template +remote_ptr MemoryPool::DoorbellBatchBuilder::AddPartialRead( + remote_ptr rptr, size_t offset, size_t bytes, bool fence, + remote_ptr prealloc) { + auto local = (prealloc == remote_nullptr) ? pool_->Allocate() : prealloc; + AddReadInternal(rptr, offset, bytes, bytes, fence, local); + return local; +} + +template +void MemoryPool::DoorbellBatchBuilder::AddReadInternal(remote_ptr rptr, + size_t offset, + size_t bytes, + size_t chunk, bool fence, + remote_ptr prealloc) { + const int num_chunks = bytes % chunk ? (bytes / chunk) + 1 : bytes / chunk; + const size_t remainder = bytes % chunk; + const bool is_multiple = remainder == 0; + + T *local = std::to_address(prealloc); + for (int i = 0; i < num_chunks; ++i) { + auto wr_sge = batch_->Add(); + auto wr = wr_sge.first; + auto sge = wr_sge.second; + + auto chunk_offset = offset + i * chunk; + sge->addr = reinterpret_cast(local) + chunk_offset; + if (is_multiple) { + sge->length = chunk; + } else { + sge->length = (i == num_chunks - 1 ? 
remainder : chunk); + } + sge->lkey = batch_->conn_info().lkey; + + wr->opcode = IBV_WR_RDMA_READ; + if (fence) wr->send_flags |= IBV_SEND_FENCE; + wr->wr.rdma.remote_addr = rptr.address() + chunk_offset; + wr->wr.rdma.rkey = batch_->conn_info().rkey; + } +} + +template +void MemoryPool::DoorbellBatchBuilder::AddWrite(remote_ptr rptr, + const T &value, bool fence) { + auto local = pool_->Allocate(); + std::memcpy(std::to_address(local), &value, sizeof(value)); + AddWrite(rptr, local, fence); +} + +template +void MemoryPool::DoorbellBatchBuilder::AddWrite(remote_ptr rptr, + remote_ptr prealloc, + bool fence) { + auto wr_sge = batch_->Add(); + ibv_send_wr *wr = wr_sge.first; + ibv_sge *sge = wr_sge.second; + + sge->addr = (uint64_t)std::to_address(prealloc); + sge->length = sizeof(T); + sge->lkey = batch_->conn_info().lkey; + + wr->opcode = IBV_WR_RDMA_WRITE; + if (fence) wr->send_flags |= IBV_SEND_FENCE; + wr->wr.rdma.remote_addr = (uint64_t)std::to_address(rptr); + wr->wr.rdma.rkey = batch_->conn_info().rkey; +} + +template +void MemoryPool::DoorbellBatchBuilder::AddWriteBytes(remote_ptr rptr, + remote_ptr prealloc, + int bytes, bool fence) { + auto wr_sge = batch_->Add(); + ibv_send_wr *wr = wr_sge.first; + ibv_sge *sge = wr_sge.second; + + sge->addr = (uint64_t)std::to_address(prealloc); + sge->length = bytes; + sge->lkey = batch_->conn_info().lkey; + + wr->opcode = IBV_WR_RDMA_WRITE; + if (fence) wr->send_flags |= IBV_SEND_FENCE; + wr->wr.rdma.remote_addr = (uint64_t)std::to_address(rptr); + wr->wr.rdma.rkey = batch_->conn_info().rkey; +} + +inline std::unique_ptr +MemoryPool::DoorbellBatchBuilder::Build() { + const int size = batch_->size(); + const int capcity = batch_->capacity(); + ROME_ASSERT(size > 0, "Cannot build an empty batch."); + ROME_ASSERT(size == capcity, "Batch must be full"); + for (int i = 0; i < size; ++i) { + batch_->wrs_[i].wr_id = batch_->wrs_[i].wr.rdma.remote_addr; + } + return std::move(batch_); +} + +MemoryPool::MemoryPool( + const Peer &self, + std::unique_ptr> connection_manager) + : self_(self), + connection_manager_(std::move(connection_manager)), + rdma_per_read_("rdma_per_read", "ops", 10000) {} + +absl::Status MemoryPool::Init(uint32_t capacity, + const std::vector &peers) { + auto status = connection_manager_->Start(self_.address, self_.port); + ROME_CHECK_OK(ROME_RETURN(status), status); + rdma_memory_ = std::make_unique( + capacity + sizeof(uint64_t), connection_manager_->pd()); + mr_ = rdma_memory_->mr(); + + auto alloc = rdma_allocator(rdma_memory_.get()); + prev_ = alloc.allocate(); + + for (const auto &p : peers) { + auto connected = connection_manager_->Connect(p.id, p.address, p.port); + while (absl::IsUnavailable(connected.status())) { + connected = connection_manager_->Connect(p.id, p.address, p.port); + } + ROME_CHECK_OK(ROME_RETURN(connected.status()), connected); + } + + RemoteObjectProto rm_proto; + rm_proto.set_rkey(mr_->rkey); + rm_proto.set_raddr(reinterpret_cast(mr_->addr)); + for (const auto &p : peers) { + auto conn = VALUE_OR_DIE(connection_manager_->GetConnection(p.id)); + status = conn->channel()->Send(rm_proto); + ROME_CHECK_OK(ROME_RETURN(status), status); + } + + for (const auto &p : peers) { + auto conn = VALUE_OR_DIE(connection_manager_->GetConnection(p.id)); + auto got = conn->channel()->TryDeliver(); + while (!got.ok() && got.status().code() == absl::StatusCode::kUnavailable) { + got = conn->channel()->TryDeliver(); + } + ROME_CHECK_OK(ROME_RETURN(got.status()), got); + conn_info_.emplace(p.id, conn_info_t{conn, 
got->rkey(), mr_->lkey}); + } + return absl::OkStatus(); +} + +template +remote_ptr MemoryPool::Allocate(size_t size) { + return remote_ptr(self_.id, + rdma_allocator(rdma_memory_.get()).allocate(size)); +} + +template +void MemoryPool::Deallocate(remote_ptr p, size_t size) { + ROME_ASSERT(p.id() == self_.id, + "Alloc/dealloc on remote node not implemented..."); + rdma_allocator(rdma_memory_.get()).deallocate(std::to_address(p), size); +} + +inline void MemoryPool::Execute(DoorbellBatch *batch) { + ibv_send_wr *bad; + auto *conn = batch->conn_info().conn; + RDMA_CM_ASSERT(ibv_post_send, conn->id()->qp, batch->wrs_, &bad); + + int poll; + ibv_wc wc; + while ((poll = ibv_poll_cq(conn->id()->send_cq, 1, &wc)) == 0) { + if (batch->is_mortal() && *kill) return; + cpu_relax(); + } + ROME_ASSERT(poll == 1 && wc.status == IBV_WC_SUCCESS, + "ibv_poll_cq(): {} (dest={})", + (poll < 0 ? strerror(errno) : ibv_wc_status_str(wc.status)), + remote_ptr(wc.wr_id)); +} + +template +remote_ptr MemoryPool::Read(remote_ptr ptr, remote_ptr prealloc, + std::atomic *kill) { + if (prealloc == remote_nullptr) prealloc = Allocate(); + ReadInternal(ptr, 0, sizeof(T), sizeof(T), prealloc, kill); + return prealloc; +} + +template +remote_ptr MemoryPool::PartialRead(remote_ptr ptr, size_t offset, + size_t bytes, remote_ptr prealloc) { + if (prealloc == remote_nullptr) prealloc = Allocate(); + ReadInternal(ptr, offset, bytes, sizeof(T), prealloc); + return prealloc; +} + +template +void MemoryPool::ReadInternal(remote_ptr ptr, size_t offset, size_t bytes, + size_t chunk_size, remote_ptr prealloc, + std::atomic *kill) { + const int num_chunks = + bytes % chunk_size ? (bytes / chunk_size) + 1 : bytes / chunk_size; + const size_t remainder = bytes % chunk_size; + const bool is_multiple = remainder == 0; + + auto info = conn_info_.at(ptr.id()); + + T *local = std::to_address(prealloc); + ibv_sge sges[num_chunks]; + ibv_send_wr wrs[num_chunks]; + for (int i = 0; i < num_chunks; ++i) { + auto chunk_offset = offset + i * chunk_size; + sges[i].addr = reinterpret_cast(local) + chunk_offset; + if (is_multiple) { + sges[i].length = chunk_size; + } else { + sges[i].length = (i == num_chunks - 1 ? remainder : chunk_size); + } + sges[i].lkey = mr_->lkey; + + wrs[i].num_sge = 1; + wrs[i].sg_list = &sges[i]; + wrs[i].opcode = IBV_WR_RDMA_READ; + wrs[i].send_flags = IBV_SEND_FENCE; + if (i == num_chunks - 1) wrs[i].send_flags |= IBV_SEND_SIGNALED; + wrs[i].wr.rdma.remote_addr = ptr.address() + chunk_offset; + wrs[i].wr.rdma.rkey = info.rkey; + wrs[i].next = (i != num_chunks - 1 ? &wrs[i + 1] : nullptr); + } + + ibv_send_wr *bad; + RDMA_CM_ASSERT(ibv_post_send, info.conn->id()->qp, wrs, &bad); + ibv_wc wc; + int poll = 0; + if (kill == nullptr) { + for (; poll == 0; poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc)) + ; + ROME_ASSERT( + poll == 1 && wc.status == IBV_WC_SUCCESS, "ibv_poll_cq(): {} @ {}", + (poll < 0 ? strerror(errno) : ibv_wc_status_str(wc.status)), ptr); + } else { + for (; poll == 0 && !(*kill); + poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc)) + ; + if (!(*kill) && (poll != 1 || wc.status != IBV_WC_SUCCESS)) { + ROME_FATAL("ibv_poll_cq(): {}", + (poll < 0 ? 
strerror(errno) : ibv_wc_status_str(wc.status))); + } + } + rdma_per_read_ << num_chunks; +} + +template +void MemoryPool::Write(remote_ptr ptr, const T &val, + remote_ptr prealloc) { + ROME_DEBUG("Write: {:x} @ {}", (uint64_t)val, ptr); + auto info = conn_info_.at(ptr.id()); + + T *local; + if (prealloc == remote_nullptr) { + auto alloc = rdma_allocator(rdma_memory_.get()); + local = alloc.allocate(); + ROME_DEBUG("Allocated memory for Write: {} bytes @ 0x{:x}", sizeof(T), + (uint64_t)local); + } else { + local = std::to_address(prealloc); + ROME_DEBUG("Preallocated memory for Write: {} bytes @ 0x{:x}", sizeof(T), + (uint64_t)local); + } + + ROME_ASSERT((uint64_t)local != ptr.address(), "WTF"); + std::memset(local, 0, sizeof(T)); + *local = val; + ibv_sge sge{}; + sge.addr = reinterpret_cast(local); + sge.length = sizeof(T); + sge.lkey = mr_->lkey; + + send_wr_.num_sge = 1; + send_wr_.sg_list = &sge; + send_wr_.opcode = IBV_WR_RDMA_WRITE; + send_wr_.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; + send_wr_.wr.rdma.remote_addr = ptr.address(); + send_wr_.wr.rdma.rkey = info.rkey; + + ibv_send_wr *bad = nullptr; + RDMA_CM_ASSERT(ibv_post_send, info.conn->id()->qp, &send_wr_, &bad); + ibv_wc wc; + auto poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc); + while (poll == 0 || (poll < 0 && errno == EAGAIN)) { + poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc); + } + + if (prealloc == remote_nullptr) { + auto alloc = rdma_allocator(rdma_memory_.get()); + alloc.deallocate(local); + } + ROME_ASSERT(poll == 1 && wc.status == IBV_WC_SUCCESS, + "ibv_poll_cq(): {} ({})", + (poll < 0 ? strerror(errno) : ibv_wc_status_str(wc.status)), + (std::stringstream() << ptr).str()); +} + +template +T MemoryPool::AtomicSwap(remote_ptr ptr, uint64_t swap, uint64_t hint) { + static_assert(sizeof(T) == 8); + auto info = conn_info_.at(ptr.id()); + + ibv_sge sge{}; + sge.addr = reinterpret_cast(prev_); + sge.length = sizeof(uint64_t); + sge.lkey = mr_->lkey; + + send_wr_.num_sge = 1; + send_wr_.sg_list = &sge; + send_wr_.opcode = IBV_WR_ATOMIC_CMP_AND_SWP; + send_wr_.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; + send_wr_.wr.atomic.remote_addr = ptr.address(); + send_wr_.wr.atomic.rkey = info.rkey; + send_wr_.wr.atomic.compare_add = hint; + send_wr_.wr.atomic.swap = swap; + + ibv_send_wr *bad = nullptr; + while (true) { + RDMA_CM_ASSERT(ibv_post_send, info.conn->id()->qp, &send_wr_, &bad); + ibv_wc wc; + auto poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc); + while (poll == 0 || (poll < 0 && errno == EAGAIN)) { + poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc); + } + ROME_ASSERT(poll == 1 && wc.status == IBV_WC_SUCCESS, "ibv_poll_cq(): {}", + (poll < 0 ? 
strerror(errno) : ibv_wc_status_str(wc.status))); + + ROME_DEBUG("Swap: expected={:x}, swap={:x}, prev={:x} (id={})", + send_wr_.wr.atomic.compare_add, (uint64_t)swap, *prev_, + self_.id); + if (*prev_ == send_wr_.wr.atomic.compare_add) break; + send_wr_.wr.atomic.compare_add = *prev_; + }; + return T(*prev_); +} + +template +T MemoryPool::CompareAndSwap(remote_ptr ptr, uint64_t expected, + uint64_t swap) { + static_assert(sizeof(T) == 8); + auto info = conn_info_.at(ptr.id()); + + ibv_sge sge{}; + sge.addr = reinterpret_cast(prev_); + sge.length = sizeof(uint64_t); + sge.lkey = mr_->lkey; + + send_wr_.num_sge = 1; + send_wr_.sg_list = &sge; + send_wr_.opcode = IBV_WR_ATOMIC_CMP_AND_SWP; + send_wr_.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; + send_wr_.wr.atomic.remote_addr = ptr.address(); + send_wr_.wr.atomic.rkey = info.rkey; + send_wr_.wr.atomic.compare_add = expected; + send_wr_.wr.atomic.swap = swap; + + ibv_send_wr *bad = nullptr; + RDMA_CM_ASSERT(ibv_post_send, info.conn->id()->qp, &send_wr_, &bad); + ibv_wc wc; + auto poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc); + while (poll == 0 || (poll < 0 && errno == EAGAIN)) { + poll = ibv_poll_cq(info.conn->id()->send_cq, 1, &wc); + } + ROME_ASSERT(poll == 1 && wc.status == IBV_WC_SUCCESS, "ibv_poll_cq(): {}", + (poll < 0 ? strerror(errno) : ibv_wc_status_str(wc.status))); + ROME_DEBUG("CompareAndSwap: expected={:x}, swap={:x}, actual={:x} (id={})", + expected, swap, *prev_, static_cast(self_.id)); + return T(*prev_); +} + +} // namespace rome::rdma \ No newline at end of file diff --git a/include/rome/rdma/memory_pool/remote_ptr.h b/include/rome/rdma/memory_pool/remote_ptr.h new file mode 100644 index 0000000..a577e40 --- /dev/null +++ b/include/rome/rdma/memory_pool/remote_ptr.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace rome::rdma { + +template +class remote_ptr; +struct nullptr_type {}; +typedef remote_ptr remote_nullptr_t; + +template +class remote_ptr { + public: + using element_type = T; + using pointer = T *; + using reference = T &; + using id_type = uint16_t; + using address_type = uint64_t; + + template + using rebind = remote_ptr; + + // Constructors + constexpr remote_ptr() : raw_(0) {} + explicit remote_ptr(uint64_t raw) : raw_(raw) {} + remote_ptr(uint64_t id, T *address) + : remote_ptr(id, reinterpret_cast(address)) {} + remote_ptr(id_type id, uint64_t address) + : raw_((((uint64_t)id) << kAddressBits) | (address & kAddressBitmask)) {} + + // Copy and Move + template >> + remote_ptr(const remote_ptr &p) : raw_(p.raw_) {} + template >> + remote_ptr(remote_ptr &&p) : raw_(p.raw_) {} + constexpr remote_ptr(const remote_nullptr_t &n) : raw_(0) {} + + // Getters + uint64_t id() const { return (raw_ & kIdBitmask) >> kAddressBits; } + uint64_t address() const { return raw_ & kAddressBitmask; } + uint64_t raw() const { return raw_; } + + // Assignment + void operator=(const remote_ptr &p) volatile; + template >> + void operator=(const remote_nullptr_t &n) volatile; + + // Increment operator + remote_ptr &operator+=(size_t s); + remote_ptr &operator++(); + remote_ptr operator++(int); + + // Conversion operators + explicit operator uint64_t() const; + template + explicit operator remote_ptr() const; + + // Pointer-like functions + static constexpr element_type *to_address(const remote_ptr &p); + static constexpr remote_ptr pointer_to(element_type &r); + pointer get() const { return (element_type *)address(); } + pointer operator->() 
const noexcept { return (element_type *)address(); } + reference operator*() const noexcept { return *((element_type *)address()); } + + // Stream operator + template + friend std::ostream &operator<<(std::ostream &os, const remote_ptr &p); + + // Equivalence + bool operator==(const volatile remote_nullptr_t &n) const volatile; + bool operator==(remote_ptr &n) { return raw_ == n.raw_; } + template + friend bool operator==(remote_ptr &p1, remote_ptr &p2); + template + friend bool operator==(const volatile remote_ptr &p1, + const volatile remote_ptr &p2); + + bool operator<(const volatile remote_ptr &p) { return raw_ < p.raw_; } + friend bool operator<(const volatile remote_ptr &p1, + const volatile remote_ptr &p2) { + return p1.raw() < p2.raw(); + } + + private: + static inline constexpr uint64_t bitsof(const uint32_t &bytes) { + return bytes * 8; + } + + static constexpr auto kAddressBits = + (bitsof(sizeof(uint64_t))) - bitsof(sizeof(id_type)); + static constexpr auto kAddressBitmask = ((1ul << kAddressBits) - 1); + static constexpr auto kIdBitmask = (uint64_t)(-1) ^ kAddressBitmask; + + uint64_t raw_; +}; + +} // namespace rome::rdma + +template +struct fmt::formatter<::rome::rdma::remote_ptr> { + typedef ::rome::rdma::remote_ptr remote_ptr; + constexpr auto parse(format_parse_context &ctx) -> decltype(ctx.begin()) { + return ctx.end(); + } + + template + auto format(const remote_ptr &input, FormatContext &ctx) + -> decltype(ctx.out()) { + return format_to(ctx.out(), "(id={}, address=0x{:x})", input.id(), + input.address()); + } +}; + +template +struct std::hash> { + std::size_t operator()(const rome::rdma::remote_ptr &ptr) const { + return std::hash()(static_cast(ptr)); + } +}; + +#include "remote_ptr_impl.h" \ No newline at end of file diff --git a/include/rome/rdma/memory_pool/remote_ptr_impl.h b/include/rome/rdma/memory_pool/remote_ptr_impl.h new file mode 100644 index 0000000..f67f904 --- /dev/null +++ b/include/rome/rdma/memory_pool/remote_ptr_impl.h @@ -0,0 +1,78 @@ +#include "remote_ptr.h" + +namespace rome::rdma { + +constexpr remote_nullptr_t remote_nullptr{}; + +template +void remote_ptr::operator=(const remote_ptr& p) volatile { + raw_ = p.raw_; +} + +template +template >> +void remote_ptr::operator=(const remote_nullptr_t& n) volatile { + raw_ = 0; +} + +template +remote_ptr& remote_ptr::operator+=(size_t s) { + const auto address = (raw_ + (sizeof(element_type) * s)) & kAddressBitmask; + raw_ = (raw_ & kIdBitmask) | address; + return *this; +} + +template +remote_ptr& remote_ptr::operator++() { + *this += 1; + return *this; +} + +template +remote_ptr remote_ptr::operator++(int) { + remote_ptr prev = *this; + *this += 1; + return prev; +} + +template +remote_ptr::operator uint64_t() const { + return raw_; +} + +template +template +remote_ptr::operator remote_ptr() const { + return remote_ptr(raw_); +} + +template +/* static */ constexpr typename remote_ptr::element_type* +remote_ptr::to_address(const remote_ptr& p) { + return (element_type*)p.address(); +} + +template +/* static */ constexpr remote_ptr remote_ptr::pointer_to(T& p) { + return remote_ptr(-1, &p); +} + +template +std::ostream& operator<<(std::ostream& os, const remote_ptr& p) { + return os << ""; +} + +template +bool remote_ptr::operator==(const volatile remote_nullptr_t& n) const + volatile { + return raw_ == 0; +} + +template +bool operator==(const volatile remote_ptr& p1, + const volatile remote_ptr& p2) { + return p1.raw_ == p2.raw_; +} + +} // namespace rome::rdma \ No newline at end of file diff --git 
a/include/rome/rdma/rmalloc/rmalloc.h b/include/rome/rdma/rmalloc/rmalloc.h new file mode 100755 index 0000000..9a68154 --- /dev/null +++ b/include/rome/rdma/rmalloc/rmalloc.h @@ -0,0 +1,159 @@ +#pragma once + +#include +#include +#include +#include + +#include "rome/rdma/rdma_memory.h" +#include "rome/util/memory_resource.h" + +template class std::unordered_map>; + +namespace rome::rdma { + +// Specialization of a `memory_resource` that wraps RDMA accessible memory. +class rdma_memory_resource : public util::pmr::memory_resource { + public: + virtual ~rdma_memory_resource() {} + explicit rdma_memory_resource(size_t bytes, ibv_pd *pd) + : rdma_memory_(std::make_unique( + bytes, "/proc/sys/vm/nr_hugepages", pd)), + head_(rdma_memory_->raw() + bytes) { + std::memset(alignments_.data(), 0, sizeof(alignments_)); + ROME_DEBUG("rdma_memory_resource: {} to {} (length={})", + fmt::ptr(rdma_memory_->raw()), fmt::ptr(head_.load()), bytes); + } + + rdma_memory_resource(const rdma_memory_resource &) = delete; + rdma_memory_resource &operator=(const rdma_memory_resource &) = delete; + ibv_mr *mr() const { return rdma_memory_->GetDefaultMemoryRegion(); } + + private: + static constexpr uint8_t kMinSlabClass = 3; + static constexpr uint8_t kMaxSlabClass = 20; + static constexpr uint8_t kNumSlabClasses = kMaxSlabClass - kMinSlabClass + 1; + static constexpr size_t kMaxAlignment = 1 << 8; + static constexpr char kLogTable[256] = { +#define LT(n) n, n, n, n, n, n, n, n, n, n, n, n, n, n, n, n + -1, 0, 1, 1, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, + LT(4), LT(5), LT(5), LT(6), LT(6), LT(6), LT(6), LT(7), + LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), + }; + + inline unsigned int UpperLog2(size_t x) { + size_t r; + size_t p, q; + if ((q = x >> 16)) { + r = (p = q >> 8) ? 24 + kLogTable[p] : 16 + kLogTable[q]; + } else { + r = (p = x >> 8) ? 8 + kLogTable[p] : kLogTable[x]; + } + return ((1ul << r) < x) ? r + 1 : r; + } + + // Returns a region of RDMA-accessible memory that satisfies the given memory + // allocation request of `bytes` with `alignment`. First, it checks whether + // there exists a region in `freelists_` that satisfies the request, then it + // attempts to allocate a new region. If the request cannot be satisfied, then + // `nullptr` is returned. 
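+ // For example, with the constants above: a 24-byte request maps to slab class 5 (a 32-byte block), anything smaller than 8 bytes is rounded up to kMinSlabClass (2^3 = 8 bytes), and requests beyond kMaxSlabClass (2^20 = 1 MiB) trip the assertion below.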
+ void *do_allocate(size_t bytes, size_t alignment) override { + if (alignment > bytes) bytes = alignment; + auto slabclass = UpperLog2(bytes); + slabclass = std::max(kMinSlabClass, static_cast(slabclass)); + auto slabclass_idx = slabclass - kMinSlabClass; + ROME_ASSERT(slabclass_idx >= 0 && slabclass_idx < kNumSlabClasses, + "Invalid allocation requested: {} bytes", bytes); + ROME_ASSERT(alignment <= kMaxAlignment, "Invalid alignment: {} bytes", + alignment); + + if (alignments_[slabclass_idx] & alignment) { + auto *freelist = &freelists_[slabclass_idx]; + ROME_ASSERT_DEBUG(!freelist->empty(), "Freelist should not be empty"); + for (auto f = freelist->begin(); f != freelist->end(); ++f) { + if (f->first >= alignment) { + auto *ptr = f->second; + f = freelist->erase(f); + if (f == freelist->end()) alignments_[slabclass_idx] &= ~alignment; + std::memset(ptr, 0, 1 << slabclass); + ROME_TRACE("(Re)allocated {} bytes @ {}", bytes, fmt::ptr(ptr)); + return ptr; + } + } + } + + uint8_t *__e = head_, *__d; + do { + __d = (uint8_t *)((uint64_t)__e & ~(alignment - 1)) - bytes; + if ((void *)(__d) < rdma_memory_->raw()) { + ROME_CRITICAL("OOM!"); + return nullptr; + } + } while (!head_.compare_exchange_strong(__e, __d)); + + ROME_TRACE("Allocated {} bytes @ {}", bytes, fmt::ptr(__d)); + return reinterpret_cast(__d); + } + + void do_deallocate(void *p, size_t bytes, size_t alignment) override { + ROME_TRACE("Deallocating {} bytes @ {}", bytes, fmt::ptr(p)); + auto slabclass = UpperLog2(bytes); + slabclass = std::max(kMinSlabClass, static_cast(slabclass)); + auto slabclass_idx = slabclass - kMinSlabClass; + + alignments_[slabclass_idx] |= alignment; + freelists_[slabclass_idx].emplace_back(alignment, p); + } + + // Only equal if they are the same object. + bool do_is_equal( + const util::pmr::memory_resource &other) const noexcept override { + return this == &other; + } + + std::unique_ptr rdma_memory_; + std::atomic head_; + + // Stores addresses of freed memory for a given slab class. + inline static thread_local std::array alignments_; + inline static thread_local std::array>, + kNumSlabClasses> + freelists_; +}; + +// An allocator wrapping `rdma_memory_resouce` to be used to allocate new +// RDMA-accessible memory. 
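+// A hypothetical usage sketch (names invented for illustration): given an rdma_memory_resource `mr`, `rdma_allocator<uint64_t> alloc(&mr); uint64_t *p = alloc.allocate(4);` hands back RDMA-registered, 64-byte-aligned memory for four values, and `alloc.deallocate(p, 4);` places the block on the calling thread's freelist for reuse.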
+template +class rdma_allocator { + public: + typedef T value_type; + + rdma_allocator() : memory_resource_(nullptr) {} + rdma_allocator(rdma_memory_resource *memory_resource) + : memory_resource_(memory_resource) {} + rdma_allocator(const rdma_allocator &other) = default; + + template + rdma_allocator(const rdma_allocator &other) noexcept { + memory_resource_ = other.memory_resource(); + } + + rdma_allocator &operator=(const rdma_allocator &) = delete; + + // Getters + rdma_memory_resource *memory_resource() const { return memory_resource_; } + + [[nodiscard]] constexpr T *allocate(std::size_t n = 1) { + return reinterpret_cast(memory_resource_->allocate(sizeof(T) * n, 64)); + } + + constexpr void deallocate(T *p, std::size_t n = 1) { + memory_resource_->deallocate(reinterpret_cast(p), sizeof(T) * n, 64); + } + + private: + rdma_memory_resource *memory_resource_; +}; + +} // namespace rome::rdma \ No newline at end of file diff --git a/include/rome/util/coroutine.h b/include/rome/util/coroutine.h index 78f8963..cea8e61 100644 --- a/include/rome/util/coroutine.h +++ b/include/rome/util/coroutine.h @@ -8,7 +8,6 @@ #include #include "rome/logging/logging.h" -#include "rome/util/coroutine.h" namespace util { using namespace std; diff --git a/include/rome/util/thread_util.h b/include/rome/util/thread_util.h new file mode 100644 index 0000000..9988639 --- /dev/null +++ b/include/rome/util/thread_util.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include + +#define XSTR(x) STR(x) +#define STR(x) #x + +static inline void cpu_relax() { asm volatile("pause\n" ::: "memory"); } + +#define PAD_SIZEOF(type) (*(*(type))(nullptr)) + +#define __PAD(pad, x, prefix, id) [[maybe_unused]] char prefix##id[(pad) - (x)] + +#define CACHELINE_SIZE 64 +#define _CACHELINE_PAD(x, prefix, id) __PAD(CACHELINE_SIZE, x, prefix, id) +#define CACHELINE_PAD(x) _CACHELINE_PAD(x, __pad, __LINE__) +#define CACHELINE_PAD1(a) _CACHELINE_PAD(sizeof(a), __pad, __LINE__) +#define CACHELINE_PAD2(a, b) \ + _CACHELINE_PAD(sizeof(a) + sizeof(b), __pad, __LINE__) +#define CACHELINE_PAD3(a, b, c) \ + _CACHELINE_PAD(sizeof(a) + sizeof(b) + sizeof(c), __pad, __LINE__) +#define CACHELINE_PAD4(a, b, c, d) \ + _CACHELINE_PAD(sizeof(a) + sizeof(b) + sizeof(c) + sizeof(d), __pad, __LINE__) +#define CACHELINE_PAD5(a, b, c, d, e) \ + _CACHELINE_PAD(sizeof(a) + sizeof(b) + sizeof(c) + sizeof(d) + sizeof(e), \ + __pad, __LINE__) + +#define PREFETCH_SIZE 128 +#define _PREFETCH_PAD(x, prefix, id) __PAD(PREFETCH_SIZE, x, prefix, id) +#define PREFETCH_PAD() _PREFETCH_PAD(0, __pad, __LINE__) +#define PREFETCH_PAD1(a) _PREFETCH_PAD(sizeof(a), __pad, __LINE__) +#define PREFETCH_PAD2(a, b) \ + _PREFETCH_PAD(sizeof(a) + sizeof(b), __pad, __LINE__) +#define PREFETCH_PAD3(a, b, c) \ + _PREFETCH_PAD(sizeof(a) + sizeof(b) + sizeof(c), __pad, __LINE__) +#define PREFETCH_PAD4(a, b, c, d) \ + _PREFETCH_PAD(sizeof(a) + sizeof(b) + sizeof(c) + sizeof(d), __pad, __LINE__) +#define PREFETCH_PAD5(a, b, c, d, e) \ + _PREFETCH_PAD(sizeof(a) + sizeof(b) + sizeof(c) + sizeof(d) + sizeof(e), \ + __pad, __LINE__) + +#define EAGER_AWAIT(cond) \ + if (!(cond)) { \ + while (!(cond)) \ + ; \ + } + +#define RELAXED_AWAIT(cond) \ + if (!(cond)) { \ + while (!(cond)) { \ + cpu_relax(); \ + } \ + } + +#define YIELDING_AWAIT(cond) \ + if (!(cond)) { \ + cpu_relax(); \ + while (!(cond)) { \ + std::this_thread::yield(); \ + } \ + } + +// A class to track thread binding and enforce various policies. 
The basic +// principle is to map a policy to a set of CPUs that threads are then bound to +// in a circular fashion. High-level policies can generate the CPU list. For +// example, NUMA-fill would schedule all cores in a NUMA zone then all +// hyperthreaded cores, before moving on to the next NUMA zone. +// template +// class CpuBinder : public Policy { +// public: +// static CpuBinder& GetInstance() { +// static CpuBinder instance; +// return instance; +// } + +// CpuBinder(const CpuBinder&) = delete; +// void operator=(const CpuBinder&) = delete; + +// // Compile time polymorphism. +// void Bind() { uint32_t cpu = static_cast(this)->GetNextCpu(); } + +// private: +// CpuBinder(); +// }; + +// class NumaFillPolicy { +// public: +// NumaFillPolicy() {} +// void Bind(); +// }; \ No newline at end of file diff --git a/include/rome/util/timing_util.h b/include/rome/util/timing_util.h new file mode 100644 index 0000000..f34604f --- /dev/null +++ b/include/rome/util/timing_util.h @@ -0,0 +1,11 @@ +#pragma once +#include + +static inline uint64_t rdtscp() { + uint32_t lo, hi; + __asm__ __volatile__("mfence\nrdtscp\nlfence\n" + : "=a"(lo), "=d"(hi) + : + : "%ecx"); + return (((uint64_t)hi) << 32) + lo; +} diff --git a/protos/CMakeLists.txt b/protos/CMakeLists.txt index aed7519..c160b8e 100644 --- a/protos/CMakeLists.txt +++ b/protos/CMakeLists.txt @@ -1,5 +1,5 @@ set(Protobuf_IMPORT_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/..) -protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS metrics.proto colosseum.proto testutil.proto) +protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS metrics.proto colosseum.proto testutil.proto rdma.proto) add_library(protos SHARED ${PROTO_SRCS}) add_library(rome::protos ALIAS protos) diff --git a/protos/rdma.proto b/protos/rdma.proto new file mode 100644 index 0000000..25dd906 --- /dev/null +++ b/protos/rdma.proto @@ -0,0 +1,21 @@ +syntax = "proto2"; + +package rome.rdma; + +// Represents a remotely accessible memory region. Used to convey the necessary information for remote nodes to interact with this memory, assuming that they have access to a QP that is connected to it. +message RemoteObjectProto { + // An string identifier for this object. Must be unique among remote objects. + optional string id = 1; + + // Address of first byte in the memory region. + optional uint64 raddr = 2; + + // Size of the memory region + optional uint32 size = 3; + + // Local access key. + optional uint32 lkey = 4; + + // Remote access key. 
+ optional uint32 rkey = 5; +} diff --git a/tests/rome/rdma/CMakeLists.txt b/tests/rome/rdma/CMakeLists.txt index 4cc4618..466722c 100644 --- a/tests/rome/rdma/CMakeLists.txt +++ b/tests/rome/rdma/CMakeLists.txt @@ -9,4 +9,8 @@ add_test_executable(rdma_util_test rdma_util_test.cc) add_test_executable(rdma_memory_test rdma_memory_test.cc) add_test_executable(rdma_device_test rdma_device_test.cc) add_test_executable(rdma_broker_test rdma_broker_test.cc) +add_subdirectory(channel) +add_subdirectory(connection_manager) +add_subdirectory(memory_pool) +add_subdirectory(rmalloc) endif() diff --git a/tests/rome/rdma/channel/CMakeLists.txt b/tests/rome/rdma/channel/CMakeLists.txt new file mode 100644 index 0000000..7619188 --- /dev/null +++ b/tests/rome/rdma/channel/CMakeLists.txt @@ -0,0 +1,5 @@ +if(NOT ${HAVE_RDMA_CARD}) +add_test_executable(twosided_messenger_test twosided_messenger_test.cc DISABLE_TEST) +else() +add_test_executable(twosided_messenger_test twosided_messenger_test.cc) +endif() diff --git a/tests/rome/rdma/channel/twosided_messenger_test.cc b/tests/rome/rdma/channel/twosided_messenger_test.cc new file mode 100644 index 0000000..08e2c8e --- /dev/null +++ b/tests/rome/rdma/channel/twosided_messenger_test.cc @@ -0,0 +1,219 @@ +#include "twosided_messenger.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "rdma_accessor.h" +#include "rdma_channel.h" +#include "protos/testutil.pb.h" +#include "rome/rdma/rdma_broker.h" + +namespace rome::rdma { +namespace { + +using ::util::InternalErrorBuilder; + +constexpr char kServer[] = "10.0.0.1"; +constexpr uint32_t kPort = 18018; +constexpr uint32_t kCapacity = 1UL << 12; +constexpr int32_t kRecvMaxBytes = 64; +constexpr uint32_t kNumWr = kCapacity / kRecvMaxBytes; +const std::string kMessage = "Shhh, it's a message!"; + +using ChannelType = RdmaChannel, + EmptyRdmaAccessor>; + +class FakeRdmaReceiver : public RdmaReceiverInterface { + public: + void OnConnectRequest(rdma_cm_id* id, rdma_cm_event* event) override { + ibv_qp_init_attr init_attr; + std::memset(&init_attr, 0, sizeof(init_attr)); + init_attr.cap.max_send_wr = init_attr.cap.max_recv_wr = kNumWr; + init_attr.cap.max_send_sge = init_attr.cap.max_recv_sge = 1; + init_attr.cap.max_inline_data = 0; + init_attr.qp_type = id->qp_type; + RDMA_CM_ASSERT(rdma_create_qp, id, nullptr, &init_attr); + + id_ = id; + channel_ = std::make_unique(id_); + + RDMA_CM_ASSERT(rdma_accept, id, nullptr); + rdma_ack_cm_event(event); + } + + void OnEstablished(rdma_cm_id* id, rdma_cm_event* event) override { + rdma_ack_cm_event(event); + } + + void OnDisconnect(rdma_cm_id* id) override { rdma_disconnect(id); } + + absl::StatusOr Deliver() { + return channel_->TryDeliver(); + } + + private: + rdma_cm_id* id_; + std::unique_ptr channel_; +}; + +class FakeRdmaClient { + public: + ~FakeRdmaClient() { rdma_destroy_ep(id_); } + + absl::Status Connect(std::string_view server, uint16_t port) { + rdma_cm_id* id = nullptr; + rdma_addrinfo hints, *resolved; + ibv_qp_init_attr init_attr; + + std::memset(&hints, 0, sizeof(hints)); + hints.ai_port_space = RDMA_PS_TCP; + hints.ai_flags = AI_NUMERICSERV; + int gai_ret = rdma_getaddrinfo( + server.data(), std::to_string(htons(port)).data(), &hints, &resolved); + ROME_CHECK_QUIET( + ROME_RETURN(InternalErrorBuilder() + << "rdma_getaddrinfo(): " << gai_strerror(gai_ret)), + gai_ret == 0); + ROME_ASSERT( + reinterpret_cast(resolved->ai_dst_addr)->sin_port == port, + "Port does not match: expected={}, actual={}", port, + 
reinterpret_cast(resolved->ai_dst_addr)->sin_port); + + std::memset(&init_attr, 0, sizeof(init_attr)); + init_attr.cap.max_send_wr = init_attr.cap.max_recv_wr = kNumWr; + init_attr.cap.max_send_sge = init_attr.cap.max_recv_sge = 1; + init_attr.cap.max_inline_data = 0; + init_attr.qp_type = ibv_qp_type(resolved->ai_qp_type); + RDMA_CM_CHECK(rdma_create_ep, &id, resolved, nullptr, &init_attr); + + id_ = id; + channel_ = std::make_unique(id_); + + RDMA_CM_CHECK(rdma_connect, id, nullptr); + ROME_INFO( + "Connected to {} (port={})", + inet_ntoa( + reinterpret_cast(rdma_get_peer_addr(id))->sin_addr), + rdma_get_dst_port(id)); + + return absl::OkStatus(); + } + + absl::Status Send(const testutil::RdmaChannelTestProto& proto) { + return channel_->Send(proto); + } + + private: + rdma_cm_id* id_; + std::unique_ptr channel_; +}; + +class RdmaChannelTest : public ::testing::Test { + protected: + RdmaChannelTest() { ROME_INIT_LOG(); } + + void SetUp() { + broker_ = RdmaBroker::Create(kServer, kPort, &receiver_); + ASSERT_NE(broker_, nullptr); + ASSERT_OK(client_.Connect(kServer, kPort)); + } + + FakeRdmaReceiver receiver_; + std::unique_ptr broker_; + FakeRdmaClient client_; +}; + +TEST_F(RdmaChannelTest, Test) { + // Test plan: Do something crazy + RdmaChannel channel(nullptr); + EXPECT_TRUE(true); +} + +TEST_F(RdmaChannelTest, SendOnce) { + // Test plan: Create a channel and test that it can send messages without + // failing. This does not actually check that the message arrives, but that + // the send does not fail. + testutil::RdmaChannelTestProto proto; + *proto.mutable_message() = kMessage; + EXPECT_OK(client_.Send(proto)); +} + +TEST_F(RdmaChannelTest, DeliverOnce) { + testutil::RdmaChannelTestProto expected; + *expected.mutable_message() = kMessage; + ASSERT_OK(client_.Send(expected)); + auto msg_or = receiver_.Deliver(); + ASSERT_OK(msg_or.status()); + EXPECT_EQ(msg_or->message(), kMessage); +} + +TEST_F(RdmaChannelTest, DeliverMultipleWithoutRepopulatingRecvBuffer) { + testutil::RdmaChannelTestProto proto; + *proto.mutable_message() = kMessage; + + for (int i = 0; i < (kCapacity / 2) / kRecvMaxBytes; ++i) { + ASSERT_OK(client_.Send(proto)); + auto proto_or = receiver_.Deliver(); + ASSERT_OK(proto_or.status()); + EXPECT_EQ(proto_or->message(), kMessage); + } +} + +TEST_F(RdmaChannelTest, DeliverMultipleRepopulatingRecvBufferOnce) { + testutil::RdmaChannelTestProto proto; + *proto.mutable_message() = kMessage; + + for (int i = 0; i < ((kCapacity / 2) / kRecvMaxBytes) + 1; ++i) { + ASSERT_OK(client_.Send(proto)); + auto proto_or = receiver_.Deliver(); + ASSERT_OK(proto_or.status()); + EXPECT_EQ(proto_or->message(), kMessage); + } +} + +TEST_F(RdmaChannelTest, DeliverMultipleRepopulatingRecvBufferMultipleTimes) { + testutil::RdmaChannelTestProto proto; + *proto.mutable_message() = kMessage; + + for (int i = 0; i < ((kCapacity / 2) / kRecvMaxBytes) * 10; ++i) { + ASSERT_OK(client_.Send(proto)); + auto proto_or = receiver_.Deliver(); + ASSERT_OK(proto_or.status()); + EXPECT_EQ(proto_or->message(), kMessage); + } +} + +TEST_F(RdmaChannelTest, LargeProtoExhaustsBuffer) { + testutil::RdmaChannelTestProto proto; + + static const char alphabet[] = + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "0123456789"; + + std::random_device rd; + std::default_random_engine eng(rd()); + std::uniform_int_distribution<> dist( + 0, sizeof(alphabet) / sizeof(*alphabet) - 2); + + // Each protobuf string includes at least one byte for its wire-format and its + // field number. 
It also includes a `varint` that denotes the strings length. + // Therefore, if we set the size of the field to be exactly the number of + // bytes we are able to handle at the receiver, then the total bytes will be + // larger and the `Send` operations should fail. + // See (https://developers.google.com/protocol-buffers/docs/encoding). + static constexpr int kSize = kRecvMaxBytes; + std::string str; + str.reserve(kSize); + std::generate_n(std::back_inserter(str), kSize, [&]() { return dist(eng); }); + + proto.mutable_message()->reserve(str.size()); + proto.mutable_message()->swap(str); + + EXPECT_THAT(client_.Send(proto), + ::testutil::StatusIs(absl::StatusCode::kResourceExhausted)); +} + +} // namespace +} // namespace rome::rdma \ No newline at end of file diff --git a/tests/rome/rdma/connection_manager/CMakeLists.txt b/tests/rome/rdma/connection_manager/CMakeLists.txt new file mode 100644 index 0000000..cea9e71 --- /dev/null +++ b/tests/rome/rdma/connection_manager/CMakeLists.txt @@ -0,0 +1,5 @@ +if(NOT ${HAVE_RDMA_CARD}) +add_test_executable(connection_manager_test connection_manager_test.cc DISABLE_TEST) +else() +add_test_executable(connection_manager_test connection_manager_test.cc) +endif() diff --git a/tests/rome/rdma/connection_manager/connection_manager_test.cc b/tests/rome/rdma/connection_manager/connection_manager_test.cc new file mode 100644 index 0000000..5892449 --- /dev/null +++ b/tests/rome/rdma/connection_manager/connection_manager_test.cc @@ -0,0 +1,231 @@ +#include "connection_manager.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "protos/rdma.pb.h" +#include "protos/testutil.pb.h" +#include "rome/rdma/channel/rdma_messenger.h" +#include "rome/testutil/status_matcher.h" +#include "rome/util/clocks.h" + +namespace rome::rdma { + +using ::util::SystemClock; +using TestProto = ::rome::testutil::ConnectionManagerTestProto; + +static constexpr char kAddress[] = "10.0.0.1"; + +using Channel = RdmaChannel, EmptyRdmaAccessor>; + +class ConnectionManagerTest : public ::testing::Test { + protected: + ConnectionManagerTest() : rd_(), rand_(rd_()), backoff_dist_(0, 10000000) {} + void SetUp() { ROME_INIT_LOG(); } + + absl::StatusOr::conn_type*> Connect( + ConnectionManager* client, uint32_t peer_id, + std::string_view address, uint16_t port) { + int tries = 1; + auto conn_or = client->Connect(peer_id, address, port); + auto backoff = std::chrono::nanoseconds(100); + while (!conn_or.ok() && tries < kMaxRetries) { + ROME_DEBUG(conn_or.status().ToString()); + conn_or = client->Connect(peer_id, address, port); + ++tries; + + std::this_thread::sleep_for(backoff); + backoff = std::max(std::chrono::nanoseconds(10000000), backoff * 2); + } + if (!conn_or.ok()) ROME_ERROR("Retries exceeded: {}", tries); + ROME_DEBUG("Number of tries: {}", tries); + return conn_or; + } + + static constexpr int kMaxRetries = std::numeric_limits::max(); + + std::random_device rd_; + std::default_random_engine rand_; + std::uniform_int_distribution backoff_dist_; +}; + +template +std::string ToString(const std::vector& v) { + std::stringstream ss; + for (auto iter = v.begin(); iter < v.end(); ++iter) { + ss << *iter; + if (&(*iter) != &v.back()) { + ss << ", "; + } + } + return ss.str(); +} + +TEST_F(ConnectionManagerTest, ConstructAndDestroy) { + // Test plan: Ensure that a `ConnectionManager` that is created then + // immediately destroyed does not crash. 
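One note on the `Connect` helper in the fixture above: because the doubled value is passed through `std::max` with 10 ms, the sleep between retries is floored at 10 ms and keeps doubling rather than being capped there. A standalone trace of that progression, using the helper's own constants:

```c++
#include <algorithm>
#include <chrono>
#include <cstdio>

// Prints the successive sleep durations the Connect() helper would use.
int main() {
  auto backoff = std::chrono::nanoseconds(100);
  for (int retry = 1; retry <= 5; ++retry) {
    std::printf("retry %d sleeps %lld ns\n", retry,
                static_cast<long long>(backoff.count()));
    backoff = std::max(std::chrono::nanoseconds(10000000), backoff * 2);
  }
  // Output: 100, 10000000, 20000000, 40000000, 80000000 ns.
}
```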
+ static constexpr int kServerId = 1; + ConnectionManager server(kServerId); +} + +TEST_F(ConnectionManagerTest, SingleConnection) { + // Test plan: Something... + static constexpr int kServerId = 1; + static constexpr int kClientId = 42; + + ConnectionManager server(kServerId); + ASSERT_OK(server.Start(kAddress, std::nullopt)); + + ConnectionManager client(kClientId); + ASSERT_OK(client.Start(kAddress, std::nullopt)); + + auto conn_or = Connect(&client, kServerId, server.address(), server.port()); + EXPECT_OK(conn_or); + + conn_or = client.GetConnection(kServerId); + EXPECT_OK(conn_or); + + client.Shutdown(); + server.Shutdown(); +} + +TEST_F(ConnectionManagerTest, ConnectionOnOtherIp) { + // Test plan: Something... + static constexpr int kServerId = 1; + static constexpr int kClientId = 42; + + ConnectionManager server(kServerId); + ASSERT_OK(server.Start("10.0.0.1", std::nullopt)); + + ConnectionManager client(kClientId); + ASSERT_OK(client.Start("10.0.0.2", std::nullopt)); + + auto conn_or = Connect(&client, kServerId, server.address(), server.port()); + EXPECT_OK(conn_or); + conn_or = client.GetConnection(kServerId); + EXPECT_OK(conn_or); + + client.Shutdown(); + server.Shutdown(); +} + +TEST_F(ConnectionManagerTest, LoopbackTest) { + // Test plan: Something... + static constexpr int kId = 1; + ConnectionManager node(kId); + ASSERT_OK(node.Start(kAddress, std::nullopt)); + auto conn_or = Connect(&node, kId, node.address(), node.port()); + EXPECT_OK(conn_or); + conn_or = node.GetConnection(kId); + EXPECT_OK(conn_or); + node.Shutdown(); +} + +// TEST_F(ConnectionManagerTest, MultipleConnections) { +// // Test plan: Something... + +// static constexpr uint32_t kServerId = 100; +// NodeProto server_proto; +// server_proto.set_node_id(kServerId); +// ConnectionManager server(kAddress, std::nullopt, server_proto, &handler_); + +// static constexpr int kNumNodes = 10; +// std::vector> nodes; +// for (int i = 0; i < kNumNodes; ++i) { +// NodeProto node; +// node.set_node_id(i); +// nodes.emplace_back(std::make_unique( +// kAddress, std::nullopt, node, &handler_)); + +// auto conn_or = Connect(&(*nodes[i]), kServerId, kAddress, server.port()); +// ASSERT_OK(conn_or); + +// auto conn = conn_or.value(); + +// TestProto p; +// *p.mutable_message() = std::to_string(i); +// EXPECT_OK(conn->channel()->Send(p)); + +// absl::StatusOr m = conn->channel()->TryDeliver(); +// for (; !m.ok() && m.status().code() == absl::StatusCode::kUnavailable; +// m = conn->channel()->TryDeliver()) +// ; +// EXPECT_OK(m); +// } + +// EXPECT_EQ(server.GetNumConnections(), kNumNodes); +// } + +TEST_F(ConnectionManagerTest, FullyConnected) { + // Test plan: Something... + + static constexpr int kNumNodes = 15; + std::vector>> conns; + std::vector> node_info; + for (int i = 0; i < kNumNodes; ++i) { + conns.emplace_back(std::make_unique>(i)); + ASSERT_OK(conns.back()->Start(kAddress, std::nullopt)); + node_info.push_back({i, conns.back()->port()}); + } + + std::random_device rd; + std::default_random_engine eng(rd()); + std::vector threads; + std::barrier sync(kNumNodes); + for (int i = 0; i < kNumNodes; ++i) { + threads.emplace_back([&conns, &node_info, &eng, &sync, i, this]() { + // Randomize connection order. 
+ std::vector> rand; + std::copy(node_info.begin(), node_info.end(), std::back_inserter(rand)); + std::shuffle(rand.begin(), rand.end(), eng); + + for (auto n : rand) { + auto conn_or = Connect(&(*conns[i]), n.first, kAddress, n.second); + if (!conn_or.ok()) { + ROME_FATAL(conn_or.status().ToString()); + } + } + + sync.arrive_and_wait(); + ASSERT_EQ(conns[i]->GetNumConnections(), kNumNodes); + + for (auto n : rand) { + TestProto p; + *p.mutable_message() = std::to_string(i); + auto conn_or = conns[i]->GetConnection(n.first); + EXPECT_OK(VALUE_OR_DIE(conn_or)->channel()->Send(p)); + } + + for (auto n : rand) { + auto* conn = VALUE_OR_DIE(conns[i]->GetConnection(n.first)); + auto m = conn->channel()->TryDeliver(); + while (!m.ok() && m.status().code() == absl::StatusCode::kUnavailable) { + m = conn->channel()->TryDeliver(); + } + EXPECT_OK(m); + if (m.ok()) { + ROME_DEBUG("[FullyConnected] (Node {}) Got: {}", conn->src_id(), + m->DebugString()); + } + } + + sync.arrive_and_wait(); + ROME_DEBUG("Shutting down: {}", i); + conns[i]->Shutdown(); + }); + } + + // Join threads. + for (auto& t : threads) { + t.join(); + } +} + +} // namespace rome::rdma \ No newline at end of file diff --git a/tests/rome/rdma/memory_pool/CMakeLists.txt b/tests/rome/rdma/memory_pool/CMakeLists.txt new file mode 100644 index 0000000..6dc4997 --- /dev/null +++ b/tests/rome/rdma/memory_pool/CMakeLists.txt @@ -0,0 +1,7 @@ +if(NOT ${HAVE_RDMA_CARD}) +add_test_executable(memory_pool_test memory_pool_test.cc DISABLE_TEST) +add_test_executable(remote_ptr_test remote_ptr_test.cc DISABLE_TEST) +else() +add_test_executable(memory_pool_test memory_pool_test.cc) +add_test_executable(remote_ptr_test remote_ptr_test.cc) +endif() diff --git a/tests/rome/rdma/memory_pool/memory_pool_test.cc b/tests/rome/rdma/memory_pool/memory_pool_test.cc new file mode 100644 index 0000000..1fc0e27 --- /dev/null +++ b/tests/rome/rdma/memory_pool/memory_pool_test.cc @@ -0,0 +1,381 @@ +#include "memory_pool.h" + +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "remote_ptr.h" +#include "rome/rdma/connection_manager/connection_manager.h" + +namespace rome::rdma { +namespace { + +constexpr char kIpAddress[] = "10.0.0.1"; +const MemoryPool::Peer kServer = {1, kIpAddress, 18018}; +const MemoryPool::Peer kClient = {2, kIpAddress, 18015}; + +class MemoryPoolTest : public ::testing::Test { + protected: + void SetUp() { ROME_INIT_LOG(); } +}; + +class LoopbackPolicy : public ::testing::Test { + protected: + void SetUp() { + ROME_INIT_LOG(); + using CM = ConnectionManager; + mp_ = std::make_unique(p_, std::make_unique(p_.id)); + ASSERT_OK(mp_->Init(1ul << 12, {p_})); // Set up loopback + } + + template + remote_ptr AllocateClient(size_t s = 1) { + return AllocateServer(s); + } + + template + remote_ptr AllocateServer(size_t s = 1) { + return mp_->Allocate(s); + } + + template + remote_ptr Read(remote_ptr ptr, + remote_ptr prealloc = remote_nullptr) { + return mp_->Read(ptr, prealloc); + } + + template + remote_ptr PartialRead(remote_ptr ptr, size_t offset, size_t bytes, + remote_ptr prealloc = remote_nullptr) { + return mp_->PartialRead(ptr, offset, bytes, prealloc); + } + + template + void Write(remote_ptr ptr, const T& value, + remote_ptr prealloc = remote_nullptr) { + return mp_->Write(ptr, value, prealloc); + } + + template + T AtomicSwap(remote_ptr ptr, uint64_t swap, uint64_t hint = 0) { + return mp_->AtomicSwap(ptr, swap, hint); + } + + template + T CompareAndSwap(remote_ptr ptr, 
uint64_t expected, uint64_t swap) { + return mp_->CompareAndSwap(ptr, expected, swap); + } + + MemoryPool::DoorbellBatchBuilder CreateDoorbellBatchBuilder(int num_ops) { + return MemoryPool::DoorbellBatchBuilder(mp_.get(), p_.id, num_ops); + } + + void Execute(MemoryPool::DoorbellBatch* batch) { return mp_->Execute(batch); } + + private: + const MemoryPool::Peer p_ = kServer; + std::unique_ptr mp_; +}; + +class ClientServerPolicy : public ::testing::Test { + protected: + void SetUp() { + ROME_INIT_LOG(); + using CM = ConnectionManager; + server_mp_ = + std::make_unique(server_, std::make_unique(server_.id)); + client_mp_ = + std::make_unique(client_, std::make_unique(client_.id)); + // No loopback + + std::thread t([&]() { + ASSERT_OK(this->server_mp_->Init(1ul << 12, {this->client_})); + }); + ASSERT_OK(client_mp_->Init(1ul << 12, {server_})); + t.join(); + } + + template + remote_ptr AllocateClient(size_t s = 1) { + return client_mp_->Allocate(s); + } + + template + remote_ptr AllocateServer(size_t s = 1) { + return server_mp_->Allocate(s); + } + + template + remote_ptr Read(remote_ptr ptr, + remote_ptr prealloc = remote_nullptr) { + return client_mp_->Read(ptr, prealloc); + } + + template + remote_ptr PartialRead(remote_ptr ptr, size_t offset, size_t bytes, + remote_ptr prealloc = remote_nullptr) { + return client_mp_->PartialRead(ptr, offset, bytes, prealloc); + } + + template + void Write(remote_ptr ptr, const T& value, + remote_ptr prealloc = remote_nullptr) { + return client_mp_->Write(ptr, value, prealloc); + } + + template + T AtomicSwap(remote_ptr ptr, uint64_t swap, uint64_t hint = 0) { + return client_mp_->AtomicSwap(ptr, swap, hint); + } + + template + T CompareAndSwap(remote_ptr ptr, uint64_t expected, uint64_t swap) { + return client_mp_->CompareAndSwap(ptr, expected, swap); + } + + MemoryPool::DoorbellBatchBuilder CreateDoorbellBatchBuilder(int num_ops) { + return MemoryPool::DoorbellBatchBuilder(client_mp_.get(), server_.id, + num_ops); + } + + void Execute(MemoryPool::DoorbellBatch* batch) { + return client_mp_->Execute(batch); + } + + private: + const MemoryPool::Peer server_ = kServer; + const MemoryPool::Peer client_ = kClient; + std::unique_ptr server_mp_; + std::unique_ptr client_mp_; +}; + +template +class MemoryPoolTestFixture : public Policy { + protected: + MemoryPoolTestFixture() : Policy() {} +}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(MemoryPoolTestFixture, TestTypes); + +TYPED_TEST(MemoryPoolTestFixture, InitTest) { + // Test plan: Do nothing to ensure that setup is done correctly. +} + +TYPED_TEST(MemoryPoolTestFixture, WriteTest) { + // Test plan: Allocate some memory on the server then write to it. + const int kValue = 42; + auto target = TestFixture::template AllocateServer(); + *target = -1; + TestFixture::Write(target, kValue); + EXPECT_EQ(*target, kValue); + + *target = -1; + auto dest = TestFixture::template AllocateClient(); + TestFixture::Write(target, kValue, dest); + EXPECT_EQ(*target, kValue); +} + +TYPED_TEST(MemoryPoolTestFixture, ReadTest) { + // Test plan: Allocate some memory to write to then write to it. 
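The fixtures above rely on GoogleTest typed tests, so every `TYPED_TEST` body in this file runs once against `LoopbackPolicy` and once against `ClientServerPolicy`. A minimal, self-contained illustration of that mechanism (toy policies, nothing to do with `MemoryPool` itself):

```c++
#include "gtest/gtest.h"

// Two interchangeable "policies" exposing the same static interface.
struct PolicyA { static int Value() { return 1; } };
struct PolicyB { static int Value() { return 2; } };

template <typename Policy>
class PolicyFixture : public ::testing::Test {};

using Policies = ::testing::Types<PolicyA, PolicyB>;
TYPED_TEST_SUITE(PolicyFixture, Policies);

// Instantiated once per entry in `Policies`; TypeParam names the policy.
TYPED_TEST(PolicyFixture, ValueIsPositive) { EXPECT_GT(TypeParam::Value(), 0); }
```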
+ const int kValue = 42; + auto target = TestFixture::template AllocateServer(); + *target = kValue; + auto result = TestFixture::Read(target); + EXPECT_EQ(*result, kValue); + + auto dest = TestFixture::template AllocateClient(); + result = TestFixture::Read(target, dest); + EXPECT_EQ(*result, kValue); + EXPECT_EQ(result, dest); +} + +TYPED_TEST(MemoryPoolTestFixture, AtomicSwapTest) { + // Test plan: Allocate some memory to write to then write to it. + auto target = TestFixture::template AllocateServer(); + for (uint64_t value = 0; value < 1000; ++value) { + *target = value; + EXPECT_EQ(TestFixture::AtomicSwap(target, -1), value); + } +} + +TYPED_TEST(MemoryPoolTestFixture, CompareAndSwapTest) { + // Test plan: Allocate some memory to write to then write to it. + auto target = TestFixture::template AllocateServer(); + for (uint64_t value = 0; value < 1000; ++value) { + *target = value; + auto expected = (value / 2) * 2; // Fails every other attempt + EXPECT_EQ(TestFixture::CompareAndSwap(target, expected, 0), value); + } +} + +template +class PartialReadConfig : public Policy { + private: + static constexpr size_t kStructSize = 256; + + public: + static constexpr size_t kReadSize = S; + PartialReadConfig() : Policy() {} + + struct TestStruct { + char buffer[kStructSize]; + }; +}; + +template +class PartialReadTestFixture : public Config { + protected: + using config = Config; + PartialReadTestFixture() : Config() {} +}; + +using PartialReadTestTypes = + ::testing::Types, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig, + PartialReadConfig>; +TYPED_TEST_SUITE(PartialReadTestFixture, PartialReadTestTypes); + +TYPED_TEST(PartialReadTestFixture, PartialReadTest) { + // Test plan: Given a fixed size struct, fill the local copy with known bytes + // then perform a partial read. Ensure that all expected bytes from the remote + // object contain the expected value. 
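To make the check in this test concrete, here is the byte layout one iteration expects, assuming for illustration a `kStructSize` of 256 (as above), a `kReadSize` of 64, and an `offset` of 8; the actual read sizes come from the `PartialReadConfig` instantiations.

```c++
#include <cstddef>

// Regions compared after PartialRead copies kReadSize bytes at kOffset:
//   [0, kOffset)                       untouched local bytes (kLocalByte)
//   [kOffset, kOffset + kReadSize)     bytes copied from the remote struct
//   [kOffset + kReadSize, kStructSize) untouched local bytes (kLocalByte)
constexpr std::size_t kStructSize = 256;
constexpr std::size_t kReadSize = 64;  // illustrative value
constexpr std::size_t kOffset = 8;     // illustrative value
constexpr std::size_t kRemainder = (kStructSize - kOffset) - kReadSize;
static_assert(kOffset + kReadSize + kRemainder == kStructSize);  // 8 + 64 + 184 == 256
```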
+ using type = typename TestFixture::config::TestStruct; + constexpr size_t size = sizeof(type); + auto target = TestFixture::template AllocateServer(); + auto target_addr = std::to_address(target); + + constexpr char kTargetByte = -1; + std::memset(std::to_address(target), kTargetByte, size); + + auto local = TestFixture::template AllocateClient(); + auto local_addr = std::to_address(local); + constexpr char kLocalByte = 0; + for (size_t offset = 0; offset < size - TestFixture::kReadSize; offset += 8) { + std::memset(std::to_address(local), kLocalByte, size); + + TestFixture::PartialRead(target, offset, TestFixture::kReadSize, local); + auto local_offset_addr = reinterpret_cast(local_addr) + offset; + auto target_offset_addr = reinterpret_cast(target_addr) + offset; + EXPECT_EQ(std::memcmp(local_offset_addr, target_offset_addr, + TestFixture::kReadSize), + 0); + + char expected[size]; + std::memset(expected, kLocalByte, size); + if (offset != 0) { + EXPECT_EQ( + std::memcmp(reinterpret_cast(local_addr), expected, offset), + 0); + } + const size_t remainder_size = (size - offset) - TestFixture::kReadSize; + EXPECT_EQ(std::memcmp(reinterpret_cast(local_offset_addr) + + TestFixture::kReadSize, + expected, remainder_size), + 0); + } +} + +template +class DoorbellBatchTestConfig : public Policy { + public: + DoorbellBatchTestConfig() : Policy() {} +}; + +template +class DoorbellBatchTestFixture : public Config { + protected: + using config = Config; + DoorbellBatchTestFixture() : Config() {} +}; + +using DoorbellBatchTestTypes = + ::testing::Types, + DoorbellBatchTestConfig>; +TYPED_TEST_SUITE(DoorbellBatchTestFixture, DoorbellBatchTestTypes); + +TYPED_TEST(DoorbellBatchTestFixture, SingleReadTest) { + const uint64_t kValue = 0xf0f0f0f0f0f0f0f0; + auto src = TestFixture::template AllocateServer(); + *(std::to_address(src)) = kValue; + ASSERT_EQ(*(std::to_address(src)), kValue); + + auto builder = TestFixture::CreateDoorbellBatchBuilder(1); + auto dest = builder.AddRead(src); + auto batch = builder.Build(); + TestFixture::Execute(batch.get()); + + EXPECT_EQ(*(std::to_address(dest)), kValue); +} + +TYPED_TEST(DoorbellBatchTestFixture, SingleWriteTest) { + const uint64_t kValue = 0xf0f0f0f0f0f0f0f0; + auto dest = TestFixture::template AllocateServer(); + *(std::to_address(dest)) = 0; + ASSERT_EQ(*(std::to_address(dest)), 0); + + auto builder = TestFixture::CreateDoorbellBatchBuilder(1); + builder.AddWrite(dest, kValue); + auto batch = builder.Build(); + TestFixture::Execute(batch.get()); + + EXPECT_EQ(*(std::to_address(dest)), kValue); +} + +TYPED_TEST(DoorbellBatchTestFixture, WriteThenReadTest) { + const uint64_t kValue = 0xf0f0f0f0f0f0f0f0; + auto dest = TestFixture::template AllocateServer(); + *(std::to_address(dest)) = 0; + ASSERT_EQ(*(std::to_address(dest)), 0); + + auto builder = TestFixture::CreateDoorbellBatchBuilder(2); + builder.AddWrite(dest, kValue); + auto read = builder.AddRead(dest); + *(std::to_address(read)) = 0; + auto batch = builder.Build(); + TestFixture::Execute(batch.get()); + + EXPECT_EQ(*(std::to_address(dest)), kValue); + EXPECT_EQ(*(std::to_address(read)), kValue); +} + +TYPED_TEST(DoorbellBatchTestFixture, ReuseTest) { + const uint64_t kValue = 0xf0f0f0f0f0f0f0f0; + auto dest = TestFixture::template AllocateServer(); + *dest = 0; + ASSERT_EQ(*(std::to_address(dest)), 0); + + auto builder = TestFixture::CreateDoorbellBatchBuilder(2); + auto src = TestFixture::template AllocateClient(); + *src = kValue; + builder.AddWrite(dest, src); + auto read = 
builder.AddRead(dest); + *read = 0; + auto batch = builder.Build(); + TestFixture::Execute(batch.get()); + + EXPECT_EQ(*(std::to_address(dest)), kValue); + EXPECT_EQ(*(std::to_address(read)), kValue); + + *src = kValue + 1; + TestFixture::Execute(batch.get()); + EXPECT_EQ(*dest, kValue + 1); + EXPECT_EQ(*read, kValue + 1); +} + +} // namespace +} // namespace rome::rdma \ No newline at end of file diff --git a/tests/rome/rdma/memory_pool/remote_ptr_test.cc b/tests/rome/rdma/memory_pool/remote_ptr_test.cc new file mode 100644 index 0000000..c7c5081 --- /dev/null +++ b/tests/rome/rdma/memory_pool/remote_ptr_test.cc @@ -0,0 +1,58 @@ +#include "remote_ptr.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "rome/logging/logging.h" + +namespace rome::rdma { +namespace { + +class RemotePtrTest : public ::testing::Test { + void SetUp() { ROME_INIT_LOG(); } +}; + +TEST(RemotePtrTest, Test) { + // Test plan: test_plan + remote_ptr p = remote_nullptr; + EXPECT_EQ(p, remote_nullptr); +} + +TEST(RemotePtrTest, Equivalence) { + remote_ptr p1; + EXPECT_TRUE(p1 == remote_nullptr); +} + +TEST(RemotePtrTest, GettersTest) { + remote_ptr p; + p = remote_ptr(1, (uint64_t)0x0fedbeef); + EXPECT_EQ(p.id(), 1); + EXPECT_EQ(p.address(), 0x0fedbeef); + EXPECT_EQ(p.raw(), (1ul << 48) | 0x0fedbeef); +} + +TEST(RemotePtrTest, CopyTest) { + remote_ptr p; + p = remote_ptr(1, (uint64_t)0x0fedbeef); + auto q = p; + EXPECT_EQ(q.id(), 1); + EXPECT_EQ(q.address(), 0x0fedbeef); + EXPECT_EQ(q.raw(), (1ul << 48) | 0x0fedbeef); +} + +TEST(RemotePtrTest, IncrementTest) { + remote_ptr p; + p = remote_ptr(4, (uint64_t)0); + ++p; + EXPECT_EQ(p.address(), sizeof(int)); + + auto q = p++; + EXPECT_EQ(q.address(), sizeof(int)); + EXPECT_EQ(p.address(), 2 * sizeof(int)); + + auto r = (p += 4); + EXPECT_EQ(r.address(), 6 * sizeof(int)); + EXPECT_EQ(p.address(), r.address()); +} + +} // namespace +} // namespace rome::rdma \ No newline at end of file diff --git a/tests/rome/rdma/rmalloc/CMakeLists.txt b/tests/rome/rdma/rmalloc/CMakeLists.txt new file mode 100644 index 0000000..b684c2f --- /dev/null +++ b/tests/rome/rdma/rmalloc/CMakeLists.txt @@ -0,0 +1,5 @@ +if(NOT ${HAVE_RDMA_CARD}) +add_test_executable(rmalloc_test rmalloc_test.cc DISABLE_TEST) +else() +add_test_executable(rmalloc_test rmalloc_test.cc) +endif() diff --git a/tests/rome/rdma/rmalloc/rmalloc_test.cc b/tests/rome/rdma/rmalloc/rmalloc_test.cc new file mode 100644 index 0000000..a974e8c --- /dev/null +++ b/tests/rome/rdma/rmalloc/rmalloc_test.cc @@ -0,0 +1,193 @@ +#include "rmalloc.h" + +#include + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "rome/rdma/connection_manager/connection_manager.h" +#include "rome/rdma/rdma_device.h" +#include "rome/rdma/rdma_util.h" + +namespace rome::rdma { +namespace { + +constexpr uint32_t kArenaCapacity = 1024; + +template +struct MyStruct { + std::array bytes; +}; + +class RdmaAllocatorTest : public ::testing::Test { + protected: + void SetUp() { + ROME_INIT_LOG(); + dev_ = RdmaDevice::Create(RdmaDevice::GetAvailableDevices()->front().first, + std::nullopt); + ASSERT_OK(dev_->CreateProtectionDomain("rdma_allocator")); + pd_ = VALUE_OR_DIE(dev_->GetProtectionDomain("rdma_allocator")); + memory_resource_ = + std::make_unique(kArenaCapacity, pd_); + } + + std::unique_ptr dev_; + std::unique_ptr memory_resource_; + ibv_pd* pd_; +}; + +TEST_F(RdmaAllocatorTest, AllocateSingleUint64) { + // Test plan: Allocate a single `uint64_t` element and check that the returned + // memory is not `nullptr`. 
+ auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(); + EXPECT_NE(x, nullptr); +} + +TEST_F(RdmaAllocatorTest, AllocateSingleUint64Repeated) { + // Test plan: Allocate a single `uint64_t` element and check that the returned + // memory is not `nullptr`. + for (int i = 0; i < 10; ++i) { + auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(); + EXPECT_NE(x, nullptr); + } +} + +TEST_F(RdmaAllocatorTest, AllocateMultipleUint64) { + // Test plan: Allocate mutliple `uint64_t` elements and check that the + // returned memory is not `nullptr`. + auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(10); + EXPECT_NE(x, nullptr); +} + +TEST_F(RdmaAllocatorTest, ReallocateMultipleUint64) { + // Test plan: Allocate mutliple `uint64_t` elements and check that the + // returned memory is not `nullptr`. Then, deallocate them and check that a + // reallocation of the same size returns valid memory. + auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(10); + ASSERT_NE(x, nullptr); + alloc.deallocate(x, 10); + x = alloc.allocate(10); + EXPECT_NE(x, nullptr); +} + +TEST_F(RdmaAllocatorTest, AllocateStruct) { + // Test plan: Allocate a struct the size of the memory capacity of the + // underlying `rdma_memory_resource` and check that the returned pointer to + // memory is not `nullptr`. + using TestStruct = MyStruct; + auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(); + EXPECT_NE(x, nullptr); +} + +TEST_F(RdmaAllocatorTest, AllocateStructFailure) { + // Test plan: Allocate a struct greater than the size of the memory capacity + // of the underlying `rdma_memory_resource` and check that the returned + // pointer to memory is `nullptr`. In this case, the memory allocation request + // cannot be serviced. + using TestStruct = MyStruct; + auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(); + EXPECT_EQ(x, nullptr); +} + +TEST_F(RdmaAllocatorTest, ReallocateStruct) { + // Test plan: Allocate a struct the size of the memory capacity of the + // underlying `rdma_memory_resource` and check that the returned pointer to + // memory is not `nullptr`. Then, deallocate the struct and try again to + // ensure that the memory is pulled from the freelist. Since the struct is the + // memory capacity, a non-freelist allocation would fail due to the bump + // allocation running out of free memory. + using TestStruct = MyStruct; + auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(); + ASSERT_NE(x, nullptr); + ASSERT_EQ(sizeof(*x), sizeof(TestStruct)); + + alloc.deallocate(x); + x = alloc.allocate(); + EXPECT_NE(x, nullptr); + EXPECT_EQ(sizeof(*x), sizeof(TestStruct)); +} + +TEST_F(RdmaAllocatorTest, ReallocateDifferentType) { + // Test plan: Allocate a region of memory pointing to `uint8_t` that is + // equivalent to the size of the underlying memory capacity. Then, deallocated + // and create a new allocator that shares the underlying + // `rdma_memory_resource` via a conversion constructor and reallocate a new + // region of memory with the same size but different type. 
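The conversion constructor this test plan refers to is what lets several element types share one arena. A hedged sketch of that rebinding, assuming an already-initialized `rdma_memory_resource` (the helper name is made up for illustration):

```c++
#include <cstdint>

#include "rmalloc.h"  // as included by the test above

// Hypothetical helper: two allocators of different value_type drawing from
// the same rdma_memory_resource, as the surrounding test exercises.
void RebindExample(rome::rdma::rdma_memory_resource* mr) {
  rome::rdma::rdma_allocator<uint8_t> byte_alloc(mr);
  uint8_t* bytes = byte_alloc.allocate(64);
  byte_alloc.deallocate(bytes, 64);

  // The converting constructor copies the memory_resource pointer, so this
  // allocator can reuse blocks freed through byte_alloc (same slab class).
  rome::rdma::rdma_allocator<uint64_t> word_alloc(byte_alloc);
  uint64_t* words = word_alloc.allocate(8);  // 8 * 8 = 64 bytes again
  word_alloc.deallocate(words, 8);
}
```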
+ auto alloc = rdma_allocator(memory_resource_.get()); + auto* x = alloc.allocate(kArenaCapacity); + ASSERT_NE(x, nullptr); + alloc.deallocate(x, kArenaCapacity); + + using TestStruct = MyStruct; + rdma_allocator new_alloc(alloc); + auto* y = new_alloc.allocate(); + EXPECT_NE(y, nullptr); +} + +TEST_F(RdmaAllocatorTest, RemotelyAccessMemory) { + // Test plan: Allocate a region of memory then test that we can remotely + // access it. Exemplifies how to use the allocator in practice. + constexpr int kServerId = 0; + constexpr int kClientId = 1; + constexpr char kAddress[] = "10.0.0.1"; + using TestStruct = MyStruct<512>; + + ConnectionManager> server( + kServerId); + ASSERT_OK(server.Start(kAddress, std::nullopt)); + rdma_memory_resource server_memory_resource(sizeof(TestStruct), server.pd()); + rdma_allocator server_rmalloc(&server_memory_resource); + auto* remote = server_rmalloc.allocate(1); + + ConnectionManager> client( + kClientId); + ASSERT_OK(client.Start(kAddress, std::nullopt)); + rdma_memory_resource client_memory_resource(sizeof(TestStruct), client.pd()); + rdma_allocator client_rmalloc(&client_memory_resource); + auto* local = client_rmalloc.allocate(1); + + auto* client_conn = + VALUE_OR_DIE(client.Connect(kServerId, server.address(), server.port())); + + ibv_sge sge; + std::memset(&sge, 0, sizeof(sge)); + sge.addr = reinterpret_cast(local); + sge.length = sizeof(TestStruct); // Read `num_nodes` nodes. + sge.lkey = client_memory_resource.mr()->lkey; + + ibv_send_wr wr; + std::memset(&wr, 0, sizeof(wr)); + wr.send_flags = IBV_SEND_SIGNALED; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.sg_list = &sge; + wr.wr.rdma.remote_addr = reinterpret_cast(remote); + wr.wr.rdma.rkey = server_memory_resource.mr()->rkey; + + ROME_DEBUG("Accessing {} bytes @ {}", sge.length, fmt::ptr(remote)); + + ibv_send_wr* bad; + RDMA_CM_ASSERT(ibv_post_send, client_conn->id()->qp, &wr, &bad); + + ibv_wc wc; + int ret = ibv_poll_cq(client_conn->id()->send_cq, 1, &wc); + while ((ret < 0 && errno == EAGAIN) || ret == 0) { + ret = ibv_poll_cq(client_conn->id()->send_cq, 1, &wc); + } + + ROME_DEBUG(ibv_wc_status_str(wc.status)); + + EXPECT_EQ(ret, 1); + EXPECT_EQ(wc.status, IBV_WC_SUCCESS); +} + +} // namespace +} // namespace rome::rdma \ No newline at end of file
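The `RemotelyAccessMemory` test above can read `remote` and the server's rkey directly only because both peers live in the same process. Across machines, that hand-off is what the new `RemoteObjectProto` in `protos/rdma.proto` is for: the server advertises its registered region, and the client uses it to target the work request. A hedged sketch of that exchange (the helper functions are illustrative, not librome APIs):

```c++
#include <cstdint>
#include <string>

#include <infiniband/verbs.h>

#include "protos/rdma.pb.h"

// Server side: describe a registered region (e.g. rdma_memory_resource::mr())
// so that a peer can address it remotely.
rome::rdma::RemoteObjectProto DescribeRegion(const ibv_mr* mr,
                                             const std::string& id) {
  rome::rdma::RemoteObjectProto proto;
  proto.set_id(id);
  proto.set_raddr(reinterpret_cast<uint64_t>(mr->addr));
  proto.set_size(static_cast<uint32_t>(mr->length));
  proto.set_lkey(mr->lkey);
  proto.set_rkey(mr->rkey);
  return proto;
}

// Client side: aim a one-sided RDMA write (or read) at the advertised region.
void TargetRemoteRegion(ibv_send_wr* wr,
                        const rome::rdma::RemoteObjectProto& proto) {
  wr->wr.rdma.remote_addr = proto.raddr();
  wr->wr.rdma.rkey = proto.rkey();
}
```

In practice the proto would travel over the two-sided `RdmaChannel` (or a `ConnectionManager` connection) before any one-sided operations begin.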