Skip to content

Commit

Permalink
Merge pull request #208 from pnnl/rdma_get_ops
Browse files Browse the repository at this point in the history
Rdma get ops
  • Loading branch information
VitoCastellana authored Feb 8, 2022
2 parents 85a60bf + c77ac93 commit a425cc2
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 3 deletions.
100 changes: 99 additions & 1 deletion include/shad/data_structures/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ class Array : public AbstractDataStructure<Array<T>> {
template <typename ApplyFunT, typename... Args>
void AsyncApply(rt::Handle &handle, const size_t pos, ApplyFunT &&function,
Args &... args);

/// @brief Asynchronously applies a user-defined function to the element at
/// a given position, passing along a caller-provided result buffer.
///
/// The function must be convertible to a pointer of type:
/// @code
/// void (*)(rt::Handle &, size_t, T &, Args &..., uint8_t *, uint32_t *)
/// @endcode
/// It is dispatched through rt::asyncExecuteAtWithRetBuff to the locality
/// that stores the element at @p pos, and receives the position, a
/// reference to the element, the user arguments, and the two result
/// pointers.
///
/// @tparam ApplyFunT User-defined function type.
/// @tparam Args Types of the user arguments forwarded to the function.
///
/// @param[in,out] handle Handle under which the asynchronous call is issued.
/// @param[in] pos Global position of the element the function is applied to.
/// @param[in] function The function to apply.
/// @param[out] result Buffer handed to the function for its result bytes.
/// @param[out] resultSize Location handed to the function for the number of
/// bytes it wrote into @p result.
/// @param[in] args The function arguments.
///
/// NOTE(review): @p result and @p resultSize presumably must stay valid
/// until @p handle has been waited on -- confirm against the runtime API.
template <typename ApplyFunT, typename... Args>
void AsyncApplyWithRetBuff(rt::Handle &handle, const size_t pos,
ApplyFunT &&function, uint8_t* result,
uint32_t* resultSize, Args &... args);

/// @brief Applies a user-defined function to every element
/// in the specified range.
Expand Down Expand Up @@ -458,6 +463,21 @@ class Array : public AbstractDataStructure<Array<T>> {
data_[std::get<0>(entry)] = std::get<1>(entry);
}

constexpr void FillPtrs() {
rt::executeOnAll([](const ObjectID &oid) {
auto This = Array<T>::GetPtr(oid);
rt::executeOnAll([](const std::tuple<ObjectID, rt::Locality, T*> &args) {
auto This = Array<T>::GetPtr(std::get<0>(args));

This->ptrs_[(uint32_t)std::get<1>(args)] = std::get<2>(args);
},
std::make_tuple(This->GetGlobalID(), rt::thisLocality(), This->data_.data()));
}, GetGlobalID());
}

/// @brief Asynchronously copies a contiguous range of elements into a local
/// buffer.
///
/// The range is split into per-locality chunks and one rt::asyncDma request
/// is issued per chunk, using the remote base addresses stored in ptrs_.
/// Those addresses are written by FillPtrs(), so FillPtrs() has to be called
/// on the array before this method.
///
/// @param[in,out] h Handle under which the DMA requests are issued.
/// @param[out] local_data Destination buffer; consecutive chunks are written
/// back to back starting at this address.
/// @param[in] idx Global position of the first element to fetch.
/// @param[in] num_el Number of elements to fetch.
///
/// NOTE(review): @p local_data presumably must hold at least @p num_el
/// elements and stay valid until @p h has been waited on -- confirm.
void AsyncGetElements(rt::Handle& h, T* local_data,
const uint64_t idx, const uint64_t num_el);

protected:
Array(ObjectID oid, size_t size, const T &initValue)
: oid_(oid),
Expand All @@ -467,7 +487,8 @@ class Array : public AbstractDataStructure<Array<T>> {
: rt::numLocalities() - (size % rt::numLocalities())),
data_(),
dataDistribution_(),
buffers_(oid) {
buffers_(oid),
ptrs_(rt::numLocalities()) {
rt::Locality pivot(pivot_);
size_t start = 0;
size_t chunkSize = size / rt::numLocalities();
Expand Down Expand Up @@ -498,6 +519,7 @@ class Array : public AbstractDataStructure<Array<T>> {
std::vector<T> data_;
std::vector<std::pair<size_t, size_t>> dataDistribution_;
BuffersVector buffers_;
// One slot per locality (sized with rt::numLocalities() in the constructor).
// FillPtrs() stores in slot i the data_.data() pointer announced by locality
// i; AsyncGetElements() uses the slots as remote base addresses.
std::vector<T*> ptrs_;

struct InsertAtArgs {
ObjectID oid;
Expand Down Expand Up @@ -600,6 +622,30 @@ class Array : public AbstractDataStructure<Array<T>> {
std::get<4>(tuple), std::make_index_sequence<Size>{});
}

/// @brief Calls the user function on one locally stored element, expanding
/// the packed user arguments and appending the result-buffer pointers.
///
/// @param handle Handle passed through to the user function.
/// @param oid Global id used to look up the array instance.
/// @param pos Global position of the element, passed to the user function.
/// @param loffset Index of the element inside the looked-up instance's data_.
/// @param function The user function to invoke.
/// @param args Tuple of user arguments, unpacked via the index sequence.
/// @param result Result buffer passed through to the user function.
/// @param resultSize Result size pointer passed through to the user function.
template <typename ApplyFunT, typename... Args, std::size_t... is>
static void AsyncCallApplyWRBFun(rt::Handle &handle, ObjectID &oid, size_t pos,
                                 size_t loffset, ApplyFunT function,
                                 std::tuple<Args...> &args,
                                 std::index_sequence<is...>, uint8_t *result,
                                 uint32_t *resultSize) {
  // Resolve the array instance for this id and hand the addressed element,
  // by reference, to the user function.
  auto localInstance = Array<T>::GetPtr(oid);
  function(handle, pos, localInstance->data_[loffset], std::get<is>(args)...,
           result, resultSize);
}

/// @brief Entry point handed to the runtime for AsyncApplyWithRetBuff.
///
/// Unpacks the argument tuple (object id, position, local offset, function
/// pointer, user-argument tuple) and forwards everything, together with the
/// result-buffer pointers, to AsyncCallApplyWRBFun.
///
/// @param handle Handle passed through to the user function.
/// @param args The packed arguments described above.
/// @param result Result buffer passed through to the user function.
/// @param resultSize Result size pointer passed through to the user function.
template <typename Tuple, typename... Args>
static void AsyncApplyWRBFunWrapper(rt::Handle &handle, const Tuple &args,
                                    uint8_t *result, uint32_t *resultSize) {
  using UserArgsT = std::decay_t<decltype(std::get<4>(args))>;
  using UserArgsIndices =
      std::make_index_sequence<std::tuple_size<UserArgsT>::value>;

  // The callee takes its parameters by non-const reference, so the
  // constness of the incoming tuple is cast away before unpacking.
  auto &packed = const_cast<Tuple &>(args);

  AsyncCallApplyWRBFun(handle, std::get<0>(packed), std::get<1>(packed),
                       std::get<2>(packed), std::get<3>(packed),
                       std::get<4>(packed), UserArgsIndices{}, result,
                       resultSize);
}

template <typename ApplyFunT, typename... Args, std::size_t... is>
static void CallForEachInRangeFun(size_t i, ObjectID &oid, size_t pos,
size_t lpos, ApplyFunT function,
Expand Down Expand Up @@ -937,6 +983,25 @@ void Array<T>::AsyncApply(rt::Handle &handle, const size_t pos,
AsyncApplyFunWrapper<ArgsTuple, Args...>, argsTuple);
}

// Asynchronously applies `function` to the element at `pos` on the locality
// that stores it.  The user arguments are copied into a tuple that travels
// with the request; `result` and `resultSize` are handed to the runtime call
// unchanged.
template <typename T>
template <typename ApplyFunT, typename... Args>
void Array<T>::AsyncApplyWithRetBuff(rt::Handle &handle, const size_t pos,
                                     ApplyFunT &&function, uint8_t *result,
                                     uint32_t *resultSize, Args &... args) {
  using FunctionTy = void (*)(rt::Handle &, size_t, T &, Args &..., uint8_t *,
                              uint32_t *);
  using ArgsTuple =
      std::tuple<ObjectID, size_t, size_t, FunctionTy, std::tuple<Args...>>;

  // first: locality storing the element, second: offset inside its data_.
  auto owner = getTargetLocalityFromTargePosition(dataDistribution_, pos);

  // Only a plain function pointer is shipped, so the callable must be
  // convertible to one.
  FunctionTy userFn = std::forward<ApplyFunT>(function);

  ArgsTuple packedArgs{oid_, pos, owner.second, userFn,
                       std::tuple<Args...>(args...)};

  rt::asyncExecuteAtWithRetBuff(handle, owner.first,
                                AsyncApplyWRBFunWrapper<ArgsTuple, Args...>,
                                packedArgs, result, resultSize);
}

template <typename T>
template <typename ApplyFunT, typename... Args>
void Array<T>::ForEachInRange(const size_t first, const size_t last,
Expand Down Expand Up @@ -1073,6 +1138,39 @@ void Array<T>::ForEach(ApplyFunT &&function, Args &... args) {
rt::executeOnAll(feLambda, arguments);
}

// Asynchronously fetches `num_el` consecutive elements starting at global
// position `idx` into `local_data`, issuing one rt::asyncDma per locality
// chunk the range touches.
//
// Layout used for the address computation: localities before `pivot_` hold
// size_ / numLocalities elements each, localities from `pivot_` on hold one
// element more.  Remote base addresses come from `ptrs_`, which is written
// by FillPtrs(); FillPtrs() must therefore have run before this method.
//
// A range that extends past the end of the array is clamped to the last
// element, and a start position at or past the end fetches nothing, so the
// locality index never leaves the bounds of `ptrs_`.
template <typename T>
void Array<T>::AsyncGetElements(rt::Handle &h, T *local_data,
                                const uint64_t idx, const uint64_t num_el) {
  if (idx >= size_) return;

  // Loop-invariant layout constants.
  const size_t smallChunk = size_ / rt::numLocalities();
  const size_t largeChunk = smallChunk + 1;
  const size_t pivotPos = pivot_ * smallChunk;  // first position on pivot_

  size_t firstPos = idx;
  size_t remainingValues = std::min<size_t>(num_el, size_ - idx);

  while (remainingValues > 0) {
    rt::Locality tgtLoc;
    size_t tgtPos = 0;
    size_t chunkSize = 0;

    if (firstPos < pivotPos) {
      // Position falls on a locality holding `smallChunk` elements
      // (smallChunk is non-zero here because firstPos < pivotPos).
      tgtLoc = rt::Locality(firstPos / smallChunk);
      tgtPos = firstPos % smallChunk;
      chunkSize = std::min(smallChunk - tgtPos, remainingValues);
    } else {
      // Position falls on a locality holding `largeChunk` elements.
      const size_t newPos = firstPos - pivotPos;
      tgtLoc = rt::Locality(pivot_ + newPos / largeChunk);
      tgtPos = newPos % largeChunk;
      chunkSize = std::min(largeChunk - tgtPos, remainingValues);
    }

    T *tgtAddress = ptrs_[static_cast<uint32_t>(tgtLoc)] + tgtPos;
    rt::asyncDma(h, local_data, tgtLoc, tgtAddress, chunkSize);

    local_data += chunkSize;
    firstPos += chunkSize;
    remainingValues -= chunkSize;
  }
}

} // namespace shad

#endif // INCLUDE_SHAD_DATA_STRUCTURES_ARRAY_H_
74 changes: 72 additions & 2 deletions test/unit_tests/data_structures/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ TEST_F(ArrayTest, RangedAsyncInsertAndAsyncGet) {
shad::rt::Handle handle;
edsPtr->AsyncInsertAt(handle, 0, inputData_.data(), kArraySize);
shad::rt::waitForCompletion(handle);
;

shad::rt::Handle handle2;
for (size_t i = 0; i < kArraySize; i++) {
Expand All @@ -136,6 +135,36 @@ TEST_F(ArrayTest, RangedAsyncInsertAndAsyncGet) {
shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}

// Fills the array with a ranged insert, then reads it back with
// AsyncGetElements: first the whole array, then a sub-range that starts in
// the middle of the data.
TEST_F(ArrayTest, RangedAsyncInsertAndAsyncGetElements) {
  auto edsPtr = shad::Array<size_t>::Create(kArraySize, kInitValue);
  // AsyncGetElements relies on the pointer table built by FillPtrs.
  edsPtr->FillPtrs();

  shad::rt::Handle insertHandle;
  edsPtr->AsyncInsertAt(insertHandle, 0, inputData_.data(), kArraySize);
  shad::rt::waitForCompletion(insertHandle);

  // Whole-array read.
  std::vector<size_t> fullCopy(kArraySize);
  shad::rt::Handle getHandle;
  edsPtr->AsyncGetElements(getHandle, fullCopy.data(), 0, kArraySize);
  shad::rt::waitForCompletion(getHandle);

  for (size_t pos = 0; pos < kArraySize; pos++) {
    ASSERT_EQ(fullCopy[pos], pos + 1);
  }

  // Partial read: half of the array, starting one sixth of the way in.
  uint64_t partialCount = kArraySize / 2;
  uint64_t partialStart = kArraySize / 6;
  std::vector<size_t> partialCopy(partialCount);

  edsPtr->AsyncGetElements(getHandle, partialCopy.data(), partialStart,
                           partialCount);
  shad::rt::waitForCompletion(getHandle);

  for (size_t pos = 0; pos < partialCount; pos++) {
    ASSERT_EQ(partialCopy[pos], pos + partialStart + 1);
  }

  shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}

TEST_F(ArrayTest, BufferedSyncInsertAndSyncGet) {
auto edsPtr = shad::Array<size_t>::Create(kArraySize, kInitValue);
for (size_t i = 0; i < kArraySize; i++) {
Expand Down Expand Up @@ -218,6 +247,15 @@ static void asyncApplyFun(shad::rt::Handle & /*unused*/, size_t i, size_t &elem,
elem += kInitValue;
}

// Apply-with-return-buffer callback used by the tests: checks the incoming
// element and increment, bumps the element by kInitValue, and writes the
// updated element into the result buffer together with its size in bytes.
static void asyncApplyWRBFun(shad::rt::Handle & /*unused*/, size_t i,
                             size_t &elem, size_t &incr, uint8_t *result,
                             uint32_t *resultSize) {
  ASSERT_EQ(incr, kInitValue);
  ASSERT_EQ(elem, i + 1);

  elem += kInitValue;

  // Report the updated value back to the caller.
  memcpy(result, &elem, sizeof(elem));
  *resultSize = sizeof(elem);
}

static void asyncApplyFunNoArgs(shad::rt::Handle & /*unused*/, size_t i,
size_t &elem) {
ASSERT_EQ(elem, i + kInitValue + 1);
Expand Down Expand Up @@ -269,6 +307,38 @@ TEST_F(ArrayTest, AsyncInsertAsyncApplyAndAsyncGet) {
shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}

// Inserts i + 1 at every position, applies asyncApplyWRBFun to every element
// with a per-element return buffer, then checks both the values stored in
// the array and the values reported through the return buffers.
TEST_F(ArrayTest, AsyncInsertAsyncApplyWRBAndAsyncGet) {
  auto edsPtr = shad::Array<size_t>::Create(kArraySize, kInitValue);

  shad::rt::Handle insertHandle;
  for (size_t pos = 0; pos < kArraySize; pos++) {
    edsPtr->AsyncInsertAt(insertHandle, pos, pos + 1);
  }
  shad::rt::waitForCompletion(insertHandle);

  // One result slot and one size slot per element.
  std::vector<size_t> ret_values(kArraySize);
  std::vector<uint32_t> ret_sizes(kArraySize);
  shad::rt::Handle applyHandle;
  for (size_t pos = 0; pos < kArraySize; pos++) {
    auto *resultSlot = reinterpret_cast<uint8_t *>(&ret_values[pos]);
    edsPtr->AsyncApplyWithRetBuff(applyHandle, pos, asyncApplyWRBFun,
                                  resultSlot, &ret_sizes[pos], kInitValue);
  }
  shad::rt::waitForCompletion(applyHandle);

  std::vector<size_t> values(kArraySize);
  shad::rt::Handle getHandle;
  for (size_t pos = 0; pos < kArraySize; pos++) {
    edsPtr->AsyncAt(getHandle, pos, &values[pos]);
  }
  shad::rt::waitForCompletion(getHandle);

  for (size_t pos = 0; pos < kArraySize; pos++) {
    ASSERT_EQ(values[pos], pos + 1 + kInitValue);
    ASSERT_EQ(ret_values[pos], pos + 1 + kInitValue);
  }

  shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}

TEST_F(ArrayTest, AsyncInsertSyncForEachInRangeAndAsyncGet) {
std::vector<size_t> values(kArraySize);
auto edsPtr = shad::Array<size_t>::Create(kArraySize, kInitValue);
Expand Down Expand Up @@ -367,4 +437,4 @@ TEST_F(ArrayTest, AsyncInsertAsyncForEachAndAsyncGet) {
ASSERT_EQ(values[i], i + 1 + (2 * kInitValue));
}
shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}
}

0 comments on commit a425cc2

Please sign in to comment.