Skip to content

Commit

Permalink
Clean up classes (#353)
Browse files Browse the repository at this point in the history
* moved IDDict

* remove keys

* remove extraneous arguments

* style: format with black
  • Loading branch information
nwlandry authored May 6, 2023
1 parent 591d681 commit a58a4b8
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 77 deletions.
4 changes: 1 addition & 3 deletions docs/source/api/classes/xgi.classes.hypergraph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,4 @@
:toctree: .
:nosignatures:

Hypergraph

IDDict
Hypergraph
2 changes: 2 additions & 0 deletions docs/source/api/utils/xgi.utils.utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
:toctree: .
:nosignatures:

IDDict


.. rubric:: Functions

Expand Down
2 changes: 1 addition & 1 deletion tests/classes/test_hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def test_double_edge_swap(edgelist1):

with pytest.raises(IDNotFound):
H.double_edge_swap(8, 3, 0, 1)

H = xgi.Hypergraph(edgelist1)
with pytest.raises(XGIError):
H.double_edge_swap(6, 7, 2, 3)
Expand Down
17 changes: 13 additions & 4 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import xgi
from xgi.exception import XGIError


def test_convert_empty_hypergraph():
H = xgi.convert_to_hypergraph(None)
assert H.num_nodes == 0
assert H.num_edges == 0


def test_convert_simplicial_complex_to_hypergraph():
SC = xgi.SimplicialComplex()
SC.add_simplices_from([[3, 4, 5], [3, 6], [6, 7, 8, 9], [1, 4, 10, 11, 12], [1, 4]])
Expand All @@ -18,43 +20,50 @@ def test_convert_simplicial_complex_to_hypergraph():
assert SC.nodes == H.nodes
assert SC.edges.maximal().members() == H.edges.members()


def test_convert_list_to_hypergraph(edgelist2):
H = xgi.convert_to_hypergraph(edgelist2)
assert isinstance(H, xgi.Hypergraph)
assert set(H.nodes) == {1, 2, 3, 4, 5, 6}
assert H.edges.members() == [{1, 2}, {3, 4}, {4, 5, 6}]



def test_convert_pandas_dataframe_to_hypergraph(dataframe5):
H = xgi.convert_to_hypergraph(dataframe5)
assert isinstance(H, xgi.Hypergraph)
assert set(H.nodes) == set(dataframe5['col1'])
assert set(H.nodes) == set(dataframe5["col1"])
assert H.edges.members() == [{0, 1, 2, 3}, {4}, {5, 6}, {8, 6, 7}]


def test_convert_empty_simplicial_complex():
S = xgi.convert_to_simplicial_complex(None)
assert S.num_nodes == 0
assert S.num_edges == 0


def test_convert_hypergraph_to_simplicial_complex():
H = xgi.Hypergraph()
H.add_edges_from([[1, 2, 3], [3, 4], [4, 5, 6, 7], [7, 8, 9, 10, 11]])
SC = xgi.convert_to_simplicial_complex(H)
assert isinstance(SC, xgi.SimplicialComplex)
assert H.nodes == SC.nodes
assert H.edges.members() == SC.edges.maximal().members()



def test_convert_list_to_simplicial_complex(edgelist2):
SC = xgi.convert_to_simplicial_complex(edgelist2)
assert isinstance(SC, xgi.SimplicialComplex)
assert set(SC.nodes) == {1, 2, 3, 4, 5, 6}
assert SC.edges.maximal().members() == [{1, 2}, {3, 4}, {4, 5, 6}]


def test_convert_pandas_dataframe_to_simplicial_complex(dataframe5):
SC = xgi.convert_to_simplicial_complex(dataframe5)
assert isinstance(SC, xgi.SimplicialComplex)
assert set(SC.nodes) == set(dataframe5['col1'])
assert set(SC.nodes) == set(dataframe5["col1"])
assert SC.edges.maximal().members() == [{0, 1, 2, 3}, {4}, {5, 6}, {8, 6, 7}]


def test_convert_to_graph(edgelist2, edgelist5):
H1 = xgi.Hypergraph(edgelist2)
H2 = xgi.Hypergraph(edgelist5)
Expand Down
47 changes: 9 additions & 38 deletions xgi/classes/hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,12 @@
from warnings import warn

from ..exception import IDNotFound, XGIError
from ..utils.utilities import update_uid_counter
from ..utils.utilities import IDDict, update_uid_counter
from .reportviews import EdgeView, NodeView

__all__ = ["Hypergraph"]


class IDDict(dict):
"""A dict that holds (node or edge) IDs.
For internal use only. Adds input validation functionality to the internal dicts
that hold nodes and edges in a network.
"""

def __getitem__(self, item):
try:
return dict.__getitem__(self, item)
except KeyError as e:
raise IDNotFound(f"ID {item} not found") from e

def __setitem__(self, item, value):
if item is None:
raise XGIError("None cannot be a node or edge")
try:
return dict.__setitem__(self, item, value)
except TypeError as e:
raise TypeError(f"ID {item} not a valid type") from e

def __delitem__(self, item):
try:
return dict.__delitem__(self, item)
except KeyError as e:
raise IDNotFound(f"ID {item} not found") from e


class Hypergraph:
r"""A hypergraph is a collection of subsets of a set of *nodes* or *vertices*.
Expand Down Expand Up @@ -834,27 +805,27 @@ def double_edge_swap(self, n_id1, n_id2, e_id1, e_id2):
temp_members1.add(n_id2)
temp_members2.add(n_id1)

# Now we handle the memberships
# Now we handle the memberships
# remove old nodes from edges
temp_memberships1.remove(e_id1)
temp_memberships2.remove(e_id2)

# swap nodes
temp_memberships1.add(e_id2)
temp_memberships2.add(e_id1)

temp_memberships2.add(e_id1)

except KeyError as e:

raise IDNotFound(
"One of the nodes specified doesn't belong to the specified edge."
) from e

if (len(temp_memberships1) != len(self._node[n_id1]) or
len(temp_memberships2) != len(self._node[n_id2]) or
len(temp_members1) != len(self._edge[e_id1]) or
len(temp_members2) != len(self._edge[e_id2])
):
if (
len(temp_memberships1) != len(self._node[n_id1])
or len(temp_memberships2) != len(self._node[n_id2])
or len(temp_members1) != len(self._edge[e_id1])
or len(temp_members2) != len(self._edge[e_id2])
):
raise XGIError("This swap does not preserve edge sizes.")

self._node[n_id1] = temp_memberships1
Expand Down
47 changes: 22 additions & 25 deletions xgi/classes/reportviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ class IDView(Mapping, Set):

__slots__ = (
"_net",
"_id_dict",
"_id_attr",
"_bi_id_dict",
"_bi_id_attr",
"_ids",
)

Expand All @@ -63,10 +59,7 @@ def __getstate__(self):
"""
return {
"_id_dict": self._id_dict,
"_id_attr": self._id_attr,
"_bi_id_dict": self._bi_id_dict,
"_bi_id_attr": self._bi_id_attr,
"_net": self._net,
"_ids": self._ids,
}

Expand All @@ -80,18 +73,22 @@ def __setstate__(self, state):
and the values are dictionarys from the Hypergraph class.
"""
self._id_dict = state["_id_dict"]
self._id_attr = state["_id_attr"]
self._bi_id_dict = state["_bi_id_dict"]
self._bi_id_attr = state["_bi_id_attr"]
self._ids = state["_ids"]
self._net = state["_net"]
self._id_kind = state["_id_kind"]

def __init__(self, network, id_dict, id_attr, bi_id_dict, bi_id_attr, ids=None):
def __init__(self, network, ids=None):
self._net = network
self._id_dict = id_dict
self._id_attr = id_attr
self._bi_id_dict = bi_id_dict
self._bi_id_attr = bi_id_attr

if self._id_kind == "node":
self._id_dict = None if self._net is None else network._node
self._id_attr = None if self._net is None else network._node_attr
self._bi_id_dict = None if self._net is None else network._edge
self._bi_id_attr = None if self._net is None else network._edge_attr
elif self._id_kind == "edge":
self._id_dict = None if self._net is None else network._edge
self._id_attr = None if self._net is None else network._edge_attr
self._bi_id_dict = None if self._net is None else network._node
self._bi_id_attr = None if self._net is None else network._node_attr

if ids is None:
self._ids = self._id_dict
Expand All @@ -116,7 +113,7 @@ def ids(self):
always use `x in view`. The latter is always faster.
"""
return set(self._id_dict.keys()) if self._ids is None else self._ids
return set(self._id_dict) if self._ids is None else self._ids

def __len__(self):
"""The number of IDs."""
Expand All @@ -125,7 +122,7 @@ def __len__(self):
def __iter__(self):
"""Returns an iterator over the IDs."""
if self._ids is None:
return iter({}) if self._id_dict is None else iter(self._id_dict.keys())
return iter({}) if self._id_dict is None else iter(self._id_dict)
else:
return iter(self._ids)

Expand Down Expand Up @@ -497,7 +494,7 @@ def from_view(cls, view, bunch=None):
newview._id_attr = view._id_attr
newview._bi_id_dict = view._bi_id_dict
newview._bi_id_attr = view._bi_id_attr
all_ids = set(view._id_dict.keys())
all_ids = set(view._id_dict)
if bunch is None:
newview._ids = all_ids
else:
Expand Down Expand Up @@ -545,9 +542,9 @@ class NodeView(IDView):

def __init__(self, H, bunch=None):
if H is None:
super().__init__(None, None, None, None, None, bunch)
super().__init__(None, bunch)
else:
super().__init__(H, H._node, H._node_attr, H._edge, H._edge_attr, bunch)
super().__init__(H, bunch)

def memberships(self, n=None):
"""Get the edge ids of which a node is a member.
Expand Down Expand Up @@ -642,9 +639,9 @@ class EdgeView(IDView):

def __init__(self, H, bunch=None):
if H is None:
super().__init__(None, None, None, None, None, bunch)
super().__init__(None, bunch)
else:
super().__init__(H, H._edge, H._edge_attr, H._node, H._node_attr, bunch)
super().__init__(H, bunch)

def members(self, e=None, dtype=list):
"""Get the node ids that are members of an edge.
Expand Down
9 changes: 3 additions & 6 deletions xgi/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,9 @@ def convert_to_hypergraph(data, create_using=None):
H = empty_hypergraph(create_using)
H.add_nodes_from((n, attr) for n, attr in data.nodes.items())
ee = data.edges
H.add_edges_from(
(ee.members(e), e, deepcopy(attr)) for e, attr in ee.items()
)
H.add_edges_from((ee.members(e), e, deepcopy(attr)) for e, attr in ee.items())
H._hypergraph = deepcopy(data._hypergraph)
return H


elif isinstance(data, SimplicialComplex):
return from_max_simplices(data)
Expand Down Expand Up @@ -223,7 +220,7 @@ def convert_to_simplicial_complex(data, create_using=None):
)
H._hypergraph = deepcopy(data._hypergraph)
return H

elif isinstance(data, Hypergraph):
H = empty_simplicial_complex(create_using)
H.add_nodes_from((n, attr) for n, attr in data.nodes.items())
Expand All @@ -243,7 +240,7 @@ def convert_to_simplicial_complex(data, create_using=None):
result = from_bipartite_pandas_dataframe(data, create_using)
if not isinstance(create_using, SimplicialComplex):
return convert_to_simplicial_complex(result)

elif isinstance(data, dict):
# edge dict in the form we need
raise XGIError("Cannot generate SimplicialComplex from simplex dictionary")
Expand Down
32 changes: 32 additions & 0 deletions xgi/utils/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,46 @@
from collections import defaultdict
from itertools import chain, combinations, count

from xgi.exception import IDNotFound, XGIError

__all__ = [
"IDDict",
"dual_dict",
"powerset",
"update_uid_counter",
"find_triangles",
]


class IDDict(dict):
"""A dict that holds (node or edge) IDs.
For internal use only. Adds input validation functionality to the internal dicts
that hold nodes and edges in a network.
"""

def __getitem__(self, item):
try:
return dict.__getitem__(self, item)
except KeyError as e:
raise IDNotFound(f"ID {item} not found") from e

def __setitem__(self, item, value):
if item is None:
raise XGIError("None cannot be a node or edge")
try:
return dict.__setitem__(self, item, value)
except TypeError as e:
raise TypeError(f"ID {item} not a valid type") from e

def __delitem__(self, item):
try:
return dict.__delitem__(self, item)
except KeyError as e:
raise IDNotFound(f"ID {item} not found") from e


def dual_dict(edge_dict):
"""Given a dictionary with IDs as keys
and sets as values, return the dual.
Expand Down

0 comments on commit a58a4b8

Please sign in to comment.