diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index 6c972b4829..3c7eb21607 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -492,59 +492,26 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem const ::std::uint16_t __subgroup_id = __subgroup.get_group_id(); const ::std::uint16_t __subgroup_size = __subgroup.get_local_linear_range(); -#if _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT - constexpr bool __can_use_subgroup_load_store = - _IsFullGroup && oneapi::dpl::__internal::__range_has_raw_ptr_iterator_v<::std::decay_t<_InRng>>; -#else - constexpr bool __can_use_subgroup_load_store = false; -#endif - auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc); - if constexpr (__can_use_subgroup_load_store) + for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize) { - _ONEDPL_PRAGMA_UNROLL - for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i) - { - auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size; - auto __val = __unary_op(__subgroup.load(__in_rng.begin() + __idx)); - __subgroup.store(__lacc_ptr + __idx, __val); - } - } - else - { - for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize) - { - __lacc[__idx] = __unary_op(__in_rng[__idx]); - } + __lacc[__idx] = __unary_op(__in_rng[__idx]); } __scan_work_group<_ValueType, _Inclusive>(__group, __lacc_ptr, __lacc_ptr + __n, __lacc_ptr, __bin_op, __init); - if constexpr (__can_use_subgroup_load_store) + for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize) { - _ONEDPL_PRAGMA_UNROLL - for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i) - { - auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size; - auto __val = __subgroup.load(__lacc_ptr + __idx); - __subgroup.store(__out_rng.begin() + __idx, __val); - } + __out_rng[__idx] = __lacc[__idx]; } - else - { - for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize) - { - __out_rng[__idx] = __lacc[__idx]; - } - const ::std::uint16_t __residual = __n % _WGSize; - const ::std::uint16_t __residual_start = __n - __residual; - if (__item_id < __residual) - { - auto __idx = __residual_start + __item_id; - __out_rng[__idx] = __lacc[__idx]; - } + const std::uint16_t __residual = __n % _WGSize; + const std::uint16_t __residual_start = __n - __residual; + if (__item_id < __residual) + { + auto __idx = __residual_start + __item_id; + __out_rng[__idx] = __lacc[__idx]; } }); }); @@ -597,30 +564,10 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W const ::std::uint16_t __item_id = __self_item.get_local_linear_id(); const ::std::uint16_t __subgroup_id = __subgroup.get_group_id(); const ::std::uint16_t __subgroup_size = __subgroup.get_local_linear_range(); - -#if _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT - constexpr bool __can_use_subgroup_load_store = - _IsFullGroup && oneapi::dpl::__internal::__range_has_raw_ptr_iterator_v<::std::decay_t<_InRng>>; -#else - constexpr bool __can_use_subgroup_load_store = false; -#endif auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc); - if constexpr (__can_use_subgroup_load_store) + for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize) { - _ONEDPL_PRAGMA_UNROLL - for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i) - { - auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size; - uint16_t __val = __unary_op(__subgroup.load(__in_rng.begin() + __idx)); - __subgroup.store(__lacc_ptr + __idx, __val); - } - } - else - { - for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize) - { - __lacc[__idx] = __unary_op(__in_rng[__idx]); - } + __lacc[__idx] = __unary_op(__in_rng[__idx]); } __scan_work_group<_ValueType, /* _Inclusive */ false>( diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h index c8618dfa62..72540c492b 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h @@ -89,9 +89,6 @@ #define _ONEDPL_SYCL_DEVICE_COPYABLE_SPECIALIZATION_BROKEN (_ONEDPL_LIBSYCL_VERSION_LESS_THAN(70100)) -// TODO: determine which compiler configurations provide subgroup load/store -#define _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT false - // Macro to check if we are compiling for SPIR-V devices. This macro must only be used within // SYCL kernels for determining SPIR-V compilation. Using this macro on the host may lead to incorrect behavior. #ifndef _ONEDPL_DETECT_SPIRV_COMPILATION // Check if overridden for testing