Skip to content

Commit

Permalink
Enhance querying kernels preferred wgsize
Browse files Browse the repository at this point in the history
Co-authored-by: Georgi Mirazchiyski <[email protected]>
  • Loading branch information
omarahmed1111 and GeorgeWeb committed Nov 28, 2024
1 parent 08a2edc commit 71739a8
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 14 deletions.
80 changes: 70 additions & 10 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ __SYCL_EXPORT size_t reduGetMaxWGSize(std::shared_ptr<queue_impl> Queue,
size_t LocalMemBytesPerWorkItem);
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
size_t &NWorkGroups);
__SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
size_t LocalMemBytesPerWorkItem);
__SYCL_EXPORT size_t reduGetPreferredDeviceWGSize(
std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem);

template <typename T, class BinaryOperation, bool IsOptional>
class ReducerElement;
Expand Down Expand Up @@ -1200,6 +1200,25 @@ void reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
});
}

/// Queries the maximum work-group size for the pre-compiled reduction kernel
/// \p KernelName on the device associated with \p Queue.
///
/// \param Queue implementation of the queue the reduction is submitted to.
/// \return the kernel's device-specific maximum work-group size, or SIZE_MAX
///         when \p KernelName is the undefined (auto-generated) kernel name
///         and no per-kernel query is possible — callers are expected to
///         std::min this against their own device-level limit.
template <typename KernelName>
size_t reduGetPreferredKernelWGSize(std::shared_ptr<queue_impl> &Queue) {
  using namespace info::kernel_device_specific;
  size_t MaxWGSize = SIZE_MAX;
  constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};

  // Use `if constexpr` so the kernel-bundle lookup is never instantiated for
  // auto_name: an unnamed kernel has no registered kernel name to query.
  // The queue/context/device are only materialized when actually needed.
  if constexpr (!IsUndefinedKernelName) {
    auto SyclQueue = createSyclObjFromImpl<queue>(Queue);
    auto Ctx = SyclQueue.get_context();
    auto Dev = SyclQueue.get_device();
    auto ExecBundle =
        get_kernel_bundle<KernelName, bundle_state::executable>(Ctx, {Dev});
    kernel Kernel = ExecBundle.template get_kernel<KernelName>();
    MaxWGSize = Kernel.template get_info<work_group_size>(Dev);
  }

  return MaxWGSize;
}

namespace reduction {
template <typename KernelName, strategy S, class... Ts> struct MainKrn;
template <typename KernelName, strategy S, class... Ts> struct AuxKrn;
Expand Down Expand Up @@ -1302,6 +1321,8 @@ struct NDRangeReduction<
reduction::strategy::group_reduce_and_last_wg_detection,
decltype(NWorkGroupsFinished)>;

WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
// Call user's functions. Reducer.MValue gets initialized there.
typename Reduction::reducer_type Reducer;
Expand Down Expand Up @@ -1515,6 +1536,8 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
using Name = __sycl_reduction_kernel<reduction::MainKrn, KernelName,
reduction::strategy::range_basic>;

WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
// Call user's functions. Reducer.MValue gets initialized there.
reducer_type Reducer = reducer_type(IdentityContainer, BOp);
Expand Down Expand Up @@ -1628,14 +1651,14 @@ struct NDRangeReduction<
using reducer_type = typename Reduction::reducer_type;
using element_type = typename ReducerTraits<reducer_type>::element_type;

std::ignore = Queue;
using Name = __sycl_reduction_kernel<
reduction::MainKrn, KernelName,
reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
Redu.template withInitializedMem<Name>(CGH, [&](auto Out) {
size_t NElements = Reduction::num_elements;
size_t WGSize = NDRange.get_local_range().size();

WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
// Use local memory to reduce elements in work-groups into zero-th
// element.
local_accessor<element_type, 1> LocalReds{WGSize, CGH};
Expand Down Expand Up @@ -1722,6 +1745,8 @@ struct NDRangeReduction<
reduction::MainKrn, KernelName,
reduction::strategy::group_reduce_and_multiple_kernels>;

MaxWGSize = std::min(MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
// Call user's functions. Reducer.MValue gets initialized there.
typename Reduction::reducer_type Reducer;
Expand Down Expand Up @@ -1781,6 +1806,8 @@ struct NDRangeReduction<
reduction::AuxKrn, KernelName,
reduction::strategy::group_reduce_and_multiple_kernels>;

WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

bool IsUpdateOfUserVar = !Reduction::is_usm &&
!Redu.initializeToIdentity() &&
NWorkGroups == 1;
Expand Down Expand Up @@ -1874,6 +1901,9 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
reduction::strategy::basic,
decltype(KernelTag)>;

MaxWGSize =
std::min(MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
// Call user's functions. Reducer.MValue gets initialized there.
typename Reduction::reducer_type Reducer =
Expand Down Expand Up @@ -1978,6 +2008,8 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
reduction::strategy::basic,
decltype(KernelTag)>;

WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

range<1> GlobalRange = {UniformPow2WG ? NWorkItems
: NWorkGroups * WGSize};
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
Expand Down Expand Up @@ -2295,8 +2327,9 @@ template <class KernelName, class Accessor> struct NDRangeMulti;
} // namespace reduction::main_krn
template <typename KernelName, typename KernelType, int Dims,
typename PropertiesT, typename... Reductions, size_t... Is>
void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
const nd_range<Dims> &Range, PropertiesT Properties,
void reduCGFuncMulti(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
KernelType KernelFunc, const nd_range<Dims> &Range,
PropertiesT Properties,
std::tuple<Reductions...> &ReduTuple,
std::index_sequence<Is...> ReduIndices) {
size_t WGSize = Range.get_local_range().size();
Expand Down Expand Up @@ -2334,6 +2367,8 @@ void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
reduction::strategy::multi,
decltype(KernelTag)>;

WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

CGH.parallel_for<Name>(Range, Properties, [=](nd_item<Dims> NDIt) {
// We can deduce IsOneWG from the tag type.
constexpr bool IsOneWG =
Expand Down Expand Up @@ -2495,7 +2530,8 @@ template <class KernelName, class Predicate> struct Multi;
} // namespace reduction::aux_krn
template <typename KernelName, typename KernelType, typename... Reductions,
size_t... Is>
size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
size_t reduAuxCGFunc(handler &CGH, std::shared_ptr<queue_impl> &Queue,
size_t NWorkItems, size_t MaxWGSize,
std::tuple<Reductions...> &ReduTuple,
std::index_sequence<Is...> ReduIndices) {
size_t NWorkGroups;
Expand Down Expand Up @@ -2533,6 +2569,8 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
using Name = __sycl_reduction_kernel<reduction::AuxKrn, KernelName,
reduction::strategy::multi,
decltype(Predicate)>;
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));

// TODO: Opportunity to parallelize across number of elements
range<1> GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
Expand Down Expand Up @@ -2617,15 +2655,15 @@ template <> struct NDRangeReduction<reduction::strategy::multi> {
" than " +
std::to_string(MaxWGSize));

reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
ReduIndices);
reduCGFuncMulti<KernelName>(CGH, Queue, KernelFunc, NDRange, Properties,
ReduTuple, ReduIndices);
reduction::finalizeHandler(CGH);

size_t NWorkItems = NDRange.get_group_range().size();
while (NWorkItems > 1) {
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
AuxHandler, Queue, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
});
} // end while (NWorkItems > 1)
}
Expand Down Expand Up @@ -2741,7 +2779,29 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
// TODO: currently the preferred work group size is determined for the given
// queue/device, while it is safer to use queries to the kernel pre-compiled
// for the device.
size_t PrefWGSize = reduGetPreferredWGSize(CGH.MQueue, OneElemSize);
size_t PrefWGSize = reduGetPreferredDeviceWGSize(CGH.MQueue, OneElemSize);

auto SyclQueue = createSyclObjFromImpl<queue>(CGH.MQueue);
auto Ctx = SyclQueue.get_context();
auto Dev = SyclQueue.get_device();

// If the reduction kernel is not name defined, we won't be able to query the
// exact kernel for the best wgsize, so we query all the reduction kernels for
// their wgsize and use the minimum wgsize as a safe and approximate option.
constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
if (IsUndefinedKernelName) {
std::vector<kernel_id> ReductionKernelIDs = get_kernel_ids();
for (auto KernelID : ReductionKernelIDs) {
std::string ReduKernelName = KernelID.get_name();
if (ReduKernelName.find("reduction") != std::string::npos) {
auto KB = get_kernel_bundle<bundle_state::executable>(Ctx, {KernelID});
kernel krn = KB.get_kernel(KernelID);
using namespace info::kernel_device_specific;
size_t MaxSize = krn.template get_info<work_group_size>(Dev);
PrefWGSize = std::min(PrefWGSize, MaxSize);
}
}
}

size_t NWorkItems = Range.size();
size_t WGSize = std::min(NWorkItems, PrefWGSize);
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ reduGetMaxWGSize(std::shared_ptr<sycl::detail::queue_impl> Queue,
return WGSize;
}

__SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
size_t LocalMemBytesPerWorkItem) {
__SYCL_EXPORT size_t reduGetPreferredDeviceWGSize(
std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem) {
// TODO: Graphs extension explicit API uses a handler with a null queue to
// process CGFs, in future we should have access to the device so we can
// correctly calculate this.
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3294,7 +3294,7 @@ _ZN4sycl3_V16detail22get_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6devi
_ZN4sycl3_V16detail22get_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EERKS5_INS0_9kernel_idESaISB_EENS0_12bundle_stateE
_ZN4sycl3_V16detail22has_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EENS0_12bundle_stateE
_ZN4sycl3_V16detail22has_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EERKS5_INS0_9kernel_idESaISB_EENS0_12bundle_stateE
_ZN4sycl3_V16detail22reduGetPreferredWGSizeERSt10shared_ptrINS1_10queue_implEEm
_ZN4sycl3_V16detail28reduGetPreferredDeviceWGSizeERSt10shared_ptrINS1_10queue_implEEm
_ZN4sycl3_V16detail22removeDuplicateDevicesERKSt6vectorINS0_6deviceESaIS3_EE
_ZN4sycl3_V16detail23constructorNotificationEPvS2_NS0_6access6targetENS3_4modeERKNS1_13code_locationE
_ZN4sycl3_V16detail24find_device_intersectionERKSt6vectorINS0_13kernel_bundleILNS0_12bundle_stateE1EEESaIS5_EE
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -4206,7 +4206,7 @@
?reduComputeWGSize@detail@_V1@sycl@@YA_K_K0AEA_K@Z
?reduGetMaxNumConcurrentWorkGroups@detail@_V1@sycl@@YAIV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@@Z
?reduGetMaxWGSize@detail@_V1@sycl@@YA_KV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
?reduGetPreferredWGSize@detail@_V1@sycl@@YA_KAEAV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
?reduGetPreferredDeviceWGSize@detail@_V1@sycl@@YA_KAEAV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
?registerDynamicParameter@handler@_V1@sycl@@AEAAXAEAVdynamic_parameter_base@detail@experimental@oneapi@ext@23@H@Z
?release_external_memory@experimental@oneapi@ext@_V1@sycl@@YAXUexternal_mem@12345@AEBVdevice@45@AEBVcontext@45@@Z
?release_external_memory@experimental@oneapi@ext@_V1@sycl@@YAXUexternal_mem@12345@AEBVqueue@45@@Z
Expand Down

0 comments on commit 71739a8

Please sign in to comment.