Enhance querying kernels preferred wgsize #16186

Open · wants to merge 1 commit into base: sycl
80 changes: 70 additions & 10 deletions sycl/include/sycl/reduction.hpp
@@ -144,8 +144,8 @@ __SYCL_EXPORT size_t reduGetMaxWGSize(std::shared_ptr<queue_impl> Queue,
size_t LocalMemBytesPerWorkItem);
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
size_t &NWorkGroups);
-__SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
-                                            size_t LocalMemBytesPerWorkItem);
+__SYCL_EXPORT size_t reduGetPreferredDeviceWGSize(
+    std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem);

template <typename T, class BinaryOperation, bool IsOptional>
class ReducerElement;
@@ -1200,6 +1200,25 @@ void reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
});
}

+template <typename KernelName>
+size_t reduGetPreferredKernelWGSize(std::shared_ptr<queue_impl> &Queue) {
+  using namespace info::kernel_device_specific;
+  auto SyclQueue = createSyclObjFromImpl<queue>(Queue);
+  auto Ctx = SyclQueue.get_context();
+  auto Dev = SyclQueue.get_device();
+  size_t MaxWGSize = SIZE_MAX;
+  constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
+
+  if (!IsUndefinedKernelName) {
+    auto ExecBundle =
+        get_kernel_bundle<KernelName, bundle_state::executable>(Ctx, {Dev});
+    kernel Kernel = ExecBundle.template get_kernel<KernelName>();
+    MaxWGSize = Kernel.template get_info<work_group_size>(Dev);
+  }
+
+  return MaxWGSize;
+}
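For reference, the new helper wraps the standard SYCL 2020 kernel_bundle query pattern. Here is a minimal standalone sketch of that pattern (not part of the diff; MyKernel is a hypothetical kernel name):

// Sketch: fetch the executable bundle for a named kernel, then ask that
// kernel how large a work-group it supports on the chosen device. This is
// the same info::kernel_device_specific::work_group_size query the helper
// performs.
#include <sycl/sycl.hpp>

class MyKernel; // hypothetical kernel name, declared up front

int main() {
  sycl::queue Q;
  auto Ctx = Q.get_context();
  auto Dev = Q.get_device();

  auto Bundle =
      sycl::get_kernel_bundle<MyKernel, sycl::bundle_state::executable>(
          Ctx, {Dev});
  sycl::kernel K = Bundle.get_kernel<MyKernel>();

  // Kernel-specific limit; may be smaller than the device-wide
  // info::device::max_work_group_size.
  size_t KernelMaxWG =
      K.get_info<sycl::info::kernel_device_specific::work_group_size>(Dev);

  Q.parallel_for<MyKernel>(
       sycl::nd_range<1>{sycl::range<1>{KernelMaxWG},
                         sycl::range<1>{KernelMaxWG}},
       [=](sycl::nd_item<1>) { /* work */ })
      .wait();
}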

namespace reduction {
template <typename KernelName, strategy S, class... Ts> struct MainKrn;
template <typename KernelName, strategy S, class... Ts> struct AuxKrn;
@@ -1302,6 +1321,8 @@ struct NDRangeReduction<
reduction::strategy::group_reduce_and_last_wg_detection,
decltype(NWorkGroupsFinished)>;

+    WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
// Call user's functions. Reducer.MValue gets initialized there.
typename Reduction::reducer_type Reducer;
@@ -1515,6 +1536,8 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
using Name = __sycl_reduction_kernel<reduction::MainKrn, KernelName,
reduction::strategy::range_basic>;

+    WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
// Call user's functions. Reducer.MValue gets initialized there.
reducer_type Reducer = reducer_type(IdentityContainer, BOp);
@@ -1628,14 +1651,14 @@ struct NDRangeReduction<
using reducer_type = typename Reduction::reducer_type;
using element_type = typename ReducerTraits<reducer_type>::element_type;

-    std::ignore = Queue;
using Name = __sycl_reduction_kernel<
reduction::MainKrn, KernelName,
reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
Redu.template withInitializedMem<Name>(CGH, [&](auto Out) {
size_t NElements = Reduction::num_elements;
size_t WGSize = NDRange.get_local_range().size();

+      WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
// Use local memory to reduce elements in work-groups into zero-th
// element.
local_accessor<element_type, 1> LocalReds{WGSize, CGH};
@@ -1722,6 +1745,8 @@ struct NDRangeReduction<
reduction::MainKrn, KernelName,
reduction::strategy::group_reduce_and_multiple_kernels>;

+    MaxWGSize = std::min(MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
// Call user's functions. Reducer.MValue gets initialized there.
typename Reduction::reducer_type Reducer;
@@ -1781,6 +1806,8 @@ struct NDRangeReduction<
reduction::AuxKrn, KernelName,
reduction::strategy::group_reduce_and_multiple_kernels>;

+    WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

bool IsUpdateOfUserVar = !Reduction::is_usm &&
!Redu.initializeToIdentity() &&
NWorkGroups == 1;
@@ -1874,6 +1901,9 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
reduction::strategy::basic,
decltype(KernelTag)>;

+      MaxWGSize =
+          std::min(MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
// Call user's functions. Reducer.MValue gets initialized there.
typename Reduction::reducer_type Reducer =
@@ -1978,6 +2008,8 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
reduction::strategy::basic,
decltype(KernelTag)>;

+      WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

range<1> GlobalRange = {UniformPow2WG ? NWorkItems
: NWorkGroups * WGSize};
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
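The std::min clamps added throughout these strategies guard against choosing a local range that the compiled kernel cannot actually launch with. A hedged, standalone illustration of that failure mode (TooWide is a hypothetical kernel name):

// Sketch: requesting a local range beyond the supported work-group size
// makes the launch fail at submit time, which is what the clamps prevent.
#include <sycl/sycl.hpp>
#include <iostream>

class TooWide; // hypothetical kernel name

int main() {
  sycl::queue Q;
  size_t DevMax =
      Q.get_device().get_info<sycl::info::device::max_work_group_size>();
  try {
    Q.parallel_for<TooWide>(
         sycl::nd_range<1>{sycl::range<1>{2 * DevMax},
                           sycl::range<1>{2 * DevMax}},
         [=](sycl::nd_item<1>) {})
        .wait();
  } catch (const sycl::exception &E) {
    std::cout << "launch rejected: " << E.what() << '\n';
  }
}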
@@ -2295,8 +2327,9 @@ template <class KernelName, class Accessor> struct NDRangeMulti;
} // namespace reduction::main_krn
template <typename KernelName, typename KernelType, int Dims,
typename PropertiesT, typename... Reductions, size_t... Is>
-void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
-                     const nd_range<Dims> &Range, PropertiesT Properties,
+void reduCGFuncMulti(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
+                     KernelType KernelFunc, const nd_range<Dims> &Range,
+                     PropertiesT Properties,
std::tuple<Reductions...> &ReduTuple,
std::index_sequence<Is...> ReduIndices) {
size_t WGSize = Range.get_local_range().size();
@@ -2334,6 +2367,8 @@ void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
reduction::strategy::multi,
decltype(KernelTag)>;

+    WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(Range, Properties, [=](nd_item<Dims> NDIt) {
// We can deduce IsOneWG from the tag type.
constexpr bool IsOneWG =
@@ -2495,7 +2530,8 @@ template <class KernelName, class Predicate> struct Multi;
} // namespace reduction::aux_krn
template <typename KernelName, typename KernelType, typename... Reductions,
size_t... Is>
-size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
+size_t reduAuxCGFunc(handler &CGH, std::shared_ptr<queue_impl> &Queue,
+                     size_t NWorkItems, size_t MaxWGSize,
std::tuple<Reductions...> &ReduTuple,
std::index_sequence<Is...> ReduIndices) {
size_t NWorkGroups;
@@ -2533,6 +2569,8 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
using Name = __sycl_reduction_kernel<reduction::AuxKrn, KernelName,
reduction::strategy::multi,
decltype(Predicate)>;
+    WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

// TODO: Opportunity to parallelize across number of elements
range<1> GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
@@ -2617,15 +2655,15 @@ template <> struct NDRangeReduction<reduction::strategy::multi> {
" than " +
std::to_string(MaxWGSize));

-    reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
-                                ReduIndices);
+    reduCGFuncMulti<KernelName>(CGH, Queue, KernelFunc, NDRange, Properties,
+                                ReduTuple, ReduIndices);
reduction::finalizeHandler(CGH);

size_t NWorkItems = NDRange.get_group_range().size();
while (NWorkItems > 1) {
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
-            AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
+            AuxHandler, Queue, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
});
} // end while (NWorkItems > 1)
}
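The while loop above repeatedly launches auxiliary kernels, each pass collapsing one partial result per work-group, until a single value remains. A plain-C++ sketch of that ladder's arithmetic (assuming ceil-division grouping, an approximation of what reduComputeWGSize does):

// Sketch: how NWorkItems shrinks across auxiliary reduction passes.
#include <cstdio>

int main() {
  size_t NWorkItems = 100000, MaxWGSize = 256;
  while (NWorkItems > 1) {
    size_t WGSize = NWorkItems < MaxWGSize ? NWorkItems : MaxWGSize;
    size_t NWorkGroups = (NWorkItems + WGSize - 1) / WGSize; // ceil division
    std::printf("pass: %zu items -> %zu groups\n", NWorkItems, NWorkGroups);
    NWorkItems = NWorkGroups; // one partial result per group feeds next pass
  }
}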
@@ -2741,7 +2779,29 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
// TODO: currently the preferred work group size is determined for the given
// queue/device, while it is safer to use queries to the kernel pre-compiled
// for the device.
-  size_t PrefWGSize = reduGetPreferredWGSize(CGH.MQueue, OneElemSize);
+  size_t PrefWGSize = reduGetPreferredDeviceWGSize(CGH.MQueue, OneElemSize);

+  auto SyclQueue = createSyclObjFromImpl<queue>(CGH.MQueue);
+  auto Ctx = SyclQueue.get_context();
+  auto Dev = SyclQueue.get_device();
+
+  // If the reduction kernel name is not defined, we cannot query the exact
+  // kernel for its best wgsize, so we query all the reduction kernels for
+  // their wgsize and use the minimum as a safe, approximate option.
+  constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
+  if (IsUndefinedKernelName) {
+    std::vector<kernel_id> ReductionKernelIDs = get_kernel_ids();
+    for (auto KernelID : ReductionKernelIDs) {
+      std::string ReduKernelName = KernelID.get_name();
+      if (ReduKernelName.find("reduction") != std::string::npos) {
+        auto KB = get_kernel_bundle<bundle_state::executable>(Ctx, {KernelID});
+        kernel krn = KB.get_kernel(KernelID);
+        using namespace info::kernel_device_specific;
+        size_t MaxSize = krn.template get_info<work_group_size>(Dev);
+        PrefWGSize = std::min(PrefWGSize, MaxSize);
+      }
+    }
+  }

size_t NWorkItems = Range.size();
size_t WGSize = std::min(NWorkItems, PrefWGSize);
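The auto_name fallback above is built on sycl::get_kernel_ids(). A standalone sketch of that enumeration (assumes it is compiled into an application that defines at least one SYCL kernel):

// Sketch: walk every kernel built into the application and query its
// per-kernel work-group limit, mirroring the fallback loop in the diff.
#include <sycl/sycl.hpp>
#include <iostream>

int main() {
  sycl::queue Q;
  auto Ctx = Q.get_context();
  auto Dev = Q.get_device();
  for (const sycl::kernel_id &ID : sycl::get_kernel_ids()) {
    auto KB =
        sycl::get_kernel_bundle<sycl::bundle_state::executable>(Ctx, {ID});
    sycl::kernel K = KB.get_kernel(ID);
    size_t WG =
        K.get_info<sycl::info::kernel_device_specific::work_group_size>(Dev);
    std::cout << ID.get_name() << ": max work-group " << WG << '\n';
  }
}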
4 changes: 2 additions & 2 deletions sycl/source/detail/reduction.cpp
@@ -113,8 +113,8 @@ reduGetMaxWGSize(std::shared_ptr<sycl::detail::queue_impl> Queue,
return WGSize;
}

-__SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
-                                            size_t LocalMemBytesPerWorkItem) {
+__SYCL_EXPORT size_t reduGetPreferredDeviceWGSize(
+    std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem) {
// TODO: Graphs extension explicit API uses a handler with a null queue to
// process CGFs, in future we should have access to the device so we can
// correctly calculate this.
2 changes: 1 addition & 1 deletion sycl/test/abi/sycl_symbols_linux.dump
@@ -3294,7 +3294,7 @@ _ZN4sycl3_V16detail22get_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6devi
_ZN4sycl3_V16detail22get_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EERKS5_INS0_9kernel_idESaISB_EENS0_12bundle_stateE
_ZN4sycl3_V16detail22has_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EENS0_12bundle_stateE
_ZN4sycl3_V16detail22has_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EERKS5_INS0_9kernel_idESaISB_EENS0_12bundle_stateE
-_ZN4sycl3_V16detail22reduGetPreferredWGSizeERSt10shared_ptrINS1_10queue_implEEm
+_ZN4sycl3_V16detail28reduGetPreferredDeviceWGSizeERSt10shared_ptrINS1_10queue_implEEm
_ZN4sycl3_V16detail22removeDuplicateDevicesERKSt6vectorINS0_6deviceESaIS3_EE
_ZN4sycl3_V16detail23constructorNotificationEPvS2_NS0_6access6targetENS3_4modeERKNS1_13code_locationE
_ZN4sycl3_V16detail24find_device_intersectionERKSt6vectorINS0_13kernel_bundleILNS0_12bundle_stateE1EEESaIS5_EE
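The updated Linux entry can be sanity-checked by demangling it; a small sketch using the Itanium ABI demangler (GCC/Clang only):

// Sketch: demangle the renamed export to confirm it is the new function.
#include <cxxabi.h>
#include <cstdio>
#include <cstdlib>

int main() {
  const char *Mangled = "_ZN4sycl3_V16detail28reduGetPreferredDeviceWGSizeE"
                        "RSt10shared_ptrINS1_10queue_implEEm";
  int Status = 0;
  char *Demangled = abi::__cxa_demangle(Mangled, nullptr, nullptr, &Status);
  if (Status == 0) {
    // Prints: sycl::_V1::detail::reduGetPreferredDeviceWGSize(
    //     std::shared_ptr<sycl::_V1::detail::queue_impl>&, unsigned long)
    std::puts(Demangled);
    std::free(Demangled);
  }
  return Status;
}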
2 changes: 1 addition & 1 deletion sycl/test/abi/sycl_symbols_windows.dump
@@ -4206,7 +4206,7 @@
?reduComputeWGSize@detail@_V1@sycl@@YA_K_K0AEA_K@Z
?reduGetMaxNumConcurrentWorkGroups@detail@_V1@sycl@@YAIV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@@Z
?reduGetMaxWGSize@detail@_V1@sycl@@YA_KV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
-?reduGetPreferredWGSize@detail@_V1@sycl@@YA_KAEAV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
+?reduGetPreferredDeviceWGSize@detail@_V1@sycl@@YA_KAEAV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
?registerDynamicParameter@handler@_V1@sycl@@AEAAXAEAVdynamic_parameter_base@detail@experimental@oneapi@ext@23@H@Z
?release_external_memory@experimental@oneapi@ext@_V1@sycl@@YAXUexternal_mem@12345@AEBVdevice@45@AEBVcontext@45@@Z
?release_external_memory@experimental@oneapi@ext@_V1@sycl@@YAXUexternal_mem@12345@AEBVqueue@45@@Z