diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index db9797dc..b95c6a33 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -38,7 +38,12 @@ import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType from qonnx.transformation.double_to_single_float import DoubleToSingleFloat -from qonnx.transformation.general import RemoveStaticGraphInputs, RemoveUnusedTensors, SortGraph +from qonnx.transformation.general import ( + RemoveStaticGraphInputs, + RemoveUnusedTensors, + SortCommutativeInputsInitializerLast, + SortGraph, +) class ModelWrapper: @@ -149,6 +154,7 @@ def cleanup(self): RemoveUnusedTensors(), RemoveStaticGraphInputs(), SortGraph(), + SortCommutativeInputsInitializerLast(), ] for trn in cleanup_transforms: transformed_model = transformed_model.transform(trn, cleanup=False, make_deepcopy=False) diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index 5153e616..d69cee5a 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -29,6 +29,9 @@ import json import numpy as np import warnings + +# Protobuf onnx graph node type +from onnx import NodeProto # noqa from onnx import mapping from toposort import toposort_flatten @@ -359,3 +362,56 @@ def apply(self, model): # one iteration is enough return (model, False) + + +# Groups inputs by categories, i.e., groups dynamic inputs first, followed by +# initializers. Keeps order of inputs in each category. +def group_inputs_by_category(node: NodeProto, model): # noqa + # Select all dynamic inputs, which are those without initializer tensor + dynamics = [i for i in node.input if model.get_initializer(i) is None] + # Select all input which are initializers, which, by exclusion, are all + # those not among the dynamic inputs + initializers = [i for i in node.input if i not in dynamics] + # Return lists of dynamic anc initializer inputs + return dynamics, initializers + + +# Tidy-Up transformation sorting the inputs to all commutative operations to +# have initializer inputs last +class SortCommutativeInputsInitializerLast(Transformation): + """ + Sorts inputs of nodes describing commutative operations to have initializer + inputs last. This order of inputs is assumed by many other transformations. + """ + + # Set of supported commutative operations + # TODO: There might be more valid operations + SUPPORTED_COMMUTATIVE_OPS = {"Add", "Mul", "And", "Or", "Xor", "Sum"} + + # Applies the transform to a whole model graph + def apply(self, model): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Check whether this node is among the supported + if node.op_type in self.SUPPORTED_COMMUTATIVE_OPS: + # Group node inputs by category + dynamics, initializers = group_inputs_by_category(node, model) + # Flatten the grouped input list + inputs = [*dynamics, *initializers] + # Length of sorted and original input list must match + assert len(inputs) == len(node.input) + # Reassigned inputs from sorted categories + for i, name in enumerate(inputs): + # The graph has been modified if any input is reordered + if node.input[i] != name: + # Note: This is never reset back to False + graph_modified = True + # Reassign input name at the new index + node.input[i] = name + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/tests/transformation/test_sort_commutative_inputs_initializer_last.py b/tests/transformation/test_sort_commutative_inputs_initializer_last.py new file mode 100644 index 00000000..1cd1eb72 --- /dev/null +++ b/tests/transformation/test_sort_commutative_inputs_initializer_last.py @@ -0,0 +1,91 @@ +# Set pytest parameters +import pytest + +# Numpy for handling simulation of tensor operations +import numpy as np + +# Helper for creating ONNX nodes +from onnx import TensorProto +from onnx import helper as oh + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# Execute QONNX model graphs +from qonnx.core.onnx_exec import execute_onnx + +# Graph transformation to be tested: Sorts the input list of commutative +# operations to have all dynamic inputs first followed by all initializer inputs +from qonnx.transformation.general import SortCommutativeInputsInitializerLast + +# QONNX utility for creating models from ONNX graphs +from qonnx.util.basic import qonnx_make_model + + +# Specify how many inputs the test should cover +@pytest.mark.parametrize("num_inputs", [4, 5, 6]) +# Specify which inputs should be turned into initializers +@pytest.mark.parametrize( + # fmt: off + "initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]] + # fmt: on +) +# Tests the SortCommutativeInputsInitializerLast transformation +def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): + # Generate the input tensor names + inputs = [f"in{i}" for i in range(num_inputs)] + # We will use the Sum ONNX operation to test this behavior, as it allows for + # arbitrary many inputs + node = oh.make_node( + # fmt: off + op_type="Sum", inputs=inputs, outputs=["out"], name="Sum" + # fmt: on + ) + # Create value infos for all input and the output tensor + inputs = [ + # fmt: off + oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs + # fmt: on + ] + out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,)) + # Make a graph comprising the Sum node and value infos for all inputs and + # the output + graph = oh.make_graph([node], inputs=inputs, outputs=[out], name="Sum") + # Wrap the graph in an QONNX model wrapper + model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests")) + # Prepare the execution context + context = {f"in{i}": np.random.rand(16) for i in range(num_inputs)} + # Make sure all inputs are of type float32 + context = {key: value.astype(np.float32) for key, value in context.items()} + # Turn selected inputs into initializers + for i in initializers: + model.set_initializer(f"in{i}", context[f"in{i}"]) + + # Execute the ONNX model before transforming + out_expected = execute_onnx(model, context)["out"] + # Apply the transformation to be tested + # Note: No cleanup, as the tested transformation is part of the cleanup, and + # we want to test this in isolation + model = model.transform( + # fmt: off + SortCommutativeInputsInitializerLast(), cleanup=False + # fmt: on + ) + # Execute the ONNX model after transforming + out_produced = execute_onnx(model, context)["out"] + + # Start with no initializer input seen so far + seen_initializer = False + # Verify that no "dynamic" input follows an initializer input + for i in model.graph.node[0].input: + # Keep track of when an initializer has been seen + if model.get_initializer(i) is not None: + seen_initializer = True + # If there has already been an initializer, this input must be an + # initializer as well + assert ( + not seen_initializer or model.get_initializer(i) is not None + ), "Non-initializer input following initializer after sorting" + + # Outputs before and after must match + assert np.allclose(out_produced, out_expected)