From a58a4b86b3ffe4def071b108866b2ea26a9bab1c Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Sat, 6 May 2023 08:23:09 -0400 Subject: [PATCH] Clean up classes (#353) * moved IDDict * remove keys * remove extraneous arguments * style: format with black --- .../api/classes/xgi.classes.hypergraph.rst | 4 +- docs/source/api/utils/xgi.utils.utilities.rst | 2 + tests/classes/test_hypergraph.py | 2 +- tests/test_convert.py | 17 +++++-- xgi/classes/hypergraph.py | 47 ++++--------------- xgi/classes/reportviews.py | 47 +++++++++---------- xgi/convert.py | 9 ++-- xgi/utils/utilities.py | 32 +++++++++++++ 8 files changed, 83 insertions(+), 77 deletions(-) diff --git a/docs/source/api/classes/xgi.classes.hypergraph.rst b/docs/source/api/classes/xgi.classes.hypergraph.rst index 3b0950748..e34178796 100644 --- a/docs/source/api/classes/xgi.classes.hypergraph.rst +++ b/docs/source/api/classes/xgi.classes.hypergraph.rst @@ -11,6 +11,4 @@ :toctree: . :nosignatures: - Hypergraph - - IDDict \ No newline at end of file + Hypergraph \ No newline at end of file diff --git a/docs/source/api/utils/xgi.utils.utilities.rst b/docs/source/api/utils/xgi.utils.utilities.rst index bec558c19..d6e453baf 100644 --- a/docs/source/api/utils/xgi.utils.utilities.rst +++ b/docs/source/api/utils/xgi.utils.utilities.rst @@ -11,6 +11,8 @@ :toctree: . :nosignatures: + IDDict + .. rubric:: Functions diff --git a/tests/classes/test_hypergraph.py b/tests/classes/test_hypergraph.py index c92741395..f446e4657 100644 --- a/tests/classes/test_hypergraph.py +++ b/tests/classes/test_hypergraph.py @@ -466,7 +466,7 @@ def test_double_edge_swap(edgelist1): with pytest.raises(IDNotFound): H.double_edge_swap(8, 3, 0, 1) - + H = xgi.Hypergraph(edgelist1) with pytest.raises(XGIError): H.double_edge_swap(6, 7, 2, 3) diff --git a/tests/test_convert.py b/tests/test_convert.py index 6b2fabfec..9f68197d6 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -5,11 +5,13 @@ import xgi from xgi.exception import XGIError + def test_convert_empty_hypergraph(): H = xgi.convert_to_hypergraph(None) assert H.num_nodes == 0 assert H.num_edges == 0 + def test_convert_simplicial_complex_to_hypergraph(): SC = xgi.SimplicialComplex() SC.add_simplices_from([[3, 4, 5], [3, 6], [6, 7, 8, 9], [1, 4, 10, 11, 12], [1, 4]]) @@ -18,23 +20,27 @@ def test_convert_simplicial_complex_to_hypergraph(): assert SC.nodes == H.nodes assert SC.edges.maximal().members() == H.edges.members() + def test_convert_list_to_hypergraph(edgelist2): H = xgi.convert_to_hypergraph(edgelist2) assert isinstance(H, xgi.Hypergraph) assert set(H.nodes) == {1, 2, 3, 4, 5, 6} assert H.edges.members() == [{1, 2}, {3, 4}, {4, 5, 6}] - + + def test_convert_pandas_dataframe_to_hypergraph(dataframe5): H = xgi.convert_to_hypergraph(dataframe5) assert isinstance(H, xgi.Hypergraph) - assert set(H.nodes) == set(dataframe5['col1']) + assert set(H.nodes) == set(dataframe5["col1"]) assert H.edges.members() == [{0, 1, 2, 3}, {4}, {5, 6}, {8, 6, 7}] + def test_convert_empty_simplicial_complex(): S = xgi.convert_to_simplicial_complex(None) assert S.num_nodes == 0 assert S.num_edges == 0 + def test_convert_hypergraph_to_simplicial_complex(): H = xgi.Hypergraph() H.add_edges_from([[1, 2, 3], [3, 4], [4, 5, 6, 7], [7, 8, 9, 10, 11]]) @@ -42,19 +48,22 @@ def test_convert_hypergraph_to_simplicial_complex(): assert isinstance(SC, xgi.SimplicialComplex) assert H.nodes == SC.nodes assert H.edges.members() == SC.edges.maximal().members() - + + def test_convert_list_to_simplicial_complex(edgelist2): SC = xgi.convert_to_simplicial_complex(edgelist2) assert isinstance(SC, xgi.SimplicialComplex) assert set(SC.nodes) == {1, 2, 3, 4, 5, 6} assert SC.edges.maximal().members() == [{1, 2}, {3, 4}, {4, 5, 6}] + def test_convert_pandas_dataframe_to_simplicial_complex(dataframe5): SC = xgi.convert_to_simplicial_complex(dataframe5) assert isinstance(SC, xgi.SimplicialComplex) - assert set(SC.nodes) == set(dataframe5['col1']) + assert set(SC.nodes) == set(dataframe5["col1"]) assert SC.edges.maximal().members() == [{0, 1, 2, 3}, {4}, {5, 6}, {8, 6, 7}] + def test_convert_to_graph(edgelist2, edgelist5): H1 = xgi.Hypergraph(edgelist2) H2 = xgi.Hypergraph(edgelist5) diff --git a/xgi/classes/hypergraph.py b/xgi/classes/hypergraph.py index 86c9055fa..32381d78d 100644 --- a/xgi/classes/hypergraph.py +++ b/xgi/classes/hypergraph.py @@ -6,41 +6,12 @@ from warnings import warn from ..exception import IDNotFound, XGIError -from ..utils.utilities import update_uid_counter +from ..utils.utilities import IDDict, update_uid_counter from .reportviews import EdgeView, NodeView __all__ = ["Hypergraph"] -class IDDict(dict): - """A dict that holds (node or edge) IDs. - - For internal use only. Adds input validation functionality to the internal dicts - that hold nodes and edges in a network. - - """ - - def __getitem__(self, item): - try: - return dict.__getitem__(self, item) - except KeyError as e: - raise IDNotFound(f"ID {item} not found") from e - - def __setitem__(self, item, value): - if item is None: - raise XGIError("None cannot be a node or edge") - try: - return dict.__setitem__(self, item, value) - except TypeError as e: - raise TypeError(f"ID {item} not a valid type") from e - - def __delitem__(self, item): - try: - return dict.__delitem__(self, item) - except KeyError as e: - raise IDNotFound(f"ID {item} not found") from e - - class Hypergraph: r"""A hypergraph is a collection of subsets of a set of *nodes* or *vertices*. @@ -834,15 +805,14 @@ def double_edge_swap(self, n_id1, n_id2, e_id1, e_id2): temp_members1.add(n_id2) temp_members2.add(n_id1) - # Now we handle the memberships + # Now we handle the memberships # remove old nodes from edges temp_memberships1.remove(e_id1) temp_memberships2.remove(e_id2) # swap nodes temp_memberships1.add(e_id2) - temp_memberships2.add(e_id1) - + temp_memberships2.add(e_id1) except KeyError as e: @@ -850,11 +820,12 @@ def double_edge_swap(self, n_id1, n_id2, e_id1, e_id2): "One of the nodes specified doesn't belong to the specified edge." ) from e - if (len(temp_memberships1) != len(self._node[n_id1]) or - len(temp_memberships2) != len(self._node[n_id2]) or - len(temp_members1) != len(self._edge[e_id1]) or - len(temp_members2) != len(self._edge[e_id2]) - ): + if ( + len(temp_memberships1) != len(self._node[n_id1]) + or len(temp_memberships2) != len(self._node[n_id2]) + or len(temp_members1) != len(self._edge[e_id1]) + or len(temp_members2) != len(self._edge[e_id2]) + ): raise XGIError("This swap does not preserve edge sizes.") self._node[n_id1] = temp_memberships1 diff --git a/xgi/classes/reportviews.py b/xgi/classes/reportviews.py index 008cbe203..4509f5603 100644 --- a/xgi/classes/reportviews.py +++ b/xgi/classes/reportviews.py @@ -45,10 +45,6 @@ class IDView(Mapping, Set): __slots__ = ( "_net", - "_id_dict", - "_id_attr", - "_bi_id_dict", - "_bi_id_attr", "_ids", ) @@ -63,10 +59,7 @@ def __getstate__(self): """ return { - "_id_dict": self._id_dict, - "_id_attr": self._id_attr, - "_bi_id_dict": self._bi_id_dict, - "_bi_id_attr": self._bi_id_attr, + "_net": self._net, "_ids": self._ids, } @@ -80,18 +73,22 @@ def __setstate__(self, state): and the values are dictionarys from the Hypergraph class. """ - self._id_dict = state["_id_dict"] - self._id_attr = state["_id_attr"] - self._bi_id_dict = state["_bi_id_dict"] - self._bi_id_attr = state["_bi_id_attr"] - self._ids = state["_ids"] + self._net = state["_net"] + self._id_kind = state["_id_kind"] - def __init__(self, network, id_dict, id_attr, bi_id_dict, bi_id_attr, ids=None): + def __init__(self, network, ids=None): self._net = network - self._id_dict = id_dict - self._id_attr = id_attr - self._bi_id_dict = bi_id_dict - self._bi_id_attr = bi_id_attr + + if self._id_kind == "node": + self._id_dict = None if self._net is None else network._node + self._id_attr = None if self._net is None else network._node_attr + self._bi_id_dict = None if self._net is None else network._edge + self._bi_id_attr = None if self._net is None else network._edge_attr + elif self._id_kind == "edge": + self._id_dict = None if self._net is None else network._edge + self._id_attr = None if self._net is None else network._edge_attr + self._bi_id_dict = None if self._net is None else network._node + self._bi_id_attr = None if self._net is None else network._node_attr if ids is None: self._ids = self._id_dict @@ -116,7 +113,7 @@ def ids(self): always use `x in view`. The latter is always faster. """ - return set(self._id_dict.keys()) if self._ids is None else self._ids + return set(self._id_dict) if self._ids is None else self._ids def __len__(self): """The number of IDs.""" @@ -125,7 +122,7 @@ def __len__(self): def __iter__(self): """Returns an iterator over the IDs.""" if self._ids is None: - return iter({}) if self._id_dict is None else iter(self._id_dict.keys()) + return iter({}) if self._id_dict is None else iter(self._id_dict) else: return iter(self._ids) @@ -497,7 +494,7 @@ def from_view(cls, view, bunch=None): newview._id_attr = view._id_attr newview._bi_id_dict = view._bi_id_dict newview._bi_id_attr = view._bi_id_attr - all_ids = set(view._id_dict.keys()) + all_ids = set(view._id_dict) if bunch is None: newview._ids = all_ids else: @@ -545,9 +542,9 @@ class NodeView(IDView): def __init__(self, H, bunch=None): if H is None: - super().__init__(None, None, None, None, None, bunch) + super().__init__(None, bunch) else: - super().__init__(H, H._node, H._node_attr, H._edge, H._edge_attr, bunch) + super().__init__(H, bunch) def memberships(self, n=None): """Get the edge ids of which a node is a member. @@ -642,9 +639,9 @@ class EdgeView(IDView): def __init__(self, H, bunch=None): if H is None: - super().__init__(None, None, None, None, None, bunch) + super().__init__(None, bunch) else: - super().__init__(H, H._edge, H._edge_attr, H._node, H._node_attr, bunch) + super().__init__(H, bunch) def members(self, e=None, dtype=list): """Get the node ids that are members of an edge. diff --git a/xgi/convert.py b/xgi/convert.py index 3362d6d7b..3bc20a2e2 100644 --- a/xgi/convert.py +++ b/xgi/convert.py @@ -76,12 +76,9 @@ def convert_to_hypergraph(data, create_using=None): H = empty_hypergraph(create_using) H.add_nodes_from((n, attr) for n, attr in data.nodes.items()) ee = data.edges - H.add_edges_from( - (ee.members(e), e, deepcopy(attr)) for e, attr in ee.items() - ) + H.add_edges_from((ee.members(e), e, deepcopy(attr)) for e, attr in ee.items()) H._hypergraph = deepcopy(data._hypergraph) return H - elif isinstance(data, SimplicialComplex): return from_max_simplices(data) @@ -223,7 +220,7 @@ def convert_to_simplicial_complex(data, create_using=None): ) H._hypergraph = deepcopy(data._hypergraph) return H - + elif isinstance(data, Hypergraph): H = empty_simplicial_complex(create_using) H.add_nodes_from((n, attr) for n, attr in data.nodes.items()) @@ -243,7 +240,7 @@ def convert_to_simplicial_complex(data, create_using=None): result = from_bipartite_pandas_dataframe(data, create_using) if not isinstance(create_using, SimplicialComplex): return convert_to_simplicial_complex(result) - + elif isinstance(data, dict): # edge dict in the form we need raise XGIError("Cannot generate SimplicialComplex from simplex dictionary") diff --git a/xgi/utils/utilities.py b/xgi/utils/utilities.py index 6cdd4d18e..65760c7fc 100644 --- a/xgi/utils/utilities.py +++ b/xgi/utils/utilities.py @@ -3,7 +3,10 @@ from collections import defaultdict from itertools import chain, combinations, count +from xgi.exception import IDNotFound, XGIError + __all__ = [ + "IDDict", "dual_dict", "powerset", "update_uid_counter", @@ -11,6 +14,35 @@ ] +class IDDict(dict): + """A dict that holds (node or edge) IDs. + + For internal use only. Adds input validation functionality to the internal dicts + that hold nodes and edges in a network. + + """ + + def __getitem__(self, item): + try: + return dict.__getitem__(self, item) + except KeyError as e: + raise IDNotFound(f"ID {item} not found") from e + + def __setitem__(self, item, value): + if item is None: + raise XGIError("None cannot be a node or edge") + try: + return dict.__setitem__(self, item, value) + except TypeError as e: + raise TypeError(f"ID {item} not a valid type") from e + + def __delitem__(self, item): + try: + return dict.__delitem__(self, item) + except KeyError as e: + raise IDNotFound(f"ID {item} not found") from e + + def dual_dict(edge_dict): """Given a dictionary with IDs as keys and sets as values, return the dual.