Commit

OnnxToTorch lowering resize op

aldesilv committed Apr 30, 2024
1 parent 72349f7 commit a6e6d49

Showing 8 changed files with 774 additions and 0 deletions.
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7211,6 +7211,35 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
}];
}

def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$size,
AnyTorchOptionalListOfTorchFloatType:$scale_factor,
Torch_StringType:$mode,
AnyTorchOptionalBoolType:$align_corners,
AnyTorchOptionalBoolType:$recompute_scale_factor,
Torch_BoolType:$antialias
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
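
For reference, the schema above mirrors the signature of torch.nn.functional.interpolate, so the new op can be read as that call at the PyTorch level. A minimal sketch (shapes and values are illustrative, not taken from this commit):

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 8, 8)  # the Tensor operand
    # In practice one of size (int[]?) / scale_factor (float[]?) is given and the other is None;
    # mode is a str, align_corners / recompute_scale_factor are optional bools, antialias a bool.
    y = F.interpolate(x, size=None, scale_factor=[2.0, 2.0], mode="bilinear",
                      align_corners=False, recompute_scale_factor=None, antialias=False)
    print(y.shape)  # torch.Size([1, 3, 16, 16])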

def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
113 changes: 113 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2595,4 +2595,117 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"Resize", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
std::string mode;
std::string coordTfMode;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode") ||
binder.customOpNameStringAttr(
coordTfMode, "coordinate_transformation_mode", "half_pixel"))
return failure();
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value modeStrValue;

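        // Extract the single element of the rank-1 slice `v` as a !torch.int or
        // !torch.float scalar, choosing the type from the dtype of the source tensor `x`.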
auto extract = [&rewriter, &binder](Value x, Value v) {
auto xTy = x.getType().cast<Torch::ValueTensorType>();
Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(xTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();

return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
v);
};

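        // Build a !torch.list<int> (or !torch.list<float>, matching the operand dtype)
        // from the last two entries of the 1-D scales/sizes operand, i.e. the spatial
        // (H, W) values for an NCHW input.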
auto getValueList = [&](Value operand) {
SmallVector<Value> itemList;
auto sizes =
dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
Torch::BaseTensorType operandType =
operand.getType().cast<Torch::BaseTensorType>();

SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = operandType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());

MLIRContext *context = binder.op->getContext();
for (int i = sizes[0] - 2; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value ext = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, operand, zero, selectIndex);
Value item = extract(operand, ext);
itemList.push_back(item);
}
auto xTy = operand.getType().cast<Torch::ValueTensorType>();
Value ValueList;
if (isa<IntegerType>(xTy.getDtype())) {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)), itemList);
} else {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::FloatType::get(context)), itemList);
}
return ValueList;
};

Value scalesValueList;
Value sizesValueList;
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;

if (mode == "cubic") {
return failure();
}
if (mode == "linear") {
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
"bilinear");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
}
if (mode == "nearest") {
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "nearest");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizesOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizesOperand);
}
}
rewriter
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
binder.op, resultType, operands[0], sizesValueList,
scalesValueList, modeStrValue,
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
}
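
The Resize pattern above maps ONNX attributes onto the new interpolate op: mode "nearest" stays "nearest", "linear" becomes "bilinear", "cubic" is rejected, align_corners is derived from coordinate_transformation_mode, and only the trailing two (spatial) entries of the scales or sizes operand are used. A rough PyTorch-level sketch of that mapping (hypothetical helper, not code from this commit; assumes a 4-D NCHW input):

    import torch
    import torch.nn.functional as F

    def resize_like_onnx(x, scales=None, sizes=None, mode="nearest",
                         coordinate_transformation_mode="half_pixel"):
        torch_mode = {"nearest": "nearest", "linear": "bilinear"}[mode]  # "cubic" is rejected
        align_corners = coordinate_transformation_mode == "align_corners"
        if torch_mode == "nearest":
            align_corners = None  # nearest-mode interpolate does not accept align_corners
        # Like getValueList, keep only the trailing (H, W) entries of scales/sizes.
        if sizes is not None:
            return F.interpolate(x, size=[int(s) for s in sizes[-2:]], mode=torch_mode,
                                 align_corners=align_corners, antialias=False)
        return F.interpolate(x, scale_factor=[float(s) for s in scales[-2:]],
                             mode=torch_mode, align_corners=align_corners, antialias=False)

    y = resize_like_onnx(torch.rand(1, 1, 2, 4), scales=[1.0, 1.0, 2.0, 3.0], mode="nearest")
    print(y.shape)  # torch.Size([1, 1, 4, 12])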
