From 83b2a0935cf3c196dbe99918f4535b9c59587949 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Oct 2024 23:40:19 +0000 Subject: [PATCH] [torchlib] Do not register rsub --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 395f1fcac..9a6057150 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7352,18 +7352,11 @@ def aten_rsqrt(self: TFloat) -> TFloat: return op.Reciprocal(op.Sqrt(self)) -@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar")) +# Do not register rsub. It will be decomposed and type promoted by torch def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - return op.Sub(other, op.Mul(self, alpha)) - - -@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"), trace_only=True, complex=True) -def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: - """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - - return aten_rsub(self, other, alpha) + raise NotImplementedError @torch_op("aten::scalar_tensor", trace_only=True)