Skip to content

Commit

Permalink
PEP8 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tbonald committed Mar 27, 2024
1 parent d35af94 commit 7c72a26
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions sknetwork/clustering/tests/test_kcenters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,26 @@
from sknetwork.data import karate_club, painters, star_wars
from sknetwork.data.test_graphs import *


class TestKCentersClustering(unittest.TestCase):

def test_kcenters(self):
# Test undirected graph
n_clusters = 2
adjacency = karate_club()
n_row = adjacency.shape[0]
kcenters = KCenters(n_clusters=n_clusters)
labels = kcenters.fit_predict(adjacency)
self.assertEqual(len(labels), n_row)
self.assertEqual(len(labels), n_row)
self.assertEqual(len(set(labels)), n_clusters)

# Test directed graph
# Test directed graph
n_clusters = 3
adjacency = painters()
n_row = adjacency.shape[0]
kcenters = KCenters(n_clusters=n_clusters, directed=True)
labels = kcenters.fit_predict(adjacency)
self.assertEqual(len(labels), n_row)
self.assertEqual(len(labels), n_row)
self.assertEqual(len(set(labels)), n_clusters)

# Test bipartite graph
Expand All @@ -38,33 +39,33 @@ def test_kcenters(self):
self.assertEqual(len(kcenters.labels_row_), n_row)
self.assertEqual(len(kcenters.labels_col_), n_col)
self.assertEqual(len(set(labels)), n_clusters)

def test_kcenters_centers(self):
# Test centers for undirected graphs
# Test centers for undirected graphs
n_clusters = 2
adjacency = karate_club()
kcenters = KCenters(n_clusters=n_clusters)
kcenters.fit(adjacency)
centers = kcenters.centers_
self.assertEqual(n_clusters, len(set(centers)))

# Test centers for bipartite graphs
# Test centers for bipartite graphs
n_clusters = 2
biadjacency = star_wars()
n_row, n_col = biadjacency.shape
for position in ["row", "col", "both"]:
kcenters = KCenters(n_clusters=n_clusters, center_position=position)
kcenters.fit(biadjacency)
centers_row = kcenters.centers_row_
centers_col = kcenters.centers_col_
centers_col = kcenters.centers_col_
if position == "row":
self.assertEqual(n_clusters, len(set(centers_row)))
self.assertTrue(np.all(centers_row < n_row))
self.assertTrue(centers_col==None)
self.assertTrue(centers_col is None)
if position == "col":
self.assertEqual(n_clusters, len(set(centers_col)))
self.assertTrue(np.all((centers_col < n_col) & (0 <= centers_col)))
self.assertTrue(centers_row==None)
self.assertTrue(centers_row is None)
if position == "both":
self.assertEqual(n_clusters, len(set(centers_row)) + len(set(centers_col)))
self.assertTrue(np.all(centers_row < n_row))
Expand All @@ -74,21 +75,18 @@ def test_kcenters_error(self):
# Test value errors
adjacency = karate_club()
biadjacency = star_wars()

# test n_clusters error
kcenters = KCenters(n_clusters=1)
with self.assertRaises(ValueError):
kcenters.fit(adjacency)

# test n_init error
# test n_init error
kcenters = KCenters(n_clusters=2, n_init=0)
with self.assertRaises(ValueError):
kcenters.fit(adjacency)

# test center_position error
# test center_position error
kcenters = KCenters(n_clusters=2, center_position="other")
with self.assertRaises(ValueError):
kcenters.fit(biadjacency)



0 comments on commit 7c72a26

Please sign in to comment.