diff --git a/tests/generators/test_nonuniform.py b/tests/generators/test_nonuniform.py index d2d4787a6..712980ae5 100644 --- a/tests/generators/test_nonuniform.py +++ b/tests/generators/test_nonuniform.py @@ -67,51 +67,3 @@ def test_random_hypergraph(): H4 = xgi.random_hypergraph(10, [0.1], order=2, seed=1) assert H4.num_nodes == 10 assert xgi.unique_edge_sizes(H4) == [3] - - -def test_random_simplicial_complex(): - # seed - S1 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=1) - S2 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2) - S3 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2) - - assert S1._edge != S2._edge - assert S2._edge == S3._edge - - # wrong input - with pytest.raises(ValueError): - S1 = xgi.random_simplicial_complex(10, [1, 1.1]) - with pytest.raises(ValueError): - S1 = xgi.random_simplicial_complex(10, [1, -2]) - - -def test_random_flag_complex(): - # seed - S1 = xgi.random_flag_complex(10, 0.1, seed=1) - S2 = xgi.random_flag_complex(10, 0.1, seed=2) - S3 = xgi.random_flag_complex(10, 0.1, seed=2) - - assert S1._edge != S2._edge - assert S2._edge == S3._edge - - # wrong input - with pytest.raises(ValueError): - S1 = xgi.random_flag_complex(10, 1.1) - with pytest.raises(ValueError): - S1 = xgi.random_flag_complex(10, -2) - - -def test_random_flag_complex_d2(): - # seed - S1 = xgi.random_flag_complex_d2(10, 0.1, seed=1) - S2 = xgi.random_flag_complex_d2(10, 0.1, seed=2) - S3 = xgi.random_flag_complex_d2(10, 0.1, seed=2) - - assert S1._edge != S2._edge - assert S2._edge == S3._edge - - # wrong input - with pytest.raises(ValueError): - S1 = xgi.random_flag_complex_d2(10, 1.1) - with pytest.raises(ValueError): - S1 = xgi.random_flag_complex_d2(10, -2) diff --git a/tests/generators/test_simplicial_complexes.py b/tests/generators/test_simplicial_complexes.py index 94f8aff7c..6daaadb21 100644 --- a/tests/generators/test_simplicial_complexes.py +++ b/tests/generators/test_simplicial_complexes.py @@ -1,4 +1,5 @@ import networkx as nx +import pytest import xgi from xgi.exception import XGIError @@ -21,6 +22,7 @@ def test_flag_complex(): assert S.edges.members() == simplices_3 + # ps S1 = xgi.flag_complex(G, ps=[1], seed=42) S2 = xgi.flag_complex(G, ps=[0.5], seed=42) S3 = xgi.flag_complex(G, ps=[0], seed=42) @@ -29,6 +31,7 @@ def test_flag_complex(): assert S2.edges.members() == simplices_2 assert S3.edges.members() == simplices_2 + # complete graph G1 = nx.complete_graph(4) S4 = xgi.flag_complex(G1) S5 = xgi.flag_complex(G1, ps=[1]) @@ -44,3 +47,116 @@ def test_flag_complex_d2(): S2 = xgi.flag_complex_d2(G) assert set(S.edges.members()) == set(S2.edges.members()) + + +def test_random_simplicial_complex(): + # seed + S1 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=1) + S2 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2) + S3 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2) + + assert S1._edge != S2._edge + assert S2._edge == S3._edge + + # wrong input + with pytest.raises(ValueError): + S1 = xgi.random_simplicial_complex(10, [1, 1.1]) + with pytest.raises(ValueError): + S1 = xgi.random_simplicial_complex(10, [1, -2]) + + +def test_random_flag_complex(): + + S = xgi.random_flag_complex(10, 0.4, seed=2) + simplices = { + frozenset({0, 4}), + frozenset({0, 7}), + frozenset({1, 8}), + frozenset({2, 5}), + frozenset({2, 9}), + frozenset({3, 5}), + frozenset({3, 6}), + frozenset({3, 7}), + frozenset({3, 8}), + frozenset({4, 5}), + frozenset({4, 7}), + frozenset({4, 8}), + frozenset({6, 7}), + frozenset({6, 8}), + frozenset({7, 8}), + frozenset({0, 4, 7}), + frozenset({3, 6, 7}), + frozenset({3, 6, 8}), + frozenset({3, 7, 8}), + frozenset({4, 7, 8}), + frozenset({6, 7, 8}), + } + + assert set(S.edges.members()) == simplices + + # max_order + S = xgi.random_flag_complex(10, 0.4, seed=2, max_order=3) + assert set(S.edges.members()) == simplices.union({frozenset({3, 6, 7, 8})}) + + # seed + S1 = xgi.random_flag_complex(10, 0.1, seed=1) + S2 = xgi.random_flag_complex(10, 0.1, seed=2) + S3 = xgi.random_flag_complex(10, 0.1, seed=2) + + assert S1._edge != S2._edge + assert S2._edge == S3._edge + + # wrong input + with pytest.raises(ValueError): + S1 = xgi.random_flag_complex(10, 1.1) + with pytest.raises(ValueError): + S1 = xgi.random_flag_complex(10, -2) + + +def test_random_flag_complex_d2(): + + S = xgi.random_flag_complex_d2(10, 0.4, seed=2) + simplices = { + frozenset({0, 4}), + frozenset({0, 7}), + frozenset({1, 8}), + frozenset({2, 5}), + frozenset({2, 9}), + frozenset({3, 5}), + frozenset({3, 6}), + frozenset({3, 7}), + frozenset({3, 8}), + frozenset({4, 5}), + frozenset({4, 7}), + frozenset({4, 8}), + frozenset({6, 7}), + frozenset({6, 8}), + frozenset({7, 8}), + frozenset({0, 4, 7}), + frozenset({3, 6, 7}), + frozenset({3, 6, 8}), + frozenset({3, 7, 8}), + frozenset({4, 7, 8}), + frozenset({6, 7, 8}), + } + + assert set(S.edges.members()) == simplices + + # consistency with other function + S = xgi.random_flag_complex(10, 0.4, seed=3, max_order=2) + S0 = xgi.random_flag_complex_d2(10, 0.4, seed=3) + assert set(S.edges.members()) == set(S0.edges.members()) + + # seed + S1 = xgi.random_flag_complex_d2(10, 0.1, seed=1) + S2 = xgi.random_flag_complex_d2(10, 0.1, seed=2) + S3 = xgi.random_flag_complex_d2(10, 0.1, seed=2) + + assert S1._edge != S2._edge + assert S2._edge == S3._edge + + # wrong input + with pytest.raises(ValueError): + S1 = xgi.random_flag_complex_d2(10, 1.1) + with pytest.raises(ValueError): + S1 = xgi.random_flag_complex_d2(10, -2) diff --git a/xgi/generators/classic.py b/xgi/generators/classic.py index 9a37f14c0..e46743233 100644 --- a/xgi/generators/classic.py +++ b/xgi/generators/classic.py @@ -217,6 +217,7 @@ def complete_hypergraph(N, order=None, max_order=None, include_singletons=False) elif max_order is not None: start = 1 if include_singletons else 2 end = max_order + 1 + assert end >= start # can be equal because adding +1 to end below s = list(nodes) edges = chain.from_iterable(combinations(s, r) for r in range(start, end + 1)) diff --git a/xgi/generators/simplicial_complexes.py b/xgi/generators/simplicial_complexes.py index 487ba3ff7..2da25f1d2 100644 --- a/xgi/generators/simplicial_complexes.py +++ b/xgi/generators/simplicial_complexes.py @@ -130,26 +130,21 @@ def flag_complex(G, max_order=2, ps=None, seed=None): random.seed(seed) nodes = G.nodes() + N = len(nodes) edges = G.edges() - # compute all maximal cliques to fill - max_cliques = list(nx.find_cliques(G)) + cliques_to_add = _cliques_to_fill(G, max_order) S = SimplicialComplex() S.add_nodes_from(nodes) S.add_simplices_from(edges) if not ps: # promote all cliques - S.add_simplices_from(max_cliques, max_order=max_order) + S.add_simplices_from(cliques_to_add, max_order=max_order) return S - if max_order: # compute subfaces of order max_order (allowed max cliques) - max_cliques_to_add = subfaces(max_cliques, order=max_order) - else: - max_cliques_to_add = max_cliques - # store max cliques per order cliques_d = defaultdict(list) - for x in max_cliques_to_add: + for x in cliques_to_add: cliques_d[len(x)].append(x) # promote cliques with a given probability @@ -277,13 +272,47 @@ def random_flag_complex(N, p, max_order=2, seed=None): G = nx.fast_gnp_random_graph(N, p, seed=seed) nodes = G.nodes() - edges = list(G.edges()) - # compute all triangles to fill - max_cliques = list(nx.find_cliques(G)) + cliques = _cliques_to_fill(G, max_order) S = SimplicialComplex() S.add_nodes_from(nodes) - S.add_simplices_from(max_cliques, max_order=max_order) + S.add_simplices_from(cliques, max_order=max_order) return S + + +def _cliques_to_fill(G, max_order): + """Return cliques to fill for flag complexes, + to be passed to `add_simplices_from`. + + This function was written to speedup flag_complex functions + by avoiding adding redundant faces. + + Parameters + ---------- + G : networkx Graph + Graph to consider + max_order: int or None + If None, return maximal cliques. If int, return all cliques + up to max_order. + + Returns + ------- + cliques : list + List of cliques + + """ + if max_order is None: + cliques = list(nx.find_cliques(G)) # max cliques + else: # avoid adding many unnecessary redundant cliques + cliques = [] + for clique in nx.enumerate_all_cliques(G): # sorted by size + if len(clique) == 1: + continue # don't add singletons + if len(clique) <= max_order + 1: + cliques.append(clique) + else: + break # dont go over whole list if not necessary + + return cliques