diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index d8f2a65932..9a84eb6e13 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -1209,27 +1209,34 @@ def _break_supported_cycles_in_graph(self) -> Tuple[networkx.MultiDiGraph, Dict[ # sender_comp will be the last element of cycle and receiver_comp will be the first. # So if cycle is [1, 2, 3, 4] we would call zip([1, 2, 3, 4], [2, 3, 4, 1]). for sender_comp, receiver_comp in zip(cycle, cycle[1:] + cycle[:1]): - # We get the key and iterate those as we want to edit the graph data while - # iterating the edges and that would raise. - # Even though the connection key set in Pipeline.connect() uses only the - # sockets name we don't have clashes since it's only used to differentiate - # multiple edges between two nodes. - edge_keys = list(temp_graph.get_edge_data(sender_comp, receiver_comp).keys()) - for edge_key in edge_keys: - edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key] - receiver_socket = edge_data["to_socket"] - if not receiver_socket.is_variadic and receiver_socket.is_mandatory: - continue - - # We found a breakable edge - sender_socket = edge_data["from_socket"] - edges_removed[sender_comp].append(sender_socket.name) - temp_graph.remove_edge(sender_comp, receiver_comp, edge_key) - - graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph) - if not graph_has_cycles: - # We removed all the cycles, we can stop - break + # for graphs with multiple nested cycles, we need to check if the edge hasn't + # been previously removed before we try to remove it again + edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp) + if edge_data is not None: + # We get the key and iterate those as we want to edit the graph data while + # iterating the edges and that would raise. + # Even though the connection key set in Pipeline.connect() uses only the + # sockets name we don't have clashes since it's only used to differentiate + # multiple edges between two nodes. + edge_keys = list(edge_data.keys()) + + for edge_key in edge_keys: + edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key] + receiver_socket = edge_data["to_socket"] + if not receiver_socket.is_variadic and receiver_socket.is_mandatory: + continue + + # We found a breakable edge + sender_socket = edge_data["from_socket"] + edges_removed[sender_comp].append(sender_socket.name) + temp_graph.remove_edge(sender_comp, receiver_comp, edge_key) + + # The following check seems heavy, considering there are more cycles to iterate on, + # although in some cases I can see that breaking one edge might break more than one cycle + graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph) + if not graph_has_cycles: + # We removed all the cycles, we can stop + break if not graph_has_cycles: # We removed all the cycles, nice diff --git a/releasenotes/notes/fix-break-sypported-cycles-in-grapth-f8769351fe4ca706.yaml b/releasenotes/notes/fix-break-sypported-cycles-in-grapth-f8769351fe4ca706.yaml new file mode 100644 index 0000000000..fca7114249 --- /dev/null +++ b/releasenotes/notes/fix-break-sypported-cycles-in-grapth-f8769351fe4ca706.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Prevents the pipeline from raising an exception when there are multiple nested cycles in the graph. diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 1cd57a5b5b..2bde85d050 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -1581,3 +1581,28 @@ def test__find_receivers_from(self): ), ) ] + + def test__break_supported_cycles_in_graph(self): + # the following pipeline has a nested cycle, which is supported by Haystack + # but was causing an exception to be raised in the _break_supported_cycles_in_graph method + comp1 = component_class("Comp1", input_types={"value": int}, output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": Variadic[int]}, output_types={"value": int})() + comp3 = component_class("Comp3", input_types={"value": Variadic[int]}, output_types={"value": int})() + comp4 = component_class("Comp4", input_types={"value": Optional[int]}, output_types={"value": int})() + comp5 = component_class("Comp5", input_types={"value": Variadic[int]}, output_types={"value": int})() + pipe = Pipeline() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.add_component("comp3", comp3) + pipe.add_component("comp4", comp4) + pipe.add_component("comp5", comp5) + pipe.connect("comp1.value", "comp2.value") + pipe.connect("comp2.value", "comp3.value") + pipe.connect("comp3.value", "comp4.value") + pipe.connect("comp3.value", "comp5.value") + pipe.connect("comp4.value", "comp5.value") + pipe.connect("comp4.value", "comp3.value") + pipe.connect("comp5.value", "comp2.value") + + # the following call should not raise an exception + pipe._break_supported_cycles_in_graph()