Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Oct 14, 2024
1 parent f568dcb commit 64b6087
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,23 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
return node_proto


def _remove_trailing_outputs(
outputs: Sequence[_protocols.ValueProtocol],
) -> Sequence[_protocols.ValueProtocol]:
"""Remove trailing outputs that have empty names.
Args:
outputs: The outputs to remove trailing outputs from.
Returns:
The outputs with trailing outputs removed.
"""
for i, output in enumerate(reversed(outputs)):
if output.name:
return outputs[: len(outputs) - i]
return []


@_capture_errors(lambda node_proto, from_: repr(from_))
def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
node_proto.op_type = from_.op_type
Expand All @@ -1293,17 +1310,9 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
node_proto.input.append("")
else:
node_proto.input.append(input_.name)

# Do not include the trailing outputs that have empty names
trailing_empty_outputs = 0
for output in reversed(from_.outputs):
if output.name:
break
trailing_empty_outputs += 1
if trailing_empty_outputs > 0:
outputs = from_.outputs[:-trailing_empty_outputs]
else:
outputs = from_.outputs
for output in outputs:
for output in _remove_trailing_outputs(from_.outputs):
node_proto.output.append(output.name)

for attr in from_.attributes.values():
Expand Down

0 comments on commit 64b6087

Please sign in to comment.