Skip to content

Commit

Permalink
Do not use deprecated sub-group load/store extension (#1979)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriy-sobolev authored Jan 6, 2025
1 parent 07dfb58 commit 77e9f07
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 68 deletions.
77 changes: 12 additions & 65 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,59 +492,26 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
const ::std::uint16_t __subgroup_id = __subgroup.get_group_id();
const ::std::uint16_t __subgroup_size = __subgroup.get_local_linear_range();

#if _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT
constexpr bool __can_use_subgroup_load_store =
_IsFullGroup && oneapi::dpl::__internal::__range_has_raw_ptr_iterator_v<::std::decay_t<_InRng>>;
#else
constexpr bool __can_use_subgroup_load_store = false;
#endif

auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc);
if constexpr (__can_use_subgroup_load_store)
for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
_ONEDPL_PRAGMA_UNROLL
for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i)
{
auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size;
auto __val = __unary_op(__subgroup.load(__in_rng.begin() + __idx));
__subgroup.store(__lacc_ptr + __idx, __val);
}
}
else
{
for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
__lacc[__idx] = __unary_op(__in_rng[__idx]);
}
__lacc[__idx] = __unary_op(__in_rng[__idx]);
}

__scan_work_group<_ValueType, _Inclusive>(__group, __lacc_ptr, __lacc_ptr + __n,
__lacc_ptr, __bin_op, __init);

if constexpr (__can_use_subgroup_load_store)
for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
_ONEDPL_PRAGMA_UNROLL
for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i)
{
auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size;
auto __val = __subgroup.load(__lacc_ptr + __idx);
__subgroup.store(__out_rng.begin() + __idx, __val);
}
__out_rng[__idx] = __lacc[__idx];
}
else
{
for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
__out_rng[__idx] = __lacc[__idx];
}

const ::std::uint16_t __residual = __n % _WGSize;
const ::std::uint16_t __residual_start = __n - __residual;
if (__item_id < __residual)
{
auto __idx = __residual_start + __item_id;
__out_rng[__idx] = __lacc[__idx];
}
const std::uint16_t __residual = __n % _WGSize;
const std::uint16_t __residual_start = __n - __residual;
if (__item_id < __residual)
{
auto __idx = __residual_start + __item_id;
__out_rng[__idx] = __lacc[__idx];
}
});
});
Expand Down Expand Up @@ -597,30 +564,10 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
const ::std::uint16_t __item_id = __self_item.get_local_linear_id();
const ::std::uint16_t __subgroup_id = __subgroup.get_group_id();
const ::std::uint16_t __subgroup_size = __subgroup.get_local_linear_range();

#if _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT
constexpr bool __can_use_subgroup_load_store =
_IsFullGroup && oneapi::dpl::__internal::__range_has_raw_ptr_iterator_v<::std::decay_t<_InRng>>;
#else
constexpr bool __can_use_subgroup_load_store = false;
#endif
auto __lacc_ptr = __dpl_sycl::__get_accessor_ptr(__lacc);
if constexpr (__can_use_subgroup_load_store)
for (std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
_ONEDPL_PRAGMA_UNROLL
for (::std::uint16_t __i = 0; __i < _ElemsPerItem; ++__i)
{
auto __idx = __i * _WGSize + __subgroup_id * __subgroup_size;
uint16_t __val = __unary_op(__subgroup.load(__in_rng.begin() + __idx));
__subgroup.store(__lacc_ptr + __idx, __val);
}
}
else
{
for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
__lacc[__idx] = __unary_op(__in_rng[__idx]);
}
__lacc[__idx] = __unary_op(__in_rng[__idx]);
}

__scan_work_group<_ValueType, /* _Inclusive */ false>(
Expand Down
3 changes: 0 additions & 3 deletions include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@
#define _ONEDPL_SYCL_DEVICE_COPYABLE_SPECIALIZATION_BROKEN (_ONEDPL_LIBSYCL_VERSION_LESS_THAN(70100))
// TODO: determine which compiler configurations provide subgroup load/store
#define _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT false
// Macro to check if we are compiling for SPIR-V devices. This macro must only be used within
// SYCL kernels for determining SPIR-V compilation. Using this macro on the host may lead to incorrect behavior.
#ifndef _ONEDPL_DETECT_SPIRV_COMPILATION // Check if overridden for testing
Expand Down

0 comments on commit 77e9f07

Please sign in to comment.