diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h index f4eb557170..a5f5321768 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h @@ -701,35 +701,41 @@ struct __deferrable_mode { }; -//A contract for future class: -//Impl details: inheritance (private) instead of aggregation for enabling the empty base optimization. -template -class __future : private std::tuple<_Args...> +// An overload of __wait_and_get_value for 'sycl::buffer' +template +constexpr auto +__wait_and_get_value(_Event&&, const sycl::buffer<_T>& __buf, std::size_t __idx = 0) { - _Event __my_event; + //according to a contract, returned value is one-element sycl::buffer + return __buf.get_host_access(sycl::read_only)[__idx]; +} - template - constexpr auto - __wait_and_get_value(const sycl::buffer<_T>& __buf) - { - //according to a contract, returned value is one-element sycl::buffer - return __buf.get_host_access(sycl::read_only)[0]; - } +// An overload of __wait_and_get_value for '__result_and_scratch_storage' +template +constexpr auto +__wait_and_get_value(_Event&& __event, const __result_and_scratch_storage<_ExecutionPolicy, _T>& __storage, + std::size_t __idx = 0) +{ + return __storage.__wait_and_get_value(__event, __idx); +} - template - constexpr auto - __wait_and_get_value(const __result_and_scratch_storage<_ExecutionPolicy, _T>& __storage) - { - return __storage.__wait_and_get_value(__my_event); - } +template +constexpr auto +__wait_and_get_value(_Event&& __event, const _T& __val, std::size_t) +{ + __event.wait_and_throw(); + return __val; +} - template - constexpr auto - __wait_and_get_value(const _T& __val) - { - wait(); - return __val; - } +//A contract for 'future' class: +//* The first argument is an event, which has 'wait_and_throw' method +//* The second and the following arguments are each a trivial 
type T or RAII storage. +//* The 'future' class extends the lifetime for such RAII objects. +//* Impl details: inheritance (private) instead of aggregation for enabling the empty base optimization. +template +class __future : private std::tuple<_Args...> +{ + _Event __my_event; public: __future(_Event __e, _Args... __args) : std::tuple<_Args...>(__args...), __my_event(__e) {} @@ -764,13 +770,17 @@ class __future : private std::tuple<_Args...> #endif } + //_ArgsIdx specifies a compile-time index of the i-th argument passed into '__future' constructor after an event + //__elem_idx specifies a runtime index of the k-th element of the i-th argument in case when the i-th argument + // is not a scalar value but an array/buffer-like type. + template auto - get() + get(std::size_t __elem_idx = 0) { if constexpr (sizeof...(_Args) > 0) { - auto& __val = std::get<0>(*this); - return __wait_and_get_value(__val); + auto& __val = std::get<_ArgsIdx>(*this); + return __wait_and_get_value(event(), __val, __elem_idx); } else wait();