From 8f1f5eeae7135481f26272445c0f09d2483c5040 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 12 Jan 2024 17:11:22 +0100 Subject: [PATCH] Fix (equalize): improved cat equalization check (#793) --- src/brevitas/graph/equalize.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index fa455cbac..580e4eb24 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -801,10 +801,13 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, channels = [] for n in node.all_input_nodes: channel_dim = find_srcs_channel_dim(graph_model, n) - if channel_dim is _UNSUPPORTED_OP: - state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP - continue channels.append(channel_dim) + + # If we found an unsupported op while walking up, we exit this branch and + # invalidate the region + if _UNSUPPORTED_OP in channels: + state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP + continue start = sum(channels[:index]) end = start + channels[index] new_state = WalkRegionState(offset=state.offset)