Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for integer division to TCP #97

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness",

def Tcp_SignednessAttr : EnumAttr<Tcp_Dialect, Tcp_Signedness, "signedness">;

// TCP rounding mode
def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>;
def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>;
def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>;

def Tcp_RoundingMode : I32EnumAttr<"RoundingMode",
"Rounding mode for integer operations which need a rounding mode",
[
Tcp_RoundingMode_Trunc,
Tcp_RoundingMode_Floor,
Tcp_RoundingMode_Ceil
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tcp";
}

def Tcp_RoundingModeAttr : EnumAttr<Tcp_Dialect, Tcp_RoundingMode, "roundingMode">;

#endif // TCP_ENUMS
40 changes: 40 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,46 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_DivSIOp : Tcp_BinaryElementwiseOp<"divsi", [SameOperandsAndResultElementType]> {
let summary = "Computes elementwise signed integer division";

let description = [{
Computes the signed integer division of `in1` and `in2`.
}];

let arguments = (ins
Tcp_IntTensor:$in1,
Tcp_IntTensor:$in2,
Tcp_RoundingModeAttr:$rounding_mode
);

let results = (outs
Tcp_IntTensor:$out
);

let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_DivUIOp : Tcp_BinaryElementwiseOp<"divui", [SameOperandsAndResultElementType]> {
let summary = "Computes elementwise unsigned integer division";

let description = [{
Computes the unsigned integer division of `in1` and `in2`.
}];

let arguments = (ins
Tcp_IntTensor:$in1,
Tcp_IntTensor:$in2,
Tcp_RoundingModeAttr:$rounding_mode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the rounding mode here? Should we just stick to using trunc mode for divui?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the arith dialect has straightforward mapping to ceildiv, why not just have it for the sake of expressibility? I do not have a strong opinion on this though. We could just leave it out now and introduce it later if we find a need for it.

);

let results = (outs
Tcp_IntTensor:$out
);

let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
let summary = "Constant op";

Expand Down
28 changes: 28 additions & 0 deletions lib/Conversion/TcpToLinalg/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,32 @@ createLinalgPayloadForElementwiseOp(Operation *op,
"createLinalgPayloadForElementwiseOp for tcp.divf");
}

if (auto divOp = dyn_cast<DivSIOp>(op)) {
if (!elemType.isa<mlir::IntegerType>())
llvm_unreachable("unsupported element type in "
"createLinalgPayloadForElementwiseOp for tcp.divsi");
if (divOp.getRoundingMode() == RoundingMode::Trunc)
return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
else if (divOp.getRoundingMode() == RoundingMode::Ceil)
return {
b.create<arith::CeilDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
else
return {
b.create<arith::FloorDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
}

if (auto divOp = dyn_cast<DivUIOp>(op)) {
if (!elemType.isa<mlir::IntegerType>())
llvm_unreachable("unsupported element type in "
"createLinalgPayloadForElementwiseOp for tcp.divui");
if (divOp.getRoundingMode() == RoundingMode::Trunc ||
divOp.getRoundingMode() == RoundingMode::Floor)
return {b.create<arith::DivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
else
return {
b.create<arith::CeilDivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
}

if (isa<Atan2Op>(op)) {
if (elemType.isa<mlir::FloatType>())
return {b.create<math::Atan2Op>(loc, payloadArgs[0], payloadArgs[1])};
Expand Down Expand Up @@ -330,6 +356,8 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality(
INSERT_TCP_TO_LINALG_PATTERN(ClampOp);
INSERT_TCP_TO_LINALG_PATTERN(MulOp);
INSERT_TCP_TO_LINALG_PATTERN(DivFOp);
INSERT_TCP_TO_LINALG_PATTERN(DivSIOp);
INSERT_TCP_TO_LINALG_PATTERN(DivUIOp);
INSERT_TCP_TO_LINALG_PATTERN(SubOp);
INSERT_TCP_TO_LINALG_PATTERN(TanhOp);
INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp);
Expand Down
44 changes: 31 additions & 13 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf();
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());

Value rhs = adaptor.getOther();

Expand All @@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Only Ranked Tensor types are supported in TCP");

// TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
// added
if (resultType.getElementType().isa<mlir::IntegerType>()) {
return rewriter.notifyMatchFailure(
op, "Only floating point division supported for now");
}

auto inputAType = op.getSelf()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
Expand All @@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();

Type inputBType = nullptr;
if (isa<AtenDivScalarOp>(op)) {
inputBType = adaptor.getOther().getType();

rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
adaptor.getOther(), outputType,
resultType.getElementType());
if (!rhs)
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
} else {
auto inputBType = op.getOther()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();
inputBType = op.getOther()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
rhs, resultType.getElementType());
}
Expand All @@ -337,7 +333,29 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
std::tie(lhs, rhs) =
torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs);

rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
if (isa<mlir::FloatType>(outputType)) {
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
} else {
auto in1IntType = cast<mlir::IntegerType>(inputAType);
auto in2IntType = cast<mlir::IntegerType>(inputBType);
auto outIntType = cast<mlir::IntegerType>(outputType);
if ((in1IntType.getSignedness() != in2IntType.getSignedness()) ||
(in1IntType.getSignedness() != outIntType.getSignedness()))
return rewriter.notifyMatchFailure(op,
"Mixed signedness not supported");
if (in1IntType.getSignedness() ==
mlir::IntegerType::SignednessSemantics::Signless)
return rewriter.notifyMatchFailure(
op, "Signless division not supported in TCP");

if (outIntType.getSignedness() ==
mlir::IntegerType::SignednessSemantics::Unsigned)
rewriter.replaceOpWithNewOp<tcp::DivUIOp>(op, resultType, lhs, rhs,
tcp::RoundingMode::Trunc);
else
rewriter.replaceOpWithNewOp<tcp::DivSIOp>(op, resultType, lhs, rhs,
tcp::RoundingMode::Trunc);
}
return success();
}
};
Expand Down
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context,
return SignednessAttr::get(context, Signedness::Unsigned);
}

Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
if (signednessInfo == IntegerType::SignednessSemantics::Signless)
return Signedness::Signless;
if (signednessInfo == IntegerType::SignednessSemantics::Signed)
return Signedness::Signed;
return Signedness::Unsigned;
}

// The parameter input is expected to be of RankedTensorType.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr
getTcpSignednessAttr(MLIRContext *context,
IntegerType::SignednessSemantics signednessInfo);

mlir::tcp::Signedness
getTcpSignedness(IntegerType::SignednessSemantics signednessInfo);

// Helper function to expand the rank of the input tensor. Works by
// adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Expand Down
24 changes: 23 additions & 1 deletion test/Pipeline/torch_to_tcp_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,30 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>

// -----

// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Signed>, out_int_signedness = #tcp<signedness Signed>} : tensor<?x?xi16> -> tensor<?x?xi32>
// CHECK: %[[V1:.+]] = tcp.divsi %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
// CHECK: return %[[V1]] : tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
// expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
return %0 : !torch.vtensor<[?, ?],si32>
}

// -----

// CHECK: func.func @torch.aten.div.Tensor$mixed_type_uint(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Unsigned>, out_int_signedness = #tcp<signedness Unsigned>} : tensor<?x?xi16> -> tensor<?x?xi32>
// CHECK: %[[V1:.+]] = tcp.divui %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
// CHECK: return %[[V1]] : tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_uint(%arg0: !torch.vtensor<[?, ?],ui16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],ui16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32>
return %0 : !torch.vtensor<[?, ?],ui32>
}

// -----

func.func @torch.aten.div.Tensor$mixed_signed_int_div(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> {
// expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32>
return %0 : !torch.vtensor<[?, ?],ui32>
}