Skip to content

Commit d08ee52

Browse files
seanshparkbatcheu
andcommitted
[circle-mlir/pass] Convert Conv2DOp
This will enable to convert ONNXConvOp to circle Conv2DOp. ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]> Co-authored-by: Jonghwa Lee <[email protected]>
1 parent 2693997 commit d08ee52

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

circle-mlir/circle-mlir/lib/pass/src/ConvertONNXToCirclePass.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
// class: ConvAbcd
2626
// file: AbcdOp.h
2727
#include "ops/ConstantOp.h"
28+
#include "ops/ConvOp.h"
2829
#include "ops/NoneOp.h"
2930
#include "ops/PadOp.h"
3031
#include "ops/ReshapeOp.h"
@@ -149,6 +150,7 @@ void ConvertONNXToCirclePass::runOnOperation()
149150
patterns.insert<ConvBinaryT<mlir::ONNXAddOp, mlir::Circle::AddOp>>(typeConverter, context);
150151

151152
patterns.insert<ConvConstant>(typeConverter, context);
153+
patterns.insert<ConvConv>(typeConverter, context);
152154
patterns.insert<ConvNone>(typeConverter, context);
153155
patterns.insert<ConvPad>(typeConverter, context);
154156
patterns.insert<ConvReshape>(typeConverter, context);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef __CIRCLE_MLIR_PASS_OPS_CONV_OP_H__
18+
#define __CIRCLE_MLIR_PASS_OPS_CONV_OP_H__
19+
20+
#include <circle-mlir/dialect/CircleDialect.h>
21+
22+
#include "ConvertHelper.h"
23+
24+
#include <mlir/Transforms/DialectConversion.h>
25+
26+
#include <src/Dialect/ONNX/ONNXOps.hpp>
27+
28+
#include <cassert>
29+
#include <vector>
30+
#include <limits>
31+
32+
namespace mlir
33+
{
34+
namespace Circle
35+
{
36+
37+
// NOTE Name ConvConv is from Convert Convolution
38+
class ConvConv : public mlir::OpConversionPattern<mlir::ONNXConvOp>
39+
{
40+
public:
41+
using mlir::OpConversionPattern<mlir::ONNXConvOp>::OpConversionPattern;
42+
using OpAdaptor = typename mlir::ONNXConvOp::Adaptor;
43+
44+
mlir::LogicalResult matchAndRewrite(mlir::ONNXConvOp op, OpAdaptor adaptor,
45+
mlir::ConversionPatternRewriter &rewriter) const override
46+
{
47+
assert(op.verify().succeeded());
48+
49+
auto op_name = GetOperationName(op.getOperation());
50+
51+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv name: " << op_name << "\n"; });
52+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv auto_pad: " << op.getAutoPad() << "\n"; });
53+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv dilations: " << op.getDilations() << "\n"; });
54+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv group: " << op.getGroup() << "\n"; });
55+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv kernel_shape: " << op.getKernelShape() << "\n"; });
56+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv pads: " << op.getPads() << "\n"; });
57+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv strides: " << op.getStrides() << "\n"; });
58+
59+
if (notYetImplemented(op))
60+
return mlir::failure();
61+
62+
mlir::Value input = adaptor.getX();
63+
mlir::Value filter = adaptor.getW();
64+
mlir::Value bias = adaptor.getB();
65+
bool biasNone = bias.getType().isa<mlir::NoneType>();
66+
67+
mlir::Location opLoc = op->getLoc();
68+
69+
// TODO support other ranks for I/O
70+
71+
mlir::RankedTensorType intype = input.getType().dyn_cast_or_null<mlir::RankedTensorType>();
72+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv intype: " << intype << "\n"; });
73+
if (intype.getRank() != 4)
74+
assert(false);
75+
CHECK_VALID_RANK_4(intype);
76+
77+
mlir::RankedTensorType outtype = op.getType().dyn_cast_or_null<mlir::RankedTensorType>();
78+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv outtype: " << outtype << "\n"; });
79+
if (outtype.getRank() != 4)
80+
assert(false);
81+
CHECK_VALID_RANK_4(outtype);
82+
83+
mlir::Value inputPreTr = input;
84+
// for op.pads != [0,0,0,0]
85+
std::vector<int32_t> padsValue;
86+
if (GetPads(op.getPads(), padsValue))
87+
{
88+
mlir::Type i32 = rewriter.getI32Type();
89+
mlir::RankedTensorType ptype = RankedTensorType::get({4, 2}, i32);
90+
llvm::SmallVector<int32_t, 8> pvalue = {
91+
0, 0, 0, 0, padsValue[0], padsValue[2], padsValue[1], padsValue[3]};
92+
mlir::Location padsval_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/padsval"));
93+
mlir::Value paddings =
94+
rewriter.create<ConstOp>(padsval_loc, DenseIntElementsAttr::get(ptype, pvalue));
95+
96+
// calc output type+shape of Pad
97+
auto shape = intype.getShape();
98+
assert(shape.size() == 4);
99+
int64_t padH = 0, padW = 0;
100+
// NOTE if input is unknown, set padH, padW as unknown.
101+
// these will be resolved in shape inference.
102+
padH =
103+
(shape[2] == std::numeric_limits<int64_t>::min() ? shape[2]
104+
: shape[2] + padsValue[0] + padsValue[2]);
105+
padW =
106+
(shape[3] == std::numeric_limits<int64_t>::min() ? shape[3]
107+
: shape[3] + padsValue[1] + padsValue[3]);
108+
auto padShape = {shape[0], shape[1], padH, padW}; // order is NCHW
109+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv padH: " << padH << ", padW: " << padW << "\n"; });
110+
auto padType = mlir::RankedTensorType::get(padShape, outtype.getElementType());
111+
112+
// change pre Transpose input to this new Pad
113+
mlir::Location pads_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/pads"));
114+
LLVM_DEBUG({ llvm::dbgs() << "ConvConv Pad: " << pads_loc << "\n"; });
115+
inputPreTr = rewriter.create<PadOp>(pads_loc, padType, input, paddings);
116+
}
117+
118+
int32_t stride_h = 1;
119+
int32_t stride_w = 1;
120+
auto strides = op.getStrides();
121+
if (strides.has_value())
122+
{
123+
auto value = strides.value();
124+
if (value.size() != 2)
125+
return mlir::failure();
126+
127+
stride_h = GetIntValue<int32_t>(value, 0);
128+
stride_w = GetIntValue<int32_t>(value, 1);
129+
}
130+
131+
int64_t dilation_h_factor = 1;
132+
int64_t dilation_w_factor = 1;
133+
auto dilations = op.getDilations();
134+
if (dilations.has_value())
135+
{
136+
auto value = dilations.value();
137+
if (value.size() != 2)
138+
return mlir::failure();
139+
140+
dilation_h_factor = GetIntValue<int64_t>(value, 0);
141+
dilation_w_factor = GetIntValue<int64_t>(value, 1);
142+
}
143+
144+
// NOTE luci-interpreter fails to execute when bias is none.
145+
// we can (1) fix luci-interpreter (2) update bias to have zero values.
146+
// onnx-tensorflow works like (2) so we follow this.
147+
if (biasNone)
148+
{
149+
auto ftype = filter.getType().dyn_cast_or_null<mlir::RankedTensorType>();
150+
assert(ftype.getElementType().isF32());
151+
auto shape = ftype.getShape();
152+
int32_t num = shape[0]; // dim 0 from OIHW
153+
mlir::RankedTensorType type = RankedTensorType::get({num}, ftype.getElementType());
154+
std::vector<float> val;
155+
for (int32_t c = 0; c < num; ++c)
156+
val.push_back(0.0f);
157+
mlir::Location nobias_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/nobias"));
158+
bias = rewriter.create<ConstOp>(nobias_loc, DenseFPElementsAttr::get(type, val));
159+
}
160+
161+
auto filter_name = GetOperationName(filter.getDefiningOp());
162+
if (filter_name.empty())
163+
filter_name = op_name + "/filter";
164+
165+
mlir::Value pre_tran = CreatePreTranspose(rewriter, inputPreTr, op_name);
166+
167+
auto group = op.getGroup();
168+
if (group == 1)
169+
{
170+
mlir::Value filter_tran = CreatePreTranspose(rewriter, filter, filter_name);
171+
172+
auto conv_output_type = GetChnLastType(outtype);
173+
// TODO support activation != NONE
174+
// TODO support padding != VALID
175+
mlir::Value conv2d = rewriter.create<Conv2DOp>(opLoc, conv_output_type, pre_tran, filter_tran,
176+
bias, dilation_h_factor, dilation_w_factor,
177+
/*fused_activation_function=*/"NONE",
178+
/*padding=*/"VALID", stride_h, stride_w);
179+
180+
ReplaceOpWithPostTranspose(rewriter, op, conv2d, op.getType(), op_name);
181+
}
182+
else if (group > 1)
183+
{
184+
// TODO convert to DepthwiseConv2DOp
185+
return mlir::failure();
186+
}
187+
else
188+
return mlir::failure();
189+
190+
return mlir::success();
191+
}
192+
193+
private:
194+
bool notYetImplemented(mlir::ONNXConvOp &op) const
195+
{
196+
// TODO support other auto_pad: 'SAME_UPPER', 'VALID', 'SAME_LOWER'
197+
if (!op.getAutoPad().equals_insensitive("NOTSET"))
198+
return true;
199+
200+
return false;
201+
}
202+
};
203+
204+
} // namespace Circle
205+
} // namespace mlir
206+
207+
#endif // __CIRCLE_MLIR_PASS_OPS_CONV_OP_H__

0 commit comments

Comments
 (0)