-
Notifications
You must be signed in to change notification settings - Fork 42
Fix deterministic indexing with broadcast #1705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4f94483
17b2e10
c6744b8
b2f566d
d86da37
24c9870
acd7ee3
53e2ed3
3a18757
abc5b6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -609,6 +609,21 @@ void index_put_kernel( | |
} | ||
} | ||
|
||
DimVector valsShape( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the helper implemented in PyTorch? Could we share the implementation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot use the implementation from PyTorch, it is not in the header file. |
||
IntArrayRef self_sizes, | ||
int64_t dims_before, | ||
int64_t dims_indexed, | ||
IntArrayRef replacement_shape) { | ||
auto shape = DimVector(self_sizes); | ||
int64_t end = dims_before + dims_indexed; | ||
shape.erase(shape.begin() + dims_before, shape.begin() + end); | ||
shape.insert( | ||
shape.begin() + dims_before, | ||
replacement_shape.begin(), | ||
replacement_shape.end()); | ||
return shape; | ||
} | ||
|
||
void index_put_deterministic_kernel( | ||
Tensor& self, | ||
const c10::List<std::optional<Tensor>>& indices, | ||
|
@@ -633,30 +648,21 @@ void index_put_deterministic_kernel( | |
bool self_contiguous = self.is_contiguous(); | ||
auto self_ = self_contiguous ? self : self.contiguous(); | ||
Tensor linearIndex, src, expandedValue = value; | ||
int64_t nElemBefore, strideBefore, sliceSize; | ||
int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed; | ||
std::vector<int64_t> inversePerm; | ||
std::tie( | ||
linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = | ||
makeLinearIndex(self_, indices, !unsafe); | ||
linearIndex, | ||
src, | ||
nElemBefore, | ||
strideBefore, | ||
sliceSize, | ||
inversePerm, | ||
dims_before, | ||
dims_indexed) = makeLinearIndex(self_, indices, !unsafe); | ||
auto vals_shape = | ||
valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes()); | ||
int64_t num_indices = linearIndex.numel(); | ||
|
||
if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) { | ||
auto expanded_size = at::DimVector(expandedValue.sizes()); | ||
|
||
auto size1 = expandedValue.sizes(); | ||
auto size2 = linearIndex.sizes(); | ||
if (are_expandable(size1, size2)) { | ||
expanded_size = infer_size_dimvector(size1, size2); | ||
} | ||
if (nElemBefore > 1) { | ||
expanded_size.insert(expanded_size.begin(), nElemBefore); | ||
} | ||
if (sliceSize > 1) { | ||
expanded_size.insert(expanded_size.end(), sliceSize); | ||
} | ||
expandedValue = expandedValue.expand(expanded_size); | ||
} | ||
expandedValue = expandedValue.contiguous(); | ||
expandedValue = expandedValue.expand(vals_shape).contiguous(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A question: Does the CUDA indexing kernel support expand case without making the tensor contiguous? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cuda also need to make the tensor contiguous |
||
|
||
if (num_indices > 0 && sliceSize > 0) { | ||
const bool permuted = !src.is_contiguous(); | ||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -57,10 +57,8 @@ static std::vector<int64_t> computeLinearStride(const Tensor& tensor) { | |||||||||
return stride; | ||||||||||
} | ||||||||||
|
||||||||||
static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex( | ||||||||||
const Tensor& src, | ||||||||||
TensorList indices, | ||||||||||
bool check_range) { | ||||||||||
static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t> | ||||||||||
computeLinearIndex(const Tensor& src, TensorList indices, bool check_range) { | ||||||||||
auto strides = computeLinearStride(src); | ||||||||||
const auto& device = src.options().device(); | ||||||||||
|
||||||||||
|
@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex( | |||||||||
// are not being index. | ||||||||||
Tensor linearIndex; | ||||||||||
int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new tuple fields
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
int64_t dims_before = 0, dims_indexed = 0; | ||||||||||
for (const auto i : c10::irange(src.dim())) { | ||||||||||
if (indices[i].defined()) { | ||||||||||
dims_indexed++; | ||||||||||
// Cast index to the longType matching src's device | ||||||||||
// This allows us to support ie indexing a xpu tensor with a cpu tensor | ||||||||||
Tensor index = | ||||||||||
|
@@ -88,17 +88,30 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex( | |||||||||
} else if (linearIndex.defined()) { | ||||||||||
nElemAfter *= src.size(i); | ||||||||||
} else { | ||||||||||
dims_before++; | ||||||||||
nElemBefore *= src.size(i); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
return std::make_tuple( | ||||||||||
std::move(linearIndex), nElemBefore, strideBefore, nElemAfter); | ||||||||||
std::move(linearIndex), | ||||||||||
nElemBefore, | ||||||||||
strideBefore, | ||||||||||
nElemAfter, | ||||||||||
dims_before, | ||||||||||
dims_indexed); | ||||||||||
} | ||||||||||
|
||||||||||
static std:: | ||||||||||
tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> | ||||||||||
makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) { | ||||||||||
static std::tuple< | ||||||||||
Tensor, | ||||||||||
Tensor, | ||||||||||
int64_t, | ||||||||||
int64_t, | ||||||||||
int64_t, | ||||||||||
std::vector<int64_t>, | ||||||||||
int64_t, | ||||||||||
int64_t> | ||||||||||
makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) { | ||||||||||
checkIndexTensorTypes(orig, /*allow_int*/ true); | ||||||||||
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more | ||||||||||
// LongTensors | ||||||||||
|
@@ -121,10 +134,22 @@ static std:: | |||||||||
std::tie(self, indices, inversePerm) = | ||||||||||
transposeToFrontAndInvPerm(self, indices); | ||||||||||
} | ||||||||||
auto [linearIndex, nElemBefore, strideBefore, nElemAfter] = | ||||||||||
computeLinearIndex(self, indices, check_range); | ||||||||||
auto | ||||||||||
[linearIndex, | ||||||||||
nElemBefore, | ||||||||||
strideBefore, | ||||||||||
nElemAfter, | ||||||||||
dims_before, | ||||||||||
dims_indexed] = computeLinearIndex(self, indices, check_range); | ||||||||||
return std::make_tuple( | ||||||||||
linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm); | ||||||||||
linearIndex, | ||||||||||
self, | ||||||||||
nElemBefore, | ||||||||||
strideBefore, | ||||||||||
nElemAfter, | ||||||||||
inversePerm, | ||||||||||
dims_before, | ||||||||||
dims_indexed); | ||||||||||
} | ||||||||||
|
||||||||||
} // namespace at::native::xpu |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||||||
# Owner(s): ["module: intel"] | ||||||||||
|
||||||||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests | ||||||||||
from torch.testing._internal.common_utils import run_tests | ||||||||||
from torch.testing._internal.common_utils import DeterministicGuard, run_tests | ||||||||||
|
||||||||||
try: | ||||||||||
from xpu_test_utils import XPUPatchForImport | ||||||||||
|
@@ -14,15 +14,15 @@ | |||||||||
|
||||||||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu | ||||||||||
|
||||||||||
def __test_index_put_accumulate_with_optional_tensors(self, device): | ||||||||||
# TODO: replace with a better solution. | ||||||||||
# Currently, here using torchscript to put None into indices. | ||||||||||
# on C++ it gives indices as a list of 2 optional tensors: first is null and | ||||||||||
# the second is a valid tensor. | ||||||||||
@torch.jit.script | ||||||||||
def __test_index_put_deterministic_with_optional_tensors(self, device): | ||||||||||
def func(x, i, v): | ||||||||||
idx = [None, i] | ||||||||||
x.index_put_(idx, v, accumulate=True) | ||||||||||
with DeterministicGuard(True): | ||||||||||
x[..., i] = v | ||||||||||
return x | ||||||||||
|
||||||||||
def func1(x, i, v): | ||||||||||
with DeterministicGuard(True): | ||||||||||
x[i] = v | ||||||||||
return x | ||||||||||
|
||||||||||
n = 4 | ||||||||||
|
@@ -32,17 +32,38 @@ def func(x, i, v): | |||||||||
indices_dev = indices.to(device) | ||||||||||
value0d = torch.tensor(10.0) | ||||||||||
value1d = torch.tensor([1.0, 2.0]) | ||||||||||
values2d = torch.randn(n, 1) | ||||||||||
|
||||||||||
out_cuda = func(t_dev, indices_dev, value0d.xpu()) | ||||||||||
out_cpu = func(t, indices, value0d) | ||||||||||
for val in (value0d, value1d, values2d): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] The loop reuses the same
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
out_cuda = func(t_dev, indices_dev, val.to(device)) | ||||||||||
out_cpu = func(t, indices, val) | ||||||||||
self.assertEqual(out_cuda.cpu(), out_cpu) | ||||||||||
|
||||||||||
t = torch.zeros((5, 4)) | ||||||||||
t_dev = t.to(device) | ||||||||||
indices = torch.tensor([1, 4, 3]) | ||||||||||
indices_dev = indices.to(device) | ||||||||||
val = torch.randn(4) | ||||||||||
out_cuda = func1(t_dev, indices_dev, val.xpu()) | ||||||||||
out_cpu = func1(t, indices, val) | ||||||||||
self.assertEqual(out_cuda.cpu(), out_cpu) | ||||||||||
|
||||||||||
out_cuda = func(t_dev, indices_dev, value1d.xpu()) | ||||||||||
out_cpu = func(t, indices, value1d) | ||||||||||
t = torch.zeros(2, 3, 4) | ||||||||||
ind = torch.tensor([0, 1]) | ||||||||||
val = torch.randn(6, 2) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tests use different error regexes (
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
with self.assertRaisesRegex(RuntimeError, "shape mismatch"): | ||||||||||
func(t, ind, val) | ||||||||||
|
||||||||||
with self.assertRaisesRegex(RuntimeError, "must match"): | ||||||||||
func(t.to(device), ind.to(device), val.to(device)) | ||||||||||
|
||||||||||
val = torch.randn(2, 3, 1) | ||||||||||
out_cuda = func1(t.to(device), ind.to(device), val.to(device)) | ||||||||||
out_cpu = func1(t, ind, val) | ||||||||||
self.assertEqual(out_cuda.cpu(), out_cpu) | ||||||||||
|
||||||||||
TestIndexing.test_index_put_accumulate_with_optional_tensors = ( | ||||||||||
__test_index_put_accumulate_with_optional_tensors | ||||||||||
TestIndexing.test_index_put_deterministic_with_optional_tensors = ( | ||||||||||
__test_index_put_deterministic_with_optional_tensors | ||||||||||
) | ||||||||||
|
||||||||||
instantiate_device_type_tests(NumpyTests, globals(), only_for=("xpu"), allow_xpu=True) | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Consider marking
valsShape
asstatic inline
or moving its declaration to the header with a doc comment, so its purpose and usage are clearer and the compiler can inline it across translation units.Copilot uses AI. Check for mistakes.