From 64b6087033d040fc8605fda5c885ef8f1d35e180 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Oct 2024 19:28:17 +0000 Subject: [PATCH] update --- onnxscript/ir/serde.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 1e311cc64..41571bcd3 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -1274,6 +1274,23 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto: return node_proto +def _remove_trailing_outputs( + outputs: Sequence[_protocols.ValueProtocol], +) -> Sequence[_protocols.ValueProtocol]: + """Remove trailing outputs that have empty names. + + Args: + outputs: The outputs to remove trailing outputs from. + + Returns: + The outputs with trailing outputs removed. + """ + for i, output in enumerate(reversed(outputs)): + if output.name: + return outputs[: len(outputs) - i] + return [] + + @_capture_errors(lambda node_proto, from_: repr(from_)) def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None: node_proto.op_type = from_.op_type @@ -1293,17 +1310,9 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc node_proto.input.append("") else: node_proto.input.append(input_.name) + # Do not include the trailing outputs that have empty names - trailing_empty_outputs = 0 - for output in reversed(from_.outputs): - if output.name: - break - trailing_empty_outputs += 1 - if trailing_empty_outputs > 0: - outputs = from_.outputs[:-trailing_empty_outputs] - else: - outputs = from_.outputs - for output in outputs: + for output in _remove_trailing_outputs(from_.outputs): node_proto.output.append(output.name) for attr in from_.attributes.values():