From d957d9dd5c68d483d6d636877f929cb529b1b6ad Mon Sep 17 00:00:00 2001 From: Thomas Bonald Date: Wed, 28 Aug 2024 15:22:52 +0200 Subject: [PATCH] Check format of input matrix in Louvain --- sknetwork/clustering/louvain.py | 6 +++--- sknetwork/clustering/tests/test_louvain.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sknetwork/clustering/louvain.py b/sknetwork/clustering/louvain.py index ababf68d..a3f01976 100644 --- a/sknetwork/clustering/louvain.py +++ b/sknetwork/clustering/louvain.py @@ -14,8 +14,8 @@ from sknetwork.clustering.base import BaseClustering from sknetwork.clustering.louvain_core import optimize_core from sknetwork.clustering.postprocess import reindex_labels -from sknetwork.utils.check import check_random_state, get_probs -from sknetwork.utils.format import check_format, get_adjacency, directed2undirected +from sknetwork.utils.check import check_format, check_random_state, get_probs +from sknetwork.utils.format import get_adjacency, directed2undirected from sknetwork.utils.membership import get_membership from sknetwork.log import Log @@ -193,7 +193,6 @@ def _pre_processing(self, input_matrix, force_bipartite): self._init_vars() # adjacency matrix - input_matrix = check_format(input_matrix) force_directed = self.modularity == 'dugue' adjacency, self.bipartite = get_adjacency(input_matrix, force_directed=force_directed, force_bipartite=force_bipartite) @@ -266,6 +265,7 @@ def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], force_bipartit ------- self : :class:`Louvain` """ + input_matrix = check_format(input_matrix) adjacency, out_weights, in_weights, membership, index = self._pre_processing(input_matrix, force_bipartite) n = adjacency.shape[0] count = 0 diff --git a/sknetwork/clustering/tests/test_louvain.py b/sknetwork/clustering/tests/test_louvain.py index 7bf998de..72884100 100644 --- a/sknetwork/clustering/tests/test_louvain.py +++ b/sknetwork/clustering/tests/test_louvain.py @@ -17,6 +17,12 @@ def test_disconnected(self): labels = Louvain().fit_predict(adjacency) self.assertEqual(len(labels), n) + def test_format(self): + adjacency = test_graph() + n = adjacency.shape[0] + labels = Louvain().fit_predict(adjacency.toarray()) + self.assertEqual(len(labels), n) + def test_modularity(self): adjacency = karate_club() louvain_d = Louvain(modularity='dugue')