From ee2eb202e8c6385fd4e5bb27db09b2b92e41262d Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Tue, 17 Dec 2024 11:43:50 -0600 Subject: [PATCH] fix removal of consecutive tranposes with branches --- src/qonnx/transformation/channels_last.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index c22889a..a48a79f 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -92,6 +92,14 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe # scalar input, always broadcastable, no action needed continue elif ndim == ndim_inp: + p_nodes = model.find_direct_predecessors(eltwise_node) + for pn in p_nodes: + if pn.output[0] == eltwise_inp: + break + pn_shape = model.get_tensor_shape(pn.output[0]) + if pn_shape == subgraph_inp_shape: + # input with matching shape, inverse transpose not needed + continue # input with matching dimensions, add inverse transpose new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) @@ -323,6 +331,15 @@ def apply(self, model): if inp == output_tensor_name: target_node.input[i] = input_tensor + n_successors = model.find_direct_successors(n) + if n_successors is None: + continue + + for ns in n_successors: + for i, inp in enumerate(ns.input): + if inp == n.output[0]: + ns.input[i] = input_tensor + # remove old nodes graph.node.remove(n) graph.node.remove(successor_node)