diff --git a/src/qonnx/custom_op/channels_last/base_wrapped_op.py b/src/qonnx/custom_op/channels_last/base_wrapped_op.py
index 23f7b726..9eb8f688 100644
--- a/src/qonnx/custom_op/channels_last/base_wrapped_op.py
+++ b/src/qonnx/custom_op/channels_last/base_wrapped_op.py
@@ -58,6 +58,11 @@ def to_channels_first_args(ndim):
     return tuple(arg_list)
 
 
+def swap_channels_from_list(arr):
+    # Swap the channels entry (index 1 in channels-first layouts such as NCHW) with the
+    # last entry (channels-last, e.g. NHWC) in a 1D list or array of per-dimension values.
+    arr[1], arr[-1] = arr[-1], arr[1]
+    return arr
+
+
 class ChannelsLastWrappedOp(CustomOp):
     # ToDo: _channelsLast_node_types should be loaded / inferred from this file or the registry.
     # Standard ONNX nodes which require a ChannelsLast data format to function properly
diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index a5416b63..7aa3d5ac 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -29,9 +29,8 @@
 import warnings
 from onnx import TensorProto, helper
 
-from qonnx.analysis.topology import is_linear
 from qonnx.custom_op import channels_last
-from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args
+from qonnx.custom_op.channels_last.base_wrapped_op import (
+    swap_channels_from_list,
+    to_channels_first_args,
+    to_channels_last_args,
+)
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.fold_constants import FoldConstants
 from qonnx.transformation.infer_shapes import InferShapes
@@ -41,10 +40,12 @@
 
 # Standard ONNX nodes which require a ChannelsLast data format to function properly
 _channelsLast_node_types = list(channels_last.custom_op.keys())
+# Standard ONNX nodes which need special handling (scales / axis adjustments) for channels-last conversion
+_channelsLast_special_node_types = ["Resize", "Upsample", "Concat"]
 
 # Nodes, which do not modify the shape of the tensor
 # And modify all values in the same way.
-_move_through_nodes = ["Quant", "Relu"]
+# More node types may need to be added here in the future.
+_move_through_nodes = ["Quant", "Relu", "LeakyRelu"]
 
 # Nodes, which do not modify the shape of the tensor,
 # And modify all values in the same way, if the second tensor is a scalar.
@@ -63,16 +64,14 @@ class ConvertToChannelsLastAndClean(Transformation):
     """
 
-    def __init__(self, make_input_channels_last=False):
+    def __init__(self, make_input_channels_last=False, remove_input_output_transposes=True):
         super().__init__()
         self._make_input_channels_last = make_input_channels_last
+        self._remove_input_output_transposes = remove_input_output_transposes
 
     def apply(self, model):
-        assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment."
-        assert model.check_all_tensor_shapes_specified(), (
-            "All tensor shapes must be specified. " "Consider running InferShapes."
-        )
         model = model.transform(InsertChannelsLastDomainsAndTrafos())
+        initial_model_string = model.model.SerializeToString()
 
         # Apply RemoveConsecutiveChanFirstAndChanLastTrafos
         model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos())
@@ -80,7 +79,6 @@ def apply(self, model):
         # Apply MoveChanLastUpstream and fold into initializers
         model = model.transform(MoveChanLastUpstream())
         model = model.transform(FoldTransposeIntoQuantInit())
-
         # Run RemoveConsecutiveChanFirstAndChanLastTrafos again,
         # Technically only required if something changed in the previous trafo
         model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos())
@@ -102,15 +100,79 @@ def apply(self, model):
 
         # Check if the model changed
         new_model_string = model.model.SerializeToString()
+        model = model.transform(RemoveDomainFromSpecialNodes())
 
         # Do small cleanup, which isn't done by the cleanup in the normal transformation
         model = model.transform(InferShapes())
         model = model.transform(FoldConstants())
 
         # Check if the model changed
         model_changed = initial_model_string != new_model_string
-
+        if model_changed:
+            model = model.transform(AddDomainToSpecialNodes())
+
+        if self._remove_input_output_transposes:
+            model = model.transform(RemoveInputOutputTransposes(), cleanup=True)
         return model, model_changed
 
 
+class RemoveInputOutputTransposes(Transformation):
+    """Removes Transpose nodes which sit directly at the graph input or output
+    and adjusts the corresponding graph input/output shapes accordingly.
+    Note: assumes a single graph input and a single graph output."""
+
+    def apply(self, model):
+        # Iterate over a copy, since nodes may be removed while iterating
+        for n in list(model.graph.node):
+            if n.op_type == "Transpose":
+                if model.find_direct_predecessors(n) is None:
+                    # Transpose directly behind the graph input: i -> n -> s
+                    s = model.find_direct_successors(n)
+                    assert len(s) == 1
+                    for i, e in enumerate(s[0].input):
+                        if n.output[0] == e:
+                            s[0].input[i] = n.input[0]
+
+                    new_input_shape = model.get_tensor_shape(n.output[0])
+
+                    # Modify input tensor shape
+                    input_name = model.graph.input[0].name
+                    new_input = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, new_input_shape)
+
+                    # Update the model graph inputs
+                    model.graph.input.remove(model.graph.input[0])
+                    model.graph.input.append(new_input)
+                    model.graph.node.remove(n)
+                    continue
+                if model.find_direct_successors(n) is None:
+                    # Transpose directly in front of the graph output: p -> n -> output
+                    p = model.find_direct_predecessors(n)
+                    assert len(p) == 1
+                    for i, e in enumerate(p[0].output):
+                        if n.input[0] == e:
+                            p[0].output[i] = n.output[0]
+
+                    new_output_shape = model.get_tensor_shape(n.input[0])
+
+                    # Modify output tensor shape
+                    output_name = model.graph.output[0].name
+                    new_output = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, new_output_shape)
+
+                    # Update the model graph outputs
+                    model.graph.output.remove(model.graph.output[0])
+                    model.graph.output.append(new_output)
+
+                    model.graph.node.remove(n)
+                    continue
+        return model, False
+
+
+class RemoveDomainFromSpecialNodes(Transformation):
+    """Clears the domain of the special node types (Resize, Upsample, Concat),
+    so that standard ONNX shape inference and constant folding treat them as regular ONNX ops."""
+
+    def apply(self, model):
+        for n in model.graph.node:
+            if n.op_type in _channelsLast_special_node_types:
+                n.domain = ""
+        return model, False
+
+
+class AddDomainToSpecialNodes(Transformation):
+    """Marks the special node types (Resize, Upsample, Concat) with the placeholder domain
+    "modified", which InsertChannelsLastDomainsAndTrafos uses to skip already converted nodes."""
+
+    def apply(self, model):
+        for n in model.graph.node:
+            if n.op_type in _channelsLast_special_node_types:
+                n.domain = "modified"
+        return model, False
+
+
 class InsertChannelsLastDomainsAndTrafos(Transformation):
     """
@@ -118,6 +180,55 @@ class InsertChannelsLastDomainsAndTrafos(Transformation):
     """
 
     def apply(self, model):
+
+        def insert_transpose_to_output(model, outp, graph, running_node_index, n, i):
+            # Get the shape of the output tensor
+            # and convert it to the shape for the intermediate tensor
+            chanFirst_shape = model.get_tensor_shape(outp)
+            ndim = len(chanFirst_shape)
+            assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors."
+            chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)]
+            # Intermediate tensor
+            outp_trans_in = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                chanLast_shape,
+            )
+            graph.value_info.append(outp_trans_in)
+            outp_trans_in = outp_trans_in.name
+
+            # ChannelsFirst -> ChannelsLast transpose
+            outp_trans_node = helper.make_node(
+                "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim)
+            )
+            graph.node.insert(running_node_index, outp_trans_node)
+            running_node_index += 1
+
+            # Attach to original node
+            n.output[i] = outp_trans_in
+            return running_node_index
+
+        def insert_transpose_to_input(model, inp, graph, running_node_index, n, i):
+            chanFirst_shape = model.get_tensor_shape(inp)
+            ndim = len(chanFirst_shape)
+            assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors."
+            chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)]
+            # Intermediate tensor
+            inp_trans_out = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                chanLast_shape,
+            )
+            graph.value_info.append(inp_trans_out)
+            inp_trans_out = inp_trans_out.name
+
+            # Channels last transpose
+            inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim))
+            graph.node.insert(running_node_index, inp_trans_node)
+            running_node_index += 1
+
+            # Attach to original node
+            n.input[i] = inp_trans_out
+            return running_node_index
+
         graph = model.graph
         node_ind = 0
         graph_modified = False
@@ -142,60 +253,53 @@ def apply(self, model):
                     # Skip Conv bias since it doesn't need a transpose
                     if n.op_type == "Conv" and i == 2:
                         continue
-                    # Get the shape of the input tensor
-                    # and convert it to the shape for the intermediate tensor
-                    chanFirst_shape = model.get_tensor_shape(inp)
-                    ndim = len(chanFirst_shape)
-                    assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors."
-                    chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)]
-                    # Intermediate tensor
-                    inp_trans_out = helper.make_tensor_value_info(
-                        model.make_new_valueinfo_name(),
-                        TensorProto.FLOAT,
-                        chanLast_shape,
-                    )
-                    graph.value_info.append(inp_trans_out)
-                    inp_trans_out = inp_trans_out.name
-
-                    # channels last transpose
-                    inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim))
-                    graph.node.insert(running_node_index, inp_trans_node)
-                    running_node_index += 1
-
-                    # Attach to original node
-                    n.input[i] = inp_trans_out
+                    running_node_index = insert_transpose_to_input(model, inp, graph, running_node_index, n, i)
 
                 # Insert transformation nodes for output nodes
                 output_tensors = n.output
                 for i, outp in enumerate(output_tensors):
-                    chanFirst_shape = model.get_tensor_shape(outp)
-                    ndim = len(chanFirst_shape)
-                    assert ndim == 3 or ndim == 4, "Channels last conversion is only available for 3D and 4D tensors."
-                    chanLast_shape = [chanFirst_shape[idx] for idx in to_channels_last_args(ndim)]
-                    # Intermediat tensor
-                    outp_trans_in = helper.make_tensor_value_info(
-                        model.make_new_valueinfo_name(),
-                        TensorProto.FLOAT,
-                        chanLast_shape,
-                    )
-                    graph.value_info.append(outp_trans_in)
-                    outp_trans_in = outp_trans_in.name
-
-                    # ChannelsFirst -> ChannelsLast transpose
-                    outp_trans_node = helper.make_node(
-                        "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim)
-                    )
-                    graph.node.insert(running_node_index, outp_trans_node)
-                    running_node_index += 1
-
-                    # Attach to original node
-                    n.output[i] = outp_trans_in
+                    running_node_index = insert_transpose_to_output(model, outp, graph, running_node_index, n, i)
 
                 # Modify domain
                 n.domain = "qonnx.custom_op.channels_last"
                 # Set modified flag
                 graph_modified = True
 
+            if (n.op_type in _channelsLast_special_node_types) and (n.domain == ""):
+                running_node_index = node_ind
+                # Insert transpose nodes for the data inputs;
+                # scales and axis attributes are adjusted instead of transposed.
+                input_tensors = n.input
+                for i, inp in enumerate(input_tensors):
+                    # Handle Resize scales: swap the channels entry instead of inserting a transpose
+                    if (n.op_type == "Resize") and (i != 0):
+                        if (len(input_tensors) == 2 and i == 1) or (len(input_tensors) == 3 and i == 2):
+                            scales = model.get_initializer(inp).copy()
+                            scales = swap_channels_from_list(scales)
+                            model.set_initializer(inp, scales)
+                        continue
+                    # Handle Upsample scales in the same way
+                    if (n.op_type == "Upsample") and (i == 1):
+                        scales = model.get_initializer(inp).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(inp, scales)
+                        continue
+                    # Concat: point the axis attribute at the new channels position (last dimension)
+                    if (n.op_type == "Concat") and (i == 0):
+                        s = len(model.get_tensor_shape(inp))
+                        get_by_name(n.attribute, "axis").i = s - 1
+                    running_node_index = insert_transpose_to_input(model, inp, graph, running_node_index, n, i)
+
+                output_tensors = n.output
+                for i, outp in enumerate(output_tensors):
+                    running_node_index = insert_transpose_to_output(model, outp, graph, running_node_index, n, i)
+
+                # Mark as handled, so the node is not processed again
+                n.domain = "modified"
+                graph_modified = True
+
         return model, graph_modified
@@ -285,22 +389,78 @@ def apply(self, model):
                         if second_inp_shape == [1] or second_inp_shape == []:
                             move_through_valid |= True
+                    # Adjust Resize / Upsample scales and the Concat axis for channels-last layout
+                    if predecessor.op_type == "Resize":
+                        if len(predecessor.input) == 2:
+                            i = 1
+                        else:
+                            i = 2
+                        scales = model.get_initializer(predecessor.input[i]).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(predecessor.input[i], scales)
+                    if predecessor.op_type == "Upsample":
+                        scales = model.get_initializer(predecessor.input[1]).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(predecessor.input[1], scales)
+                    if predecessor.op_type == "Concat":
+                        s = len(model.get_tensor_shape(predecessor.output[0]))
+                        get_by_name(predecessor.attribute, "axis").i = s - 1
+
                     # Apply move through trafo if possible
                     if move_through_valid:
                         # Input tensors are always input 0
                         inp = predecessor.input[0]
                         if isinstance(model.get_initializer(inp), type(None)):
-                            # Swap around node "predecessor" and "n"
-                            # collect tensors
-                            tensor_1 = inp
-                            tensor_2 = n.input[0]
-                            tensor_3 = n.output[0]
-                            # Now connect the tensors to the nodes again,
-                            # but in different order
-                            n.input[0] = tensor_1
-                            n.output[0] = tensor_2
-                            predecessor.input[0] = tensor_2
-                            predecessor.output[0] = tensor_3
+                            # Handle the case in which the predecessor is a fork node
+                            # (models with branches)
+                            if model.is_fork_node(predecessor):
+                                # Here we are considering one branch of the fork.
+                                # This case must be handled separately, since the
+                                # transpose on the other branch has to be simplified as well.
+                                transposes = model.find_direct_successors(predecessor)
+                                assert len(transposes) == 2, "Only the case of 2 branches is handled"
+                                both_transpose = all(t.op_type == "Transpose" for t in transposes)
+                                if not both_transpose:
+                                    continue
+                                x2 = transposes[0] if transposes[1] == n else transposes[1]
+
+                                # Rewire the nodes and tensors so that one transpose is moved in front
+                                # of the fork node (usually an activation function) and the other
+                                # transpose is removed from the graph:
+                                #   before:  xa -> predecessor -> n  -> x1
+                                #                              -> x2 -> x0
+                                #   after:   xa -> n -> predecessor -> x1
+                                #                                   -> x0
+
+                                # Define the nodes and tensors to be rewired
+                                xa = model.find_direct_predecessors(predecessor)[0]
+                                x0 = model.find_direct_successors(x2)[0]
+                                x1 = model.find_direct_successors(n)[0]
+                                tensor_1 = xa.output[0]
+                                tensor_2 = predecessor.output[0]
+                                tensor_3 = n.output[0]
+
+                                # Perform the rewiring; replace x0's input in place
+                                # to preserve the input order of multi-input nodes
+                                for idx, x0_inp in enumerate(x0.input):
+                                    if x0_inp == x2.output[0]:
+                                        x0.input[idx] = tensor_2
+                                x1.input[0] = tensor_2
+                                n.input[0] = tensor_1
+                                xa.output[0] = tensor_1
+                                predecessor.input[0] = tensor_3
+                                graph.node.remove(x2)
+                            else:
+                                # Swap around node "predecessor" and "n"
+                                # collect tensors
+                                tensor_1 = inp
+                                tensor_2 = n.input[0]
+                                tensor_3 = n.output[0]
+                                # Now connect the tensors to the nodes again,
+                                # but in different order
+                                n.input[0] = tensor_1
+                                n.output[0] = tensor_2
+                                predecessor.input[0] = tensor_2
+                                predecessor.output[0] = tensor_3
 
                             # Change the shape of the middle tensor
                             target_shape = model.get_tensor_shape(tensor_3)
@@ -362,6 +522,23 @@ def apply(self, model):
                         second_inp_shape = model.get_tensor_shape(successor.input[1])
                         if second_inp_shape == [1] or second_inp_shape == []:
                             move_through_valid |= True
+
+                    # Adjust Resize / Upsample scales and the Concat axis for channels-last layout
+                    if successor.op_type == "Resize":
+                        if len(successor.input) == 2:
+                            i = 1
+                        else:
+                            i = 2
+                        scales = model.get_initializer(successor.input[i]).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(successor.input[i], scales)
+                    if successor.op_type == "Upsample":
+                        scales = model.get_initializer(successor.input[1]).copy()
+                        scales = swap_channels_from_list(scales)
+                        model.set_initializer(successor.input[1], scales)
+                    if successor.op_type == "Concat":
+                        s = len(model.get_tensor_shape(successor.output[0]))
+                        get_by_name(successor.attribute, "axis").i = s - 1
+
                     # Apply move through trafo if possible
                     if move_through_valid:
                         # Collect all tensors connecting n and successor
diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py
index b80a6d60..74302f76 100644
--- a/tests/transformation/test_channelslast.py
+++ b/tests/transformation/test_channelslast.py
@@ -75,6 +75,8 @@ def analysis_testing_for_chanlast_domain(model):
         "Conv": 3,
         "MaxPool": 3,
         "BatchNormalization": 3,
+        "Resize": 4,
+        "Concat": 1,
     }
     # Check that all wrapped_ops in the registry have a definition here
     chanlast_op_types = list(channels_last.custom_op.keys())
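For reference, a minimal usage sketch (not part of the diff) of how the extended transformation could be invoked via the qonnx ModelWrapper API; the model file names are illustrative:

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean

    model = ModelWrapper("model.onnx")  # hypothetical input model
    # remove_input_output_transposes=True additionally strips the Transpose nodes that would
    # otherwise remain at the graph input and output after the channels-last conversion.
    model = model.transform(ConvertToChannelsLastAndClean(remove_input_output_transposes=True))
    model.save("model_channels_last.onnx")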