From 4fe55f559216493b6f8296fcd20a4e2a5673fa17 Mon Sep 17 00:00:00 2001 From: Greg Shikhman Date: Tue, 17 Sep 2024 17:44:35 -0400 Subject: [PATCH] Fix unbind op. (#8033) --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3b95476d413..7b7004d3eae 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -190,7 +190,6 @@ "take_along_dim", "to_sparse", # We are not supporting sparse tensors yet. "triu", - "unbind", "unfold_copy", "unfold", "unique_consecutive", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a7ffa8ddec9..ea7b60484f9 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2396,13 +2396,9 @@ def _aten_trunc(a): return jnp.trunc(a) -@op(torch.ops.aten.unbind) @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, dim=0): - return tuple( - _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) - for i in range(a.shape[dim]) - ) + return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])] # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d