Skip to content

Commit

Permalink
Fix negative dims in scatter_index
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg committed Sep 19, 2024
1 parent c0501f0 commit 65c3333
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@
"resize_as_",
"rot90",
"rsub",
"scatter_add",
"scatter",
"scatter_reduce",
"searchsorted",
Expand Down
2 changes: 2 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,8 @@ def _scatter_index(dim, index):
index_shape = list(index.shape)
input_indexes = []
source_indexes = []
if dim < 0:
dim += len(index_shape)
for i in range(len(index_shape)):
source_indexes.append(slice(0, index_shape[i]))
if i == dim:
Expand Down

0 comments on commit 65c3333

Please sign in to comment.