Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broadcasting bug in repick_unused_centers() for K-means #283

Merged
merged 2 commits into from
Jan 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 26 additions & 34 deletions src/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
update_centers!(X, weights, assignments, to_update, centers, wcounts)

if !isempty(unused)
repick_unused_centers(X, costs, centers, unused, distance, rng)
repick_unused_centers!(centers, unused, X, costs, distance, rng)
to_update[unused] .= true
end

Expand Down Expand Up @@ -211,18 +211,16 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
wcounts, objv, t, converged)
end

#
# Updates assignments, costs, and counts based on
# an updated (squared) distance matrix
#
# Update point assignments, costs, and cluster counts based on
# an updated (squared) distance matrix
function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k x n)
is_init::Bool, # in: whether it is the initial run
assignments::Vector{Int}, # out: assignment vector (n)
costs::Vector{<:Real}, # out: costs of the resultant assignment (n)
counts::Vector{Int}, # out: # of points assigned to each cluster (k)
to_update::Vector{Bool}, # out: whether a center needs update (k)
unused::Vector{Int} # out: list of centers with no points assigned
)
unused::Vector{Int}, # out: list of centers with no points assigned
)
k, n = size(dmat)

# re-initialize the counting vector
Expand Down Expand Up @@ -272,17 +270,15 @@ function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k
end
end

#
# Update centers based on updated assignments
#
# (specific to the case where points are not weighted)
#
# Update cluster centers and weights to match updated assignments
# (non-weighted points case)
function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
weights::Nothing, # in: point weights
assignments::Vector{Int}, # in: assignments (n)
to_update::Vector{Bool}, # in: whether a center needs update (k)
centers::AbstractMatrix{<:AbstractFloat}, # out: updated centers (d x k)
wcounts::Vector{Int}) # out: updated cluster weights (k)
wcounts::Vector{Int}, # out: updated cluster weights (k)
)
d, n = size(X)
k = size(centers, 2)

Expand Down Expand Up @@ -318,18 +314,15 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d
end
end

#
# Update centers based on updated assignments
#
# (specific to the case where points are weighted)
#
# Update cluster centers and weights to match updated assignments
# (weighted points case)
function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
weights::Vector{W}, # in: point weights (n)
assignments::Vector{Int}, # in: assignments (n)
to_update::Vector{Bool}, # in: whether a center needs update (k)
centers::AbstractMatrix{<:Real}, # out: updated centers (d x k)
wcounts::Vector{W} # out: updated cluster weights (k)
) where W<:Real
wcounts::Vector{W}, # out: updated cluster weights (k)
) where W<:Real
d, n = size(X)
k = size(centers, 2)

Expand Down Expand Up @@ -368,26 +361,25 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
end


#
# Re-picks centers that have no points assigned to them.
#
function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix (d x n)
costs::Vector{<:Real}, # in: the current assignment costs (n)
centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
unused::Vector{Int}, # in: indices of centers to be updated
distance::SemiMetric, # in: function to calculate the distance with
rng::AbstractRNG) # in: RNG object
# Re-pick centers that have no points assigned to them.
function repick_unused_centers!(centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
unused::Vector{Int}, # in: indices of centers to be updated (k)
X::AbstractMatrix{<:Real}, # in: the data matrix (d x n)
costs::Vector{<:Real}, # in: the current assignment costs (n)
distance::SemiMetric, # in: function to calculate the distance with
rng::AbstractRNG,
)
# pick new centers using a scheme like kmeans++
ds = similar(costs)
tcosts = copy(costs)
tcosts = copy(costs) # temporary costs used as sampling weights
n = size(X, 2)

for i in unused
# select a random point as a new center
j = wsample(rng, 1:n, tcosts)
tcosts[j] = 0
v = view(X, :, j)
centers[:, i] = v
centers[:, i] = v = view(X, :, j)
colwise!(distance, ds, v, X)
tcosts = min(tcosts, ds)
ds[j] = 0 # calculated distance might be not exactly zero
tcosts .= min.(tcosts, ds)
end
end
Loading