From 041a54ae0c29e04703f0d9616bc4effebf2e6998 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 5 Feb 2024 16:23:04 -0800 Subject: [PATCH] [torch] Supporting `torch.aten.mul.float` lowering to `arith` (#2833) Simple missing scalar operation for multiply floats was missing. --- lib/Conversion/TorchToArith/TorchToArith.cpp | 4 +++- projects/pt1/e2e_testing/xfail_sets.py | 3 +++ .../torch_mlir_e2e_test/test_suite/scalar.py | 22 +++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index d2000d7fc3d2..0ca2d108a5e3 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -443,9 +443,11 @@ class ConvertTorchToArith typeConverter, context); patterns.add>( typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2ee5d279a9d3..973f75a2637a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -100,6 +100,7 @@ # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} 'AtenSubFloatModule_basic', + 'AtenMulFloatModule_basic', 'BoolFloatFalseModule_basic', 'BoolFloatTrueModule_basic', 'CeilFloatModule_basic', @@ -109,6 +110,7 @@ 'GtFloatIntModule_basic', 'NeFloatIntModule_basic', 'SubFloatModule_basic', + 'MulFloatModule_basic', 'TensorToFloatZeroRank_basic', 'TensorToFloat_basic', # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} @@ -1489,6 +1491,7 @@ "SliceStartEqEndModule_basic", "SqrtIntModule_basic", "SubFloatModule_basic", + "MulFloatModule_basic", "SubIntModule_basic", "TensorsStackPromoteDTypeModule_basic", "TensorToBoolZeroRank_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 303c3f0a801a..51b9fb993088 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -78,6 +78,28 @@ def SubFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand().double(), tu.rand().double()) +# ============================================================================== + +class MulFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) * float(rhs) + + +@register_test_case(module_factory=lambda: MulFloatModule()) +def MulFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand().double(), tu.rand().double()) + + # ==============================================================================