Skip to content

Commit

Permalink
Fixes for IR optimizer (#1865)
Browse files Browse the repository at this point in the history
A few fixes needed to make the IR based optimizer work for the HF
benchmark.

Removed the changes relating to SymbolicTensor hash issue, and what's
remaining are pure bug fixes.
  • Loading branch information
gramalingam authored Sep 17, 2024
1 parent 1eef633 commit 82dac0f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = node.inputs[0]
inputs = state.get_sym_value(input)
if any(x is None for x in inputs):
if inputs is None or any(x is None for x in inputs):
return None
new_axis = _get_int_attribute(node, "new_axis", 0)
axis = _get_int_attribute(node, "axis", None)
Expand Down
20 changes: 16 additions & 4 deletions onnxscript/optimizer/_inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def clone_value(self, value: ir.Value) -> ir.Value | None:
if value in self._value_map:
return self._value_map[value]
# If the value is not in the value map, it must be a graph input.
assert value.producer() is not None, f"Value {value} has no entry in the value map"
assert value.producer() is None, f"Value {value} has no entry in the value map"
new_value = ir.Value(
name=value.name,
type=value.type,
Expand Down Expand Up @@ -90,8 +90,17 @@ def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAt
)
return attr
assert isinstance(attr, ir.RefAttr)
if key in self._attr_map:
return self._attr_map[key]
ref_attr_name = attr.ref_attr_name
if ref_attr_name in self._attr_map:
ref_attr = self._attr_map[ref_attr_name]
if isinstance(ref_attr, ir.Attr):
return ir.Attr(
key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
)
assert isinstance(ref_attr, ir.RefAttr)
return ir.RefAttr(
key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
)
# Note that if a function has an attribute-parameter X, and a call (node) to the function
# has no attribute X, all references to X in nodes inside the function body will be
# removed. This is just the ONNX representation of optional-attributes.
Expand Down Expand Up @@ -142,10 +151,13 @@ def clone_graph(self, graph: ir.Graph) -> ir.Graph:
input_values = [self.clone_value(v) for v in graph.inputs]
nodes = [self.clone_node(node) for node in graph]
initializers = [self.clone_value(init) for init in graph.initializers.values()]
output_values = [
self.clone_value(v) for v in graph.outputs
] # Looks up already cloned values

return ir.Graph(
input_values, # type: ignore
graph.outputs,
output_values, # type: ignore
nodes=nodes,
initializers=initializers, # type: ignore
doc_string=graph.doc_string,
Expand Down

0 comments on commit 82dac0f

Please sign in to comment.