diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py
index eb964fc4..d2419813 100644
--- a/src/qonnx/custom_op/general/maxpoolnhwc.py
+++ b/src/qonnx/custom_op/general/maxpoolnhwc.py
@@ -117,3 +117,78 @@ def execute_node(self, context, graph):
     def verify_node(self):
         info_messages = []
         return info_messages
+
+
+class AveragePoolNHWC(CustomOp):
+    # an AveragePool node, but using the NHWC data layout
+
+    def get_nodeattr_types(self):
+        # no specific attributes for AveragePoolNHWC
+        # attributes below are identical to the standard ONNX AveragePool op:
+        # https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
+        return {
+            "kernel_shape": ("ints", True, []),
+            "pads": ("ints", True, []),
+            "strides": ("ints", True, []),
+            "ceil_mode": ("i", False, 0),
+        }
+
+    def make_shape_compatible_op(self, model):
+        node = self.onnx_node
+        iname = node.input[0]
+        ishape = model.get_tensor_shape(iname)
+        kernel_shape = self.get_nodeattr("kernel_shape")
+        pads = self.get_nodeattr("pads")
+        strides = self.get_nodeattr("strides")
+        ceil_mode = self.get_nodeattr("ceil_mode")
+        assert len(kernel_shape) == 2, "Non-2D AveragePoolNHWC not supported"
+        assert pads[0] == pads[2], "Uneven padding not supported"
+        assert pads[1] == pads[3], "Uneven padding not supported"
+        (n, hi, wi, c) = ishape
+        ho = compute_pool_output_dim(hi, kernel_shape[0], strides[0], pads[0], ceil_mode)
+        wo = compute_pool_output_dim(wi, kernel_shape[1], strides[1], pads[1], ceil_mode)
+        oshape = (n, ho, wo, c)
+        return super().make_const_shape_op(oshape)
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # data type stays the same
+        dtype = model.get_tensor_datatype(node.input[0])
+        model.set_tensor_datatype(node.output[0], dtype)
+
+    def execute_node(self, context, graph):
+        node = self.onnx_node
+        inp_name = node.input[0]
+        out_name = node.output[0]
+        inp = context[inp_name]
+        dummy_out = context[out_name]
+        # convert i/o NHWC -> NCHW
+        inp = np.transpose(inp, (0, 3, 1, 2))
+        dummy_out = np.transpose(dummy_out, (0, 3, 1, 2))
+        # execute as a regular AveragePool in a temporary single-node model
+        orig_domain = node.domain
+        node.domain = ""
+        node.op_type = "AveragePool"
+        inp_vi = helper.make_tensor_value_info(inp_name, TensorProto.FLOAT, inp.shape)
+        out_vi = helper.make_tensor_value_info(out_name, TensorProto.FLOAT, dummy_out.shape)
+        tmp_graph = helper.make_graph(nodes=[node], name="tmp_graph", inputs=[inp_vi], outputs=[out_vi])
+        opset_version = self.onnx_opset_version
+        opset_imports = [helper.make_opsetid("", opset_version)]
+        onnx_kwargs = {"opset_imports": opset_imports}
+        tmp_model = qonnx_make_model(tmp_graph, producer_name="finn", **onnx_kwargs)
+        tmp_model = ModelWrapper(tmp_model)
+        new_ctx = {inp_name: inp}
+        from qonnx.core.onnx_exec import execute_onnx
+
+        ret = execute_onnx(tmp_model, new_ctx)
+        # restore original node props
+        node.domain = orig_domain
+        node.op_type = "AveragePoolNHWC"
+        outp = ret[out_name]
+        # convert output NCHW -> NHWC
+        outp = np.transpose(outp, (0, 2, 3, 1))
+        context[out_name] = outp
+
+    def verify_node(self):
+        info_messages = []
+        return info_messages
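
Usage sketch (not part of the patch): the snippet below builds a single-node model using the new AveragePoolNHWC op and runs it through qonnx's ONNX executor. It assumes the op is registered under the "qonnx.custom_op.general" domain alongside MaxPoolNHWC; the tensor names, shapes, and attribute values are illustrative only.

# minimal sketch, assuming AveragePoolNHWC is registered in the
# qonnx.custom_op.general domain (same registry as MaxPoolNHWC)
import numpy as np
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.onnx_exec import execute_onnx
from qonnx.util.basic import qonnx_make_model

# NHWC input: batch 1, 4x4 spatial, 2 channels; 2x2 kernel with stride 2 -> 2x2 output
inp_vi = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 4, 2])
out_vi = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node(
    "AveragePoolNHWC",
    ["inp"],
    ["out"],
    domain="qonnx.custom_op.general",  # assumed registration domain
    kernel_shape=[2, 2],
    pads=[0, 0, 0, 0],
    strides=[2, 2],
    ceil_mode=0,
)
graph = helper.make_graph([node], "avgpool_nhwc_example", [inp_vi], [out_vi])
model = ModelWrapper(qonnx_make_model(graph, producer_name="avgpool-nhwc-example"))
idict = {"inp": np.random.rand(1, 4, 4, 2).astype(np.float32)}
odict = execute_onnx(model, idict)
print(odict["out"].shape)  # expected (1, 2, 2, 2), still NHWC

Internally, execute_node transposes the NHWC input to NCHW, delegates to the standard ONNX AveragePool via a temporary single-node model, and transposes the result back, so the caller only ever sees NHWC tensors.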