Skip to content

Commit

Permalink
Check format of input matrix in Louvain
Browse files Browse the repository at this point in the history
  • Loading branch information
tbonald committed Aug 28, 2024
1 parent f7507f3 commit d957d9d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sknetwork/clustering/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from sknetwork.clustering.base import BaseClustering
from sknetwork.clustering.louvain_core import optimize_core
from sknetwork.clustering.postprocess import reindex_labels
from sknetwork.utils.check import check_random_state, get_probs
from sknetwork.utils.format import check_format, get_adjacency, directed2undirected
from sknetwork.utils.check import check_format, check_random_state, get_probs
from sknetwork.utils.format import get_adjacency, directed2undirected
from sknetwork.utils.membership import get_membership
from sknetwork.log import Log

Expand Down Expand Up @@ -193,7 +193,6 @@ def _pre_processing(self, input_matrix, force_bipartite):
self._init_vars()

# adjacency matrix
input_matrix = check_format(input_matrix)
force_directed = self.modularity == 'dugue'
adjacency, self.bipartite = get_adjacency(input_matrix, force_directed=force_directed,
force_bipartite=force_bipartite)
Expand Down Expand Up @@ -266,6 +265,7 @@ def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], force_bipartit
-------
self : :class:`Louvain`
"""
input_matrix = check_format(input_matrix)
adjacency, out_weights, in_weights, membership, index = self._pre_processing(input_matrix, force_bipartite)
n = adjacency.shape[0]
count = 0
Expand Down
6 changes: 6 additions & 0 deletions sknetwork/clustering/tests/test_louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def test_disconnected(self):
labels = Louvain().fit_predict(adjacency)
self.assertEqual(len(labels), n)

def test_format(self):
adjacency = test_graph()
n = adjacency.shape[0]
labels = Louvain().fit_predict(adjacency.toarray())
self.assertEqual(len(labels), n)

def test_modularity(self):
adjacency = karate_club()
louvain_d = Louvain(modularity='dugue')
Expand Down

0 comments on commit d957d9d

Please sign in to comment.