Skip to content

Commit

Permalink
Faster flag complex (#355)
Browse files Browse the repository at this point in the history
* feat: added complete hypergraph

* tests: added corresponding

* docs: added new function. style: black isort

* review comments

* review comments

* fix docs maths

* perf: faster random_flag_complex by adding only necessary cliques. tests: moved for consistency and added corresponding

* perf: same for flag_complex

* review comments

* style: black
  • Loading branch information
maximelucas authored May 10, 2023
1 parent a58a4b8 commit 26ba651
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 61 deletions.
48 changes: 0 additions & 48 deletions tests/generators/test_nonuniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,51 +67,3 @@ def test_random_hypergraph():
H4 = xgi.random_hypergraph(10, [0.1], order=2, seed=1)
assert H4.num_nodes == 10
assert xgi.unique_edge_sizes(H4) == [3]


def test_random_simplicial_complex():
# seed
S1 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=1)
S2 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)
S3 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, 1.1])
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, -2])


def test_random_flag_complex():
# seed
S1 = xgi.random_flag_complex(10, 0.1, seed=1)
S2 = xgi.random_flag_complex(10, 0.1, seed=2)
S3 = xgi.random_flag_complex(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, -2)


def test_random_flag_complex_d2():
# seed
S1 = xgi.random_flag_complex_d2(10, 0.1, seed=1)
S2 = xgi.random_flag_complex_d2(10, 0.1, seed=2)
S3 = xgi.random_flag_complex_d2(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, -2)
116 changes: 116 additions & 0 deletions tests/generators/test_simplicial_complexes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import networkx as nx
import pytest

import xgi
from xgi.exception import XGIError
Expand All @@ -21,6 +22,7 @@ def test_flag_complex():

assert S.edges.members() == simplices_3

# ps
S1 = xgi.flag_complex(G, ps=[1], seed=42)
S2 = xgi.flag_complex(G, ps=[0.5], seed=42)
S3 = xgi.flag_complex(G, ps=[0], seed=42)
Expand All @@ -29,6 +31,7 @@ def test_flag_complex():
assert S2.edges.members() == simplices_2
assert S3.edges.members() == simplices_2

# complete graph
G1 = nx.complete_graph(4)
S4 = xgi.flag_complex(G1)
S5 = xgi.flag_complex(G1, ps=[1])
Expand All @@ -44,3 +47,116 @@ def test_flag_complex_d2():
S2 = xgi.flag_complex_d2(G)

assert set(S.edges.members()) == set(S2.edges.members())


def test_random_simplicial_complex():
# seed
S1 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=1)
S2 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)
S3 = xgi.random_simplicial_complex(10, [0.1, 0.001], seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, 1.1])
with pytest.raises(ValueError):
S1 = xgi.random_simplicial_complex(10, [1, -2])


def test_random_flag_complex():

S = xgi.random_flag_complex(10, 0.4, seed=2)
simplices = {
frozenset({0, 4}),
frozenset({0, 7}),
frozenset({1, 8}),
frozenset({2, 5}),
frozenset({2, 9}),
frozenset({3, 5}),
frozenset({3, 6}),
frozenset({3, 7}),
frozenset({3, 8}),
frozenset({4, 5}),
frozenset({4, 7}),
frozenset({4, 8}),
frozenset({6, 7}),
frozenset({6, 8}),
frozenset({7, 8}),
frozenset({0, 4, 7}),
frozenset({3, 6, 7}),
frozenset({3, 6, 8}),
frozenset({3, 7, 8}),
frozenset({4, 7, 8}),
frozenset({6, 7, 8}),
}

assert set(S.edges.members()) == simplices

# max_order
S = xgi.random_flag_complex(10, 0.4, seed=2, max_order=3)
assert set(S.edges.members()) == simplices.union({frozenset({3, 6, 7, 8})})

# seed
S1 = xgi.random_flag_complex(10, 0.1, seed=1)
S2 = xgi.random_flag_complex(10, 0.1, seed=2)
S3 = xgi.random_flag_complex(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex(10, -2)


def test_random_flag_complex_d2():

S = xgi.random_flag_complex_d2(10, 0.4, seed=2)
simplices = {
frozenset({0, 4}),
frozenset({0, 7}),
frozenset({1, 8}),
frozenset({2, 5}),
frozenset({2, 9}),
frozenset({3, 5}),
frozenset({3, 6}),
frozenset({3, 7}),
frozenset({3, 8}),
frozenset({4, 5}),
frozenset({4, 7}),
frozenset({4, 8}),
frozenset({6, 7}),
frozenset({6, 8}),
frozenset({7, 8}),
frozenset({0, 4, 7}),
frozenset({3, 6, 7}),
frozenset({3, 6, 8}),
frozenset({3, 7, 8}),
frozenset({4, 7, 8}),
frozenset({6, 7, 8}),
}

assert set(S.edges.members()) == simplices

# consistency with other function
S = xgi.random_flag_complex(10, 0.4, seed=3, max_order=2)
S0 = xgi.random_flag_complex_d2(10, 0.4, seed=3)
assert set(S.edges.members()) == set(S0.edges.members())

# seed
S1 = xgi.random_flag_complex_d2(10, 0.1, seed=1)
S2 = xgi.random_flag_complex_d2(10, 0.1, seed=2)
S3 = xgi.random_flag_complex_d2(10, 0.1, seed=2)

assert S1._edge != S2._edge
assert S2._edge == S3._edge

# wrong input
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, 1.1)
with pytest.raises(ValueError):
S1 = xgi.random_flag_complex_d2(10, -2)
1 change: 1 addition & 0 deletions xgi/generators/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def complete_hypergraph(N, order=None, max_order=None, include_singletons=False)
elif max_order is not None:
start = 1 if include_singletons else 2
end = max_order + 1
assert end >= start # can be equal because adding +1 to end below

s = list(nodes)
edges = chain.from_iterable(combinations(s, r) for r in range(start, end + 1))
Expand Down
55 changes: 42 additions & 13 deletions xgi/generators/simplicial_complexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,21 @@ def flag_complex(G, max_order=2, ps=None, seed=None):
random.seed(seed)

nodes = G.nodes()
N = len(nodes)
edges = G.edges()

# compute all maximal cliques to fill
max_cliques = list(nx.find_cliques(G))
cliques_to_add = _cliques_to_fill(G, max_order)

S = SimplicialComplex()
S.add_nodes_from(nodes)
S.add_simplices_from(edges)
if not ps: # promote all cliques
S.add_simplices_from(max_cliques, max_order=max_order)
S.add_simplices_from(cliques_to_add, max_order=max_order)
return S

if max_order: # compute subfaces of order max_order (allowed max cliques)
max_cliques_to_add = subfaces(max_cliques, order=max_order)
else:
max_cliques_to_add = max_cliques

# store max cliques per order
cliques_d = defaultdict(list)
for x in max_cliques_to_add:
for x in cliques_to_add:
cliques_d[len(x)].append(x)

# promote cliques with a given probability
Expand Down Expand Up @@ -277,13 +272,47 @@ def random_flag_complex(N, p, max_order=2, seed=None):
G = nx.fast_gnp_random_graph(N, p, seed=seed)

nodes = G.nodes()
edges = list(G.edges())

# compute all triangles to fill
max_cliques = list(nx.find_cliques(G))
cliques = _cliques_to_fill(G, max_order)

S = SimplicialComplex()
S.add_nodes_from(nodes)
S.add_simplices_from(max_cliques, max_order=max_order)
S.add_simplices_from(cliques, max_order=max_order)

return S


def _cliques_to_fill(G, max_order):
"""Return cliques to fill for flag complexes,
to be passed to `add_simplices_from`.
This function was written to speedup flag_complex functions
by avoiding adding redundant faces.
Parameters
----------
G : networkx Graph
Graph to consider
max_order: int or None
If None, return maximal cliques. If int, return all cliques
up to max_order.
Returns
-------
cliques : list
List of cliques
"""
if max_order is None:
cliques = list(nx.find_cliques(G)) # max cliques
else: # avoid adding many unnecessary redundant cliques
cliques = []
for clique in nx.enumerate_all_cliques(G): # sorted by size
if len(clique) == 1:
continue # don't add singletons
if len(clique) <= max_order + 1:
cliques.append(clique)
else:
break # dont go over whole list if not necessary

return cliques

0 comments on commit 26ba651

Please sign in to comment.