Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support node name and node output name not same #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions onnx_opcounter/onnx_opcounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def onnx_node_attributes_to_dict(args):
:param args: ONNX attributes object
:return: Python dictionary
"""

def onnx_attribute_to_dict(onnx_attr):
"""
Parse ONNX attribute
Expand All @@ -40,6 +41,7 @@ def onnx_attribute_to_dict(onnx_attr):
for attr_type in ['floats', 'ints', 'strings']:
if getattr(onnx_attr, attr_type):
return list(getattr(onnx_attr, attr_type))

return {arg.name: onnx_attribute_to_dict(arg) for arg in args}


Expand Down Expand Up @@ -67,9 +69,9 @@ def get_mapping_for_node(node, graph_outputs):
for output in node.output:
if output in graph_outputs:
return output
return node.name
return node.output[0]

output_name_mapping = {node.name: get_mapping_for_node(node, graph_outputs) for node in onnx_nodes}
output_name_mapping = {node.output[0]: get_mapping_for_node(node, graph_outputs) for node in onnx_nodes}
output_mapping = {}

for name in output_name_mapping:
Expand All @@ -83,7 +85,7 @@ def get_mapping_for_node(node, graph_outputs):
graph_outputs.append(output)
output_mapping[name] = graph_outputs.index(output)

print(name, '->', output, 'index', output_mapping[name])
# print(name, '->', output, 'index', output_mapping[name])

onnx.save(model, '+all-intermediate.onnx')

Expand Down Expand Up @@ -143,9 +145,10 @@ def no_macs(*args, **kwargs):

macs = 0
for node in onnx_nodes:
node_output_shape = output_shapes[node.name]
node_input_shape = output_shapes[node.input[0]]
macs += mac_calculators[node.op_type](
node, node_input_shape, node_output_shape, onnx_node_attributes_to_dict(node.attribute)
)
if node.op_type in mac_calculators:
node_output_shape = output_shapes[node.output[0]]
node_input_shape = output_shapes[node.input[0]]
macs += mac_calculators[node.op_type](
node, node_input_shape, node_output_shape, onnx_node_attributes_to_dict(node.attribute)
)
return macs