Skip to content

Commit

Permalink
Add experimental AveragePoolNHWC op
Browse files Browse the repository at this point in the history
  • Loading branch information
maltanar authored Apr 5, 2024
1 parent 6b8b4ec commit a0d4a0a
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions src/qonnx/custom_op/general/maxpoolnhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,78 @@ def execute_node(self, context, graph):
def verify_node(self):
info_messages = []
return info_messages


class AveragePoolNHWC(CustomOp):
# an AveragePool node, but using the NHWC data layout

def get_nodeattr_types(self):
# no specific attributes for AveragePoolNHWC
# attributes below are identical to the standard ONNX AveragePool op:
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
return {
"kernel_shape": ("ints", True, []),
"pads": ("ints", True, []),
"strides": ("ints", True, []),
"ceil_mode": ("i", False, 0),
}

def make_shape_compatible_op(self, model):
node = self.onnx_node
iname = node.input[0]
ishape = model.get_tensor_shape(iname)
kernel_shape = self.get_nodeattr("kernel_shape")
pads = self.get_nodeattr("pads")
strides = self.get_nodeattr("strides")
ceil_mode = self.get_nodeattr("ceil_mode")
assert len(kernel_shape) == 2, "Non-2D AveragePoolNHWC not supported"
assert pads[0] == pads[2], "Uneven padding not supported"
assert pads[1] == pads[3], "Uneven padding not supported"
(n, hi, wi, c) = ishape
ho = compute_pool_output_dim(hi, kernel_shape[0], strides[0], pads[0], ceil_mode)
wo = compute_pool_output_dim(wi, kernel_shape[1], strides[1], pads[1], ceil_mode)
oshape = (n, ho, wo, c)
return super().make_const_shape_op(oshape)

def infer_node_datatype(self, model):
node = self.onnx_node
# data type stays the same
dtype = model.get_tensor_datatype(node.input[0])
model.set_tensor_datatype(node.output[0], dtype)

def execute_node(self, context, graph):
node = self.onnx_node
inp_name = node.input[0]
out_name = node.output[0]
inp = context[inp_name]
dummy_out = context[out_name]
# convert i/o NHWC -> NCHW
inp = np.transpose(inp, (0, 3, 1, 2))
dummy_out = np.transpose(dummy_out, (0, 3, 1, 2))
# execute as regular MaxPool
orig_domain = node.domain
node.domain = ""
node.op_type = "AveragePool"
inp_vi = helper.make_tensor_value_info(inp_name, TensorProto.FLOAT, inp.shape)
out_vi = helper.make_tensor_value_info(out_name, TensorProto.FLOAT, dummy_out.shape)
tmp_graph = helper.make_graph(nodes=[node], name="tmp_graph", inputs=[inp_vi], outputs=[out_vi])
opset_version = self.onnx_opset_version
opset_imports = [helper.make_opsetid("", opset_version)]
onnx_kwargs = {"opset_imports": opset_imports}
tmp_model = qonnx_make_model(tmp_graph, producer_name="finn", **onnx_kwargs)
tmp_model = ModelWrapper(tmp_model)
new_ctx = {inp_name: inp}
from qonnx.core.onnx_exec import execute_onnx

ret = execute_onnx(tmp_model, new_ctx)
# restore original node props
node.domain = orig_domain
node.op_type = "AveragePoolNHWC"
outp = ret[out_name]
# convert output NCHW -> NHWC
outp = np.transpose(outp, (0, 2, 3, 1))
context[out_name] = outp

def verify_node(self):
info_messages = []
return info_messages

0 comments on commit a0d4a0a

Please sign in to comment.