From 5926de2cb39cc5fd5def1a6ad542aa9fcb24c01d Mon Sep 17 00:00:00 2001 From: Azamat Berdyshev Date: Sat, 7 Sep 2024 11:17:53 +0000 Subject: [PATCH 1/2] fix bug --- src/kmeans.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kmeans.jl b/src/kmeans.jl index e3720e83..0983a416 100644 --- a/src/kmeans.jl +++ b/src/kmeans.jl @@ -388,6 +388,6 @@ function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix v = view(X, :, j) centers[:, i] = v colwise!(distance, ds, v, X) - tcosts = min(tcosts, ds) + tcosts .= min.(tcosts, ds) end end From 5f280eb5911ff6ba8a6c9c5b9069aa1cdbf4bc13 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 5 Jan 2025 21:50:04 -0800 Subject: [PATCH 2/2] kmeans(): tweak repick/update comments --- src/kmeans.jl | 58 ++++++++++++++++++++++----------------------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/src/kmeans.jl b/src/kmeans.jl index 0983a416..efb2171d 100644 --- a/src/kmeans.jl +++ b/src/kmeans.jl @@ -164,7 +164,7 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d update_centers!(X, weights, assignments, to_update, centers, wcounts) if !isempty(unused) - repick_unused_centers(X, costs, centers, unused, distance, rng) + repick_unused_centers!(centers, unused, X, costs, distance, rng) to_update[unused] .= true end @@ -211,18 +211,16 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d wcounts, objv, t, converged) end -# -# Updates assignments, costs, and counts based on -# an updated (squared) distance matrix -# +# Update point assignments, costs, and cluster counts based on +# an updated (squared) distance matrix function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k x n) is_init::Bool, # in: whether it is the initial run assignments::Vector{Int}, # out: assignment vector (n) costs::Vector{<:Real}, # out: costs of the resultant assignment (n) counts::Vector{Int}, # out: # of points assigned to each cluster (k) to_update::Vector{Bool}, # out: whether a center needs update (k) - unused::Vector{Int} # out: list of centers with no points assigned - ) + unused::Vector{Int}, # out: list of centers with no points assigned +) k, n = size(dmat) # re-initialize the counting vector @@ -272,17 +270,15 @@ function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k end end -# -# Update centers based on updated assignments -# -# (specific to the case where points are not weighted) -# +# Update cluster centers and weights to match updated assignments +# (non-weighted points case) function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n) weights::Nothing, # in: point weights assignments::Vector{Int}, # in: assignments (n) to_update::Vector{Bool}, # in: whether a center needs update (k) centers::AbstractMatrix{<:AbstractFloat}, # out: updated centers (d x k) - wcounts::Vector{Int}) # out: updated cluster weights (k) + wcounts::Vector{Int}, # out: updated cluster weights (k) +) d, n = size(X) k = size(centers, 2) @@ -318,18 +314,15 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d end end -# -# Update centers based on updated assignments -# -# (specific to the case where points are weighted) -# +# Update cluster centers and weights to match updated assignments +# (weighted points case) function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n) weights::Vector{W}, # in: point weights (n) assignments::Vector{Int}, # in: assignments (n) to_update::Vector{Bool}, # in: whether a center needs update (k) centers::AbstractMatrix{<:Real}, # out: updated centers (d x k) - wcounts::Vector{W} # out: updated cluster weights (k) - ) where W<:Real + wcounts::Vector{W}, # out: updated cluster weights (k) +) where W<:Real d, n = size(X) k = size(centers, 2) @@ -368,26 +361,25 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n) end -# -# Re-picks centers that have no points assigned to them. -# -function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix (d x n) - costs::Vector{<:Real}, # in: the current assignment costs (n) - centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k) - unused::Vector{Int}, # in: indices of centers to be updated - distance::SemiMetric, # in: function to calculate the distance with - rng::AbstractRNG) # in: RNG object +# Re-pick centers that have no points assigned to them. +function repick_unused_centers!(centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k) + unused::Vector{Int}, # in: indices of centers to be updated (k) + X::AbstractMatrix{<:Real}, # in: the data matrix (d x n) + costs::Vector{<:Real}, # in: the current assignment costs (n) + distance::SemiMetric, # in: function to calculate the distance with + rng::AbstractRNG, +) # pick new centers using a scheme like kmeans++ ds = similar(costs) - tcosts = copy(costs) + tcosts = copy(costs) # temporary costs used as sampling weights n = size(X, 2) for i in unused + # select a random point as a new center j = wsample(rng, 1:n, tcosts) - tcosts[j] = 0 - v = view(X, :, j) - centers[:, i] = v + centers[:, i] = v = view(X, :, j) colwise!(distance, ds, v, X) + ds[j] = 0 # calculated distance might be not exactly zero tcosts .= min.(tcosts, ds) end end