From b31f5c51862d88e079ee14bf22efbe79d5645cf3 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:29:22 -0700 Subject: [PATCH 1/3] add `resize_as_` --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index dc7eb8eee9c..bbf3ff1048f 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -151,7 +151,6 @@ "quantile", "repeat_interleave", "resize_", - "resize_as_", "rot90", "rsub", "scatter_add", From 9c8ae53ed98e008534a6875ebf1f6941080e0973 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:30:45 -0700 Subject: [PATCH 2/3] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index bbf3ff1048f..ed2e25003c4 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -150,7 +150,6 @@ "put", "quantile", "repeat_interleave", - "resize_", "rot90", "rsub", "scatter_add", From f66f8a48df8d7ae6c8c065d483a75eb1855ae2e4 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:32:35 -0700 Subject: [PATCH 3/3] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index ea7b60484f9..ccddcfd2122 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -239,6 +239,16 @@ def _aten_real(x): return jnp.real(x) +@op(torch.ops.aten.resize_) +def _aten_resize_as_(x, y): + return jax.image.resize(x, size, method=interpolation) + + +@op(torch.ops.aten.resize_as_) +def _aten_resize_as_(x, y): + return jax.image.resize(x, y.shape, method='linear') + + @op(torch.ops.aten.view_as_real) def _aten_view_as_real(x): real = jnp.real(x)