diff --git a/include/shad/data_structures/array.h b/include/shad/data_structures/array.h index 169ecd93..a7c8082f 100644 --- a/include/shad/data_structures/array.h +++ b/include/shad/data_structures/array.h @@ -319,6 +319,11 @@ class Array : public AbstractDataStructure> { template void AsyncApply(rt::Handle &handle, const size_t pos, ApplyFunT &&function, Args &... args); + + template + void AsyncApplyWithRetBuff(rt::Handle &handle, const size_t pos, + ApplyFunT &&function, uint8_t* result, + uint32_t* resultSize, Args &... args); /// @brief Applies a user-defined function to every element /// in the specified range. @@ -617,6 +622,30 @@ class Array : public AbstractDataStructure> { std::get<4>(tuple), std::make_index_sequence{}); } + template + static void AsyncCallApplyWRBFun(rt::Handle &handle, ObjectID &oid, size_t pos, + size_t loffset, ApplyFunT function, + std::tuple &args, std::index_sequence, + uint8_t* result, uint32_t* resultSize) { + // Get a local instance on the remote node. + auto arrayPtr = Array::GetPtr(oid); + T &element = arrayPtr->data_[loffset]; + function(handle, pos, element, std::get(args)..., result, resultSize); + } + + template + static void AsyncApplyWRBFunWrapper(rt::Handle &handle, const Tuple &args, + uint8_t* result, uint32_t* resultSize) { + constexpr auto Size = std::tuple_size< + typename std::decay(args))>::type>::value; + + Tuple &tuple = const_cast(args); + + AsyncCallApplyWRBFun(handle, std::get<0>(tuple), std::get<1>(tuple), + std::get<2>(tuple), std::get<3>(tuple), + std::get<4>(tuple), std::make_index_sequence{}, result, resultSize); + } + template static void CallForEachInRangeFun(size_t i, ObjectID &oid, size_t pos, size_t lpos, ApplyFunT function, @@ -954,6 +983,25 @@ void Array::AsyncApply(rt::Handle &handle, const size_t pos, AsyncApplyFunWrapper, argsTuple); } +template +template +void Array::AsyncApplyWithRetBuff(rt::Handle &handle, const size_t pos, + ApplyFunT &&function, uint8_t* result, + uint32_t* resultSize, Args &... args) { + auto target = getTargetLocalityFromTargePosition(dataDistribution_, pos); + + using FunctionTy = void (*)(rt::Handle &, size_t, T &, Args & ..., uint8_t*, uint32_t*); + FunctionTy fn = std::forward(function); + using ArgsTuple = + std::tuple>; + ArgsTuple argsTuple{oid_, pos, target.second, fn, + std::tuple(args...)}; + + rt::asyncExecuteAtWithRetBuff(handle, target.first, + AsyncApplyWRBFunWrapper, argsTuple, + result, resultSize); +} + template template void Array::ForEachInRange(const size_t first, const size_t last, diff --git a/test/unit_tests/data_structures/array_test.cc b/test/unit_tests/data_structures/array_test.cc index d5fa6efb..3627c4bb 100644 --- a/test/unit_tests/data_structures/array_test.cc +++ b/test/unit_tests/data_structures/array_test.cc @@ -247,6 +247,15 @@ static void asyncApplyFun(shad::rt::Handle & /*unused*/, size_t i, size_t &elem, elem += kInitValue; } +static void asyncApplyWRBFun(shad::rt::Handle & /*unused*/, size_t i, size_t &elem, size_t &incr, + uint8_t* result, uint32_t* resultSize) { + ASSERT_EQ(incr, kInitValue); + ASSERT_EQ(elem, i + 1); + elem += kInitValue; + *resultSize = sizeof(elem); + memcpy(result, &elem, sizeof(elem)); +} + static void asyncApplyFunNoArgs(shad::rt::Handle & /*unused*/, size_t i, size_t &elem) { ASSERT_EQ(elem, i + kInitValue + 1); @@ -298,6 +307,38 @@ TEST_F(ArrayTest, AsyncInsertAsyncApplyAndAsyncGet) { shad::Array::Destroy(edsPtr->GetGlobalID()); } +TEST_F(ArrayTest, AsyncInsertAsyncApplyWRBAndAsyncGet) { + std::vector values(kArraySize); + auto edsPtr = shad::Array::Create(kArraySize, kInitValue); + + shad::rt::Handle handle; + for (size_t i = 0; i < kArraySize; i++) { + edsPtr->AsyncInsertAt(handle, i, i + 1); + } + shad::rt::waitForCompletion(handle); + + shad::rt::Handle handle2; + std::vector ret_values(kArraySize); + std::vector ret_sizes(kArraySize); + for (size_t i = 0; i < kArraySize; i++) { + edsPtr->AsyncApplyWithRetBuff(handle2, i, asyncApplyWRBFun, + (uint8_t*)(&ret_values[i]), + &ret_sizes[i], kInitValue); + } + shad::rt::waitForCompletion(handle2); + + shad::rt::Handle handle3; + for (size_t i = 0; i < kArraySize; i++) { + edsPtr->AsyncAt(handle3, i, &values[i]); + } + shad::rt::waitForCompletion(handle3); + for (size_t i = 0; i < kArraySize; i++) { + ASSERT_EQ(values[i], i + 1 + kInitValue); + ASSERT_EQ(ret_values[i], i + 1 + kInitValue); + } + shad::Array::Destroy(edsPtr->GetGlobalID()); +} + TEST_F(ArrayTest, AsyncInsertSyncForEachInRangeAndAsyncGet) { std::vector values(kArraySize); auto edsPtr = shad::Array::Create(kArraySize, kInitValue);