Add support for integer division to TCP #97

Merged: 6 commits on Sep 20, 2024
Changes from 2 commits
18 changes: 18 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpEnums.td
@@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness",

def Tcp_SignednessAttr : EnumAttr<Tcp_Dialect, Tcp_Signedness, "signedness">;

// TCP rounding mode
def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>;
def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>;
Collaborator: minor nit: fix spacing here and below

def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>;

def Tcp_RoundingMode : I32EnumAttr<"RoundingMode",
"Rounding mode for integer operations which need a rounding mode",
[
Tcp_RoundingMode_Trunc,
Tcp_RoundingMode_Floor,
Tcp_RoundingMode_Ceil
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tcp";
}

def Tcp_RoundingModeAttr : EnumAttr<Tcp_Dialect, Tcp_RoundingMode, "roundingMode">;

#endif // TCP_ENUMS
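A quick reference on how the three modes differ (they only disagree when the quotient is not exact), worked on 7 / 2 = 3.5 and -7 / 2 = -3.5:

Trunc (round toward zero): 7 / 2 = 3, -7 / 2 = -3
Floor (round toward -inf): 7 / 2 = 3, -7 / 2 = -4
Ceil (round toward +inf): 7 / 2 = 4, -7 / 2 = -3

Note that for unsigned operands Trunc and Floor coincide, which is why the lowering below maps both to arith.divui.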
21 changes: 21 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -160,6 +160,27 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementType]> {
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_DivIOp : Tcp_BinaryElementwiseOp<"divi", [SameOperandsAndResultElementType]> {
let summary = "Computes elementwise integer division";

let description = [{
Computes the integer division of `in1` and `in2`.
}];

let arguments = (ins
Tcp_IntTensor:$in1,
Tcp_IntTensor:$in2,
Tcp_SignednessAttr:$signedness,
Tcp_RoundingModeAttr:$rounding_mode
Collaborator: Why do we need a "Ceil" rounding mode here? Are there any cases where that is necessary?

Contributor Author: I don't think there's a good use case for the tcp.divi op itself to support a ceil rounding mode, given that such a thing does not exist for the torch.div op. I added it because there is a natural mapping to the arith ceildiv ops (arith.ceildivsi / arith.ceildivui). I can remove it.

However, I think the RoundingMode type itself should have Ceil in it for completeness.

Collaborator: Do we need the rounding mode here? Should we just stick to using trunc mode for divui?

Contributor Author: Given that the arith dialect has a straightforward mapping to ceildiv, why not just have it for the sake of expressibility? I do not have a strong opinion on this, though. We could leave it out for now and introduce it later if we find a need for it.

);

let results = (outs
Tcp_IntTensor:$out
);

let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
let summary = "Constant op";

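For illustration, here is what the new op looks like in IR, inferred from the assemblyFormat of the tcp.divi definition above; the attribute syntax matches the pipeline test at the end of this diff (%lhs and %rhs are placeholder names):

```mlir
// Signed integer division, rounding toward zero.
%0 = tcp.divi %lhs, %rhs {rounding_mode = #tcp<roundingMode Trunc>,
                          signedness = #tcp<signedness Signed>}
    : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
```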
31 changes: 29 additions & 2 deletions lib/Conversion/TcpToLinalg/Elementwise.cpp
@@ -190,9 +190,35 @@ createLinalgPayloadForElementwiseOp(Operation *op,
if (isa<DivFOp>(op)) {
if (elemType.isa<mlir::FloatType>())
return {b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1])};
-    else
else if (elemType.isa<mlir::IntegerType>()) {
return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
}
}

if (auto divOp = dyn_cast<DivIOp>(op)) {
if (!elemType.isa<mlir::IntegerType>())
      llvm_unreachable("unsupported element type in "
-                      "createLinalgPayloadForElementwiseOp for tcp.divf");
                       "createLinalgPayloadForElementwiseOp for tcp.divi");
if (divOp.getSignedness() == Signedness::Unsigned) {
if (divOp.getRoundingMode() == RoundingMode::Trunc ||
divOp.getRoundingMode() == RoundingMode::Floor)
return {b.create<arith::DivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
else
return {
b.create<arith::CeilDivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
} else if (divOp.getSignedness() == Signedness::Signed) {
if (divOp.getRoundingMode() == RoundingMode::Trunc)
return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
      else if (divOp.getRoundingMode() == RoundingMode::Ceil)
        return {
            b.create<arith::CeilDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
else
return {
b.create<arith::FloorDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
} else {
llvm_unreachable("unsupported signedness in "
"createLinalgPayloadForElementwiseOp for tcp.divi");
}
}

if (isa<Atan2Op>(op)) {
@@ -330,6 +356,7 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality(
INSERT_TCP_TO_LINALG_PATTERN(ClampOp);
INSERT_TCP_TO_LINALG_PATTERN(MulOp);
INSERT_TCP_TO_LINALG_PATTERN(DivFOp);
INSERT_TCP_TO_LINALG_PATTERN(DivIOp);
INSERT_TCP_TO_LINALG_PATTERN(SubOp);
INSERT_TCP_TO_LINALG_PATTERN(TanhOp);
INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp);
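To make the payload mapping above concrete, here is a hand-written sketch (not actual pass output) of what a Signed + Floor tcp.divi lowers to. The linalg scaffolding (indexing maps, init tensor, function name) is assumed from the usual elementwise conversion; only the arith.floordivsi payload comes from the branch above:

```mlir
#map = affine_map<(d0) -> (d0)>
func.func @divi_signed_floor(%a: tensor<4xi32>, %b: tensor<4xi32>) -> tensor<4xi32> {
  %init = tensor.empty() : tensor<4xi32>
  %0 = linalg.generic
      {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
      ins(%a, %b : tensor<4xi32>, tensor<4xi32>)
      outs(%init : tensor<4xi32>) {
  ^bb0(%x: i32, %y: i32, %out: i32):
    // Payload selected by the Signed + Floor branch above.
    %q = arith.floordivsi %x, %y : i32
    linalg.yield %q : i32
  } -> tensor<4xi32>
  return %0 : tensor<4xi32>
}
```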
41 changes: 28 additions & 13 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
@@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());

Value rhs = adaptor.getOther();

@@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Only Ranked Tensor types are supported in TCP");

-    // TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
-    //       added
-    if (resultType.getElementType().isa<mlir::IntegerType>()) {
-      return rewriter.notifyMatchFailure(
-          op, "Only floating point division supported for now");
-    }

auto inputAType = op.getSelf()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
@@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();

Type inputBType = nullptr;
if (isa<AtenDivScalarOp>(op)) {
inputBType = adaptor.getOther().getType();

rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
adaptor.getOther(), outputType,
resultType.getElementType());
if (!rhs)
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
} else {
-      auto inputBType = op.getOther()
-                            .getType()
-                            .template dyn_cast<torch::Torch::ValueTensorType>()
-                            .getDtype();
inputBType = op.getOther()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
rhs, resultType.getElementType());
}
@@ -337,7 +333,26 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
std::tie(lhs, rhs) =
torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs);

-    rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
if (isa<mlir::FloatType>(outputType)) {
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
} else {
auto in1IntType = cast<mlir::IntegerType>(inputAType);
auto in2IntType = cast<mlir::IntegerType>(inputBType);
auto outIntType = cast<mlir::IntegerType>(outputType);
if ((in1IntType.getSignedness() != in2IntType.getSignedness()) ||
(in1IntType.getSignedness() != outIntType.getSignedness()))
return rewriter.notifyMatchFailure(op,
"Mixed signedness not supported");
if (in1IntType.getSignedness() ==
mlir::IntegerType::SignednessSemantics::Signless)
return rewriter.notifyMatchFailure(
op, "Signless division not supported in TCP");

rewriter.replaceOpWithNewOp<tcp::DivIOp>(
op, resultType, lhs, rhs,
torch_to_tcp::getTcpSignedness(outIntType.getSignedness()),
tcp::RoundingMode::Trunc);
}
return success();
}
};
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
@@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context,
return SignednessAttr::get(context, Signedness::Unsigned);
}

Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
if (signednessInfo == IntegerType::SignednessSemantics::Signless)
return Signedness::Signless;
if (signednessInfo == IntegerType::SignednessSemantics::Signed)
return Signedness::Signed;
return Signedness::Unsigned;
}

// The parameter input is expected to be of RankedTensorType.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.h
@@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr
getTcpSignednessAttr(MLIRContext *context,
IntegerType::SignednessSemantics signednessInfo);

mlir::tcp::Signedness
getTcpSignedness(IntegerType::SignednessSemantics signednessInfo);

// Helper function to expand the rank of the input tensor. Works by
// adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
5 changes: 4 additions & 1 deletion test/Pipeline/torch_to_tcp_pipeline.mlir
@@ -108,8 +108,11 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>

// -----

// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Signed>, out_int_signedness = #tcp<signedness Signed>} : tensor<?x?xi16> -> tensor<?x?xi32>
// CHECK: %[[V1:.+]] = tcp.divi %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>, signedness = #tcp<signedness Signed>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
// CHECK: return %[[V1]] : tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
-  // expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
return %0 : !torch.vtensor<[?, ?],si32>
}