Skip to content

Commit

Permalink
Merge remote-tracking branch 'gh-public/main' into mfl/index-tensor-hacked-twin
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfl committed Sep 20, 2024
2 parents ea0e5d8 + 3fc3290 commit 5439258
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 14 deletions.
18 changes: 18 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness",

def Tcp_SignednessAttr : EnumAttr<Tcp_Dialect, Tcp_Signedness, "signedness">;

// TCP rounding mode
//
// Rounding behavior for integer ops that must round (e.g. tcp.divsi /
// tcp.divui):
//   Trunc — round toward zero
//   Floor — round toward negative infinity
//   Ceil  — round toward positive infinity
def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>;
def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>;
def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>;

def Tcp_RoundingMode : I32EnumAttr<"RoundingMode",
"Rounding mode for integer operations which need a rounding mode",
[
Tcp_RoundingMode_Trunc,
Tcp_RoundingMode_Floor,
Tcp_RoundingMode_Ceil
]> {
// Only generate the enum itself plus the generic EnumAttr wrapper below;
// the C++ enum lives in the ::mlir::tcp namespace.
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tcp";
}

// Dialect attribute wrapping the enum; prints as #tcp<roundingMode ...>.
def Tcp_RoundingModeAttr : EnumAttr<Tcp_Dialect, Tcp_RoundingMode, "roundingMode">;

#endif // TCP_ENUMS
40 changes: 40 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,46 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

// Elementwise signed integer division. The required rounding_mode attribute
// selects Trunc / Floor / Ceil rounding (see Tcp_RoundingModeAttr).
def Tcp_DivSIOp : Tcp_BinaryElementwiseOp<"divsi", [SameOperandsAndResultElementType]> {
let summary = "Computes elementwise signed integer division";

let description = [{
Computes the signed integer division of `in1` and `in2`.
}];

// Both operands and the result share the same integer element type
// (enforced by SameOperandsAndResultElementType above).
let arguments = (ins
Tcp_IntTensor:$in1,
Tcp_IntTensor:$in2,
Tcp_RoundingModeAttr:$rounding_mode
);

let results = (outs
Tcp_IntTensor:$out
);

let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

// Elementwise unsigned integer division. The required rounding_mode
// attribute selects Trunc / Floor / Ceil rounding; for unsigned division
// Trunc and Floor coincide (quotients are non-negative).
def Tcp_DivUIOp : Tcp_BinaryElementwiseOp<"divui", [SameOperandsAndResultElementType]> {
let summary = "Computes elementwise unsigned integer division";

let description = [{
Computes the unsigned integer division of `in1` and `in2`.
}];

// Both operands and the result share the same integer element type
// (enforced by SameOperandsAndResultElementType above).
let arguments = (ins
Tcp_IntTensor:$in1,
Tcp_IntTensor:$in2,
Tcp_RoundingModeAttr:$rounding_mode
);

let results = (outs
Tcp_IntTensor:$out
);

let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
let summary = "Constant op";

Expand Down
28 changes: 28 additions & 0 deletions lib/Conversion/TcpToLinalg/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,32 @@ createLinalgPayloadForElementwiseOp(Operation *op,
"createLinalgPayloadForElementwiseOp for tcp.divf");
}

if (auto divOp = dyn_cast<DivSIOp>(op)) {
  // tcp.divsi is only defined for integer element types.
  // NOTE: use the free-function cast form (isa<T>(v)) rather than the
  // deprecated member form (v.isa<T>()), matching the dyn_cast<...>(...)
  // style used elsewhere in this change.
  if (!isa<mlir::IntegerType>(elemType))
    llvm_unreachable("unsupported element type in "
                     "createLinalgPayloadForElementwiseOp for tcp.divsi");
  // Map the TCP rounding mode onto the matching arith signed-division op;
  // any mode other than Trunc/Ceil falls through to Floor.
  const auto roundingMode = divOp.getRoundingMode();
  if (roundingMode == RoundingMode::Trunc)
    return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
  if (roundingMode == RoundingMode::Ceil)
    return {b.create<arith::CeilDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
  return {b.create<arith::FloorDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
}

if (auto divOp = dyn_cast<DivUIOp>(op)) {
  // tcp.divui is only defined for integer element types.
  // NOTE: free-function isa<T>(v) replaces the deprecated member v.isa<T>().
  if (!isa<mlir::IntegerType>(elemType))
    llvm_unreachable("unsupported element type in "
                     "createLinalgPayloadForElementwiseOp for tcp.divui");
  // Unsigned quotients are non-negative, so Trunc and Floor are the same
  // operation; only Ceil needs a distinct arith op.
  const auto roundingMode = divOp.getRoundingMode();
  if (roundingMode == RoundingMode::Trunc ||
      roundingMode == RoundingMode::Floor)
    return {b.create<arith::DivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
  return {b.create<arith::CeilDivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
}

if (isa<Atan2Op>(op)) {
if (elemType.isa<mlir::FloatType>())
return {b.create<math::Atan2Op>(loc, payloadArgs[0], payloadArgs[1])};
Expand Down Expand Up @@ -330,6 +356,8 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality(
INSERT_TCP_TO_LINALG_PATTERN(ClampOp);
INSERT_TCP_TO_LINALG_PATTERN(MulOp);
INSERT_TCP_TO_LINALG_PATTERN(DivFOp);
INSERT_TCP_TO_LINALG_PATTERN(DivSIOp);
INSERT_TCP_TO_LINALG_PATTERN(DivUIOp);
INSERT_TCP_TO_LINALG_PATTERN(SubOp);
INSERT_TCP_TO_LINALG_PATTERN(TanhOp);
INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp);
Expand Down
44 changes: 31 additions & 13 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.getSelf();
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());

Value rhs = adaptor.getOther();

Expand All @@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Only Ranked Tensor types are supported in TCP");

// TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
// added
if (resultType.getElementType().isa<mlir::IntegerType>()) {
return rewriter.notifyMatchFailure(
op, "Only floating point division supported for now");
}

auto inputAType = op.getSelf()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
Expand All @@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();

Type inputBType = nullptr;
if (isa<AtenDivScalarOp>(op)) {
inputBType = adaptor.getOther().getType();

rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
adaptor.getOther(), outputType,
resultType.getElementType());
if (!rhs)
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
} else {
auto inputBType = op.getOther()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();
inputBType = op.getOther()
.getType()
.template dyn_cast<torch::Torch::ValueTensorType>()
.getDtype();
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
rhs, resultType.getElementType());
}
Expand All @@ -337,7 +333,29 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
std::tie(lhs, rhs) =
torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs);

rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
if (isa<mlir::FloatType>(outputType)) {
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
} else {
auto in1IntType = cast<mlir::IntegerType>(inputAType);
auto in2IntType = cast<mlir::IntegerType>(inputBType);
auto outIntType = cast<mlir::IntegerType>(outputType);
if ((in1IntType.getSignedness() != in2IntType.getSignedness()) ||
(in1IntType.getSignedness() != outIntType.getSignedness()))
return rewriter.notifyMatchFailure(op,
"Mixed signedness not supported");
if (in1IntType.getSignedness() ==
mlir::IntegerType::SignednessSemantics::Signless)
return rewriter.notifyMatchFailure(
op, "Signless division not supported in TCP");

if (outIntType.getSignedness() ==
mlir::IntegerType::SignednessSemantics::Unsigned)
rewriter.replaceOpWithNewOp<tcp::DivUIOp>(op, resultType, lhs, rhs,
tcp::RoundingMode::Trunc);
else
rewriter.replaceOpWithNewOp<tcp::DivSIOp>(op, resultType, lhs, rhs,
tcp::RoundingMode::Trunc);
}
return success();
}
};
Expand Down
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context,
return SignednessAttr::get(context, Signedness::Unsigned);
}

// Translates an MLIR IntegerType signedness kind into the corresponding
// TCP dialect Signedness enum value.
Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
  switch (signednessInfo) {
  case IntegerType::SignednessSemantics::Signless:
    return Signedness::Signless;
  case IntegerType::SignednessSemantics::Signed:
    return Signedness::Signed;
  default:
    return Signedness::Unsigned;
  }
}

// The parameter input is expected to be of RankedTensorType.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr
getTcpSignednessAttr(MLIRContext *context,
IntegerType::SignednessSemantics signednessInfo);

// Maps an IntegerType signedness kind (Signless/Signed/Unsigned) to the
// corresponding tcp::Signedness enum value.
mlir::tcp::Signedness
getTcpSignedness(IntegerType::SignednessSemantics signednessInfo);

// Helper function to expand the rank of the input tensor. Works by
// adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`.
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
Expand Down
24 changes: 23 additions & 1 deletion test/Pipeline/torch_to_tcp_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,30 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>

// -----

// Signed si16/si32 division now legalizes to tcp.divsi (after casting the
// narrower operand to the result element type), so the stale
// expected-error from before integer division was supported is removed —
// it contradicted the CHECK lines below.
// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Signed>, out_int_signedness = #tcp<signedness Signed>} : tensor<?x?xi16> -> tensor<?x?xi32>
// CHECK: %[[V1:.+]] = tcp.divsi %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
// CHECK: return %[[V1]] : tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
return %0 : !torch.vtensor<[?, ?],si32>
}

// -----

// Unsigned ui16/ui32 division lowers to tcp.divui; the narrower operand is
// first tcp.cast to the ui32 result element type.
// CHECK: func.func @torch.aten.div.Tensor$mixed_type_uint(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Unsigned>, out_int_signedness = #tcp<signedness Unsigned>} : tensor<?x?xi16> -> tensor<?x?xi32>
// CHECK: %[[V1:.+]] = tcp.divui %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
// CHECK: return %[[V1]] : tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_uint(%arg0: !torch.vtensor<[?, ?],ui16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],ui16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32>
return %0 : !torch.vtensor<[?, ?],ui32>
}

// -----

// Division with mixed operand signedness (si16 / ui32) has no TCP
// equivalent — the TorchToTcp pattern rejects mixed signedness — so
// legalization is expected to fail.
func.func @torch.aten.div.Tensor$mixed_signed_int_div(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> {
// expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32>
return %0 : !torch.vtensor<[?, ?],ui32>
}

0 comments on commit 5439258

Please sign in to comment.