diff --git a/src/ATen/native/xpu/sycl/Loops.h b/src/ATen/native/xpu/sycl/Loops.h index 05a2f6b46..690ad92f1 100644 --- a/src/ATen/native/xpu/sycl/Loops.h +++ b/src/ATen/native/xpu/sycl/Loops.h @@ -305,7 +305,7 @@ struct LegacyKernelWithCastScalarFunctor { const func_t f_; }; -template +template static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) { TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); if (N == 0) { @@ -316,7 +316,12 @@ static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) { int64_t wg_sz = syclMaxWorkItemsPerSubSlice(); int64_t num_wg = ceil_div(N, wg_sz * vec_size); - sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); + if constexpr (force_small_grf) { + sycl_kernel_submit_small_grf( + wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); + } else { + sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); + } } template @@ -341,7 +346,8 @@ template < typename in_calc_t, typename out_calc_t, typename loader_t, - typename storer_t> + typename storer_t, + bool force_small_grf = false> static inline void launch_unrolled_kernel( int64_t N, const func_t& f, @@ -357,7 +363,12 @@ static inline void launch_unrolled_kernel( int64_t wg_sz = syclMaxWorkItemsPerSubSlice(); int64_t num_wg = ceil_div(N, wg_sz * ker_t::item_work_size); - sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); + if constexpr (force_small_grf) { + sycl_kernel_submit_small_grf( + wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); + } else { + sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); + } } constexpr int max_scalar_size_(std::tuple<>) { @@ -570,7 +581,14 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { auto storer = memory::StoreWithCast<1>(iter); auto input_offset_calculator = TrivialOffsetCalculator(); auto output_offset_calculator = TrivialOffsetCalculator<1>(); - launch_unrolled_kernel( + launch_unrolled_kernel< + func_t, + decltype(data), + decltype(input_offset_calculator), + decltype(output_offset_calculator), + decltype(loader), + decltype(storer), + true>( numel, f, data, @@ -585,13 +603,13 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { } auto offset_calc = ::make_offset_calculator(iter); constexpr int unroll_factor = sizeof(arg0_t) > 4 ? 2 : 4; - launch_legacy_group_range_kernel( - numel, - LegacyKernelWithCastScalarFunctor< - arg0_t, - ntensors, - decltype(offset_calc), - func_t>(data, dtypes, offset_calc, f)); + using functor = LegacyKernelWithCastScalarFunctor< + arg0_t, + ntensors, + decltype(offset_calc), + func_t>; + launch_legacy_group_range_kernel( + numel, functor(data, dtypes, offset_calc, f)); } } diff --git a/src/comm/SYCLHelpers.h b/src/comm/SYCLHelpers.h index 8ceaa3c3f..97b847561 100644 --- a/src/comm/SYCLHelpers.h +++ b/src/comm/SYCLHelpers.h @@ -1,6 +1,9 @@ #pragma once #include +#include +#include +#include #include // sycl access address space @@ -139,6 +142,52 @@ sycl_kernel_submit( q.submit(cgf); } +template +struct SmallGRF; + +template +static inline typename std::enable_if< + std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>, + void>::type +sycl_kernel_submit_small_grf( + int64_t global_range, + int64_t local_range, + ::sycl::queue q, + ker_t ker) { + ::sycl::ext::oneapi::experimental::properties kernel_props{ + ::sycl::ext::intel::experimental::grf_size<128>}; + auto range = ::sycl::nd_range<1>( + ::sycl::range<1>(global_range), ::sycl::range<1>(local_range)); + ::sycl::ext::oneapi::experimental::launch_config config(range, kernel_props); + auto cgf = [&](::sycl::handler& cgh) { + ker.sycl_ker_config_convention(cgh); + ::sycl::ext::oneapi::experimental::nd_launch>( + cgh, config, ker); + }; + q.submit(cgf); +} + +template +static inline typename std::enable_if< + !std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>, + void>::type +sycl_kernel_submit_small_grf( + int64_t global_range, + int64_t local_range, + ::sycl::queue q, + ker_t ker) { + ::sycl::ext::oneapi::experimental::properties kernel_props{ + ::sycl::ext::intel::experimental::grf_size<128>}; + auto range = ::sycl::nd_range<1>( + ::sycl::range<1>(global_range), ::sycl::range<1>(local_range)); + ::sycl::ext::oneapi::experimental::launch_config config(range, kernel_props); + auto cgf = [&](::sycl::handler& cgh) { + ::sycl::ext::oneapi::experimental::nd_launch>( + cgh, config, ker); + }; + q.submit(cgf); +} + #ifdef __SYCL_DEVICE_ONLY__ #define SYCL_KERNEL_STRING(var, str) \ static const __attribute__((opencl_constant)) char var[] = str