From fffb0b58d58d00e04cb987a28b64b40b2eada704 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 10 Sep 2024 19:23:39 -0700 Subject: [PATCH 1/6] Add support for integer division to TCP --- include/mlir-tcp/Dialect/IR/TcpEnums.td | 18 ++++++++++ include/mlir-tcp/Dialect/IR/TcpOps.td | 21 +++++++++++ lib/Conversion/TcpToLinalg/Elementwise.cpp | 30 ++++++++++++++-- lib/Conversion/TorchToTcp/Elementwise.cpp | 41 +++++++++++++++------- lib/Conversion/TorchToTcp/Utils.cpp | 8 +++++ lib/Conversion/TorchToTcp/Utils.h | 3 ++ test/Pipeline/torch_to_tcp_pipeline.mlir | 5 ++- 7 files changed, 110 insertions(+), 16 deletions(-) diff --git a/include/mlir-tcp/Dialect/IR/TcpEnums.td b/include/mlir-tcp/Dialect/IR/TcpEnums.td index 6d78072a..af422e18 100644 --- a/include/mlir-tcp/Dialect/IR/TcpEnums.td +++ b/include/mlir-tcp/Dialect/IR/TcpEnums.td @@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness", def Tcp_SignednessAttr : EnumAttr; +// TCP rounding mode +def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>; +def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>; +def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>; + +def Tcp_RoundingMode : I32EnumAttr<"RoundingMode", + "Rounding mode for integer operations which need a rounding mode", + [ + Tcp_RoundingMode_Trunc, + Tcp_RoundingMode_Floor, + Tcp_RoundingMode_Ceil + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tcp"; +} + +def Tcp_RoundingModeAttr : EnumAttr; + #endif // TCP_ENUMS diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index 45fa0e90..99b8bafc 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -160,6 +160,27 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; } +def Tcp_DivIOp : Tcp_BinaryElementwiseOp<"divi", [SameOperandsAndResultElementType]> { + let summary = "Computes elementwise integer division"; + + let description = [{ + Computes the integer division of `in1` and `in2`. + }]; + + let arguments = (ins + Tcp_IntTensor:$in1, + Tcp_IntTensor:$in2, + Tcp_SignednessAttr:$signedness, + Tcp_RoundingModeAttr:$rounding_mode + ); + + let results = (outs + Tcp_IntTensor:$out + ); + + let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; +} + def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> { let summary = "Constant op"; diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp index 50dc553b..ebb51163 100644 --- a/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -190,9 +190,35 @@ createLinalgPayloadForElementwiseOp(Operation *op, if (isa(op)) { if (elemType.isa()) return {b.create(loc, payloadArgs[0], payloadArgs[1])}; - else + else if (elemType.isa()) { + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + } + } + + if (auto divOp = dyn_cast(op)) { + if (!elemType.isa()) llvm_unreachable("unsupported element type in " - "createLinalgPayloadForElementwiseOp for tcp.divf"); + "createLinalgPayloadForElementwiseOp for tcp.divi"); + if (divOp.getSignedness() == Signedness::Unsigned) { + if (divOp.getRoundingMode() == RoundingMode::Trunc || + divOp.getRoundingMode() == RoundingMode::Floor) + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + else + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + } else if (divOp.getSignedness() == Signedness::Signed) { + if (divOp.getRoundingMode() == RoundingMode::Trunc) + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + else if (divOp.getRoundingMode() == RoundingMode::Ceil) + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + else + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + } else { + llvm_unreachable("unsupported signedness in " + "createLinalgPayloadForElementwiseOp for tcp.divi"); + } } if (isa(op)) { diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index d28f28da..5b2ae720 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - RankedTensorType lhsType = lhs.getType().dyn_cast(); + RankedTensorType lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); @@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Only Ranked Tensor types are supported in TCP"); - // TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are - // added - if (resultType.getElementType().isa()) { - return rewriter.notifyMatchFailure( - op, "Only floating point division supported for now"); - } - auto inputAType = op.getSelf() .getType() .template dyn_cast() @@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern { .template dyn_cast() .getDtype(); + Type inputBType = nullptr; if (isa(op)) { + inputBType = adaptor.getOther().getType(); + rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(), adaptor.getOther(), outputType, resultType.getElementType()); if (!rhs) return rewriter.notifyMatchFailure(op, "Unsupported rhs data type"); } else { - auto inputBType = op.getOther() - .getType() - .template dyn_cast() - .getDtype(); + inputBType = op.getOther() + .getType() + .template dyn_cast() + .getDtype(); rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType, rhs, resultType.getElementType()); } @@ -337,7 +333,26 @@ class ConvertAtenDivOp : public OpConversionPattern { std::tie(lhs, rhs) = torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs); - rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + if (isa(outputType)) { + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + } else { + auto in1IntType = cast(inputAType); + auto in2IntType = cast(inputBType); + auto outIntType = cast(outputType); + if ((in1IntType.getSignedness() != in2IntType.getSignedness()) || + (in1IntType.getSignedness() != outIntType.getSignedness())) + return rewriter.notifyMatchFailure(op, + "Mixed signedness not supported"); + if (in1IntType.getSignedness() == + mlir::IntegerType::SignednessSemantics::Signless) + return rewriter.notifyMatchFailure( + op, "Signless division not supported in TCP"); + + rewriter.replaceOpWithNewOp( + op, resultType, lhs, rhs, + torch_to_tcp::getTcpSignedness(outIntType.getSignedness()), + tcp::RoundingMode::Trunc); + } return success(); } }; diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp index c762ff13..85a52673 100644 --- a/lib/Conversion/TorchToTcp/Utils.cpp +++ b/lib/Conversion/TorchToTcp/Utils.cpp @@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context, return SignednessAttr::get(context, Signedness::Unsigned); } +Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) { + if (signednessInfo == IntegerType::SignednessSemantics::Signless) + return Signedness::Signless; + if (signednessInfo == IntegerType::SignednessSemantics::Signed) + return Signedness::Signed; + return Signedness::Unsigned; +} + // The parameter input is expected to be of RankedTensorType. Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter, Value input, int64_t rankIncrease) { diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h index 4ebfd122..42156eab 100644 --- a/lib/Conversion/TorchToTcp/Utils.h +++ b/lib/Conversion/TorchToTcp/Utils.h @@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr getTcpSignednessAttr(MLIRContext *context, IntegerType::SignednessSemantics signednessInfo); +mlir::tcp::Signedness +getTcpSignedness(IntegerType::SignednessSemantics signednessInfo); + // Helper function to expand the rank of the input tensor. Works by // adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`. Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter, diff --git a/test/Pipeline/torch_to_tcp_pipeline.mlir b/test/Pipeline/torch_to_tcp_pipeline.mlir index 83c08bbb..edc055a5 100644 --- a/test/Pipeline/torch_to_tcp_pipeline.mlir +++ b/test/Pipeline/torch_to_tcp_pipeline.mlir @@ -108,8 +108,11 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32> // ----- +// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { +// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp, out_int_signedness = #tcp} : tensor -> tensor +// CHECK: %[[V1:.+]] = tcp.divi %[[V0]], %[[ARG1]] {rounding_mode = #tcp, signedness = #tcp} : tensor, tensor -> tensor +// CHECK: return %[[V1]] : tensor func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> { - // expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}} %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32> return %0 : !torch.vtensor<[?, ?],si32> } From 7a5ebc87296a9d93ab8b8376b4d1e01b448c099b Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Thu, 12 Sep 2024 09:52:27 -0700 Subject: [PATCH 2/6] fixed missing location --- lib/Conversion/TcpToLinalg/Elementwise.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp index ebb51163..13ee75b6 100644 --- a/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -356,6 +356,7 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality( INSERT_TCP_TO_LINALG_PATTERN(ClampOp); INSERT_TCP_TO_LINALG_PATTERN(MulOp); INSERT_TCP_TO_LINALG_PATTERN(DivFOp); + INSERT_TCP_TO_LINALG_PATTERN(DivIOp); INSERT_TCP_TO_LINALG_PATTERN(SubOp); INSERT_TCP_TO_LINALG_PATTERN(TanhOp); INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp); From 52a95dc3bdc15b4810d58d149a65ed50ddc3e8c8 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 17 Sep 2024 12:27:57 -0700 Subject: [PATCH 3/6] Use separate ops for signed/unsigned integer divison --- include/mlir-tcp/Dialect/IR/TcpOps.td | 27 +++++++++++-- lib/Conversion/TcpToLinalg/Elementwise.cpp | 45 ++++++++++------------ lib/Conversion/TorchToTcp/Elementwise.cpp | 11 ++++-- test/Pipeline/torch_to_tcp_pipeline.mlir | 2 +- 4 files changed, 51 insertions(+), 34 deletions(-) diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index 99b8bafc..d73d28b5 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -160,17 +160,36 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; } -def Tcp_DivIOp : Tcp_BinaryElementwiseOp<"divi", [SameOperandsAndResultElementType]> { - let summary = "Computes elementwise integer division"; +def Tcp_DivSIOp : Tcp_BinaryElementwiseOp<"divsi", [SameOperandsAndResultElementType]> { + let summary = "Computes elementwise signed integer division"; let description = [{ - Computes the integer division of `in1` and `in2`. + Computes the signed integer division of `in1` and `in2`. + }]; + + let arguments = (ins + Tcp_IntTensor:$in1, + Tcp_IntTensor:$in2, + Tcp_RoundingModeAttr:$rounding_mode + ); + + let results = (outs + Tcp_IntTensor:$out + ); + + let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; +} + +def Tcp_DivUIOp : Tcp_BinaryElementwiseOp<"divui", [SameOperandsAndResultElementType]> { + let summary = "Computes elementwise unsigned integer division"; + + let description = [{ + Computes the unsigned integer division of `in1` and `in2`. }]; let arguments = (ins Tcp_IntTensor:$in1, Tcp_IntTensor:$in2, - Tcp_SignednessAttr:$signedness, Tcp_RoundingModeAttr:$rounding_mode ); diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp index 13ee75b6..4232deb3 100644 --- a/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -195,30 +195,24 @@ createLinalgPayloadForElementwiseOp(Operation *op, } } - if (auto divOp = dyn_cast(op)) { - if (!elemType.isa()) - llvm_unreachable("unsupported element type in " - "createLinalgPayloadForElementwiseOp for tcp.divi"); - if (divOp.getSignedness() == Signedness::Unsigned) { - if (divOp.getRoundingMode() == RoundingMode::Trunc || - divOp.getRoundingMode() == RoundingMode::Floor) - return {b.create(loc, payloadArgs[0], payloadArgs[1])}; - else - return { - b.create(loc, payloadArgs[0], payloadArgs[1])}; - } else if (divOp.getSignedness() == Signedness::Signed) { - if (divOp.getRoundingMode() == RoundingMode::Trunc) - return {b.create(loc, payloadArgs[0], payloadArgs[1])}; - else if (divOp.getRoundingMode() == RoundingMode::Ceil) - return { - b.create(loc, payloadArgs[0], payloadArgs[1])}; - else - return { - b.create(loc, payloadArgs[0], payloadArgs[1])}; - } else { - llvm_unreachable("unsupported signedness in " - "createLinalgPayloadForElementwiseOp for tcp.divi"); - } + if (auto divOp = dyn_cast(op)) { + if (divOp.getRoundingMode() == RoundingMode::Trunc) + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + else if (divOp.getRoundingMode() == RoundingMode::Ceil) + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + else + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + } + + if (auto divOp = dyn_cast(op)) { + if (divOp.getRoundingMode() == RoundingMode::Trunc || + divOp.getRoundingMode() == RoundingMode::Floor) + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + else + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; } if (isa(op)) { @@ -356,7 +350,8 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality( INSERT_TCP_TO_LINALG_PATTERN(ClampOp); INSERT_TCP_TO_LINALG_PATTERN(MulOp); INSERT_TCP_TO_LINALG_PATTERN(DivFOp); - INSERT_TCP_TO_LINALG_PATTERN(DivIOp); + INSERT_TCP_TO_LINALG_PATTERN(DivSIOp); + INSERT_TCP_TO_LINALG_PATTERN(DivUIOp); INSERT_TCP_TO_LINALG_PATTERN(SubOp); INSERT_TCP_TO_LINALG_PATTERN(TanhOp); INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp); diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index 5b2ae720..3ac8c86f 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -348,10 +348,13 @@ class ConvertAtenDivOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Signless division not supported in TCP"); - rewriter.replaceOpWithNewOp( - op, resultType, lhs, rhs, - torch_to_tcp::getTcpSignedness(outIntType.getSignedness()), - tcp::RoundingMode::Trunc); + if (outIntType.getSignedness() == + mlir::IntegerType::SignednessSemantics::Unsigned) + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs, + tcp::RoundingMode::Trunc); + else + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs, + tcp::RoundingMode::Trunc); } return success(); } diff --git a/test/Pipeline/torch_to_tcp_pipeline.mlir b/test/Pipeline/torch_to_tcp_pipeline.mlir index edc055a5..896dc1d3 100644 --- a/test/Pipeline/torch_to_tcp_pipeline.mlir +++ b/test/Pipeline/torch_to_tcp_pipeline.mlir @@ -110,7 +110,7 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32> // CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { // CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp, out_int_signedness = #tcp} : tensor -> tensor -// CHECK: %[[V1:.+]] = tcp.divi %[[V0]], %[[ARG1]] {rounding_mode = #tcp, signedness = #tcp} : tensor, tensor -> tensor +// CHECK: %[[V1:.+]] = tcp.divsi %[[V0]], %[[ARG1]] {rounding_mode = #tcp} : tensor, tensor -> tensor // CHECK: return %[[V1]] : tensor func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32> From 17e1ad205bdc1433f5ebb2042c2913813c2ac8ec Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 17 Sep 2024 12:33:33 -0700 Subject: [PATCH 4/6] bit more cleanup --- lib/Conversion/TcpToLinalg/Elementwise.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp index 4232deb3..41782d23 100644 --- a/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -190,12 +190,15 @@ createLinalgPayloadForElementwiseOp(Operation *op, if (isa(op)) { if (elemType.isa()) return {b.create(loc, payloadArgs[0], payloadArgs[1])}; - else if (elemType.isa()) { - return {b.create(loc, payloadArgs[0], payloadArgs[1])}; - } + else + llvm_unreachable("unsupported element type in " + "createLinalgPayloadForElementwiseOp for tcp.divf"); } if (auto divOp = dyn_cast(op)) { + if (!elemType.isa()) + llvm_unreachable("unsupported element type in " + "createLinalgPayloadForElementwiseOp for tcp.divsi"); if (divOp.getRoundingMode() == RoundingMode::Trunc) return {b.create(loc, payloadArgs[0], payloadArgs[1])}; else if (divOp.getRoundingMode() == RoundingMode::Ceil) @@ -207,6 +210,9 @@ createLinalgPayloadForElementwiseOp(Operation *op, } if (auto divOp = dyn_cast(op)) { + if (!elemType.isa()) + llvm_unreachable("unsupported element type in " + "createLinalgPayloadForElementwiseOp for tcp.divui"); if (divOp.getRoundingMode() == RoundingMode::Trunc || divOp.getRoundingMode() == RoundingMode::Floor) return {b.create(loc, payloadArgs[0], payloadArgs[1])}; From e965953615b02004aded36907e71b43c248f769d Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 17 Sep 2024 12:38:23 -0700 Subject: [PATCH 5/6] more lit tests --- test/Pipeline/torch_to_tcp_pipeline.mlir | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/Pipeline/torch_to_tcp_pipeline.mlir b/test/Pipeline/torch_to_tcp_pipeline.mlir index 896dc1d3..33c45be7 100644 --- a/test/Pipeline/torch_to_tcp_pipeline.mlir +++ b/test/Pipeline/torch_to_tcp_pipeline.mlir @@ -116,3 +116,22 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32> return %0 : !torch.vtensor<[?, ?],si32> } + +// ----- + +// CHECK: func.func @torch.aten.div.Tensor$mixed_type_uint(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { +// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp, out_int_signedness = #tcp} : tensor -> tensor +// CHECK: %[[V1:.+]] = tcp.divui %[[V0]], %[[ARG1]] {rounding_mode = #tcp} : tensor, tensor -> tensor +// CHECK: return %[[V1]] : tensor +func.func @torch.aten.div.Tensor$mixed_type_uint(%arg0: !torch.vtensor<[?, ?],ui16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> { + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],ui16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32> + return %0 : !torch.vtensor<[?, ?],ui32> +} + +// ----- + +func.func @torch.aten.div.Tensor$mixed_signed_int_div(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> { + // expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}} + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32> + return %0 : !torch.vtensor<[?, ?],ui32> +} From 7fc2f1bee37f3d4d04acea63b932f70c6eb8d0b3 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Wed, 18 Sep 2024 06:14:23 -0700 Subject: [PATCH 6/6] PR feedback --- include/mlir-tcp/Dialect/IR/TcpEnums.td | 4 ++-- lib/Conversion/TcpToLinalg/Elementwise.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/mlir-tcp/Dialect/IR/TcpEnums.td b/include/mlir-tcp/Dialect/IR/TcpEnums.td index af422e18..e4aa98f2 100644 --- a/include/mlir-tcp/Dialect/IR/TcpEnums.td +++ b/include/mlir-tcp/Dialect/IR/TcpEnums.td @@ -34,8 +34,8 @@ def Tcp_SignednessAttr : EnumAttr; // TCP rounding mode def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>; -def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>; -def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>; +def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>; +def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>; def Tcp_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode for integer operations which need a rounding mode", diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp index 41782d23..3ff292b4 100644 --- a/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -203,7 +203,7 @@ createLinalgPayloadForElementwiseOp(Operation *op, return {b.create(loc, payloadArgs[0], payloadArgs[1])}; else if (divOp.getRoundingMode() == RoundingMode::Ceil) return { - b.create(loc, payloadArgs[0], payloadArgs[1])}; + b.create(loc, payloadArgs[0], payloadArgs[1])}; else return { b.create(loc, payloadArgs[0], payloadArgs[1])};