From bc2e28743cb6b79c42836a6fa1df4d6578bb5a7b Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:34:09 +0200 Subject: [PATCH] mimic sklearn weighted random init --- sklearn_numba_dpex/kmeans/engine.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py index dcc8812..f4f2e3f 100644 --- a/sklearn_numba_dpex/kmeans/engine.py +++ b/sklearn_numba_dpex/kmeans/engine.py @@ -188,16 +188,13 @@ def init_centroids(self, X, sample_weight): else: # NB: sampling without replacement must be executed sequentially so # it's better done on CPU + sample_weight_numpy = dpt.asnumpy(sample_weight) + p = sample_weight_numpy / sample_weight_numpy.sum() centers_idx = self.random_state.choice( - X.shape[0], size=n_clusters, replace=False + X.shape[0], size=n_clusters, replace=False, p=p ) # Poor man's fancy indexing - # TODO: write a kernel ? or replace with better equivalent when available ? - # Relevant issue: https://github.com/IntelPython/dpctl/issues/1003 - centers_t = dpt.concat( - [dpt.expand_dims(X[center_idx], axis=1) for center_idx in centers_idx], - axis=1, - ) + centers_t = dpt.take(X.T, dpt.asarray(centers_idx), axis=1) return centers_t