Fix deterministic indexing with broadcast #1705

Open: wants to merge 10 commits into base: main
Changes from all commits
src/ATen/native/xpu/sycl/Indexing.cpp (48 changes: 27 additions, 21 deletions)
@@ -609,6 +609,21 @@ void index_put_kernel(
   }
 }
 
+DimVector valsShape(
+    IntArrayRef self_sizes,
+    int64_t dims_before,
+    int64_t dims_indexed,
+    IntArrayRef replacement_shape) {
+  auto shape = DimVector(self_sizes);
+  int64_t end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  shape.insert(
+      shape.begin() + dims_before,
+      replacement_shape.begin(),
+      replacement_shape.end());
+  return shape;
+}
+
 void index_put_deterministic_kernel(
     Tensor& self,
     const c10::List<std::optional<Tensor>>& indices,

Review thread on valsShape:

Copilot AI (May 27, 2025): [nitpick] Consider marking valsShape as static inline or moving its declaration to the header with a doc comment, so its purpose and usage are clearer and the compiler can inline it across translation units.

Contributor: Is the helper implemented in PyTorch? Could we share the implementation?

Author: We cannot reuse the PyTorch implementation; it is not exposed in a header file.
Ref: https://github.com/pytorch/pytorch/blob/28cb3c0fe5dec58c595617066acd8bd082aa867e/aten/src/ATen/native/cuda/Indexing.cu#L645
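For illustration (not part of the diff): valsShape replaces the indexed dimensions of src with the shape of the index being applied, which is exactly the shape the values must broadcast to. A minimal Python sketch of the same computation, with made-up sizes:

def vals_shape(self_sizes, dims_before, dims_indexed, replacement_shape):
    # Mirror of the C++ helper: splice the replacement (index) shape in place
    # of the indexed dimensions.
    shape = list(self_sizes)
    shape[dims_before:dims_before + dims_indexed] = list(replacement_shape)
    return shape

# Example: a (4, 2) tensor whose last dim is indexed by a 1-D index of length 3;
# dims_before = 1 (the leading dim is untouched), dims_indexed = 1.
print(vals_shape([4, 2], 1, 1, [3]))  # -> [4, 3]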
@@ -633,30 +648,21 @@ void index_put_deterministic_kernel(
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
   std::tie(
-      linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) =
-      makeLinearIndex(self_, indices, !unsafe);
+      linearIndex,
+      src,
+      nElemBefore,
+      strideBefore,
+      sliceSize,
+      inversePerm,
+      dims_before,
+      dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape =
+      valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
 
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    if (sliceSize > 1) {
-      expanded_size.insert(expanded_size.end(), sliceSize);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();
 
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();

Review comment on the expandedValue.expand(vals_shape) line:

Contributor: A question: does the CUDA indexing kernel support the expand case without making the tensor contiguous?
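For illustration (not part of the diff), the broadcast scenario this kernel change targets looks like the following at the Python level; the sizes and the "xpu" device are illustrative assumptions, not taken from the PR:

import torch

torch.use_deterministic_algorithms(True)  # same effect as DeterministicGuard(True)

x = torch.zeros(4, 2, device="xpu")     # self
i = torch.tensor([0, 1], device="xpu")  # index applied to the last dim
v = torch.tensor(10.0, device="xpu")    # 0-d value that must broadcast

# x[..., i] = v is an index_put_ with accumulate=False; in deterministic mode it
# takes the deterministic index_put path shown above. Here dims_before = 1,
# dims_indexed = 1, and linearIndex has shape (2,), so vals_shape is (4, 2) and
# the scalar value is expanded to that shape before being scattered.
x[..., i] = v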
src/ATen/native/xpu/sycl/IndexingUtils.h (47 changes: 36 additions, 11 deletions)
@@ -57,10 +57,8 @@ static std::vector<int64_t> computeLinearStride(const Tensor& tensor) {
   return stride;
 }
 
-static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
-    const Tensor& src,
-    TensorList indices,
-    bool check_range) {
+static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t>
+computeLinearIndex(const Tensor& src, TensorList indices, bool check_range) {
   auto strides = computeLinearStride(src);
   const auto& device = src.options().device();
 
@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
   // are not being index.
   Tensor linearIndex;
   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+  int64_t dims_before = 0, dims_indexed = 0;
   for (const auto i : c10::irange(src.dim())) {
     if (indices[i].defined()) {
+      dims_indexed++;
       // Cast index to the longType matching src's device
       // This allows us to support ie indexing a xpu tensor with a cpu tensor
       Tensor index =

Review comment from Copilot AI (May 27, 2025): The new tuple fields dims_before and dims_indexed would benefit from a brief inline comment explaining their meaning and relationship to the indexing algorithm. Suggested change:

   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+  // `dims_before` counts the number of dimensions before the indexed dimensions.
+  // `dims_indexed` counts the number of dimensions that are being indexed.
@@ -88,17 +88,30 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
     } else if (linearIndex.defined()) {
       nElemAfter *= src.size(i);
     } else {
+      dims_before++;
       nElemBefore *= src.size(i);
     }
   }
 
   return std::make_tuple(
-      std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
+      std::move(linearIndex),
+      nElemBefore,
+      strideBefore,
+      nElemAfter,
+      dims_before,
+      dims_indexed);
 }
 
-static std::
-    tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>>
-    makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
+static std::tuple<
+    Tensor,
+    Tensor,
+    int64_t,
+    int64_t,
+    int64_t,
+    std::vector<int64_t>,
+    int64_t,
+    int64_t>
+makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
   checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
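For illustration (not part of the diff), the bookkeeping that computeLinearIndex now returns can be read as a plain walk over the dimensions; a small Python sketch with made-up sizes (names mirror the C++ variables, and the indexed dims are assumed adjacent, which makeLinearIndex arranges beforehand):

def count_dims(sizes, is_indexed):
    # is_indexed[i] is True where indices[i] is defined. Dims before the first
    # indexed dim feed nElemBefore, indexed dims are collapsed into the linear
    # index, and trailing dims feed nElemAfter.
    n_elem_before = n_elem_after = 1
    dims_before = dims_indexed = 0
    seen_indexed = False
    for size, indexed in zip(sizes, is_indexed):
        if indexed:
            dims_indexed += 1
            seen_indexed = True
        elif seen_indexed:
            n_elem_after *= size
        else:
            dims_before += 1
            n_elem_before *= size
    return dims_before, dims_indexed, n_elem_before, n_elem_after

# A (4, 3, 5) tensor indexed only on dim 1:
print(count_dims([4, 3, 5], [False, True, False]))  # -> (1, 1, 4, 5)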
@@ -121,10 +134,22 @@ static std::
     std::tie(self, indices, inversePerm) =
         transposeToFrontAndInvPerm(self, indices);
   }
-  auto [linearIndex, nElemBefore, strideBefore, nElemAfter] =
-      computeLinearIndex(self, indices, check_range);
+  auto
+      [linearIndex,
+       nElemBefore,
+       strideBefore,
+       nElemAfter,
+       dims_before,
+       dims_indexed] = computeLinearIndex(self, indices, check_range);
   return std::make_tuple(
-      linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
+      linearIndex,
+      self,
+      nElemBefore,
+      strideBefore,
+      nElemAfter,
+      inversePerm,
+      dims_before,
+      dims_indexed);
 }
 
 } // namespace at::native::xpu
test/xpu/skip_list_common.py (2 changes: 0 additions, 2 deletions)
@@ -1031,8 +1031,6 @@
         # https://github.com/intel/torch-xpu-ops/issues/461
         "test_index_put_src_datatype_xpu_float8_e5m2",
         "test_index_put_src_datatype_xpu_float8_e4m3fn",
-        # https://github.com/intel/torch-xpu-ops/issues/1702
-        "test_index_put_deterministic_with_optional_tensors_xpu",
     ),
     "nn/test_pooling_xpu.py": None,
     "nn/test_dropout_xpu.py": None,
test/xpu/test_indexing_xpu.py (51 changes: 36 additions, 15 deletions)
@@ -1,7 +1,7 @@
 # Owner(s): ["module: intel"]
 
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import DeterministicGuard, run_tests
 
 try:
     from xpu_test_utils import XPUPatchForImport
@@ -14,15 +14,15 @@
 
     torch.Tensor.is_cuda = torch.Tensor.is_xpu
 
-    def __test_index_put_accumulate_with_optional_tensors(self, device):
-        # TODO: replace with a better solution.
-        # Currently, here using torchscript to put None into indices.
-        # on C++ it gives indices as a list of 2 optional tensors: first is null and
-        # the second is a valid tensor.
-        @torch.jit.script
+    def __test_index_put_deterministic_with_optional_tensors(self, device):
         def func(x, i, v):
-            idx = [None, i]
-            x.index_put_(idx, v, accumulate=True)
+            with DeterministicGuard(True):
+                x[..., i] = v
             return x
 
+        def func1(x, i, v):
+            with DeterministicGuard(True):
+                x[i] = v
+            return x
+
         n = 4
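For context (not part of the diff): DeterministicGuard comes from torch.testing._internal.common_utils and flips PyTorch's deterministic-algorithms flag for the duration of the with-block, restoring the previous setting on exit. Roughly, and ignoring its extra options such as warn_only, it behaves like this sketch (the tensors are made up for the example):

import torch

x = torch.zeros(4, 2)
i = torch.tensor([0, 1])
v = torch.tensor(10.0)

# Approximate behavior of `with DeterministicGuard(True): x[..., i] = v`
prev = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
try:
    x[..., i] = v
finally:
    torch.use_deterministic_algorithms(prev)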
@@ -32,17 +32,38 @@ def func(x, i, v):
         indices_dev = indices.to(device)
         value0d = torch.tensor(10.0)
         value1d = torch.tensor([1.0, 2.0])
+        values2d = torch.randn(n, 1)
 
-        out_cuda = func(t_dev, indices_dev, value0d.xpu())
-        out_cpu = func(t, indices, value0d)
-        self.assertEqual(out_cuda.cpu(), out_cpu)
+        for val in (value0d, value1d, values2d):
+            out_cuda = func(t_dev, indices_dev, val.to(device))
+            out_cpu = func(t, indices, val)
+            self.assertEqual(out_cuda.cpu(), out_cpu)
 
-        out_cuda = func(t_dev, indices_dev, value1d.xpu())
-        out_cpu = func(t, indices, value1d)
-        self.assertEqual(out_cuda.cpu(), out_cpu)
+        t = torch.zeros((5, 4))
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 4, 3])
+        indices_dev = indices.to(device)
+        val = torch.randn(4)
+        out_cuda = func1(t_dev, indices_dev, val.xpu())
+        out_cpu = func1(t, indices, val)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
+        t = torch.zeros(2, 3, 4)
+        ind = torch.tensor([0, 1])
+        val = torch.randn(6, 2)
+        with self.assertRaisesRegex(RuntimeError, "shape mismatch"):
+            func(t, ind, val)
+
+        with self.assertRaisesRegex(RuntimeError, "must match"):
+            func(t.to(device), ind.to(device), val.to(device))
+
+        val = torch.randn(2, 3, 1)
+        out_cuda = func1(t.to(device), ind.to(device), val.to(device))
+        out_cpu = func1(t, ind, val)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
 
-    TestIndexing.test_index_put_accumulate_with_optional_tensors = (
-        __test_index_put_accumulate_with_optional_tensors
+    TestIndexing.test_index_put_deterministic_with_optional_tensors = (
+        __test_index_put_deterministic_with_optional_tensors
     )
 
 instantiate_device_type_tests(NumpyTests, globals(), only_for=("xpu"), allow_xpu=True)

Review comments from Copilot AI (May 27, 2025):

On the `for val in (value0d, value1d, values2d):` loop: [nitpick] The loop reuses the same t_dev/t across multiple func calls, mutating them cumulatively; consider reinitializing t and t_dev inside the loop to make each subtest independent. Suggested change:

     for val in (value0d, value1d, values2d):
+        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
+        t_dev = t.to(device)

On `val = torch.randn(6, 2)`: The tests use different error regexes ("shape mismatch" vs "must match") for CPU vs XPU; consider unifying the expected message or adding a brief comment explaining the discrepancy to prevent brittleness. Suggested change:

     val = torch.randn(6, 2)
+    # The error messages differ between CPU ("shape mismatch") and XPU ("must match")
+    # due to implementation-specific differences in error handling.