Skip to content

Commit

Permalink
Fix nonzeros stride error. (#1366)
Browse files Browse the repository at this point in the history
Fix the incorrect output strides of `nonzero` on XPU, as reported in
pytorch/pytorch#146883.

---------

Co-authored-by: zhuyuhua-v <[email protected]>
  • Loading branch information
xiaowangintel and zhuyuhua-v authored Feb 19, 2025
1 parent 354725e commit d9918c6
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions src/ATen/native/xpu/sycl/NonzeroKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <comm/TensorInfo.h>

#include <ATen/native/xpu/sycl/NonzeroKernel.h>
#include <ATen/xpu/EmptyTensor.h>

namespace at::native::xpu {

Expand All @@ -14,21 +15,23 @@ struct FlattenIdxtoRealIdxKernelFunctor {
auto global_id = item_id.get_global_linear_id();

if (global_id < N_) {
auto index = global_id / num_dim_;
auto dim = global_id % num_dim_;
auto dim = global_id / num_nonzeros_;
auto index = global_id % num_nonzeros_;
tensor_begin_[global_id] =
idx_flat_begin_[index] / divisor_[dim] % sizes_[dim];
}
}
FlattenIdxtoRealIdxKernelFunctor(
int64_t N,
const int64_t num_dim,
const int64_t num_nonzeros,
int64_t* tensor_begin,
int64_t* idx_flat_begin,
int64_t* divisor,
int64_t* sizes)
: N_(N),
num_dim_(num_dim),
num_nonzeros_(num_nonzeros),
tensor_begin_(tensor_begin),
idx_flat_begin_(idx_flat_begin) {
for (auto dim = num_dim - 1; dim >= 0; dim--) {
Expand All @@ -40,6 +43,7 @@ struct FlattenIdxtoRealIdxKernelFunctor {
private:
int64_t N_;
const int64_t num_dim_;
const int64_t num_nonzeros_;
int64_t* tensor_begin_;
int64_t* idx_flat_begin_;
int64_t divisor_[XPU_MAX_TENSORINFO_DIMS];
Expand Down Expand Up @@ -99,7 +103,14 @@ void nonzero_template(const Tensor& self_, Tensor& tensor) {

auto num_nonzeros = std::distance(idx_flat_begin, idx_flat_end);

Tensor tensor_ = tensor.resize_({num_nonzeros, num_dim}).contiguous();
bool need_to_copy = tensor.dim() == 2 &&
tensor.sizes()[0] == num_nonzeros && tensor.sizes()[1] == self_.dim() &&
!tensor.t().is_contiguous();
at::Tensor tensor_ = need_to_copy
? Tensor(at::detail::empty_xpu(
{self_.dim(), num_nonzeros}, tensor.options()))
: tensor.resize_({self_.dim(), num_nonzeros});

if (num_nonzeros > 0 && num_dim > 0) {
int64_t* tensor_begin = tensor_.data_ptr<int64_t>();

Expand All @@ -116,22 +127,30 @@ void nonzero_template(const Tensor& self_, Tensor& tensor) {
const int64_t N = num_nonzeros * num_dim;
// restore flatten idx to indices
FlattenIdxtoRealIdxKernelFunctor kfn(
N, num_dim, tensor_begin, idx_flat_begin, divisor, sizes);
N,
num_dim,
num_nonzeros,
tensor_begin,
idx_flat_begin,
divisor,
sizes);

const auto wg_sz = std::min(syclMaxWorkGroupSize(kfn), N);
const auto num_wg = (N + wg_sz - 1) / wg_sz;

sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), kfn);

// Support non-contiguous/outplace cases
// TODO: Next step, we will give state of art algo/implementation.
// Non-contiguous/outplace cases performance will be covered there.
if (tensor.data_ptr() != tensor_.data_ptr()) {
tensor.copy_(tensor_);
}
}
if (need_to_copy) {
tensor.copy_(tensor_.t());
} else {
// transpose out so it is correct size
Tensor tensor_temp = tensor_.t();
tensor.set_(tensor_temp);
}

} else {
tensor = tensor.resize_({N, num_dim}).contiguous();
tensor = tensor.resize_({num_dim, N}).contiguous().t();
}
}

Expand Down

0 comments on commit d9918c6

Please sign in to comment.