Skip to content

Commit

Permalink
Merge pull request #1508 from helmholtz-analytics/features/1448-rando…
Browse files Browse the repository at this point in the history
…m_arrays_of_arbitrary_size

Features/1448 Refactor random module
  • Loading branch information
mrfh92 authored Jul 22, 2024
2 parents fea923b + 803a016 commit b153d30
Show file tree
Hide file tree
Showing 10 changed files with 674 additions and 148 deletions.
33 changes: 4 additions & 29 deletions heat/cluster/_kcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,34 +109,9 @@ def _initialize_cluster_centers(self, x: DNDarray):

# initialize the centroids by randomly picking some of the points
if self.init == "random":
# Samples will be equally distributed drawn from all involved processes
_, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
centroids = ht.empty(
(self.n_clusters, x.shape[1]), split=None, device=x.device, comm=x.comm
)
if x.split is None or x.split == 0:
for i in range(self.n_clusters):
samplerange = (
x.gshape[0] // self.n_clusters * i,
x.gshape[0] // self.n_clusters * (i + 1),
)
sample = ht.random.randint(samplerange[0], samplerange[1]).item()
proc = 0
for p in range(x.comm.size):
if displ[p] > sample:
break
proc = p
xi = ht.zeros(x.shape[1], dtype=x.dtype)
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
xi.comm.Bcast(xi, root=proc)
centroids[i, :] = xi

else:
raise NotImplementedError("Not implemented for other splitting-axes")

self._cluster_centers = centroids
idx = ht.random.randint(0, x.shape[0] - 1, size=(self.n_clusters,), split=None)
centroids = x[idx, :]
self._cluster_centers = centroids if x.split == 1 else centroids.resplit_(None)

# directly passed centroids
elif isinstance(self.init, DNDarray):
Expand Down Expand Up @@ -172,7 +147,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
D2 = distances.min(axis=1)
D2.resplit_(axis=None)
prob = D2 / D2.sum()
random_position = ht.random.rand().item()
random_position = ht.random.rand()
sample = 0
sum = 0
for j in range(len(prob)):
Expand Down
1 change: 0 additions & 1 deletion heat/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def fit(self, x: DNDarray) -> self:
# initialize the clustering
self._initialize_cluster_centers(x)
self._n_iter = 0
matching_centroids = ht.zeros((x.shape[0]), split=x.split, device=x.device, comm=x.comm)

# iteratively fit the points to the centroids
for epoch in range(self.max_iter):
Expand Down
2 changes: 1 addition & 1 deletion heat/cluster/kmedians.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def fit(self, x: DNDarray):
# initialize the clustering
self._initialize_cluster_centers(x)
self._n_iter = 0
matching_centroids = ht.zeros((x.shape[0]), split=x.split, device=x.device, comm=x.comm)

# iteratively fit the points to the centroids
for epoch in range(self.max_iter):
# increment the iteration count
Expand Down
2 changes: 1 addition & 1 deletion heat/cluster/kmedoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def fit(self, x: DNDarray):
# initialize the clustering
self._initialize_cluster_centers(x)
self._n_iter = 0
matching_centroids = ht.zeros((x.shape[0]), split=x.split, device=x.device, comm=x.comm)

# iteratively fit the points to the centroids
for epoch in range(self.max_iter):
# increment the iteration count
Expand Down
4 changes: 2 additions & 2 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,13 @@ def test_inv(self):
ainv = ht.linalg.inv(a)
i = ht.eye(a.shape, split=1, dtype=a.dtype)
# loss of precision in distributed floating-point ops
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-12))
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-10))

ht.random.seed(42)
a = ht.random.random((20, 20), dtype=ht.float64, split=0)
ainv = ht.linalg.inv(a)
i = ht.eye(a.shape, split=0, dtype=a.dtype)
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-12))
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-10))

with self.assertRaises(RuntimeError):
ht.linalg.inv(ht.array([1, 2, 3], split=0))
Expand Down
4 changes: 2 additions & 2 deletions heat/core/linalg/tests/test_svdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_hsvd_rank_part1(self):
)
self.assertTrue(V_orth_err <= dtype_tol)
true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
self.assertTrue(true_rel_err <= err_est)
self.assertTrue(true_rel_err <= err_est or true_rel_err < dtype_tol)
else:
self.assertEqual(hsvd_rk, 1)
self.assertEqual(ht.norm(U), 0)
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_hsvd_rank_part1(self):
)
self.assertTrue(V_orth_err <= dtype_tol)
true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
self.assertTrue(true_rel_err <= err_est)
self.assertTrue(true_rel_err <= err_est or true_rel_err < dtype_tol)
self.assertTrue(true_rel_err <= tol)
else:
self.assertEqual(hsvd_rk, 1)
Expand Down
Loading

0 comments on commit b153d30

Please sign in to comment.