Skip to content

Commit

Permalink
fix: convert scalar tensors to native python type for values in torch…
Browse files Browse the repository at this point in the history
… backend put_along_axis
  • Loading branch information
Sam-Armstrong committed Jul 9, 2024
1 parent cf9e5b4 commit ace8434
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ivy/functional/backends/torch/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ def put_along_axis(
}
mode = mode_mappings.get(mode, mode)
indices = indices.to(torch.int64)
if not isinstance(values, torch.Tensor):
values = torch.tensor(values)
if isinstance(values, torch.Tensor) and values.dim() == 0:
values = values.item()
if mode == "replace":
return torch.scatter(arr, axis, indices, values, out=out)
else:
Expand Down

0 comments on commit ace8434

Please sign in to comment.