From 526735419f68c9923bb75b819c576db6e756719b Mon Sep 17 00:00:00 2001
From: Anish Karthik <89824626+anishfish2@users.noreply.github.com>
Date: Wed, 18 Sep 2024 15:39:54 -0700
Subject: [PATCH] Fixed masked_scatter and masked_select (#8037)

---
 experimental/torch_xla2/test/test_ops.py |  2 --
 .../torch_xla2/torch_xla2/ops/jaten.py   | 35 +++++++++++++++++++
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 3ac2ab83dfc..1638dff404a 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -78,8 +78,6 @@
     "lu_solve",
     "lu_unpack",
     "masked.median",
-    "masked_scatter",
-    "masked_select",
     "max_pool2d_with_indices_backward",
     "min",
     "mode",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 2e0e5ed35c5..35e01da90ae 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1366,6 +1366,41 @@ def _aten_scatter_add(input, dim, index, src):
   input_indexes, source_indexes = _scatter_index(dim, index)
   return input.at[input_indexes].add(src[source_indexes])
 
+# aten.masked_scatter
+@op(torch.ops.aten.masked_scatter)
+def _aten_masked_scatter(self, mask, source):
+
+  broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape)
+
+  if self.shape != broadcast_shape:
+    self = jnp.broadcast_to(self, broadcast_shape)
+  elif mask.shape != broadcast_shape:
+    mask = jnp.broadcast_to(mask, broadcast_shape)
+
+  self_flat = self.flatten()
+  mask_flat = mask.flatten()
+  source_flat = source.flatten()
+
+  true_indices = jnp.where(mask_flat)[0]
+  self_flat = self_flat.at[true_indices].set(source_flat[:len(true_indices)])
+  final_arr = self_flat.reshape(self.shape)
+
+  return final_arr
+
+@op(torch.ops.aten.masked_select)
+def _aten_masked_select(self, mask, *args, **kwargs):
+  broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape)
+
+  if self.shape != broadcast_shape:
+    self = jnp.broadcast_to(self, broadcast_shape)
+  if mask.shape != broadcast_shape:
+    mask = jnp.broadcast_to(mask, broadcast_shape)
+
+  self_flat = self.flatten()
+  mask_flat = mask.flatten()
+  true_indices = jnp.where(mask_flat)[0]
+
+  return self_flat[true_indices]
 
 # aten.logical_not