diff --git a/include/mlir-tcp/Dialect/IR/TcpEnums.td b/include/mlir-tcp/Dialect/IR/TcpEnums.td index 6d78072a..e4aa98f2 100644 --- a/include/mlir-tcp/Dialect/IR/TcpEnums.td +++ b/include/mlir-tcp/Dialect/IR/TcpEnums.td @@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness", def Tcp_SignednessAttr : EnumAttr; +// TCP rounding mode +def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>; +def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>; +def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>; + +def Tcp_RoundingMode : I32EnumAttr<"RoundingMode", + "Rounding mode for integer operations which need a rounding mode", + [ + Tcp_RoundingMode_Trunc, + Tcp_RoundingMode_Floor, + Tcp_RoundingMode_Ceil + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tcp"; +} + +def Tcp_RoundingModeAttr : EnumAttr; + #endif // TCP_ENUMS diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index 45fa0e90..d73d28b5 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -160,6 +160,46 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; } +def Tcp_DivSIOp : Tcp_BinaryElementwiseOp<"divsi", [SameOperandsAndResultElementType]> { + let summary = "Computes elementwise signed integer division"; + + let description = [{ + Computes the signed integer division of `in1` and `in2`. + }]; + + let arguments = (ins + Tcp_IntTensor:$in1, + Tcp_IntTensor:$in2, + Tcp_RoundingModeAttr:$rounding_mode + ); + + let results = (outs + Tcp_IntTensor:$out + ); + + let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; +} + +def Tcp_DivUIOp : Tcp_BinaryElementwiseOp<"divui", [SameOperandsAndResultElementType]> { + let summary = "Computes elementwise unsigned integer division"; + + let description = [{ + Computes the unsigned integer division of `in1` and `in2`. + }]; + + let arguments = (ins + Tcp_IntTensor:$in1, + Tcp_IntTensor:$in2, + Tcp_RoundingModeAttr:$rounding_mode + ); + + let results = (outs + Tcp_IntTensor:$out + ); + + let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)"; +} + def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> { let summary = "Constant op"; diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp index 50dc553b..3ff292b4 100644 --- a/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -195,6 +195,32 @@ createLinalgPayloadForElementwiseOp(Operation *op, "createLinalgPayloadForElementwiseOp for tcp.divf"); } + if (auto divOp = dyn_cast(op)) { + if (!elemType.isa()) + llvm_unreachable("unsupported element type in " + "createLinalgPayloadForElementwiseOp for tcp.divsi"); + if (divOp.getRoundingMode() == RoundingMode::Trunc) + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + else if (divOp.getRoundingMode() == RoundingMode::Ceil) + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + else + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + } + + if (auto divOp = dyn_cast(op)) { + if (!elemType.isa()) + llvm_unreachable("unsupported element type in " + "createLinalgPayloadForElementwiseOp for tcp.divui"); + if (divOp.getRoundingMode() == RoundingMode::Trunc || + divOp.getRoundingMode() == RoundingMode::Floor) + return {b.create(loc, payloadArgs[0], payloadArgs[1])}; + else + return { + b.create(loc, payloadArgs[0], payloadArgs[1])}; + } + if (isa(op)) { if (elemType.isa()) return {b.create(loc, payloadArgs[0], payloadArgs[1])}; @@ -330,6 +356,8 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality( INSERT_TCP_TO_LINALG_PATTERN(ClampOp); INSERT_TCP_TO_LINALG_PATTERN(MulOp); INSERT_TCP_TO_LINALG_PATTERN(DivFOp); + INSERT_TCP_TO_LINALG_PATTERN(DivSIOp); + INSERT_TCP_TO_LINALG_PATTERN(DivUIOp); INSERT_TCP_TO_LINALG_PATTERN(SubOp); INSERT_TCP_TO_LINALG_PATTERN(TanhOp); INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp); diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index d28f28da..3ac8c86f 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - RankedTensorType lhsType = lhs.getType().dyn_cast(); + RankedTensorType lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); @@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Only Ranked Tensor types are supported in TCP"); - // TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are - // added - if (resultType.getElementType().isa()) { - return rewriter.notifyMatchFailure( - op, "Only floating point division supported for now"); - } - auto inputAType = op.getSelf() .getType() .template dyn_cast() @@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern { .template dyn_cast() .getDtype(); + Type inputBType = nullptr; if (isa(op)) { + inputBType = adaptor.getOther().getType(); + rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(), adaptor.getOther(), outputType, resultType.getElementType()); if (!rhs) return rewriter.notifyMatchFailure(op, "Unsupported rhs data type"); } else { - auto inputBType = op.getOther() - .getType() - .template dyn_cast() - .getDtype(); + inputBType = op.getOther() + .getType() + .template dyn_cast() + .getDtype(); rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType, rhs, resultType.getElementType()); } @@ -337,7 +333,29 @@ class ConvertAtenDivOp : public OpConversionPattern { std::tie(lhs, rhs) = torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs); - rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + if (isa(outputType)) { + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs); + } else { + auto in1IntType = cast(inputAType); + auto in2IntType = cast(inputBType); + auto outIntType = cast(outputType); + if ((in1IntType.getSignedness() != in2IntType.getSignedness()) || + (in1IntType.getSignedness() != outIntType.getSignedness())) + return rewriter.notifyMatchFailure(op, + "Mixed signedness not supported"); + if (in1IntType.getSignedness() == + mlir::IntegerType::SignednessSemantics::Signless) + return rewriter.notifyMatchFailure( + op, "Signless division not supported in TCP"); + + if (outIntType.getSignedness() == + mlir::IntegerType::SignednessSemantics::Unsigned) + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs, + tcp::RoundingMode::Trunc); + else + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs, + tcp::RoundingMode::Trunc); + } return success(); } }; diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp index c762ff13..85a52673 100644 --- a/lib/Conversion/TorchToTcp/Utils.cpp +++ b/lib/Conversion/TorchToTcp/Utils.cpp @@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context, return SignednessAttr::get(context, Signedness::Unsigned); } +Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) { + if (signednessInfo == IntegerType::SignednessSemantics::Signless) + return Signedness::Signless; + if (signednessInfo == IntegerType::SignednessSemantics::Signed) + return Signedness::Signed; + return Signedness::Unsigned; +} + // The parameter input is expected to be of RankedTensorType. Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter, Value input, int64_t rankIncrease) { diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h index 4ebfd122..42156eab 100644 --- a/lib/Conversion/TorchToTcp/Utils.h +++ b/lib/Conversion/TorchToTcp/Utils.h @@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr getTcpSignednessAttr(MLIRContext *context, IntegerType::SignednessSemantics signednessInfo); +mlir::tcp::Signedness +getTcpSignedness(IntegerType::SignednessSemantics signednessInfo); + // Helper function to expand the rank of the input tensor. Works by // adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`. Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter, diff --git a/test/Pipeline/torch_to_tcp_pipeline.mlir b/test/Pipeline/torch_to_tcp_pipeline.mlir index 83c08bbb..33c45be7 100644 --- a/test/Pipeline/torch_to_tcp_pipeline.mlir +++ b/test/Pipeline/torch_to_tcp_pipeline.mlir @@ -108,8 +108,30 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32> // ----- +// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { +// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp, out_int_signedness = #tcp} : tensor -> tensor +// CHECK: %[[V1:.+]] = tcp.divsi %[[V0]], %[[ARG1]] {rounding_mode = #tcp} : tensor, tensor -> tensor +// CHECK: return %[[V1]] : tensor func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> { - // expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}} %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32> return %0 : !torch.vtensor<[?, ?],si32> } + +// ----- + +// CHECK: func.func @torch.aten.div.Tensor$mixed_type_uint(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { +// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp, out_int_signedness = #tcp} : tensor -> tensor +// CHECK: %[[V1:.+]] = tcp.divui %[[V0]], %[[ARG1]] {rounding_mode = #tcp} : tensor, tensor -> tensor +// CHECK: return %[[V1]] : tensor +func.func @torch.aten.div.Tensor$mixed_type_uint(%arg0: !torch.vtensor<[?, ?],ui16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> { + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],ui16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32> + return %0 : !torch.vtensor<[?, ?],ui32> +} + +// ----- + +func.func @torch.aten.div.Tensor$mixed_signed_int_div(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> { + // expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}} + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32> + return %0 : !torch.vtensor<[?, ?],ui32> +}