Skip to content

Commit

Permalink
Add force_bipartite option in hierarchical algorithms
Browse files (browse the repository at this point in the history)
  • Loading branch information
tbonald committed Jul 16, 2024
1 parent 8a7b479 commit bc49fe0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
16 changes: 10 additions & 6 deletions sknetwork/hierarchy/louvain_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,23 @@ def _recursive_louvain(self, adjacency: Union[sparse.csr_matrix, np.ndarray], de
tree.append(self._recursive_louvain(adjacency_cluster, depth - 1, nodes_cluster))
return tree

def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray]) -> 'LouvainIteration':
def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], force_bipartite: bool = False) \
-> 'LouvainIteration':
"""Fit algorithm to data.
Parameters
----------
input_matrix : sparse.csr_matrix, np.ndarray
Adjacency matrix or biadjacency matrix of the graph.
force_bipartite :
If ``True``, force the input matrix to be considered as a biadjacency matrix.
Returns
-------
self: :class:`LouvainIteration`
"""
self._init_vars()
input_matrix = check_format(input_matrix)
adjacency, self.bipartite = get_adjacency(input_matrix)
adjacency, self.bipartite = get_adjacency(input_matrix, force_bipartite=force_bipartite)
tree = self._recursive_louvain(adjacency, self.depth)
dendrogram, _ = get_dendrogram(tree)
dendrogram = np.array(dendrogram)
Expand Down Expand Up @@ -243,21 +245,23 @@ def _get_hierarchy(self, adjacency: Union[sparse.csr_matrix, np.ndarray]):
labels_unique = np.unique(labels)
return tree

def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray]) -> 'LouvainHierarchy':
def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], force_bipartite: bool = False) \
-> 'LouvainHierarchy':
"""Fit algorithm to data.
Parameters
----------
input_matrix : sparse.csr_matrix, np.ndarray
Adjacency matrix or biadjacency matrix of the graph.
force_bipartite :
If ``True``, force the input matrix to be considered as a biadjacency matrix.
Returns
-------
self: :class:`LouvainHierarchy`
"""
self._init_vars()
input_matrix = check_format(input_matrix)
adjacency, self.bipartite = get_adjacency(input_matrix)
adjacency, self.bipartite = get_adjacency(input_matrix, force_bipartite=force_bipartite)
tree = self._get_hierarchy(adjacency)
dendrogram, _ = get_dendrogram(tree)
dendrogram = np.array(dendrogram)
Expand Down
7 changes: 4 additions & 3 deletions sknetwork/hierarchy/paris.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,15 @@ class Paris(BaseHierarchy):

@cython.boundscheck(False)
@cython.wraparound(False)
def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray]) -> 'Paris':
def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], force_bipartite: bool = False) -> 'Paris':
"""Agglomerative clustering using the nearest neighbor chain.

Parameters
----------
input_matrix : sparse.csr_matrix, np.ndarray
Adjacency matrix or biadjacency matrix of the graph.
force_bipartite :
If ``True``, force the input matrix to be considered as a biadjacency matrix.

Returns
-------
Expand All @@ -227,8 +229,7 @@ class Paris(BaseHierarchy):
self._init_vars()

# input
input_matrix = check_format(input_matrix)
adjacency, self.bipartite = get_adjacency(input_matrix)
adjacency, self.bipartite = get_adjacency(input_matrix, force_bipartite=force_bipartite)

weights = self.weights
out_weights = get_probs(weights, adjacency)
Expand Down
8 changes: 4 additions & 4 deletions sknetwork/hierarchy/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestMetrics(unittest.TestCase):

def setUp(self):
self.paris = Paris()
self.louvain_hierarchy = LouvainIteration()
self.louvain_iteration = LouvainIteration()

def test_undirected(self):
adjacency = cyclic_graph(3)
Expand All @@ -31,7 +31,7 @@ def test_undirected(self):
self.assertAlmostEqual(dasgupta_cost(adjacency, dendrogram), 4.26, 2)
self.assertAlmostEqual(dasgupta_score(adjacency, dendrogram), 0.573, 2)
self.assertAlmostEqual(tree_sampling_divergence(adjacency, dendrogram), 0.304, 2)
dendrogram = self.louvain_hierarchy.fit_transform(adjacency)
dendrogram = self.louvain_iteration.fit_transform(adjacency)
self.assertAlmostEqual(dasgupta_cost(adjacency, dendrogram), 4.43, 2)
self.assertAlmostEqual(dasgupta_score(adjacency, dendrogram), 0.555, 2)
self.assertAlmostEqual(tree_sampling_divergence(adjacency, dendrogram), 0.286, 2)
Expand All @@ -41,7 +41,7 @@ def test_directed(self):
dendrogram = self.paris.fit_transform(adjacency)
self.assertAlmostEqual(dasgupta_score(adjacency, dendrogram), 0.566, 2)
self.assertAlmostEqual(tree_sampling_divergence(adjacency, dendrogram), 0.318, 2)
dendrogram = self.louvain_hierarchy.fit_transform(adjacency)
dendrogram = self.louvain_iteration.fit_transform(adjacency)
self.assertAlmostEqual(dasgupta_score(adjacency, dendrogram), 0.55, 2)
self.assertAlmostEqual(tree_sampling_divergence(adjacency, dendrogram), 0.313, 2)

Expand All @@ -50,7 +50,7 @@ def test_disconnected(self):
dendrogram = self.paris.fit_transform(adjacency)
self.assertAlmostEqual(dasgupta_score(adjacency, dendrogram), 0.682, 2)
self.assertAlmostEqual(tree_sampling_divergence(adjacency, dendrogram), 0.464, 2)
dendrogram = self.louvain_hierarchy.fit_transform(adjacency)
dendrogram = self.louvain_iteration.fit_transform(adjacency)
self.assertAlmostEqual(dasgupta_score(adjacency, dendrogram), 0.670, 2)
self.assertAlmostEqual(tree_sampling_divergence(adjacency, dendrogram), 0.594, 2)

Expand Down

0 comments on commit bc49fe0

Please sign in to comment.