diff --git a/releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml b/releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml new file mode 100644 index 000000000..78c3fb8a6 --- /dev/null +++ b/releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added new methods, :meth:`.PyDiGraph.subsitute_subgraph` and + :meth:`.PyGraph.substitute_subgraph`, which is used to replace + a subgraph in a graph object with an external graph. diff --git a/src/digraph.rs b/src/digraph.rs index 2d8a97d18..2d392c5c4 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -22,7 +22,7 @@ use std::io::{BufReader, BufWriter}; use std::str; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use rustworkx_core::dictmap::*; @@ -226,9 +226,10 @@ impl PyDiGraph { p_index: NodeIndex, c_index: NodeIndex, edge: PyObject, + force: bool, ) -> PyResult { // Only check for cycles if instance attribute is set to true - if self.check_cycle { + if self.check_cycle || force { // Only check for a cycle (by running has_path_connecting) if // the new edge could potentially add a cycle let cycle_check_required = is_cycle_check_required(self, p_index, c_index); @@ -269,11 +270,11 @@ impl PyDiGraph { .collect::>(); for (other_index, edge_index, weight) in edges { if direction { - self._add_edge(node_between_index, index, weight.clone_ref(py))?; - self._add_edge(index, other_index, weight.clone_ref(py))?; + self._add_edge(node_between_index, index, weight.clone_ref(py), false)?; + self._add_edge(index, other_index, weight.clone_ref(py), false)?; } else { - self._add_edge(other_index, index, weight.clone_ref(py))?; - self._add_edge(index, node_between_index, weight.clone_ref(py))?; + self._add_edge(other_index, index, weight.clone_ref(py), false)?; + self._add_edge(index, node_between_index, weight.clone_ref(py), false)?; } self.graph.remove_edge(edge_index); } @@ -1029,7 +1030,7 @@ impl PyDiGraph { } } for (source, target, weight) in edge_list { - self._add_edge(source, target, weight)?; + self._add_edge(source, target, weight, false)?; } self.graph.remove_node(index); self.node_removed = true; @@ -1061,7 +1062,7 @@ impl PyDiGraph { "One of the endpoints of the edge does not exist in graph", )); } - let out_index = self._add_edge(p_index, c_index, edge)?; + let out_index = self._add_edge(p_index, c_index, edge, false)?; Ok(out_index) } @@ -1131,7 +1132,12 @@ impl PyDiGraph { while max_index >= self.node_count() { self.graph.add_node(py.None()); } - self._add_edge(NodeIndex::new(source), NodeIndex::new(target), py.None())?; + self._add_edge( + NodeIndex::new(source), + NodeIndex::new(target), + py.None(), + false, + )?; } Ok(()) } @@ -1156,7 +1162,12 @@ impl PyDiGraph { while max_index >= self.node_count() { self.graph.add_node(py.None()); } - self._add_edge(NodeIndex::new(source), NodeIndex::new(target), weight)?; + self._add_edge( + NodeIndex::new(source), + NodeIndex::new(target), + weight, + false, + )?; } Ok(()) } @@ -2264,7 +2275,7 @@ impl PyDiGraph { let new_p_index = new_node_map.get(&edge.source()).unwrap(); let new_c_index = new_node_map.get(&edge.target()).unwrap(); let weight = weight_transform_callable(py, &edge_map_func, edge.weight())?; - self._add_edge(*new_p_index, *new_c_index, weight)?; + self._add_edge(*new_p_index, *new_c_index, weight, false)?; } // Add edges from map for (this_index, (index, weight)) in node_map.iter() { @@ -2273,6 +2284,7 @@ impl PyDiGraph { NodeIndex::new(*this_index), *new_index, weight.clone_ref(py), + false, )?; } let out_dict = PyDict::new(py); @@ -2378,6 +2390,7 @@ impl PyDiGraph { NodeIndex::new(out_map[&edge.source().index()]), NodeIndex::new(out_map[&edge.target().index()]), weight_map_fn(edge.weight(), &edge_weight_map)?, + false, )?; } // Add edges to/from node to nodes in other @@ -2405,7 +2418,7 @@ impl PyDiGraph { }, None => continue, }; - self._add_edge(source, target_out, weight)?; + self._add_edge(source, target_out, weight, false)?; } for (source, target, weight) in out_edges { let old_index = map_fn(source.index(), target.index(), &weight)?; @@ -2421,7 +2434,7 @@ impl PyDiGraph { }, None => continue, }; - self._add_edge(source_out, target, weight)?; + self._add_edge(source_out, target, weight, false)?; } // Remove node self.remove_node(node_index.index())?; @@ -2605,6 +2618,137 @@ impl PyDiGraph { } } + /// Substitute a subgraph in the graph with a different subgraph + /// + /// This is used to replace a subgraph in this graph with another graph. A similar result + /// can be achieved by combining :meth:`~.PyDiGraph.contract_nodes` and + /// :meth:`~.PyDiGraph.substitute_node_with_subgraph`. + /// + /// :param list nodes: A list of nodes in this graph representing the subgraph + /// to be removed. + /// :param PyDiGraph other: The subgraph to replace ``nodes`` with + /// :param dict input_node_map: The mapping of node indices from ``nodes`` to a node + /// in ``subgraph``. This is used for incoming and outgoing edges into the removed + /// subgraph. This will replace any edges connected to a node in ``nodes`` with the + /// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this + /// mapping. + /// :param callable edge_weight_map: An optional callable object that when + /// used will receive an edge's weight/data payload from ``subgraph`` and + /// will return an object to use as the weight for a newly created edge + /// after the edge is mapped from ``other``. If not specified the weight + /// from the edge in ``other`` will be copied by reference and used. + /// + /// :param bool cycle_check: To check and raise if the substitution would introduce a cycle. + /// If set to ``True`` or :attr:`.check_cycle` is set to ``True`` when a cycle would be + /// added a :class:`~.DAGWouldCycle` exception will be raised. However, in this case the + /// state of the graph will be partially modified through the internal steps required for the + /// substitution. If your intent is to detect and use the graph if a + /// cycle were to be detected, you should make a copy of the graph + /// (see :meth:`.copy`) prior to calling this method so you have a + /// copy of the input graph to use. + /// + /// :returns: A mapping of node indices in ``other`` to the new node index in this graph + /// :rtype: NodeMap + /// + /// :raises DAGWouldCycle: If ``cycle_check`` or the :attr:`.check_cycle` attribute are set to + /// ``True`` and a cycle would be introduced by the substitution. + #[pyo3(signature=(nodes, other, input_node_map, edge_weight_map=None, cycle_check=false))] + pub fn substitute_subgraph( + &mut self, + py: Python, + nodes: Vec, + other: &PyDiGraph, + input_node_map: HashMap, + edge_weight_map: Option, + cycle_check: bool, + ) -> PyResult { + let mut in_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut out_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut node_map: IndexMap = + IndexMap::with_capacity_and_hasher( + other.graph.node_count(), + ahash::RandomState::default(), + ); + let removed_nodes: HashSet = nodes.iter().map(|n| NodeIndex::new(*n)).collect(); + + let weight_map_fn = |obj: &PyObject, weight_fn: &Option| -> PyResult { + match weight_fn { + Some(weight_fn) => weight_fn.call1(py, (obj,)), + None => Ok(obj.clone_ref(py)), + } + }; + for node in nodes { + let index = NodeIndex::new(node); + in_nodes.extend( + self.graph + .edges_directed(index, petgraph::Direction::Incoming) + .filter_map(|edge| { + if !removed_nodes.contains(&edge.source()) { + Some((edge.source(), edge.target(), edge.weight().clone_ref(py))) + } else { + None + } + }), + ); + out_nodes.extend( + self.graph + .edges_directed(index, petgraph::Direction::Outgoing) + .filter_map(|edge| { + if !removed_nodes.contains(&edge.target()) { + Some((edge.source(), edge.target(), edge.weight().clone_ref(py))) + } else { + None + } + }), + ); + self.graph.remove_node(index); + } + for node in other.graph.node_indices() { + let weight = other.graph.node_weight(node).unwrap(); + let new_index = self.graph.add_node(weight.clone_ref(py)); + node_map.insert(node.index(), new_index.index()); + } + for edge in other.graph.edge_references() { + let new_source = node_map[edge.source().index()]; + let new_target = node_map[edge.target().index()]; + self._add_edge( + NodeIndex::new(new_source), + NodeIndex::new(new_target), + weight_map_fn(edge.weight(), &edge_weight_map)?, + cycle_check, + )?; + } + for edge in out_nodes { + let old_source = edge.0; + let new_source = match input_node_map.get(&old_source.index()) { + Some(new_source) => NodeIndex::new(node_map[new_source]), + None => { + let missing_index = old_source.index(); + return Err(PyIndexError::new_err(format!( + "Input node {} not found in io_node_map", + missing_index + ))); + } + }; + self._add_edge(new_source, edge.1, edge.2, cycle_check)?; + } + for edge in in_nodes { + let old_target = edge.1; + let new_target = match input_node_map.get(&old_target.index()) { + Some(new_target) => NodeIndex::new(node_map[new_target]), + None => { + let missing_index = old_target.index(); + return Err(PyIndexError::new_err(format!( + "Output node {} not found in io_node_map", + missing_index + ))); + } + }; + self._add_edge(edge.0, new_target, edge.2, cycle_check)?; + } + Ok(NodeMap { node_map }) + } + /// Return a new PyDiGraph object for an edge induced subgraph of this graph /// /// The induced subgraph contains each edge in `edge_list` and each node @@ -2716,7 +2860,7 @@ impl PyDiGraph { Some(callback) => callback.call1(py, (forward_weight,))?, None => forward_weight.clone_ref(py), }; - self._add_edge(*edge_target, *edge_source, weight)?; + self._add_edge(*edge_target, *edge_source, weight, false)?; } } Ok(()) diff --git a/src/graph.rs b/src/graph.rs index 45d8902a7..96fb88750 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -20,7 +20,7 @@ use std::io::{BufReader, BufWriter}; use std::str; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; @@ -35,10 +35,10 @@ use num_traits::Zero; use numpy::Complex64; use numpy::PyReadonlyArray2; -use crate::iterators::NodeMap; - use super::dot_utils::build_dot; -use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList}; +use super::iterators::{ + EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList, +}; use super::{ find_node_by_weight, merge_duplicates, weight_callable, IsNan, NoEdgeBetweenNodes, NodesRemoved, StablePyGraph, @@ -1947,6 +1947,111 @@ impl PyGraph { out_graph } + /// Substitute a subgraph in the graph with a different subgraph + /// + /// :param list nodes: A list of nodes in this graph representing the subgraph + /// to be removed. + /// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with + /// :param dict input_node_map: The mapping of node indices from ```nodes`` to a node + /// in ``subgraph``. This is used for incoming and outgoing edges into the removed + /// subgraph. This will replace any edges conneted to a node in ``nodes`` with the + /// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this + /// mapping. + /// :param callable edge_weight_map: An optional callable object that when + /// used will receive an edge's weight/data payload from ``subgraph`` and + /// will return an object to use as the weight for a newly created edge + /// after the edge is mapped from ``other``. If not specified the weight + /// from the edge in ``other`` will be copied by reference and used. + /// + /// :returns: A mapping of node indices in ``other`` to the new node index in this graph + /// :rtype: NodeMap + pub fn substitute_subgraph( + &mut self, + py: Python, + nodes: Vec, + other: &PyGraph, + input_node_map: HashMap, + edge_weight_map: Option, + ) -> PyResult { + let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut node_map: IndexMap = + IndexMap::with_capacity_and_hasher( + other.graph.node_count(), + ahash::RandomState::default(), + ); + let removed_nodes: HashSet = nodes.iter().map(|n| NodeIndex::new(*n)).collect(); + + let weight_map_fn = |obj: &PyObject, weight_fn: &Option| -> PyResult { + match weight_fn { + Some(weight_fn) => weight_fn.call1(py, (obj,)), + None => Ok(obj.clone_ref(py)), + } + }; + for node in nodes { + let index = NodeIndex::new(node); + io_nodes.extend( + self.graph + .edges_directed(index, petgraph::Direction::Outgoing) + .filter_map(|edge| { + if !removed_nodes.contains(&edge.target()) { + Some((edge.source(), edge.target(), edge.weight().clone_ref(py))) + } else { + None + } + }), + ); + self.graph.remove_node(index); + } + for node in other.graph.node_indices() { + let weight = other.graph.node_weight(node).unwrap(); + let new_index = self.graph.add_node(weight.clone_ref(py)); + node_map.insert(node.index(), new_index.index()); + } + for edge in other.graph.edge_references() { + let new_source = node_map[edge.source().index()]; + let new_target = node_map[edge.target().index()]; + self.graph.add_edge( + NodeIndex::new(new_source), + NodeIndex::new(new_target), + weight_map_fn(edge.weight(), &edge_weight_map)?, + ); + } + for edge in io_nodes { + let old_source = edge.0; + let new_source = if removed_nodes.contains(&old_source) { + match input_node_map.get(&old_source.index()) { + Some(new_source) => NodeIndex::new(node_map[new_source]), + None => { + let missing_index = old_source.index(); + return Err(PyIndexError::new_err(format!( + "Input/Output node {} not found in io_node_map", + missing_index + ))); + } + } + } else { + old_source + }; + let old_target = edge.1; + let new_target = if removed_nodes.contains(&old_target) { + match input_node_map.get(&old_target.index()) { + Some(new_target) => NodeIndex::new(node_map[new_target]), + None => { + let missing_index = old_target.index(); + return Err(PyIndexError::new_err(format!( + "Input/Output node {} not found in io_node_map", + missing_index + ))); + } + } + } else { + old_target + }; + self.graph.add_edge(new_source, new_target, edge.2); + } + Ok(NodeMap { node_map }) + } + /// Return a shallow copy of the graph /// /// All node and edge weight/data payloads in the copy will have a diff --git a/tests/rustworkx_tests/digraph/test_substitute_subgraph.py b/tests/rustworkx_tests/digraph/test_substitute_subgraph.py new file mode 100644 index 000000000..25004dea2 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_substitute_subgraph.py @@ -0,0 +1,57 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import rustworkx + + +class TestSubstitute(unittest.TestCase): + def setUp(self): + super().setUp() + self.graph = rustworkx.generators.directed_path_graph(5) + + def test_empty_replacement(self): + in_graph = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + self.graph.substitute_subgraph([2], in_graph, {}) + + def test_single_node(self): + in_graph = rustworkx.PyDiGraph() + in_graph.add_node(0) + in_graph.add_child(0, 1, "edge") + res = self.graph.substitute_subgraph([2], in_graph, {2: 0}) + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (1, 2)], self.graph.edge_list()) + self.assertEqual("edge", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_edge_weight_modifier(self): + in_graph = rustworkx.PyDiGraph() + in_graph.add_node(0) + in_graph.add_child(0, 1, "edge") + res = self.graph.substitute_subgraph( + [2], + in_graph, + {2: 0}, + edge_weight_map=lambda edge: edge + "-migrated", + ) + + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (1, 2)], self.graph.edge_list()) + self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_multiple_mapping(self): + graph = rustworkx.generators.directed_star_graph(5) + in_graph = rustworkx.generators.directed_star_graph(3, inward=True) + res = graph.substitute_subgraph([0, 1, 2], in_graph, {0: 0, 1: 1, 2: 2}) + self.assertEqual({0: 2, 1: 1, 2: 0}, res) + expected = [(1, 2), (0, 2), (2, 4), (2, 3)] + self.assertEqual(expected, graph.edge_list()) diff --git a/tests/rustworkx_tests/graph/test_substitute_subgraph.py b/tests/rustworkx_tests/graph/test_substitute_subgraph.py new file mode 100644 index 000000000..3cb992fca --- /dev/null +++ b/tests/rustworkx_tests/graph/test_substitute_subgraph.py @@ -0,0 +1,58 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import rustworkx + + +class TestSubstitute(unittest.TestCase): + def setUp(self): + super().setUp() + self.graph = rustworkx.generators.path_graph(5) + + def test_empty_replacement(self): + in_graph = rustworkx.PyGraph() + with self.assertRaises(IndexError): + self.graph.substitute_subgraph([2], in_graph, {}) + + def test_single_node(self): + in_graph = rustworkx.PyGraph() + in_graph.add_node(0) + in_graph.add_node(1) + in_graph.add_edge(0, 1, "edge") + res = self.graph.substitute_subgraph([2], in_graph, {2: 0}) + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (2, 1)], self.graph.edge_list()) + self.assertEqual("edge", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_edge_weight_modifier(self): + in_graph = rustworkx.PyGraph() + in_graph.add_node(0) + in_graph.add_node(1) + in_graph.add_edge(0, 1, "edge") + res = self.graph.substitute_subgraph( + [2], + in_graph, + {2: 0}, + edge_weight_map=lambda edge: edge + "-migrated", + ) + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (2, 1)], self.graph.edge_list()) + self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_multiple_mapping(self): + graph = rustworkx.generators.star_graph(5) + in_graph = rustworkx.generators.path_graph(3) + res = graph.substitute_subgraph([0, 1, 2], in_graph, {0: 0, 1: 1, 2: 2}) + self.assertEqual({0: 2, 1: 1, 2: 0}, res) + expected = [(2, 1), (1, 0), (2, 4), (2, 3)] + self.assertEqual(expected, graph.edge_list())