Skip to content

Commit

Permalink
[#206] Implemented AsyncApplyWithRetBuff in Array
Browse files Browse the repository at this point in the history
  • Loading branch information
VitoCastellana committed Feb 5, 2022
1 parent 4cc1cec commit c77ac93
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
48 changes: 48 additions & 0 deletions include/shad/data_structures/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ class Array : public AbstractDataStructure<Array<T>> {
template <typename ApplyFunT, typename... Args>
void AsyncApply(rt::Handle &handle, const size_t pos, ApplyFunT &&function,
Args &... args);

template <typename ApplyFunT, typename... Args>
void AsyncApplyWithRetBuff(rt::Handle &handle, const size_t pos,
ApplyFunT &&function, uint8_t* result,
uint32_t* resultSize, Args &... args);

/// @brief Applies a user-defined function to every element
/// in the specified range.
Expand Down Expand Up @@ -617,6 +622,30 @@ class Array : public AbstractDataStructure<Array<T>> {
std::get<4>(tuple), std::make_index_sequence<Size>{});
}

template <typename ApplyFunT, typename... Args, std::size_t... is>
static void AsyncCallApplyWRBFun(rt::Handle &handle, ObjectID &oid, size_t pos,
size_t loffset, ApplyFunT function,
std::tuple<Args...> &args, std::index_sequence<is...>,
uint8_t* result, uint32_t* resultSize) {
// Get a local instance on the remote node.
auto arrayPtr = Array<T>::GetPtr(oid);
T &element = arrayPtr->data_[loffset];
function(handle, pos, element, std::get<is>(args)..., result, resultSize);
}

template <typename Tuple, typename... Args>
static void AsyncApplyWRBFunWrapper(rt::Handle &handle, const Tuple &args,
uint8_t* result, uint32_t* resultSize) {
constexpr auto Size = std::tuple_size<
typename std::decay<decltype(std::get<4>(args))>::type>::value;

Tuple &tuple = const_cast<Tuple &>(args);

AsyncCallApplyWRBFun(handle, std::get<0>(tuple), std::get<1>(tuple),
std::get<2>(tuple), std::get<3>(tuple),
std::get<4>(tuple), std::make_index_sequence<Size>{}, result, resultSize);
}

template <typename ApplyFunT, typename... Args, std::size_t... is>
static void CallForEachInRangeFun(size_t i, ObjectID &oid, size_t pos,
size_t lpos, ApplyFunT function,
Expand Down Expand Up @@ -954,6 +983,25 @@ void Array<T>::AsyncApply(rt::Handle &handle, const size_t pos,
AsyncApplyFunWrapper<ArgsTuple, Args...>, argsTuple);
}

template <typename T>
template <typename ApplyFunT, typename... Args>
void Array<T>::AsyncApplyWithRetBuff(rt::Handle &handle, const size_t pos,
ApplyFunT &&function, uint8_t* result,
uint32_t* resultSize, Args &... args) {
auto target = getTargetLocalityFromTargePosition(dataDistribution_, pos);

using FunctionTy = void (*)(rt::Handle &, size_t, T &, Args & ..., uint8_t*, uint32_t*);
FunctionTy fn = std::forward<decltype(function)>(function);
using ArgsTuple =
std::tuple<ObjectID, size_t, size_t, FunctionTy, std::tuple<Args...>>;
ArgsTuple argsTuple{oid_, pos, target.second, fn,
std::tuple<Args...>(args...)};

rt::asyncExecuteAtWithRetBuff(handle, target.first,
AsyncApplyWRBFunWrapper<ArgsTuple, Args...>, argsTuple,
result, resultSize);
}

template <typename T>
template <typename ApplyFunT, typename... Args>
void Array<T>::ForEachInRange(const size_t first, const size_t last,
Expand Down
41 changes: 41 additions & 0 deletions test/unit_tests/data_structures/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ static void asyncApplyFun(shad::rt::Handle & /*unused*/, size_t i, size_t &elem,
elem += kInitValue;
}

static void asyncApplyWRBFun(shad::rt::Handle & /*unused*/, size_t i, size_t &elem, size_t &incr,
uint8_t* result, uint32_t* resultSize) {
ASSERT_EQ(incr, kInitValue);
ASSERT_EQ(elem, i + 1);
elem += kInitValue;
*resultSize = sizeof(elem);
memcpy(result, &elem, sizeof(elem));
}

static void asyncApplyFunNoArgs(shad::rt::Handle & /*unused*/, size_t i,
size_t &elem) {
ASSERT_EQ(elem, i + kInitValue + 1);
Expand Down Expand Up @@ -298,6 +307,38 @@ TEST_F(ArrayTest, AsyncInsertAsyncApplyAndAsyncGet) {
shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}

TEST_F(ArrayTest, AsyncInsertAsyncApplyWRBAndAsyncGet) {
std::vector<size_t> values(kArraySize);
auto edsPtr = shad::Array<size_t>::Create(kArraySize, kInitValue);

shad::rt::Handle handle;
for (size_t i = 0; i < kArraySize; i++) {
edsPtr->AsyncInsertAt(handle, i, i + 1);
}
shad::rt::waitForCompletion(handle);

shad::rt::Handle handle2;
std::vector<size_t> ret_values(kArraySize);
std::vector<uint32_t> ret_sizes(kArraySize);
for (size_t i = 0; i < kArraySize; i++) {
edsPtr->AsyncApplyWithRetBuff(handle2, i, asyncApplyWRBFun,
(uint8_t*)(&ret_values[i]),
&ret_sizes[i], kInitValue);
}
shad::rt::waitForCompletion(handle2);

shad::rt::Handle handle3;
for (size_t i = 0; i < kArraySize; i++) {
edsPtr->AsyncAt(handle3, i, &values[i]);
}
shad::rt::waitForCompletion(handle3);
for (size_t i = 0; i < kArraySize; i++) {
ASSERT_EQ(values[i], i + 1 + kInitValue);
ASSERT_EQ(ret_values[i], i + 1 + kInitValue);
}
shad::Array<size_t>::Destroy(edsPtr->GetGlobalID());
}

TEST_F(ArrayTest, AsyncInsertSyncForEachInRangeAndAsyncGet) {
std::vector<size_t> values(kArraySize);
auto edsPtr = shad::Array<size_t>::Create(kArraySize, kInitValue);
Expand Down

0 comments on commit c77ac93

Please sign in to comment.