From d9918c646e12382783b19acb1d25953eb67925ee Mon Sep 17 00:00:00 2001
From: "Xiao, Wang" <109140002+xiaowangintel@users.noreply.github.com>
Date: Wed, 19 Feb 2025 15:51:10 +0800
Subject: [PATCH] Fix nonzeros stride error. (#1366)

Fix nonzeros stride error found in
https://github.com/pytorch/pytorch/issues/146883.

---------

Co-authored-by: zhuyuhua-v
---
 src/ATen/native/xpu/sycl/NonzeroKernel.cpp | 41 ++++++++++++++++------
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/NonzeroKernel.cpp b/src/ATen/native/xpu/sycl/NonzeroKernel.cpp
index 5b1a86973..84d1687c5 100644
--- a/src/ATen/native/xpu/sycl/NonzeroKernel.cpp
+++ b/src/ATen/native/xpu/sycl/NonzeroKernel.cpp
@@ -6,6 +6,7 @@
 
 #include <ATen/native/xpu/sycl/NonzeroKernel.h>
 #include <ATen/native/xpu/sycl/pstl/PSTLFunctions.h>
+#include <ATen/xpu/EmptyTensor.h>
 
 namespace at::native::xpu {
 
@@ -14,8 +15,8 @@ struct FlattenIdxtoRealIdxKernelFunctor {
     auto global_id = item_id.get_global_linear_id();
 
     if (global_id < N_) {
-      auto index = global_id / num_dim_;
-      auto dim = global_id % num_dim_;
+      auto dim = global_id / num_nonzeros_;
+      auto index = global_id % num_nonzeros_;
       tensor_begin_[global_id] =
           idx_flat_begin_[index] / divisor_[dim] % sizes_[dim];
     }
@@ -23,12 +24,14 @@
   FlattenIdxtoRealIdxKernelFunctor(
       int64_t N,
       const int64_t num_dim,
+      const int64_t num_nonzeros,
       int64_t* tensor_begin,
       int64_t* idx_flat_begin,
      int64_t* divisor,
       int64_t* sizes)
       : N_(N),
         num_dim_(num_dim),
+        num_nonzeros_(num_nonzeros),
         tensor_begin_(tensor_begin),
         idx_flat_begin_(idx_flat_begin) {
     for (auto dim = num_dim - 1; dim >= 0; dim--) {
@@ -40,6 +43,7 @@
  private:
   int64_t N_;
   const int64_t num_dim_;
+  const int64_t num_nonzeros_;
   int64_t* tensor_begin_;
   int64_t* idx_flat_begin_;
   int64_t divisor_[XPU_MAX_TENSORINFO_DIMS];
@@ -99,7 +103,14 @@ void nonzero_template(const Tensor& self_, Tensor& tensor) {
 
     auto num_nonzeros = std::distance(idx_flat_begin, idx_flat_end);
 
-    Tensor tensor_ = tensor.resize_({num_nonzeros, num_dim}).contiguous();
+    bool need_to_copy = tensor.dim() == 2 &&
+        tensor.sizes()[0] == num_nonzeros && tensor.sizes()[1] == self_.dim() &&
+        !tensor.t().is_contiguous();
+    at::Tensor tensor_ = need_to_copy
+        ? Tensor(at::detail::empty_xpu(
+              {self_.dim(), num_nonzeros}, tensor.options()))
+        : tensor.resize_({self_.dim(), num_nonzeros});
+
     if (num_nonzeros > 0 && num_dim > 0) {
       int64_t* tensor_begin = tensor_.data_ptr<int64_t>();
 
@@ -116,22 +127,30 @@ void nonzero_template(const Tensor& self_, Tensor& tensor) {
       const int64_t N = num_nonzeros * num_dim;
       // restore flatten idx to indices
       FlattenIdxtoRealIdxKernelFunctor kfn(
-          N, num_dim, tensor_begin, idx_flat_begin, divisor, sizes);
+          N,
+          num_dim,
+          num_nonzeros,
+          tensor_begin,
+          idx_flat_begin,
+          divisor,
+          sizes);
 
       const auto wg_sz = std::min(syclMaxWorkGroupSize(kfn), N);
       const auto num_wg = (N + wg_sz - 1) / wg_sz;
       sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), kfn);
 
-      // Support non-contiguous/outplace cases
-      // TODO: Next step, we will give state of art algo/implementation.
-      // Non-contiguous/outplace cases performance will be covered there.
-      if (tensor.data_ptr() != tensor_.data_ptr()) {
-        tensor.copy_(tensor_);
-      }
     }
+
+    if (need_to_copy) {
+      tensor.copy_(tensor_.t());
+    } else {
+      // transpose out so it is correct size
+      Tensor tensor_temp = tensor_.t();
+      tensor.set_(tensor_temp);
+    }
 
   } else {
-    tensor = tensor.resize_({N, num_dim}).contiguous();
+    tensor = tensor.resize_({num_dim, N}).contiguous().t();
   }
 }
 
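Note on the kernel change, with a sketch below: the patched functor decomposes
global_id as dim = global_id / num_nonzeros and index = global_id % num_nonzeros
because the intermediate buffer is now laid out as [num_dim, num_nonzeros] (one
row per input dimension) rather than [num_nonzeros, num_dim]. The standalone C++
below replays that index math on an assumed 2x3 input; the sample values and all
names in it are illustrative, not code from this patch.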
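    // Hypothetical standalone sketch (not part of the patch): mirrors the
    // patched kernel's flat-index-to-coordinate mapping.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Flat nonzero positions of an assumed row-major 2x3 input
      // [[0,1,0],[1,0,1]]; its nonzeros sit at (0,1), (1,0), (1,2).
      const std::vector<int64_t> idx_flat = {1, 3, 5};
      const std::vector<int64_t> sizes = {2, 3};    // input shape
      const std::vector<int64_t> divisor = {3, 1};  // row-major strides
      const int64_t num_dim = 2;
      const int64_t num_nonzeros = 3;
      const int64_t N = num_dim * num_nonzeros;

      // Output laid out as [num_dim, num_nonzeros], as in the patched kernel:
      // row `dim` holds the dim-th coordinate of every nonzero element.
      std::vector<int64_t> out(N);
      for (int64_t global_id = 0; global_id < N; ++global_id) {
        const int64_t dim = global_id / num_nonzeros;
        const int64_t index = global_id % num_nonzeros;
        out[global_id] = idx_flat[index] / divisor[dim] % sizes[dim];
      }

      // Reading the transpose recovers the usual [num_nonzeros, num_dim]
      // result; prints 0,1 then 1,0 then 1,2.
      for (int64_t i = 0; i < num_nonzeros; ++i) {
        std::cout << out[i] << "," << out[num_nonzeros + i] << "\n";
      }
      return 0;
    }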
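Since tensor_ is contiguous with sizes [num_dim, num_nonzeros], tensor_.t()
views the same storage with sizes [num_nonzeros, num_dim] and strides
[1, num_nonzeros]; set_ hands that view to the output directly. The
need_to_copy path covers a user-supplied out tensor that already has the final
shape but a different layout: the result is computed into a scratch buffer and
copy_'d over, preserving the out tensor's own strides.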