From 82dac0f74ebee036b290f761fa168f79c56f0025 Mon Sep 17 00:00:00 2001
From: "G. Ramalingam"
Date: Mon, 16 Sep 2024 17:03:30 -0700
Subject: [PATCH] Fixes for IR optimizer (#1865)

A few fixes needed to make the IR-based optimizer work for the HF
benchmark. The changes relating to the SymbolicTensor hash issue have
been removed; what remains are pure bug fixes.
---
 onnxscript/optimizer/_constant_folding.py |  2 +-
 onnxscript/optimizer/_inliner.py          | 20 ++++++++++++++++----
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index b7cbc0bb2..2e9486b68 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -348,7 +348,7 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
 def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = node.inputs[0]
     inputs = state.get_sym_value(input)
-    if any(x is None for x in inputs):
+    if inputs is None or any(x is None for x in inputs):
         return None
     new_axis = _get_int_attribute(node, "new_axis", 0)
     axis = _get_int_attribute(node, "axis", None)
diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/optimizer/_inliner.py
index 31221de02..590937397 100644
--- a/onnxscript/optimizer/_inliner.py
+++ b/onnxscript/optimizer/_inliner.py
@@ -62,7 +62,7 @@ def clone_value(self, value: ir.Value) -> ir.Value | None:
         if value in self._value_map:
             return self._value_map[value]
         # If the value is not in the value map, it must be a graph input.
-        assert value.producer() is not None, f"Value {value} has no entry in the value map"
+        assert value.producer() is None, f"Value {value} has no entry in the value map"
         new_value = ir.Value(
             name=value.name,
             type=value.type,
@@ -90,8 +90,17 @@ def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAt
                 )
             return attr
         assert isinstance(attr, ir.RefAttr)
-        if key in self._attr_map:
-            return self._attr_map[key]
+        ref_attr_name = attr.ref_attr_name
+        if ref_attr_name in self._attr_map:
+            ref_attr = self._attr_map[ref_attr_name]
+            if isinstance(ref_attr, ir.Attr):
+                return ir.Attr(
+                    key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
+                )
+            assert isinstance(ref_attr, ir.RefAttr)
+            return ir.RefAttr(
+                key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
+            )
         # Note that if a function has an attribute-parameter X, and a call (node) to the function
         # has no attribute X, all references to X in nodes inside the function body will be
         # removed. This is just the ONNX representation of optional-attributes.
@@ -142,10 +151,13 @@ def clone_graph(self, graph: ir.Graph) -> ir.Graph:
         input_values = [self.clone_value(v) for v in graph.inputs]
         nodes = [self.clone_node(node) for node in graph]
         initializers = [self.clone_value(init) for init in graph.initializers.values()]
+        output_values = [
+            self.clone_value(v) for v in graph.outputs
+        ]  # Looks up already cloned values
 
         return ir.Graph(
             input_values,  # type: ignore
-            graph.outputs,
+            output_values,  # type: ignore
             nodes=nodes,
             initializers=initializers,  # type: ignore
             doc_string=graph.doc_string,
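
Reviewer note on the clone_value fix: the inverted assertion now encodes the IR
invariant that only graph inputs lack a producing node, which is exactly what
the preceding comment states. Below is a minimal sketch of that invariant,
assuming the onnxscript.ir surface behaves as used in the diff above
(ir.Value(name=...), ir.Node(domain, op_type, inputs), Value.producer()); it is
illustrative, not an excerpt from the repository:

    from onnxscript import ir

    # A standalone value, such as a graph input, has no producing node.
    graph_input = ir.Value(name="x")
    assert graph_input.producer() is None

    # A value created as a node output reports that node as its producer,
    # so reaching the assertion with a non-None producer would mean the
    # value map was not populated for a node-produced value.
    node = ir.Node("", "Identity", inputs=[graph_input])
    assert node.outputs[0].producer() is node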
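
Reviewer note on the clone_attr fix: a node inside a function body may use an
attribute name (key) that differs from the attribute-parameter it references
(attr.ref_attr_name), so the call-site map must be consulted with the parameter
name and the result re-keyed under the node's own name. A hypothetical,
dict-based sketch of that rebinding follows (rebind_ref_attr and its
plain-tuple return are illustrative, not onnxscript API):

    def rebind_ref_attr(key, ref_attr_name, attr_map):
        """Resolve a reference-attribute against the call site's attributes.

        key:           attribute name on the node inside the function body
        ref_attr_name: function attribute-parameter the node refers to
        attr_map:      attributes supplied at the call site
        """
        if ref_attr_name not in attr_map:
            # Not supplied at the call site: drop the reference. This is the
            # ONNX representation of an optional attribute that was omitted.
            return None
        # Re-key the resolved attribute under the node's own name; if the
        # call site itself passed a reference, it is forwarded unchanged.
        return (key, attr_map[ref_attr_name])

    # A body node uses attribute "alpha" that refers to parameter "slope":
    assert rebind_ref_attr("alpha", "slope", {"slope": 0.2}) == ("alpha", 0.2)
    assert rebind_ref_attr("alpha", "slope", {}) is None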