Skip to content

Commit

Permalink
[torch] Supporting torch.aten.mul.float lowering to arith (llvm#2833
Browse files Browse the repository at this point in the history
)

Simple missing scalar operation for multiply floats was missing.
  • Loading branch information
rsuderman authored Feb 6, 2024
1 parent e3faef5 commit 041a54a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
4 changes: 3 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,11 @@ class ConvertTorchToArith
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
target.addIllegalOp<AtenSubFloatOp>();
target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
typeConverter, context);
target.addIllegalOp<AtenDivIntOp>();
patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
target.addIllegalOp<AtenDivFloatOp>();
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@

# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
'AtenSubFloatModule_basic',
'AtenMulFloatModule_basic',
'BoolFloatFalseModule_basic',
'BoolFloatTrueModule_basic',
'CeilFloatModule_basic',
Expand All @@ -109,6 +110,7 @@
'GtFloatIntModule_basic',
'NeFloatIntModule_basic',
'SubFloatModule_basic',
'MulFloatModule_basic',
'TensorToFloatZeroRank_basic',
'TensorToFloat_basic',
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
Expand Down Expand Up @@ -1489,6 +1491,7 @@
"SliceStartEqEndModule_basic",
"SqrtIntModule_basic",
"SubFloatModule_basic",
"MulFloatModule_basic",
"SubIntModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TensorToBoolZeroRank_basic",
Expand Down
22 changes: 22 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ def SubFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand().double(), tu.rand().double())


# ==============================================================================

class MulFloatModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.float64, True),
([], torch.float64, True),
])
def forward(self, lhs, rhs):
return float(lhs) * float(rhs)


@register_test_case(module_factory=lambda: MulFloatModule())
def MulFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand().double(), tu.rand().double())


# ==============================================================================


Expand Down

0 comments on commit 041a54a

Please sign in to comment.