From 294229bbecbfceac81501fbded3c337cb0bec101 Mon Sep 17 00:00:00 2001 From: Eye <380614540@qq.com> Date: Wed, 27 Oct 2021 14:59:06 +0800 Subject: [PATCH] fix: support node name and node output name not same --- onnx_opcounter/onnx_opcounter.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/onnx_opcounter/onnx_opcounter.py b/onnx_opcounter/onnx_opcounter.py index ed9f07a..7391ee6 100644 --- a/onnx_opcounter/onnx_opcounter.py +++ b/onnx_opcounter/onnx_opcounter.py @@ -24,6 +24,7 @@ def onnx_node_attributes_to_dict(args): :param args: ONNX attributes object :return: Python dictionary """ + def onnx_attribute_to_dict(onnx_attr): """ Parse ONNX attribute @@ -40,6 +41,7 @@ def onnx_attribute_to_dict(onnx_attr): for attr_type in ['floats', 'ints', 'strings']: if getattr(onnx_attr, attr_type): return list(getattr(onnx_attr, attr_type)) + return {arg.name: onnx_attribute_to_dict(arg) for arg in args} @@ -67,9 +69,9 @@ def get_mapping_for_node(node, graph_outputs): for output in node.output: if output in graph_outputs: return output - return node.name + return node.output[0] - output_name_mapping = {node.name: get_mapping_for_node(node, graph_outputs) for node in onnx_nodes} + output_name_mapping = {node.output[0]: get_mapping_for_node(node, graph_outputs) for node in onnx_nodes} output_mapping = {} for name in output_name_mapping: @@ -83,7 +85,7 @@ def get_mapping_for_node(node, graph_outputs): graph_outputs.append(output) output_mapping[name] = graph_outputs.index(output) - print(name, '->', output, 'index', output_mapping[name]) + # print(name, '->', output, 'index', output_mapping[name]) onnx.save(model, '+all-intermediate.onnx') @@ -143,9 +145,10 @@ def no_macs(*args, **kwargs): macs = 0 for node in onnx_nodes: - node_output_shape = output_shapes[node.name] - node_input_shape = output_shapes[node.input[0]] - macs += mac_calculators[node.op_type]( - node, node_input_shape, node_output_shape, onnx_node_attributes_to_dict(node.attribute) - ) + if node.op_type in mac_calculators: + node_output_shape = output_shapes[node.output[0]] + node_input_shape = output_shapes[node.input[0]] + macs += mac_calculators[node.op_type]( + node, node_input_shape, node_output_shape, onnx_node_attributes_to_dict(node.attribute) + ) return macs