diff --git a/graph/base.py b/graph/base.py index 94ea70c..8f5c63a 100644 --- a/graph/base.py +++ b/graph/base.py @@ -35,6 +35,8 @@ def __eq__(self, other): return False if self._nodes != other._nodes: return False + if self._edge_count != other._edge_count: + return False if self._edges != other._edges: return False return True @@ -126,17 +128,16 @@ def del_edge(self, node1, node2): del self._edges[node1][node2] except KeyError: return + if not self._edges[node1]: + del self._edges[node1] del self._reverse_edges[node2][node1] - self._edge_count -= 1 + if not self._reverse_edges[node2]: + del self._reverse_edges[node2] - if self._out_degree[node1] == 0: - assert False self._out_degree[node1] -= 1 - - if self._in_degree[node2] == 0: - assert False self._in_degree[node2] -= 1 + self._edge_count -= 1 def add_node(self, node_id, obj=None): """ @@ -169,18 +170,23 @@ def del_node(self, node_id): return # outgoing - for n2, d in self._edges[node_id].copy().items(): + for n2, _ in self._edges[node_id].copy().items(): self.del_edge(node_id, n2) # incoming - for n1, d in self._reverse_edges[node_id].copy().items(): + for n1, _ in self._reverse_edges[node_id].copy().items(): self.del_edge(n1, node_id) - del self._edges[node_id] - del self._reverse_edges[node_id] # removes outgoing edges - del self._nodes[node_id] - del self._in_degree[node_id] - del self._out_degree[node_id] + for _d in [ + self._edges, + self._reverse_edges, + self._nodes, + self._in_degree, + self._out_degree]: + try: + del _d[node_id] + except KeyError: + pass def nodes(self, from_node=None, to_node=None, in_degree=None, out_degree=None): """ diff --git a/tests/test_basics.py b/tests/test_basics.py index f946820..bb192ef 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -230,6 +230,25 @@ def test_in_and_out_degree(): assert g.in_degree(node) == in_degree[node] assert g.out_degree(node) == out_degree[node] +def test_equals(): + g1 = Graph(from_list=[(1,2),(2,3)]) + g2 = g1.copy() + assert g1 == g2 + + g1.add_edge(3,1) + g1.del_edge(3,1) + assert g1 == g2 + + assert g1.edges() == g2.edges() + assert g1.nodes() == g2.nodes() + +def test_copy_equals(): + g1 = Graph(from_list=[(1,2),(2,3)]) + g1.add_edge(3,1) + g1.del_edge(3,1) + g2 = g1.copy() + assert g1 == g2 + def test_copy(): g = graph05()