From 1544ee16b4bf34b950d5ca10ce6156c10a7ee7df Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Oct 2024 17:06:35 -0700 Subject: [PATCH] [torchlib] Do not register rsub (#1907) Remove rsub since it is handled by decomp, and torch doesn't have a type promotion rule for rsub so we use sub instead. Tested with ```python import torch class Model(torch.nn.Module): def forward(self, x): return 1 - x ep = torch.export.export(Model(), (torch.tensor(1),)) print(ep) program = torch.onnx.export(Model(), (torch.tensor(1),), dynamo=True) print(program) ``` --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++--------- tests/function_libs/torch_lib/ops_test_data.py | 2 -- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 395f1fcac..9a6057150 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7352,18 +7352,11 @@ def aten_rsqrt(self: TFloat) -> TFloat: return op.Reciprocal(op.Sqrt(self)) -@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar")) +# Do not register rsub. It will be decomposed and type promoted by torch def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - return op.Sub(other, op.Mul(self, alpha)) - - -@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"), trace_only=True, complex=True) -def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: - """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - - return aten_rsub(self, other, alpha) + raise NotImplementedError @torch_op("aten::scalar_tensor", trace_only=True) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index c180c1b71..35c691109 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1360,8 +1360,6 @@ def _where_input_wrangler( ), TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals), TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt), - TorchLibOpInfo("rsub", core_ops.aten_rsub), - TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True), TorchLibOpInfo( "scalar_tensor", core_ops.aten_scalar_tensor,