From c70923b0db85ef928b11461a6fdd0087d049d751 Mon Sep 17 00:00:00 2001
From: nicologhielmetti
Date: Wed, 17 Apr 2024 17:41:27 +0200
Subject: [PATCH 01/11] Some initial modifications to support channels-last
 with branches. The remaining problem seems to be moving the transpose
 upwards before the fork and properly reconnecting the network. I am also
 unsure what to do about the model tensor shape, since I see that it gets
 modified in the base case.

---
 src/qonnx/custom_op/channels_last/__init__.py |   6 +
 src/qonnx/custom_op/channels_last/concat.py   | 112 +++++++++++++++
 src/qonnx/custom_op/channels_last/resize.py   | 133 ++++++++++++++++++
 src/qonnx/transformation/channels_last.py     |  95 ++++++++++---
 4 files changed, 326 insertions(+), 20 deletions(-)
 create mode 100644 src/qonnx/custom_op/channels_last/concat.py
 create mode 100644 src/qonnx/custom_op/channels_last/resize.py

diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py
index f1d7c39b..8b11f8c2 100644
--- a/src/qonnx/custom_op/channels_last/__init__.py
+++ b/src/qonnx/custom_op/channels_last/__init__.py
@@ -1,9 +1,15 @@
 from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
 from qonnx.custom_op.channels_last.conv import Conv
 from qonnx.custom_op.channels_last.max_pool import MaxPool
+from qonnx.custom_op.channels_last.concat import Concat
+from qonnx.custom_op.channels_last.resize import Resize
+
 
 custom_op = dict()
 
 custom_op["Conv"] = Conv
 custom_op["MaxPool"] = MaxPool
 custom_op["BatchNormalization"] = BatchNormalization
+custom_op["Concat"] = Concat
+custom_op["Resize"] = Resize
+
diff --git a/src/qonnx/custom_op/channels_last/concat.py b/src/qonnx/custom_op/channels_last/concat.py
new file mode 100644
index 00000000..84fe23d9
--- /dev/null
+++ b/src/qonnx/custom_op/channels_last/concat.py
@@ -0,0 +1,112 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp
+
+class Concat(ChannelsLastWrappedOp):
+    def get_nodeattr_types(self):
+        """Returns a dict of permitted attributes for node, where:
+        ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
+        - dtype indicates which member of the ONNX AttributeProto
+        will be utilized
+        - require indicates whether this attribute is required
+        - default_val indicates the default value that will be used if the
+        attribute is not set
+        - <allowed_values> (if specified) indicates that this attribute can only
+        be set to one of the values in the set <allowed_values>. If not specified,
+        all values permitted by dtype are allowed.
+ """ + return { + # axis attribute of Concat layer, default 1 + "axis": ("i", True, 1) + } + + def make_shape_compatible_op(self, model): + """Returns a standard ONNX op which is compatible with this CustomOp + for performing shape inference.""" + + node = self.onnx_node + iname0 = node.input[0] + iname1 = node.input[1] + ishape0 = model.get_tensor_shape(iname0) + ishape1 = model.get_tensor_shape(iname1) + # axis = self.get_nodeattr("axis") + # not sure about what's the shape of inputs, don't know how to check it + # check that ishape0[1] == ishape1[1] and ishape0[2] == ishape1[2] + assert ishape0[1] == ishape1[1], "Input shape [1] has to be the same between the 2 input nodes of concat" + assert ishape0[2] == ishape1[2], "Input shape [2] has to be the same between the 2 input nodes of concat" + + # implement tensor with correct shape + output_shape = [1, ishape0[1], ishape0[2], ishape0[3] + ishape1[3]] + + # implement tensor with correct shape + values = np.random.randn(*output_shape).astype(np.float32) + return helper.make_node( + "Constant", + inputs=[], + outputs=[self.onnx_node.output[0]], + value=helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=values.shape, + vals=values.flatten().astype(float), + ), + name=self.onnx_node.name, + ) + + def verify_node(self): + node = self.onnx_node + + verification_successful = True + info_messages = [] + + wrapper_info = ChannelsLastWrappedOp.verify_node(self) + info_messages.extend(wrapper_info) + + # verify number of attributes + num_of_attr_min = 1 + num_of_attr_max = 1 + if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max: + info_messages.append("The number of attributes is correct") + else: + info_messages.append( + """The number of attributes is incorrect, + {} should have between {} and {} attributes""".format( + node.op_type, num_of_attr_min, num_of_attr_max + ) + ) + verification_successful = False + + # verify that all necessary attributes exist + try: + self.get_nodeattr("axis") + info_messages.append("All necessary attributes exist") + except Exception: + info_messages.append( + """The necessary attributes do not exist. + Concat needs the following attributes: + axis""" + ) + verification_successful = False + + # verify that attributes have the correct datatype. 
+        try:
+            assert isinstance(self.get_nodeattr("axis"), int)
+            info_messages.append("All attributes are of the correct type")
+        except Exception:
+            info_messages.append("One or more attributes are of the wrong datatype")
+            verification_successful = False
+
+        # verify the number of inputs
+        if len(node.input) == 2:
+            info_messages.append("The number of inputs is correct")
+        else:
+            info_messages.append("{} needs 2 data input".format(node.op_type))
+            verification_successful = False
+
+        if not verification_successful:
+            raise RuntimeError(
+                f"Verification of node {node.name} failed, please check the " f"attached info messages: {info_messages}"
+            )
+
+        return info_messages
diff --git a/src/qonnx/custom_op/channels_last/resize.py b/src/qonnx/custom_op/channels_last/resize.py
new file mode 100644
index 00000000..2aaa98c7
--- /dev/null
+++ b/src/qonnx/custom_op/channels_last/resize.py
@@ -0,0 +1,133 @@
+import struct
+import numpy as np
+from onnx import TensorProto, helper
+
+from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp
+
+class Resize(ChannelsLastWrappedOp):
+    def get_nodeattr_types(self):
+        """Returns a dict of permitted attributes for node, where:
+        ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
+        - dtype indicates which member of the ONNX AttributeProto
+        will be utilized
+        - require indicates whether this attribute is required
+        - default_val indicates the default value that will be used if the
+        attribute is not set
+        - <allowed_values> (if specified) indicates that this attribute can only
+        be set to one of the values in the set <allowed_values>. If not specified,
+        all values permitted by dtype are allowed.
+        """
+        return {
+            "coordinate_transformation_mode": ("s", True, "half_pixel"),
+            "cubic_coeff_a": ("f", True, -0.75),
+            "mode": ("s", True, "linear"),
+            "nearest_mode": ("s", True, "floor")
+        }
+
+    def _get_initializer_from_name(self, model, iname):
+        for i in model.graph.initializer:
+            if i.name == iname:
+                return i
+
+    def _compute_fmt(self, tensor_shape):
+        fmt = "<"
+        for _ in range(tensor_shape):
+            fmt += "f"
+        return fmt
+
+    def _compute_resize_output_shape(self, scales, input_shape):
+        assert len(scales) == len(input_shape)
+        scales = [int(i) for i in scales]
+        output_shape = input_shape.copy()
+        output_shape[1], output_shape[-1] = output_shape[-1], output_shape[1]
+        for i in range(len(input_shape)):
+            output_shape[i] *= scales[i]
+        output_shape[1], output_shape[-1] = output_shape[-1], output_shape[1]
+        return output_shape
+
+    def make_shape_compatible_op(self, model):
+        """Returns a standard ONNX op which is compatible with this CustomOp
+        for performing shape inference."""
+        node = self.onnx_node
+        iscalesn = node.input[2]
+        inode = node.input[0]
+        inodes = model.get_tensor_shape(inode)
+        iscalesns = model.get_tensor_shape(iscalesn)
+        i = self._get_initializer_from_name(model, iscalesn).raw_data
+        fmt = self._compute_fmt(iscalesns[0])
+        scales = struct.unpack(fmt, i)
+
+        # implement tensor with correct shape
+        output_shape = self._compute_resize_output_shape(scales, inodes)
+
+        # implement tensor with correct shape
+        values = np.random.randn(*output_shape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+            name=self.onnx_node.name,
+        )
+
+    def verify_node(self):
+        node = self.onnx_node
+
+        verification_successful = True
+        info_messages = []
+
wrapper_info = ChannelsLastWrappedOp.verify_node(self) + info_messages.extend(wrapper_info) + + # verify number of attributes + num_of_attr_min = 1 + num_of_attr_max = 1 + if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max: + info_messages.append("The number of attributes is correct") + else: + info_messages.append( + """The number of attributes is incorrect, + {} should have between {} and {} attributes""".format( + node.op_type, num_of_attr_min, num_of_attr_max + ) + ) + verification_successful = False + + # verify that all necessary attributes exist + try: + self.get_nodeattr("axis") + info_messages.append("All necessary attributes exist") + except Exception: + info_messages.append( + """The necessary attributes do not exist. + Concat needs the following attributes: + axis""" + ) + verification_successful = False + + # verify that attributes have the correct datatype. + try: + assert isinstance(self.get_nodeattr("axis"), int) + info_messages.append("All attributes are of the correct type") + except Exception: + info_messages.append("One or more attributes are of the wrong datatype") + verification_successful = False + + # verify the number of inputs + if len(node.input) == 2: + info_messages.append("The number of inputs is correct") + else: + info_messages.append("{} needs 2 data input".format(node.op_type)) + verification_successful = False + + if not verification_successful: + raise RuntimeError( + f"Verification of node {node.name} failed, please check the " f"attached info messages: {info_messages}" + ) + + return info_messages diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index a5416b63..17b8013f 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -40,11 +40,12 @@ from qonnx.util.basic import get_by_name # Standard ONNX nodes which require a ChannelsLast data format to function properly +# _channelsLast_node_types = [x for x in list(channels_last.custom_op.keys()) if x != 'Resize'] _channelsLast_node_types = list(channels_last.custom_op.keys()) # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. -_move_through_nodes = ["Quant", "Relu"] +_move_through_nodes = ["Quant", "Relu", "LeakyRelu", "Resize"] # Nodes, which do not modify the shape of the tensor, # And modify all values in the same way, if the second tensor is a scalar. @@ -68,40 +69,74 @@ def __init__(self, make_input_channels_last=False): self._make_input_channels_last = make_input_channels_last def apply(self, model): - assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment." - assert model.check_all_tensor_shapes_specified(), ( - "All tensor shapes must be specified. " "Consider running InferShapes." - ) + # assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment." + # assert model.check_all_tensor_shapes_specified(), ( + # "All tensor shapes must be specified. " "Consider running InferShapes." 
+ # ) model = model.transform(InsertChannelsLastDomainsAndTrafos()) + + import onnx + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_1.onnx') + print('ONNX model saved InsertChannelsLastDomainsAndTrafos - 1') + initial_model_string = model.model.SerializeToString() # Apply RemoveConsecutiveChanFirstAndChanLastTrafos model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_2.onnx') + print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 2') # Apply MoveChanLastUpstream and fold into initializers model = model.transform(MoveChanLastUpstream()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_3.onnx') + print('ONNX model saved MoveChanLastUpstream - 3') + model = model.transform(FoldTransposeIntoQuantInit()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_4.onnx') + print('ONNX model saved FoldTransposeIntoQuantInit - 4') # Run RemoveConsecutiveChanFirstAndChanLastTrafos again, # Technically only required if something changed in the previous trafo model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_5.onnx') + print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 5') # Apply MoveChanLastDownStream model = model.transform(MoveChanFirstDownstream()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_6.onnx') + print('ONNX model saved MoveChanFirstDownstream - 6') # Run RemoveConsecutiveChanFirstAndChanLastTrafos again, # Technically only required if something changed in the previous trafo model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_7.onnx') + print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 7') # Apply AbsorbChanFirstIntoMatMul model = model.transform(AbsorbChanFirstIntoMatMul()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_8.onnx') + print('ONNX model saved AbsorbChanFirstIntoMatMul - 8') if self._make_input_channels_last: model = model.transform(MakeInputChannelsLast()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_9.onnx') + print('ONNX model saved MakeInputChannelsLast - 9') + model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) + + onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_10.onnx') + print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 10') # Check if the model changed new_model_string = model.model.SerializeToString() - + # Do small cleanup, which isn't done by the cleanup in the normal transformation model = model.transform(InferShapes()) model = model.transform(FoldConstants()) @@ -124,6 +159,7 @@ def apply(self, model): # Find nodes, where the domain should be changed for n in graph.node: node_ind += 1 + if (n.op_type in _channelsLast_node_types) and (n.domain == ""): running_node_index = node_ind # Insert transformation nodes for input nodes @@ -133,7 +169,6 @@ def apply(self, model): chanFirst_shape = model.get_tensor_shape(input_tensors[0]) if n.op_type == "BatchNormalization" and len(chanFirst_shape) == 2: continue - for i, inp in enumerate(input_tensors): # Skip higher "order" inputs of the Batch-Norm, # these don't need a transpose. 
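For orientation, the perm arguments of the Transpose nodes inserted by this pass come from the to_channels_last_args/to_channels_first_args helpers in qonnx.custom_op.channels_last.base_wrapped_op. A minimal sketch of their behaviour (the bodies below are illustrative re-implementations, not the library source):

    import numpy as np

    def to_channels_last_args(ndim):
        # NCHW -> NHWC: move axis 1 (channels) to the end, e.g. (0, 2, 3, 1) for ndim=4
        return tuple([0] + list(range(2, ndim)) + [1])

    def to_channels_first_args(ndim):
        # NHWC -> NCHW: move the last axis back to position 1, e.g. (0, 3, 1, 2) for ndim=4
        return tuple([0, ndim - 1] + list(range(1, ndim - 1)))

    x_nchw = np.zeros((1, 8, 32, 32))                    # N, C, H, W
    x_nhwc = x_nchw.transpose(to_channels_last_args(4))  # -> (1, 32, 32, 8)
    assert x_nhwc.shape == (1, 32, 32, 8)
    assert x_nhwc.transpose(to_channels_first_args(4)).shape == x_nchw.shape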
@@ -142,6 +177,9 @@ def apply(self, model): # Skip Conv bias since it doesn't need a transpose if n.op_type == "Conv" and i == 2: continue + # Skip Resize scales since it does not need a transpose + if n.op_type == "Resize" and (i == 1 or i == 2): + continue # Get the shape of the input tensor # and convert it to the shape for the intermediate tensor chanFirst_shape = model.get_tensor_shape(inp) @@ -292,19 +330,36 @@ def apply(self, model): if isinstance(model.get_initializer(inp), type(None)): # Swap around node "predecessor" and "n" # collect tensors - tensor_1 = inp - tensor_2 = n.input[0] - tensor_3 = n.output[0] - # Now connect the tensors to the nodes again, - # but in different order - n.input[0] = tensor_1 - n.output[0] = tensor_2 - predecessor.input[0] = tensor_2 - predecessor.output[0] = tensor_3 - - # Change the shape of the middle tensor - target_shape = model.get_tensor_shape(tensor_3) - model.set_tensor_shape(tensor_2, target_shape) + if model.is_fork_node(predecessor): + # Here we are considering one branch of the fork. + # This case must be handles separately since the + # transpose on the other branch has to be simplified as well + transposes = model.find_direct_successors(predecessor) + both_transpose = True if all(n.op_type == 'Transpose' for n in transposes) else False + assert both_transpose, "Not handled case for branched model" + # only for 2 branches + other_node = transposes[0] if transposes[1] == n else transposes[1] + x0 = model.find_direct_successors(other_node)[0] + x1 = model.find_direct_successors(n)[0] + x0.input[0] = predecessor.output[0] + x1.input[0] = predecessor.output[0] + n.input[0] = predecessor.input[0] + predecessor.input[0] = n.name + graph.node.remove(other_node) + else: + tensor_1 = inp + tensor_2 = n.input[0] + tensor_3 = n.output[0] + # Now connect the tensors to the nodes again, + # but in different order + n.input[0] = tensor_1 + n.output[0] = tensor_2 + predecessor.input[0] = tensor_2 + predecessor.output[0] = tensor_3 + + # Change the shape of the middle tensor + target_shape = model.get_tensor_shape(tensor_3) + model.set_tensor_shape(tensor_2, target_shape) graph_modified = True return model, graph_modified From fd49030bbf352a0547576037e01f4c5f5611c72a Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Fri, 19 Apr 2024 11:16:48 +0200 Subject: [PATCH 02/11] Cleaned. 
Should be in good shape for PR --- src/qonnx/custom_op/channels_last/concat.py | 2 +- src/qonnx/custom_op/channels_last/resize.py | 18 +++-- src/qonnx/transformation/channels_last.py | 76 ++++++++------------- 3 files changed, 40 insertions(+), 56 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/concat.py b/src/qonnx/custom_op/channels_last/concat.py index 84fe23d9..4e158f5c 100644 --- a/src/qonnx/custom_op/channels_last/concat.py +++ b/src/qonnx/custom_op/channels_last/concat.py @@ -98,7 +98,7 @@ def verify_node(self): verification_successful = False # verify the number of inputs - if len(node.input) == 2: + if len(node.input) >= 2: info_messages.append("The number of inputs is correct") else: info_messages.append("{} needs 2 data input".format(node.op_type)) diff --git a/src/qonnx/custom_op/channels_last/resize.py b/src/qonnx/custom_op/channels_last/resize.py index 2aaa98c7..c4b59717 100644 --- a/src/qonnx/custom_op/channels_last/resize.py +++ b/src/qonnx/custom_op/channels_last/resize.py @@ -49,7 +49,7 @@ def make_shape_compatible_op(self, model): """Returns a standard ONNX op which is compatible with this CustomOp for performing shape inference.""" node = self.onnx_node - iscalesn = node.input[2] + iscalesn = node.input[-1] inode = node.input[0] inodes = model.get_tensor_shape(inode) iscalesns = model.get_tensor_shape(iscalesn) @@ -85,8 +85,8 @@ def verify_node(self): info_messages.extend(wrapper_info) # verify number of attributes - num_of_attr_min = 1 - num_of_attr_max = 1 + num_of_attr_min = 4 + num_of_attr_max = 4 if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max: info_messages.append("The number of attributes is correct") else: @@ -100,7 +100,10 @@ def verify_node(self): # verify that all necessary attributes exist try: - self.get_nodeattr("axis") + self.get_nodeattr("coordinate_transformation_mode") + self.get_nodeattr("cubic_coeff_a") + self.get_nodeattr("mode") + self.get_nodeattr("nearest_mode") info_messages.append("All necessary attributes exist") except Exception: info_messages.append( @@ -112,14 +115,17 @@ def verify_node(self): # verify that attributes have the correct datatype. try: - assert isinstance(self.get_nodeattr("axis"), int) + assert isinstance(self.get_nodeattr("coordinate_transformation_mode"), str) + assert isinstance(self.get_nodeattr("cubic_coeff_a"), float) + assert isinstance(self.get_nodeattr("mode"), str) + assert isinstance(self.get_nodeattr("nearest_mode"), str) info_messages.append("All attributes are of the correct type") except Exception: info_messages.append("One or more attributes are of the wrong datatype") verification_successful = False # verify the number of inputs - if len(node.input) == 2: + if len(node.input) == 1: info_messages.append("The number of inputs is correct") else: info_messages.append("{} needs 2 data input".format(node.op_type)) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 17b8013f..8920469d 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -69,70 +69,33 @@ def __init__(self, make_input_channels_last=False): self._make_input_channels_last = make_input_channels_last def apply(self, model): - # assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment." - # assert model.check_all_tensor_shapes_specified(), ( - # "All tensor shapes must be specified. " "Consider running InferShapes." 
- # ) model = model.transform(InsertChannelsLastDomainsAndTrafos()) - import onnx - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_1.onnx') - print('ONNX model saved InsertChannelsLastDomainsAndTrafos - 1') - initial_model_string = model.model.SerializeToString() # Apply RemoveConsecutiveChanFirstAndChanLastTrafos model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_2.onnx') - print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 2') # Apply MoveChanLastUpstream and fold into initializers model = model.transform(MoveChanLastUpstream()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_3.onnx') - print('ONNX model saved MoveChanLastUpstream - 3') - model = model.transform(FoldTransposeIntoQuantInit()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_4.onnx') - print('ONNX model saved FoldTransposeIntoQuantInit - 4') # Run RemoveConsecutiveChanFirstAndChanLastTrafos again, # Technically only required if something changed in the previous trafo model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_5.onnx') - print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 5') # Apply MoveChanLastDownStream model = model.transform(MoveChanFirstDownstream()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_6.onnx') - print('ONNX model saved MoveChanFirstDownstream - 6') # Run RemoveConsecutiveChanFirstAndChanLastTrafos again, # Technically only required if something changed in the previous trafo model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_7.onnx') - print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 7') # Apply AbsorbChanFirstIntoMatMul model = model.transform(AbsorbChanFirstIntoMatMul()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_8.onnx') - print('ONNX model saved AbsorbChanFirstIntoMatMul - 8') if self._make_input_channels_last: model = model.transform(MakeInputChannelsLast()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_9.onnx') - print('ONNX model saved MakeInputChannelsLast - 9') - model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) - - onnx.save(model.model, 'modified-onnx-models/modified_tiny_unet_100k_ch_last_10.onnx') - print('ONNX model saved RemoveConsecutiveChanFirstAndChanLastTrafos - 10') # Check if the model changed new_model_string = model.model.SerializeToString() @@ -336,16 +299,31 @@ def apply(self, model): # transpose on the other branch has to be simplified as well transposes = model.find_direct_successors(predecessor) both_transpose = True if all(n.op_type == 'Transpose' for n in transposes) else False - assert both_transpose, "Not handled case for branched model" - # only for 2 branches - other_node = transposes[0] if transposes[1] == n else transposes[1] - x0 = model.find_direct_successors(other_node)[0] + assert len(transposes) == 2, "Only the case of 2 branches is handled" + assert both_transpose, "The first 2 nodes of the branches must be transpose nodes" + x2 = transposes[0] if transposes[1] == n else transposes[1] + + # It basically rewires the nodes and tensors in order to move + # one transpose before the fork 
node (usually an activation function) + # and removes the other transpose from the graph. + # Easier to understand by writing a simple graph + + # Define the nodes and tensor to be rewired + xa = model.find_direct_predecessors(predecessor)[0] + x0 = model.find_direct_successors(x2)[0] x1 = model.find_direct_successors(n)[0] - x0.input[0] = predecessor.output[0] - x1.input[0] = predecessor.output[0] - n.input[0] = predecessor.input[0] - predecessor.input[0] = n.name - graph.node.remove(other_node) + tensor_1 = xa.output[0] + tensor_2 = predecessor.output[0] + tensor_3 = n.output[0] + + # Perform the rewiring + x0.input.remove(x2.output[0]) + x0.input.append(tensor_2) + x1.input[0] = tensor_2 + n.input[0] = tensor_1 + xa.output[0] = tensor_1 + predecessor.input[0] = tensor_3 + graph.node.remove(x2) else: tensor_1 = inp tensor_2 = n.input[0] @@ -357,9 +335,9 @@ def apply(self, model): predecessor.input[0] = tensor_2 predecessor.output[0] = tensor_3 - # Change the shape of the middle tensor - target_shape = model.get_tensor_shape(tensor_3) - model.set_tensor_shape(tensor_2, target_shape) + # Change the shape of the middle tensor + target_shape = model.get_tensor_shape(tensor_3) + model.set_tensor_shape(tensor_2, target_shape) graph_modified = True return model, graph_modified From b7b6117d3aa4189d5c1bb6a0dfc7797ee2275522 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Fri, 19 Apr 2024 11:20:59 +0200 Subject: [PATCH 03/11] Some minor cleaning --- src/qonnx/custom_op/channels_last/__init__.py | 2 -- src/qonnx/transformation/channels_last.py | 6 ++---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index 8b11f8c2..f264aa1f 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -4,7 +4,6 @@ from qonnx.custom_op.channels_last.concat import Concat from qonnx.custom_op.channels_last.resize import Resize - custom_op = dict() custom_op["Conv"] = Conv @@ -12,4 +11,3 @@ custom_op["BatchNormalization"] = BatchNormalization custom_op["Concat"] = Concat custom_op["Resize"] = Resize - diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 8920469d..34c18a50 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -40,11 +40,11 @@ from qonnx.util.basic import get_by_name # Standard ONNX nodes which require a ChannelsLast data format to function properly -# _channelsLast_node_types = [x for x in list(channels_last.custom_op.keys()) if x != 'Resize'] _channelsLast_node_types = list(channels_last.custom_op.keys()) # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. +# Probably some more nodes have to be added here. 
_move_through_nodes = ["Quant", "Relu", "LeakyRelu", "Resize"] # Nodes, which do not modify the shape of the tensor, @@ -70,7 +70,6 @@ def __init__(self, make_input_channels_last=False): def apply(self, model): model = model.transform(InsertChannelsLastDomainsAndTrafos()) - initial_model_string = model.model.SerializeToString() # Apply RemoveConsecutiveChanFirstAndChanLastTrafos model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) @@ -99,7 +98,6 @@ def apply(self, model): # Check if the model changed new_model_string = model.model.SerializeToString() - # Do small cleanup, which isn't done by the cleanup in the normal transformation model = model.transform(InferShapes()) model = model.transform(FoldConstants()) @@ -122,7 +120,6 @@ def apply(self, model): # Find nodes, where the domain should be changed for n in graph.node: node_ind += 1 - if (n.op_type in _channelsLast_node_types) and (n.domain == ""): running_node_index = node_ind # Insert transformation nodes for input nodes @@ -132,6 +129,7 @@ def apply(self, model): chanFirst_shape = model.get_tensor_shape(input_tensors[0]) if n.op_type == "BatchNormalization" and len(chanFirst_shape) == 2: continue + for i, inp in enumerate(input_tensors): # Skip higher "order" inputs of the Batch-Norm, # these don't need a transpose. From ec78537321b555cf231e417c32f5dcccc9be1eb9 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Fri, 19 Apr 2024 11:23:55 +0200 Subject: [PATCH 04/11] Some other minor aesthetic fixes --- src/qonnx/transformation/channels_last.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 34c18a50..da17618f 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -98,6 +98,7 @@ def apply(self, model): # Check if the model changed new_model_string = model.model.SerializeToString() + # Do small cleanup, which isn't done by the cleanup in the normal transformation model = model.transform(InferShapes()) model = model.transform(FoldConstants()) @@ -289,8 +290,8 @@ def apply(self, model): # Input tensors are always input 0 inp = predecessor.input[0] if isinstance(model.get_initializer(inp), type(None)): - # Swap around node "predecessor" and "n" - # collect tensors + # Handle of the case in which predecessor is fork node + # for models with branches if model.is_fork_node(predecessor): # Here we are considering one branch of the fork. 
# This case must be handles separately since the @@ -323,6 +324,8 @@ def apply(self, model): predecessor.input[0] = tensor_3 graph.node.remove(x2) else: + # Swap around node "predecessor" and "n" + # collect tensors tensor_1 = inp tensor_2 = n.input[0] tensor_3 = n.output[0] From 0125648acfe791e6990a33e285fc3cad18809400 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Tue, 23 Apr 2024 16:52:37 +0200 Subject: [PATCH 05/11] added a missing part that caused some tests to fail --- tests/transformation/test_channelslast.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py index b80a6d60..74302f76 100644 --- a/tests/transformation/test_channelslast.py +++ b/tests/transformation/test_channelslast.py @@ -75,6 +75,8 @@ def analysis_testing_for_chanlast_domain(model): "Conv": 3, "MaxPool": 3, "BatchNormalization": 3, + "Resize": 4, + "Concat": 1 } # Check that all wrapped_ops in the registry have a definition here chanlast_op_types = list(channels_last.custom_op.keys()) From 4b0596469d09078a428a943f4e55bded8cf9bceb Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Tue, 23 Apr 2024 17:01:16 +0200 Subject: [PATCH 06/11] Removed an unused import --- src/qonnx/transformation/channels_last.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index da17618f..bd420387 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -29,7 +29,6 @@ import warnings from onnx import TensorProto, helper -from qonnx.analysis.topology import is_linear from qonnx.custom_op import channels_last from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args from qonnx.transformation.base import Transformation From 9d0b4bf1a1dab0c75fb0a2a8d0aa70de2d06608d Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Tue, 14 May 2024 11:28:47 +0200 Subject: [PATCH 07/11] Some fixes, now resize initializer is transposed as well --- src/qonnx/custom_op/channels_last/concat.py | 1 + src/qonnx/custom_op/channels_last/resize.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/concat.py b/src/qonnx/custom_op/channels_last/concat.py index 4e158f5c..71272e9f 100644 --- a/src/qonnx/custom_op/channels_last/concat.py +++ b/src/qonnx/custom_op/channels_last/concat.py @@ -30,6 +30,7 @@ def make_shape_compatible_op(self, model): iname1 = node.input[1] ishape0 = model.get_tensor_shape(iname0) ishape1 = model.get_tensor_shape(iname1) + self.set_nodeattr("axis", 3) # axis = self.get_nodeattr("axis") # not sure about what's the shape of inputs, don't know how to check it # check that ishape0[1] == ishape1[1] and ishape0[2] == ishape1[2] diff --git a/src/qonnx/custom_op/channels_last/resize.py b/src/qonnx/custom_op/channels_last/resize.py index c4b59717..c85ebe73 100644 --- a/src/qonnx/custom_op/channels_last/resize.py +++ b/src/qonnx/custom_op/channels_last/resize.py @@ -39,10 +39,8 @@ def _compute_resize_output_shape(self, scales, input_shape): assert len(scales) == len(input_shape) scales = [int(i) for i in scales] output_shape = input_shape.copy() - output_shape[1], output_shape[-1] = output_shape[-1], output_shape[1] for i in range(len(input_shape)): output_shape[i] *= scales[i] - output_shape[1], output_shape[-1] = output_shape[-1], output_shape[1] return output_shape def make_shape_compatible_op(self, model): @@ 
-53,10 +51,12 @@ def make_shape_compatible_op(self, model): inode = node.input[0] inodes = model.get_tensor_shape(inode) iscalesns = model.get_tensor_shape(iscalesn) - i = self._get_initializer_from_name(model, iscalesn).raw_data + i = self._get_initializer_from_name(model, iscalesn) + i_raw = i.raw_data fmt = self._compute_fmt(iscalesns[0]) - scales = struct.unpack(fmt, i) - + scales = struct.unpack(fmt, i_raw) + scales = (scales[0], scales[-1], scales[2], scales[1]) + i.raw_data = struct.pack(fmt, *scales) # implement tensor with correct shape output_shape = self._compute_resize_output_shape(scales, inodes) From acdc59b61cad6f9d12e86cb6642a8afa835926a2 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Thu, 23 May 2024 14:33:31 +0200 Subject: [PATCH 08/11] `Resize` and `Concat` are now handled as special cases. It probably better though to handle them in the move up and down transformations --- src/qonnx/custom_op/channels_last/__init__.py | 4 - .../channels_last/base_wrapped_op.py | 5 + src/qonnx/custom_op/channels_last/concat.py | 113 -------------- src/qonnx/custom_op/channels_last/resize.py | 139 ------------------ src/qonnx/transformation/channels_last.py | 59 +++++++- src/qonnx/util/cleanup.py | 2 + 6 files changed, 59 insertions(+), 263 deletions(-) delete mode 100644 src/qonnx/custom_op/channels_last/concat.py delete mode 100644 src/qonnx/custom_op/channels_last/resize.py diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f264aa1f..f1d7c39b 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -1,13 +1,9 @@ from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization from qonnx.custom_op.channels_last.conv import Conv from qonnx.custom_op.channels_last.max_pool import MaxPool -from qonnx.custom_op.channels_last.concat import Concat -from qonnx.custom_op.channels_last.resize import Resize custom_op = dict() custom_op["Conv"] = Conv custom_op["MaxPool"] = MaxPool custom_op["BatchNormalization"] = BatchNormalization -custom_op["Concat"] = Concat -custom_op["Resize"] = Resize diff --git a/src/qonnx/custom_op/channels_last/base_wrapped_op.py b/src/qonnx/custom_op/channels_last/base_wrapped_op.py index 23f7b726..095226bb 100644 --- a/src/qonnx/custom_op/channels_last/base_wrapped_op.py +++ b/src/qonnx/custom_op/channels_last/base_wrapped_op.py @@ -58,6 +58,11 @@ def to_channels_first_args(ndim): return tuple(arg_list) +def to_channels_last_list(l): + l[1], l[-1] = l[-1], l[1] + return l + + class ChannelsLastWrappedOp(CustomOp): # ToDo: _channelsLast_node_types should be loaded / inferred from this file or the registry. 
# Standard ONNX nodes which require a ChannelsLast data format to function properly diff --git a/src/qonnx/custom_op/channels_last/concat.py b/src/qonnx/custom_op/channels_last/concat.py deleted file mode 100644 index 71272e9f..00000000 --- a/src/qonnx/custom_op/channels_last/concat.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -from onnx import TensorProto, helper - -from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp - -class Concat(ChannelsLastWrappedOp): - def get_nodeattr_types(self): - """Returns a dict of permitted attributes for node, where: - ret_dict[attribute_name] = (dtype, require, default_value, ) - - dtype indicates which member of the ONNX AttributeProto - will be utilized - - require indicates whether this attribute is required - - default_val indicates the default value that will be used if the - attribute is not set - - (if specified) indicates that this attribute can only - be set to one of the values in the set . If not specified, - all values permitted by dtype are allowed. - """ - return { - # axis attribute of Concat layer, default 1 - "axis": ("i", True, 1) - } - - def make_shape_compatible_op(self, model): - """Returns a standard ONNX op which is compatible with this CustomOp - for performing shape inference.""" - - node = self.onnx_node - iname0 = node.input[0] - iname1 = node.input[1] - ishape0 = model.get_tensor_shape(iname0) - ishape1 = model.get_tensor_shape(iname1) - self.set_nodeattr("axis", 3) - # axis = self.get_nodeattr("axis") - # not sure about what's the shape of inputs, don't know how to check it - # check that ishape0[1] == ishape1[1] and ishape0[2] == ishape1[2] - assert ishape0[1] == ishape1[1], "Input shape [1] has to be the same between the 2 input nodes of concat" - assert ishape0[2] == ishape1[2], "Input shape [2] has to be the same between the 2 input nodes of concat" - - # implement tensor with correct shape - output_shape = [1, ishape0[1], ishape0[2], ishape0[3] + ishape1[3]] - - # implement tensor with correct shape - values = np.random.randn(*output_shape).astype(np.float32) - return helper.make_node( - "Constant", - inputs=[], - outputs=[self.onnx_node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), - name=self.onnx_node.name, - ) - - def verify_node(self): - node = self.onnx_node - - verification_successful = True - info_messages = [] - - wrapper_info = ChannelsLastWrappedOp.verify_node(self) - info_messages.extend(wrapper_info) - - # verify number of attributes - num_of_attr_min = 1 - num_of_attr_max = 1 - if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max: - info_messages.append("The number of attributes is correct") - else: - info_messages.append( - """The number of attributes is incorrect, - {} should have between {} and {} attributes""".format( - node.op_type, num_of_attr_min, num_of_attr_max - ) - ) - verification_successful = False - - # verify that all necessary attributes exist - try: - self.get_nodeattr("axis") - info_messages.append("All necessary attributes exist") - except Exception: - info_messages.append( - """The necessary attributes do not exist. - Concat needs the following attributes: - axis""" - ) - verification_successful = False - - # verify that attributes have the correct datatype. 
- try: - assert isinstance(self.get_nodeattr("axis"), int) - info_messages.append("All attributes are of the correct type") - except Exception: - info_messages.append("One or more attributes are of the wrong datatype") - verification_successful = False - - # verify the number of inputs - if len(node.input) >= 2: - info_messages.append("The number of inputs is correct") - else: - info_messages.append("{} needs 2 data input".format(node.op_type)) - verification_successful = False - - if not verification_successful: - raise RuntimeError( - f"Verification of node {node.name} failed, please check the " f"attached info messages: {info_messages}" - ) - - return info_messages diff --git a/src/qonnx/custom_op/channels_last/resize.py b/src/qonnx/custom_op/channels_last/resize.py deleted file mode 100644 index c85ebe73..00000000 --- a/src/qonnx/custom_op/channels_last/resize.py +++ /dev/null @@ -1,139 +0,0 @@ -import struct -import numpy as np -from onnx import TensorProto, helper - -from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp - -class Resize(ChannelsLastWrappedOp): - def get_nodeattr_types(self): - """Returns a dict of permitted attributes for node, where: - ret_dict[attribute_name] = (dtype, require, default_value, ) - - dtype indicates which member of the ONNX AttributeProto - will be utilized - - require indicates whether this attribute is required - - default_val indicates the default value that will be used if the - attribute is not set - - (if specified) indicates that this attribute can only - be set to one of the values in the set . If not specified, - all values permitted by dtype are allowed. - """ - return { - "coordinate_transformation_mode": ("s", True, "half_pixel"), - "cubic_coeff_a": ("f", True, -0.75), - "mode": ("s", True, "linear"), - "nearest_mode": ("s", True, "floor") - } - - def _get_initializer_from_name(self, model, iname): - for i in model.graph.initializer: - if i.name == iname: - return i - - def _compute_fmt(self, tensor_shape): - fmt = "<" - for _ in range(tensor_shape): - fmt += "f" - return fmt - - def _compute_resize_output_shape(self, scales, input_shape): - assert len(scales) == len(input_shape) - scales = [int(i) for i in scales] - output_shape = input_shape.copy() - for i in range(len(input_shape)): - output_shape[i] *= scales[i] - return output_shape - - def make_shape_compatible_op(self, model): - """Returns a standard ONNX op which is compatible with this CustomOp - for performing shape inference.""" - node = self.onnx_node - iscalesn = node.input[-1] - inode = node.input[0] - inodes = model.get_tensor_shape(inode) - iscalesns = model.get_tensor_shape(iscalesn) - i = self._get_initializer_from_name(model, iscalesn) - i_raw = i.raw_data - fmt = self._compute_fmt(iscalesns[0]) - scales = struct.unpack(fmt, i_raw) - scales = (scales[0], scales[-1], scales[2], scales[1]) - i.raw_data = struct.pack(fmt, *scales) - # implement tensor with correct shape - output_shape = self._compute_resize_output_shape(scales, inodes) - - # implement tensor with correct shape - values = np.random.randn(*output_shape).astype(np.float32) - return helper.make_node( - "Constant", - inputs=[], - outputs=[self.onnx_node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), - name=self.onnx_node.name, - ) - - def verify_node(self): - node = self.onnx_node - - verification_successful = True - info_messages = [] - - wrapper_info = 
ChannelsLastWrappedOp.verify_node(self) - info_messages.extend(wrapper_info) - - # verify number of attributes - num_of_attr_min = 4 - num_of_attr_max = 4 - if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max: - info_messages.append("The number of attributes is correct") - else: - info_messages.append( - """The number of attributes is incorrect, - {} should have between {} and {} attributes""".format( - node.op_type, num_of_attr_min, num_of_attr_max - ) - ) - verification_successful = False - - # verify that all necessary attributes exist - try: - self.get_nodeattr("coordinate_transformation_mode") - self.get_nodeattr("cubic_coeff_a") - self.get_nodeattr("mode") - self.get_nodeattr("nearest_mode") - info_messages.append("All necessary attributes exist") - except Exception: - info_messages.append( - """The necessary attributes do not exist. - Concat needs the following attributes: - axis""" - ) - verification_successful = False - - # verify that attributes have the correct datatype. - try: - assert isinstance(self.get_nodeattr("coordinate_transformation_mode"), str) - assert isinstance(self.get_nodeattr("cubic_coeff_a"), float) - assert isinstance(self.get_nodeattr("mode"), str) - assert isinstance(self.get_nodeattr("nearest_mode"), str) - info_messages.append("All attributes are of the correct type") - except Exception: - info_messages.append("One or more attributes are of the wrong datatype") - verification_successful = False - - # verify the number of inputs - if len(node.input) == 1: - info_messages.append("The number of inputs is correct") - else: - info_messages.append("{} needs 2 data input".format(node.op_type)) - verification_successful = False - - if not verification_successful: - raise RuntimeError( - f"Verification of node {node.name} failed, please check the " f"attached info messages: {info_messages}" - ) - - return info_messages diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index bd420387..59c4d541 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -30,7 +30,7 @@ from onnx import TensorProto, helper from qonnx.custom_op import channels_last -from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args +from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args, to_channels_last_list from qonnx.transformation.base import Transformation from qonnx.transformation.fold_constants import FoldConstants from qonnx.transformation.infer_shapes import InferShapes @@ -40,11 +40,12 @@ # Standard ONNX nodes which require a ChannelsLast data format to function properly _channelsLast_node_types = list(channels_last.custom_op.keys()) +_channelsLast_special_node_types = ['Resize', 'Upsample', 'Concat'] # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. # Probably some more nodes have to be added here. -_move_through_nodes = ["Quant", "Relu", "LeakyRelu", "Resize"] +_move_through_nodes = ["Quant", "Relu", "LeakyRelu"] # Nodes, which do not modify the shape of the tensor, # And modify all values in the same way, if the second tensor is a scalar. 
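One observation on why the hunks below single out Resize, Upsample and Concat as "special": wrapping them in transposes is not enough, because their scales initializers and axis attributes are still expressed in channels-first terms and must be permuted explicitly. A small sketch of the idea (illustrative only; to_channels_last_list is the helper added to base_wrapped_op.py earlier in this patch, renamed swap_channels_from_list in a later one):

    import numpy as np

    # The scales input of Resize/Upsample is per-dimension, so it has to be
    # permuted like the data tensor: NCHW (n, c, h, w) -> NHWC (n, h, w, c).
    scales_nchw = np.array([1.0, 1.0, 2.0, 3.0])  # upscale H by 2, W by 3
    scales_nhwc = scales_nchw[[0, 2, 3, 1]]
    assert scales_nhwc.tolist() == [1.0, 2.0, 3.0, 1.0]

    # Note: swapping only entries 1 and -1, as to_channels_last_list does,
    # would give [1.0, 3.0, 2.0, 1.0] here; the two agree exactly when the
    # H and W scale factors are equal, as in uniform 2x upsampling.

    # Concat along channels: axis 1 in NCHW becomes the last axis in NHWC,
    # which is what get_by_name(n.attribute, "axis").i = s - 1 encodes below.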
@@ -69,6 +70,9 @@ def __init__(self, make_input_channels_last=False): def apply(self, model): model = model.transform(InsertChannelsLastDomainsAndTrafos()) + + model = model.transform(RemoveDomainFromSpecialNodes()) + initial_model_string = model.model.SerializeToString() # Apply RemoveConsecutiveChanFirstAndChanLastTrafos model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) @@ -101,12 +105,28 @@ def apply(self, model): # Do small cleanup, which isn't done by the cleanup in the normal transformation model = model.transform(InferShapes()) model = model.transform(FoldConstants()) + model = model.transform(AddDomainFromSpecialNodes()) # Check if the model changed model_changed = initial_model_string != new_model_string return model, model_changed +class RemoveDomainFromSpecialNodes(Transformation): + + def apply(self, model): + for n in model.graph.node: + if n.op_type in _channelsLast_special_node_types: + n.domain = "" + return model, False + +class AddDomainFromSpecialNodes(Transformation): + + def apply(self, model): + for n in model.graph.node: + if n.op_type in _channelsLast_special_node_types: + n.domain = "qonnx.custom_op.channels_last_special" + return model, False class InsertChannelsLastDomainsAndTrafos(Transformation): """ @@ -120,7 +140,7 @@ def apply(self, model): # Find nodes, where the domain should be changed for n in graph.node: node_ind += 1 - if (n.op_type in _channelsLast_node_types) and (n.domain == ""): + if (n.op_type in _channelsLast_node_types or n.op_type in _channelsLast_special_node_types) and (n.domain == ""): running_node_index = node_ind # Insert transformation nodes for input nodes input_tensors = n.input @@ -138,9 +158,31 @@ def apply(self, model): # Skip Conv bias since it doesn't need a transpose if n.op_type == "Conv" and i == 2: continue - # Skip Resize scales since it does not need a transpose - if n.op_type == "Resize" and (i == 1 or i == 2): + # Handle Resize scales + if (n.op_type == "Resize") and (i == 1 or i == 2): + if i == 2: + scales = model.get_initializer(inp).copy() + scales = to_channels_last_list(scales) + model.set_initializer(inp, scales) + # old_shape = model.get_tensor_shape(n.output[0]) + # new_shape = to_channels_last_list(old_shape) + # model.set_tensor_shape(inp, new_shape) + continue + if (n.op_type == "Upsample") and (i == 1): + scales = model.get_initializer(inp).copy() + scales = to_channels_last_list(scales) + model.set_initializer(inp, scales) + # old_shape = model.get_tensor_shape(n.output[0]) + # new_shape = to_channels_last_list(old_shape) + # model.set_tensor_shape(inp, new_shape) continue + if (n.op_type == "Concat"): + if i == 0: + s = len(model.get_tensor_shape(inp)) + get_by_name(n.attribute, "axis").i = s - 1 + # t = model.get_tensor_shape(inp) + # t_new = to_channels_last_list(t) + # model.set_tensor_shape(inp, t_new) # Get the shape of the input tensor # and convert it to the shape for the intermediate tensor chanFirst_shape = model.get_tensor_shape(inp) @@ -191,7 +233,10 @@ def apply(self, model): n.output[i] = outp_trans_in # Modify domain - n.domain = "qonnx.custom_op.channels_last" + if (n.op_type in _channelsLast_node_types): + n.domain = "qonnx.custom_op.channels_last" + if (n.op_type in _channelsLast_special_node_types): + n.domain = "qonnx.custom_op.channels_last_special" # Set modified flag graph_modified = True @@ -293,7 +338,7 @@ def apply(self, model): # for models with branches if model.is_fork_node(predecessor): # Here we are considering one branch of the fork. 
-                # This case must be handles separately since the
+                # This case must be handled separately since the
                 # transpose on the other branch has to be simplified as well
                 transposes = model.find_direct_successors(predecessor)
                 both_transpose = True if all(n.op_type == 'Transpose' for n in transposes) else False
diff --git a/src/qonnx/util/cleanup.py b/src/qonnx/util/cleanup.py
index 933f729d..36b178f3 100644
--- a/src/qonnx/util/cleanup.py
+++ b/src/qonnx/util/cleanup.py
@@ -41,6 +41,7 @@
 )
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
+from qonnx.transformation.channels_last import RemoveDomainFromSpecialNodes
 
 
 def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False):
@@ -75,6 +76,7 @@ def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_c
         model.set_tensor_shape(iname, override_inpsize)
 
     cleanup_transformations = [
+        RemoveDomainFromSpecialNodes(),
         InferShapes(),
         GiveUniqueParameterTensors(),
         FoldConstants(exclude_op_types=preserve_qnt_optypes),

From 2de8240ec7ed5d62c28e893c50a49e87f468f67c Mon Sep 17 00:00:00 2001
From: nicologhielmetti
Date: Wed, 29 May 2024 16:41:50 +0200
Subject: [PATCH 09/11] The nodes are now handled as special nodes. I also
 considered the case in which the transposes pass the special nodes. I added
 some cleanup transformations for the domain field; they are not very
 elegant, but I think they are strictly necessary.

---
 .../channels_last/base_wrapped_op.py      |   2 +-
 src/qonnx/transformation/channels_last.py | 199 +++++++++++-------
 src/qonnx/util/cleanup.py                 |   2 -
 3 files changed, 121 insertions(+), 82 deletions(-)

diff --git a/src/qonnx/custom_op/channels_last/base_wrapped_op.py b/src/qonnx/custom_op/channels_last/base_wrapped_op.py
index 095226bb..9eb8f688 100644
--- a/src/qonnx/custom_op/channels_last/base_wrapped_op.py
+++ b/src/qonnx/custom_op/channels_last/base_wrapped_op.py
@@ -58,7 +58,7 @@ def to_channels_first_args(ndim):
     return tuple(arg_list)
 
 
-def to_channels_last_list(l):
+def swap_channels_from_list(l):
     l[1], l[-1] = l[-1], l[1]
     return l
 
diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index 59c4d541..cab20bab 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -30,7 +30,7 @@
 from onnx import TensorProto, helper
 
 from qonnx.custom_op import channels_last
-from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args, to_channels_last_list
+from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args, swap_channels_from_list
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.fold_constants import FoldConstants
 from qonnx.transformation.infer_shapes import InferShapes
@@ -69,8 +70,6 @@ def __init__(self, make_input_channels_last=False):
 
     def apply(self, model):
         model = model.transform(InsertChannelsLastDomainsAndTrafos())
-
-        model = model.transform(RemoveDomainFromSpecialNodes())
         initial_model_string = model.model.SerializeToString()
         # Apply RemoveConsecutiveChanFirstAndChanLastTrafos
         model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos())
@@ -80,7 +78,6 @@ def apply(self, model):
         # Apply MoveChanLastUpstream and fold into initializers
         model = model.transform(MoveChanLastUpstream())
         model = model.transform(FoldTransposeIntoQuantInit())
-
         # Run RemoveConsecutiveChanFirstAndChanLastTrafos again,
         # Technically only required if something changed in the previous
trafo model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos()) @@ -102,14 +99,15 @@ def apply(self, model): # Check if the model changed new_model_string = model.model.SerializeToString() + model = model.transform(RemoveDomainFromSpecialNodes()) # Do small cleanup, which isn't done by the cleanup in the normal transformation model = model.transform(InferShapes()) model = model.transform(FoldConstants()) - model = model.transform(AddDomainFromSpecialNodes()) # Check if the model changed model_changed = initial_model_string != new_model_string - + if model_changed: + model = model.transform(AddDomainToSpecialNodes()) return model, model_changed class RemoveDomainFromSpecialNodes(Transformation): @@ -120,12 +118,12 @@ def apply(self, model): n.domain = "" return model, False -class AddDomainFromSpecialNodes(Transformation): +class AddDomainToSpecialNodes(Transformation): def apply(self, model): for n in model.graph.node: if n.op_type in _channelsLast_special_node_types: - n.domain = "qonnx.custom_op.channels_last_special" + n.domain = "modified" return model, False class InsertChannelsLastDomainsAndTrafos(Transformation): @@ -134,13 +132,62 @@ class InsertChannelsLastDomainsAndTrafos(Transformation): """ def apply(self, model): + + def insert_transpose_to_output(model, outp, graph, running_node_index, n, i): + # Get the shape of the input tensor + # and convert it to the shape for the intermediate tensor + chanFirst_shape = model.get_tensor_shape(outp) + ndim = len(chanFirst_shape) + assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors." + chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)] + # Intermediat tensor + outp_trans_in = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + chanLast_shape, + ) + graph.value_info.append(outp_trans_in) + outp_trans_in = outp_trans_in.name + + # ChannelsFirst -> ChannelsLast transpose + outp_trans_node = helper.make_node( + "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim) + ) + graph.node.insert(running_node_index, outp_trans_node) + running_node_index += 1 + + # Attach to original node + n.output[i] = outp_trans_in + + def insert_transpose_to_input(model, inp, graph, running_node_index, n, i): + chanFirst_shape = model.get_tensor_shape(inp) + ndim = len(chanFirst_shape) + assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors." 
+ chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)] + # Intermediate tensor + inp_trans_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + chanLast_shape, + ) + graph.value_info.append(inp_trans_out) + inp_trans_out = inp_trans_out.name + + # Channels last transpose + inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim)) + graph.node.insert(running_node_index, inp_trans_node) + running_node_index += 1 + + # Attach to original node + n.input[i] = inp_trans_out + graph = model.graph node_ind = 0 graph_modified = False # Find nodes, where the domain should be changed for n in graph.node: node_ind += 1 - if (n.op_type in _channelsLast_node_types or n.op_type in _channelsLast_special_node_types) and (n.domain == ""): + if (n.op_type in _channelsLast_node_types) and (n.domain == ""): running_node_index = node_ind # Insert transformation nodes for input nodes input_tensors = n.input @@ -158,88 +205,53 @@ def apply(self, model): # Skip Conv bias since it doesn't need a transpose if n.op_type == "Conv" and i == 2: continue + insert_transpose_to_input(model, inp, graph, running_node_index, n, i) + + # Insert transformation nodes for output nodes + output_tensors = n.output + for i, outp in enumerate(output_tensors): + insert_transpose_to_output(model, outp, graph, running_node_index, n, i) + + # Modify domain + # if (n.op_type in _channelsLast_node_types): + n.domain = "qonnx.custom_op.channels_last" + # if (n.op_type in _channelsLast_special_node_types): + # n.domain = "qonnx.custom_op.channels_last_special" + # # Set modified flag + graph_modified = True + + if (n.op_type in _channelsLast_special_node_types) and (n.domain == ""): + running_node_index = node_ind + # Insert transformation nodes for input nodes + input_tensors = n.input + # Skip for BatchNorm and 2D input tensors, + # these contain only channels and need no transpose. + chanFirst_shape = model.get_tensor_shape(input_tensors[0]) + for i, inp in enumerate(input_tensors): # Handle Resize scales - if (n.op_type == "Resize") and (i == 1 or i == 2): + if (n.op_type == "Resize") and (i != 0): if i == 2: scales = model.get_initializer(inp).copy() - scales = to_channels_last_list(scales) + scales = swap_channels_from_list(scales) model.set_initializer(inp, scales) - # old_shape = model.get_tensor_shape(n.output[0]) - # new_shape = to_channels_last_list(old_shape) - # model.set_tensor_shape(inp, new_shape) continue if (n.op_type == "Upsample") and (i == 1): scales = model.get_initializer(inp).copy() - scales = to_channels_last_list(scales) + scales = swap_channels_from_list(scales) model.set_initializer(inp, scales) - # old_shape = model.get_tensor_shape(n.output[0]) - # new_shape = to_channels_last_list(old_shape) - # model.set_tensor_shape(inp, new_shape) continue - if (n.op_type == "Concat"): - if i == 0: - s = len(model.get_tensor_shape(inp)) - get_by_name(n.attribute, "axis").i = s - 1 - # t = model.get_tensor_shape(inp) - # t_new = to_channels_last_list(t) - # model.set_tensor_shape(inp, t_new) - # Get the shape of the input tensor - # and convert it to the shape for the intermediate tensor - chanFirst_shape = model.get_tensor_shape(inp) - ndim = len(chanFirst_shape) - assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors." 
-                if (n.op_type == "Concat"):
-                    if i == 0:
-                        s = len(model.get_tensor_shape(inp))
-                        get_by_name(n.attribute, "axis").i = s - 1
-                    # t = model.get_tensor_shape(inp)
-                    # t_new = to_channels_last_list(t)
-                    # model.set_tensor_shape(inp, t_new)
-                # Get the shape of the input tensor
-                # and convert it to the shape for the intermediate tensor
-                chanFirst_shape = model.get_tensor_shape(inp)
-                ndim = len(chanFirst_shape)
-                assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors."
-                chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)]
-                # Intermediate tensor
-                inp_trans_out = helper.make_tensor_value_info(
-                    model.make_new_valueinfo_name(),
-                    TensorProto.FLOAT,
-                    chanLast_shape,
-                )
-                graph.value_info.append(inp_trans_out)
-                inp_trans_out = inp_trans_out.name
-
-                # channels last transpose
-                inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim))
-                graph.node.insert(running_node_index, inp_trans_node)
-                running_node_index += 1
-
-                # Attach to original node
-                n.input[i] = inp_trans_out
-
-            # Insert transformation nodes for output nodes
+                if (n.op_type == "Concat") and (i == 0):
+                    s = len(model.get_tensor_shape(inp))
+                    get_by_name(n.attribute, "axis").i = s - 1
+                running_node_index = insert_transpose_to_input(model, inp, graph, running_node_index, n, i)
+
             output_tensors = n.output
             for i, outp in enumerate(output_tensors):
-                chanFirst_shape = model.get_tensor_shape(outp)
-                ndim = len(chanFirst_shape)
-                assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors."
-                chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)]
-                # Intermediat tensor
-                outp_trans_in = helper.make_tensor_value_info(
-                    model.make_new_valueinfo_name(),
-                    TensorProto.FLOAT,
-                    chanLast_shape,
-                )
-                graph.value_info.append(outp_trans_in)
-                outp_trans_in = outp_trans_in.name
-
-                # ChannelsFirst -> ChannelsLast transpose
-                outp_trans_node = helper.make_node(
-                    "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim)
-                )
-                graph.node.insert(running_node_index, outp_trans_node)
-                running_node_index += 1
-
-                # Attach to original node
-                n.output[i] = outp_trans_in
-
-            # Modify domain
-            if (n.op_type in _channelsLast_node_types):
-                n.domain = "qonnx.custom_op.channels_last"
-            if (n.op_type in _channelsLast_special_node_types):
-                n.domain = "qonnx.custom_op.channels_last_special"
-            # Set modified flag
+                running_node_index = insert_transpose_to_output(model, outp, graph, running_node_index, n, i)
+
+            n.domain = "modified"
             graph_modified = True
-
+
         return model, graph_modified
 
@@ -329,6 +341,22 @@ def apply(self, model):
                 if second_inp_shape == [1] or second_inp_shape == []:
                     move_through_valid |= True
 
+                if (predecessor.op_type == "Resize"):
+                    scales = model.get_initializer(predecessor.input[2]).copy()
+                    scales = swap_channels_from_list(scales)
+                    model.set_initializer(predecessor.input[2], scales)
+                if (predecessor.op_type == "Upsample"):
+                    scales = model.get_initializer(predecessor.input[1]).copy()
+                    scales = swap_channels_from_list(scales)
+                    model.set_initializer(predecessor.input[1], scales)
+                if (predecessor.op_type == "Concat"):
+                    s = len(model.get_tensor_shape(predecessor.output[0]))
+                    get_by_name(predecessor.attribute, "axis").i = s - 1
+
                 # Apply move through trafo if possible
                 if move_through_valid:
                     # Input tensors are always input 0
@@ -440,6 +468,19 @@ def apply(self, model):
                     second_inp_shape = model.get_tensor_shape(successor.input[1])
                     if second_inp_shape == [1] or second_inp_shape == []:
                         move_through_valid |= True
+
+                    if (successor.op_type == "Resize"):
+                        scales = model.get_initializer(successor.input[2]).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(successor.input[2], scales)
+                    if (successor.op_type == "Upsample"):
+                        scales = model.get_initializer(successor.input[1]).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(successor.input[1], scales)
+                    if (successor.op_type == "Concat"):
+                        s = len(model.get_tensor_shape(successor.output[0]))
+                        get_by_name(successor.attribute, "axis").i = s - 1
+
                 # Apply move through trafo if possible
                 if move_through_valid:
                     # Collect all tensors connecting n and successor
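The Concat handling in the move-through logic follows from the layout change: concatenation along the channel axis is axis 1 in channels-first but the last axis in channels-last, so the `axis` attribute must be rewritten whenever a Transpose is moved past a Concat. A small numpy check of that equivalence, illustrative only:

    import numpy as np

    a = np.random.rand(1, 3, 8, 8)  # NCHW, 3 channels
    b = np.random.rand(1, 5, 8, 8)  # NCHW, 5 channels
    nchw = np.concatenate([a, b], axis=1)  # channel concat on axis 1

    # Same concat in channels-last: transpose first, concat on the last axis
    nhwc = np.concatenate([a.transpose(0, 2, 3, 1), b.transpose(0, 2, 3, 1)], axis=-1)
    assert (nhwc.transpose(0, 3, 1, 2) == nchw).all()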
diff --git a/src/qonnx/util/cleanup.py b/src/qonnx/util/cleanup.py
index 36b178f3..933f729d 100644
--- a/src/qonnx/util/cleanup.py
+++ b/src/qonnx/util/cleanup.py
@@ -41,7 +41,6 @@
 )
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
-from qonnx.transformation.channels_last import RemoveDomainFromSpecialNodes
 
 
 def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False):
@@ -76,7 +75,6 @@ def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_c
         model.set_tensor_shape(iname, override_inpsize)
 
     cleanup_transformations = [
-        RemoveDomainFromSpecialNodes(),
         InferShapes(),
         GiveUniqueParameterTensors(),
         FoldConstants(exclude_op_types=preserve_qnt_optypes),
 
From 5463de2e05fdf1437406013fca7b3052c560bbdc Mon Sep 17 00:00:00 2001
From: nicologhielmetti
Date: Thu, 30 May 2024 10:54:52 +0200
Subject: [PATCH 10/11] One minor fix on `Resize`

---
 src/qonnx/transformation/channels_last.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index cab20bab..3ceefbf8 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -230,7 +230,7 @@ def insert_transpose_to_input(model, inp, graph, running_node_index, n, i):
                 for i, inp in enumerate(input_tensors):
                 # Handle Resize scales
                 if (n.op_type == "Resize") and (i != 0):
-                    if i == 2:
+                    if (len(input_tensors) == 2 and i == 1) or (len(input_tensors) == 3 and i == 2):
                         scales = model.get_initializer(inp).copy()
                         scales = swap_channels_from_list(scales)
                         model.set_initializer(inp, scales)
@@ -346,7 +346,11 @@ def apply(self, model):
                     if second_inp_shape == [1] or second_inp_shape == []:
                         move_through_valid |= True
                     if (predecessor.op_type == "Resize"):
-                        scales = model.get_initializer(predecessor.input[2]).copy()
+                        # Opset-10 Resize carries scales at input 1, opset-11+ at input 2
+                        if len(predecessor.input) == 2:
+                            i = 1
+                        else:
+                            i = 2
+                        scales = model.get_initializer(predecessor.input[i]).copy()
                         scales = swap_channels_from_list(scales)
-                        model.set_initializer(predecessor.input[2], scales)
+                        model.set_initializer(predecessor.input[i], scales)
                     if (predecessor.op_type == "Upsample"):
@@ -470,7 +474,11 @@ def apply(self, model):
                         move_through_valid |= True
 
                     if (successor.op_type == "Resize"):
-                        scales = model.get_initializer(successor.input[2]).copy()
+                        if len(successor.input) == 2:
+                            i = 1
+                        else:
+                            i = 2
+                        scales = model.get_initializer(successor.input[i]).copy()
                         scales = swap_channels_from_list(scales)
-                        model.set_initializer(successor.input[2], scales)
+                        model.set_initializer(successor.input[i], scales)
                     if (successor.op_type == "Upsample"):
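The fix in this commit accounts for the two Resize signatures: opset-10 Resize takes (X, scales), while opset-11 and later take (X, roi, scales[, sizes]), so the scales initializer sits at input index 1 or 2 respectively. The repeated index selection could also live in a small helper — a sketch, not part of the patch:

    def resize_scales_input_index(resize_node):
        # opset-10 Resize: inputs (X, scales) -> scales at index 1
        # opset-11+ Resize: inputs (X, roi, scales[, sizes]) -> scales at index 2
        return 1 if len(resize_node.input) == 2 else 2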
From 8ab1f28c56a936217cdb60718eaeadb8648065a9 Mon Sep 17 00:00:00 2001
From: nicologhielmetti
Date: Thu, 27 Jun 2024 12:13:57 +0200
Subject: [PATCH 11/11] Added an optional parameter to remove any input and
 output transposes

---
 src/qonnx/transformation/channels_last.py | 54 ++++++++++++++++++++++-
 1 file changed, 52 insertions(+), 2 deletions(-)

diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index 3ceefbf8..7aa3d5ac 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -64,9 +64,10 @@ class ConvertToChannelsLastAndClean(Transformation):
     """
 
-    def __init__(self, make_input_channels_last=False):
+    def __init__(self, make_input_channels_last=False, remove_input_output_transposes=True):
         super().__init__()
         self._make_input_channels_last = make_input_channels_last
+        self._remove_input_output_transposes = remove_input_output_transposes
 
     def apply(self, model):
         model = model.transform(InsertChannelsLastDomainsAndTrafos())
@@ -108,8 +109,55 @@ def apply(self, model):
         model_changed = initial_model_string != new_model_string
         if model_changed:
             model = model.transform(AddDomainToSpecialNodes())
+
+        if self._remove_input_output_transposes:
+            model = model.transform(RemoveInputOutputTransposes(), cleanup=True)
         return model, model_changed
 
+
+class RemoveInputOutputTransposes(Transformation):
+
+    def apply(self, model):
+        # Iterate over a copy, since nodes are removed while iterating
+        for n in list(model.graph.node):
+            if n.op_type == 'Transpose':
+                if model.find_direct_predecessors(n) is None:
+                    s = model.find_direct_successors(n)  # i -> n -> s
+                    assert len(s) == 1
+                    for i, e in enumerate(s[0].input):
+                        if n.output[0] == e:
+                            s[0].input[i] = n.input[0]
+
+                    new_input_shape = model.get_tensor_shape(n.output[0])
+
+                    # Modify input tensor shape
+                    input_name = model.graph.input[0].name
+                    new_input = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, new_input_shape)
+
+                    # Update the model graph inputs
+                    model.graph.input.remove(model.graph.input[0])
+                    model.graph.input.append(new_input)
+                    model.graph.node.remove(n)
+                    continue
+                if model.find_direct_successors(n) is None:
+                    p = model.find_direct_predecessors(n)
+                    assert len(p) == 1
+                    for i, e in enumerate(p[0].output):
+                        if n.input[0] == e:
+                            p[0].output[i] = n.output[0]
+
+                    new_output_shape = model.get_tensor_shape(n.input[0])
+
+                    # Modify output tensor shape
+                    output_name = model.graph.output[0].name
+                    new_output = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, new_output_shape)
+
+                    # Update the model graph outputs
+                    model.graph.output.remove(model.graph.output[0])
+                    model.graph.output.append(new_output)
+
+                    model.graph.node.remove(n)
+                    continue
+        return model, False
+
+
 class RemoveDomainFromSpecialNodes(Transformation):
 
     def apply(self, model):
@@ -375,7 +423,9 @@ def apply(self, model):
             transposes = model.find_direct_successors(predecessor)
             both_transpose = True if all(n.op_type == 'Transpose' for n in transposes) else False
             assert len(transposes) == 2, "Only the case of 2 branches is handled"
-            assert both_transpose, "The first 2 nodes of the branches must be transpose nodes"
+            if not both_transpose:
+                continue
             x2 = transposes[0] if transposes[1] == n else transposes[1]
 
             # It basically rewires the nodes and tensors in order to move
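For completeness, a usage sketch of the new option (the file names are placeholders; ModelWrapper and ConvertToChannelsLastAndClean are the existing qonnx entry points):

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean

    model = ModelWrapper("model.onnx")
    # Default: boundary Transposes at the graph inputs/outputs are removed,
    # so the converted graph consumes and produces channels-last tensors.
    model_cl = model.transform(ConvertToChannelsLastAndClean())
    # Keep the boundary Transposes so the graph I/O stays channels-first:
    model_cf_io = model.transform(ConvertToChannelsLastAndClean(remove_input_output_transposes=False))
    model_cl.save("model_channels_last.onnx")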