Skip to content

Commit

Permalink
fix removal of consecutive tranposes with branches
Browse files Browse the repository at this point in the history
  • Loading branch information
jicampos committed Dec 17, 2024
1 parent cf640b9 commit ee2eb20
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/qonnx/transformation/channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe
# scalar input, always broadcastable, no action needed
continue
elif ndim == ndim_inp:
p_nodes = model.find_direct_predecessors(eltwise_node)
for pn in p_nodes:
if pn.output[0] == eltwise_inp:
break
pn_shape = model.get_tensor_shape(pn.output[0])
if pn_shape == subgraph_inp_shape:
# input with matching shape, inverse transpose not needed
continue
# input with matching dimensions, add inverse transpose
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
Expand Down Expand Up @@ -323,6 +331,15 @@ def apply(self, model):
if inp == output_tensor_name:
target_node.input[i] = input_tensor

n_successors = model.find_direct_successors(n)
if n_successors is None:
continue

for ns in n_successors:
for i, inp in enumerate(ns.input):
if inp == n.output[0]:
ns.input[i] = input_tensor

# remove old nodes
graph.node.remove(n)
graph.node.remove(successor_node)
Expand Down

0 comments on commit ee2eb20

Please sign in to comment.