[onnx] Convert onnx.QLinearConv to torch (llvm#2851)
Leaning on the QDQ functionality in torch, we can support the QLinearConv
operation by piggybacking on `torch.Convolution`. This includes
some changes, such as allowing the `onnx` rewriter to run recursively.
Doing so allows `QLinearConv` to decompose to `onnx.Convolution`, which
is then lowered to `torch`.
rsuderman authored Feb 6, 2024
1 parent cb52c4b commit e3faef5
Showing 9 changed files with 285 additions and 47 deletions.
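
For orientation, the arithmetic being lowered is small: ONNX's QLinearConv convolves two per-tensor-quantized operands in an int32 accumulator and requantizes the result to the output's scale and zero point. Below is a minimal, self-contained scalar sketch of that contract; the values and helper name are illustrative only, not code from this patch.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// One output element of a per-tensor-quantized convolution, reduced to a
// dot product: accumulate in int32 with zero points subtracted, then
// requantize to the output scale/zero point.
int8_t qlinearDot(const int8_t *a, const int8_t *b, int n, float aScale,
                  int aZp, float bScale, int bZp, float cScale, int cZp) {
  int32_t acc = 0;
  for (int i = 0; i < n; ++i)
    acc += (int32_t(a[i]) - aZp) * (int32_t(b[i]) - bZp);
  // The accumulator is itself a quantized value with scale aScale * bScale
  // and zero point 0; requantization is a rescale plus the output offset.
  float real = acc * (aScale * bScale);
  int32_t q = int32_t(std::lround(real / cScale)) + cZp;
  return int8_t(std::clamp(q, int32_t(-128), int32_t(127)));
}

int main() {
  int8_t a[] = {12, -3, 7}, b[] = {5, 9, -2};
  std::cout << int(qlinearDot(a, b, 3, /*aScale=*/0.02f, /*aZp=*/0,
                              /*bScale=*/0.05f, /*bZp=*/0,
                              /*cScale=*/0.1f, /*cZp=*/10))
            << "\n"; // prints 10
}
```

The QDQ route in this commit implements exactly this: quantized inputs are tagged with their parameters, convolved through the existing integer conv path, and the i32 accumulator is rescaled into the output encoding.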
5 changes: 4 additions & 1 deletion include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -251,7 +251,10 @@ class OnnxCustomOpConversionPattern
   OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix,
                                 int64_t domainVersion)
       : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)),
-        domainVersion(domainVersion) {}
+        domainVersion(domainVersion) {
+    // Onnx lowerings could produce other Onnx operations during the rewrite.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult
   matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -18,6 +18,8 @@ Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
                             SmallVector<int64_t> cstInput);
 
+Type getQTorchTypeFromTorchIntType(Type ty);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
2 changes: 1 addition & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -690,7 +690,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         return failure();
       });
   patterns.onOp(
-      "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        std::string autoPad;
        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
          return failure();
139 changes: 117 additions & 22 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -99,6 +100,113 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         return failure();
       });
+  patterns.onOp(
+      "QLinearConv", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        if ((binder.tensorOperands(operands, 8) &&
+             binder.tensorOperands(operands, 9)) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        Value a = operands[0];
+        Value aScale = operands[1];
+        Value aZp = operands[2];
+        Value b = operands[3];
+        Value bScale = operands[4];
+        Value bZp = operands[5];
+        Value cScale = operands[6];
+        Value cZp = operands[7];
+        Value c = operands.size() == 9 ? operands[8] : nullptr;
+
+        auto check = [](Value v) {
+          auto vTy = v.getType().cast<Torch::ValueTensorType>();
+          return llvm::all_of(vTy.getSizes(),
+                              [](int64_t d) { return d == 1; });
+        };
+        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
+            !check(cScale) || !check(cZp))
+          return rewriter.notifyMatchFailure(
+              binder.op, "not supported for non per-tensor quantization");
+
+        auto extract = [&rewriter, &binder](Value v) {
+          auto vTy = v.getType().cast<Torch::ValueTensorType>();
+          Type extractTy = rewriter.getType<Torch::FloatType>();
+          if (isa<IntegerType>(vTy.getDtype()))
+            extractTy = rewriter.getType<Torch::IntType>();
+
+          return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
+                                                    extractTy, v);
+        };
+
+        aZp = extract(aZp);
+        bZp = extract(bZp);
+        cZp = extract(cZp);
+        aScale = extract(aScale);
+        bScale = extract(bScale);
+        cScale = extract(cScale);
+
+        auto make = [&rewriter, &binder](Value v, Value scale,
+                                         Value zp) -> Value {
+          auto ty = v.getType().cast<Torch::ValueTensorType>();
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
+          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), newTy, v, scale, zp);
+        };
+
+        a = make(a, aScale, aZp);
+        b = make(b, bScale, bZp);
+
+        llvm::SmallVector<Value> newOperands = {a, b};
+        if (c)
+          newOperands.push_back(c);
+
+        auto cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(),
+            rewriter.getType<Torch::QInt32Type>());
+
+        llvm::SmallVector<NamedAttribute> newAttributes;
+        newAttributes.push_back(
+            rewriter.getNamedAttr("name",
+                                  rewriter.getStringAttr("onnx.Conv")));
+        for (auto namedAttr : binder.op->getAttrDictionary()) {
+          if (namedAttr.getName().getValue().compare("name") == 0)
+            continue;
+          newAttributes.push_back(namedAttr);
+        }
+
+        c = rewriter
+                .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
+                                           newAttributes)
+                .getResult(0);
+
+        Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
+            bScale);
+        Value outZp = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+            binder.getLoc(), cTy, c, outScale, outZp);
+        cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF32Type());
+
+        c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
+                                                         c);
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(cTy.getDtype()))));
+        c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op,
+                                                          resultType, c);
+        return success();
+      });
   patterns.onOp(
       "QLinearMatMul", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -157,28 +269,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         bScale = extract(bScale);
         cScale = extract(cScale);
 
-        auto getQTy =
-            [&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType {
-          auto dt = ty.getDtype();
-          Type newDt;
-          if (dt.isUnsignedInteger(8)) {
-            newDt = rewriter.getType<Torch::QUInt8Type>();
-          } else if (dt.isSignedInteger(8)) {
-            newDt = rewriter.getType<Torch::QInt8Type>();
-          } else if (dt.isSignedInteger(32)) {
-            newDt = rewriter.getType<Torch::QInt32Type>();
-          } else {
-            return nullptr;
-          }
-
-          return rewriter.getType<Torch::ValueTensorType>(ty.getOptionalSizes(),
-                                                          newDt);
-        };
-
-        auto make = [&rewriter, &binder, &getQTy](Value v, Value scale,
-                                                  Value zp) -> Value {
+        auto make = [&rewriter, &binder](Value v, Value scale,
+                                         Value zp) -> Value {
           auto ty = v.getType().cast<Torch::ValueTensorType>();
-          auto newTy = getQTy(ty);
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
           return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
               binder.getLoc(), newTy, v, scale, zp);
         };
@@ -214,7 +308,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                          c);
-        cTy = getQTy(resultType);
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
         Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(
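
To see why the QLinearConv pattern above is correct, it helps to trace its op chain on a scalar: `_make_per_tensor_quantized_tensor` tags the i32 convolution result with scale `aScale * bScale` and zero point 0, `aten.dequantize.self` recovers the real value, and `aten.quantize_per_tensor` followed by `aten.int_repr` lands on the ONNX output's encoding. The standalone trace below uses assumed values and plain C++ in place of MLIR ops; it is a sketch, not code from the patch.

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed per-tensor parameters (hypothetical values).
  float aScale = 0.02f, bScale = 0.05f, cScale = 0.1f;
  int cZp = 10;

  int32_t acc = 19; // i32 result produced by the integer convolution

  // _make_per_tensor_quantized_tensor: acc carries scale aScale * bScale
  // and zero point 0.
  float outScale = aScale * bScale;

  // aten.dequantize.self: back to a real-valued f32.
  float real = (acc - 0) * outScale;

  // aten.quantize_per_tensor + aten.int_repr: into the ONNX result dtype.
  int8_t result = int8_t(std::lround(real / cScale) + cZp);
  std::cout << int(result) << "\n"; // prints 10
}
```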
21 changes: 21 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -26,3 +27,23 @@ Value mlir::torch::onnx_c::createConstantIntList(
       Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
       cstValue);
 }
+
+Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
+  Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
+  if (!tty)
+    return nullptr;
+
+  auto ctx = ty.getContext();
+  Type dty = tty.getDtype();
+
+  if (dty.isUnsignedInteger(8))
+    dty = Torch::QUInt8Type::get(ctx);
+  else if (dty.isSignedInteger(8))
+    dty = Torch::QInt8Type::get(ctx);
+  else if (dty.isSignedInteger(32))
+    dty = Torch::QInt32Type::get(ctx);
+  else
+    return nullptr;
+
+  return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty);
+}
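
The helper's contract is narrow: only unsigned 8-bit, signed 8-bit, and signed 32-bit storage types have quantized counterparts, and anything else yields no type. A dependency-free sketch of the same mapping, with hypothetical enum names for illustration only:

```cpp
#include <iostream>
#include <optional>

enum class Dtype { UI8, SI8, SI32, F32 };
enum class QDtype { QUInt8, QInt8, QInt32 };

// Mirrors the contract of getQTorchTypeFromTorchIntType: map an integer
// storage dtype to its quantized counterpart, or nothing if unsupported.
std::optional<QDtype> toQuantized(Dtype d) {
  switch (d) {
  case Dtype::UI8:  return QDtype::QUInt8;
  case Dtype::SI8:  return QDtype::QInt8;
  case Dtype::SI32: return QDtype::QInt32;
  default:          return std::nullopt;
  }
}

int main() {
  std::cout << toQuantized(Dtype::SI8).has_value() << "\n"; // 1
  std::cout << toQuantized(Dtype::F32).has_value() << "\n"; // 0
}
```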
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/Linear.cpp
@@ -653,7 +653,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           op, "lhs and rhs of convolution must either be both int or fp");
     }
 
-    if (inputZp && weightZp) {
+    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
      auto biasDTy = bias.getType().cast<RankedTensorType>().getElementType();
      if (!biasDTy.isInteger(32)) {
        return rewriter.notifyMatchFailure(
45 changes: 25 additions & 20 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -64,10 +64,6 @@ template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
     if (operands.size() < 3)
       return failure();
 
-    Value bias = operands[2];
-    if (bias.getDefiningOp<AtenDequantizeTensorOp>())
-      return failure();
-
     Value lhsScale;
     if (auto qLhs =
             operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
@@ -82,11 +78,18 @@ template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
       return failure();
 
     auto resultTy = cast<ValueTensorType>(op.getType());
-    auto biasTy = bias.getType().cast<ValueTensorType>();
-    auto biasETy = biasTy.getOptionalDtype();
-    if (!biasETy || !isa<mlir::FloatType>(biasETy))
+    if (!isa<mlir::FloatType>(resultTy.getDtype()))
       return failure();
 
+    Value bias = operands[2];
+    auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
+
+    if (biasTy) {
+      auto biasETy = biasTy.getOptionalDtype();
+      if (!biasETy || !isa<mlir::FloatType>(biasETy))
+        return failure();
+    }
+
     Value biasScale = rewriter.create<AtenMulFloatOp>(
         op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
 
@@ -95,19 +98,21 @@ template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
         rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
 
     auto qi32Ty = rewriter.getType<QInt32Type>();
-    auto newBiasTy =
-        rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
-    Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
-    bias = rewriter.create<AtenQuantizePerTensorOp>(
-        op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
-    bias = rewriter.create<AtenIntReprOp>(
-        op.getLoc(),
-        rewriter.getType<ValueTensorType>(
-            biasTy.getOptionalSizes(),
-            rewriter.getIntegerType(32, IntegerType::Signed)),
-        bias);
-
-    operands[2] = bias;
+    if (biasTy) {
+      auto newBiasTy =
+          rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
+      Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
+      bias = rewriter.create<AtenQuantizePerTensorOp>(
+          op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
+      bias = rewriter.create<AtenIntReprOp>(
+          op.getLoc(),
+          rewriter.getType<ValueTensorType>(
+              biasTy.getOptionalSizes(),
+              rewriter.getIntegerType(32, IntegerType::Signed)),
+          bias);
+      operands[2] = bias;
+    }
 
     auto convTy = rewriter.getType<ValueTensorType>(
         resultTy.getOptionalSizes(),
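
The bias path in QuantizeBias relies on a standard identity: the integer conv accumulator carries scale `lhsScale * rhsScale`, so a float bias can only be folded in after being quantized to int32 with exactly that scale and a zero point of 0. A standalone sketch with assumed values, not code from the patch:

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed per-tensor scales of the two conv inputs (hypothetical).
  double lhsScale = 0.02, rhsScale = 0.05;
  double biasScale = lhsScale * rhsScale; // scale of the i32 accumulator

  double bias = 0.37; // float bias as it appears in the source model

  // Quantize the bias to qint32 with zero point 0 so it can be added
  // directly onto the i32 accumulator inside the quantized convolution.
  int32_t qBias = int32_t(std::lround(bias / biasScale));
  std::cout << qBias << "\n"; // 370

  // Dequantizing recovers the original bias (up to rounding).
  std::cout << qBias * biasScale << "\n"; // 0.37
}
```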
(The remaining 2 file diffs were not loaded in this view.)