Skip to content

Commit

Permalink
Cherry pick: Lower torch transposed convolution to a custom TCP op (c…
Browse files Browse the repository at this point in the history
…ruise-automation#25) (cruise-automation#10)

As titled, lower torch transposed convolution to a custom TCP op to
avoid a mis-compilation in `TorchToTosa`.

Cherry-pick from upstream: cruise-automation#25

---------

Co-authored-by: Srinath Avadhanula <[email protected]>
  • Loading branch information
Srinath Avadhanula authored and GitHub Enterprise committed Jan 16, 2024
2 parents aec758d + fdf0338 commit bf6191f
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 2 deletions.
97 changes: 97 additions & 0 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tcp;
Expand Down Expand Up @@ -146,6 +147,88 @@ class ConvertAten_IndexPutImplOp
}
};

class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(OpConversionPattern<AtenConvolutionOp>::getTypeConverter()
->convertTypes(op->getResultTypes(), resultTypes))) {
return failure();
}

SmallVector<Value> operands;
SmallVector<StringRef> operandNames;

auto addOperand = [&](std::string name, Value value) {
operandNames.push_back(name);
operands.push_back(value);
};

addOperand("input", adaptor.getInput());
addOperand("weight", adaptor.getWeight());
if (!adaptor.getBias().getType().isa<Torch::NoneType>()) {
addOperand("bias", adaptor.getBias());
}

SmallVector<NamedAttribute> attrs;

attrs.push_back(rewriter.getNamedAttr(
"torch_operand_names", rewriter.getStrArrayAttr(operandNames)));

auto addListOfIntAttr = [&](const std::string &name, Value value) {
SmallVector<int64_t> valueInt;
if (!matchPattern(value, m_TorchListOfConstantInts(valueInt)))
return rewriter.notifyMatchFailure(op, std::string("non-const") + name +
"list unsupported");
attrs.push_back(
rewriter.getNamedAttr(name, rewriter.getIndexArrayAttr(valueInt)));
return success();
};

if (auto result = addListOfIntAttr("stride", adaptor.getStride());
result.failed()) {
return result;
}
if (auto result = addListOfIntAttr("padding", adaptor.getPadding());
result.failed()) {
return result;
}
if (auto result = addListOfIntAttr("dilation", adaptor.getDilation());
result.failed()) {
return result;
}
if (auto result =
addListOfIntAttr("output_padding", adaptor.getOutputPadding());
result.failed()) {
return result;
}

bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(op,
"non const transposed unsupported");
attrs.push_back(
rewriter.getNamedAttr("transposed", rewriter.getBoolAttr(transposed)));

int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups)))
return rewriter.notifyMatchFailure(op, "non const groups unsupported");
attrs.push_back(
rewriter.getNamedAttr("groups", rewriter.getI64IntegerAttr(groups)));

auto replOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes,
operands, attrs);

replOp.setOpName(op->getName().getStringRef());

return success();
}
};

} // namespace

void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
Expand All @@ -159,4 +242,18 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(Aten_IndexPutImplOp);
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN

auto isTransposedConvOp = [](AtenConvolutionOp op) {
bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return false;
return transposed;
};

// Only want to convert transposed conv ops, i.e., if its not transposed,
// its "legal", i.e., will not get converted.
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenConvolutionOp,
AtenConvolutionOp>(
typeConverter, patterns, target, convertTorchOpsSet,
[&](AtenConvolutionOp op) { return !isTransposedConvOp(op); });
}
7 changes: 5 additions & 2 deletions lib/Conversion/TorchToTcp/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ bool getConstTensorWithType(ConversionPatternRewriter &rewriter, Operation *op,
template <typename TorchToTcpPattern, typename AtenOp>
inline void addPatternIfOpInConvertTorchOpsSet(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) {
ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet,
std::function<bool(AtenOp)> dynamicLegalityFcn = [](AtenOp) {
return false;
}) {
MLIRContext *context = patterns.getContext();
std::optional<OperationName> opName =
TorchToTcpPattern(context).getRootKind();
Expand All @@ -90,7 +93,7 @@ inline void addPatternIfOpInConvertTorchOpsSet(
if (convertTorchOpsSet.empty() ||
convertTorchOpsSet.contains(
opName->getStringRef().ltrim(torch::Torch::kTorchOpPrefix))) {
target.addIllegalOp<AtenOp>();
target.addDynamicallyLegalOp<AtenOp>(dynamicLegalityFcn);
patterns.add<TorchToTcpPattern>(typeConverter, context);
}
}
Expand Down
41 changes: 41 additions & 0 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,44 @@ func.func @torch.aten.index_put_impl_op(%arg0: !torch.vtensor<[25],f32>, %arg1:
%1 = torch.aten._index_put_impl %arg0, %0, %arg2, %false, %false : !torch.vtensor<[25],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[25],f32>
return %1 : !torch.vtensor<[25],f32>
}


// -----

// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} {
// CHECK-SAME: dilation = [1 : index, 1 : index],
// CHECK-SAME: groups = 1 : i64,
// CHECK-SAME: output_padding = [1 : index, 1 : index],
// CHECK-SAME: padding = [1 : index, 1 : index],
// CHECK-SAME: stride = [2 : index, 2 : index],
// CHECK-SAME: torch_operand_names = ["input", "weight", "bias"],
// CHECK-SAME: transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32>
func.func @torcn.aten.transposed_convolution(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
%bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>

return %output : !torch.vtensor<[1,64,2,200],f32>
}

// -----

// CHECK: torch.aten.convolution %{{.*}}
func.func @torch.aten.regular_convolution() -> !torch.vtensor<[1,32,16,1600],f32> {
%false = torch.constant.bool false
%input = torch.vtensor.literal(dense<0.0> : tensor<1x9x16x1600xf32>) : !torch.vtensor<[1,9,16,1600],f32>
%weights = torch.vtensor.literal(dense<0.0> : tensor<32x9x3x3xf32>) : !torch.vtensor<[32,9,3,3],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int0x0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%output = torch.aten.convolution %input, %weights, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,16,1600],f32>

return %output : !torch.vtensor<[1,32,16,1600],f32>
}

0 comments on commit bf6191f

Please sign in to comment.