From 775eefcdc25e0f67ceee2f866edc584104cf14d7 Mon Sep 17 00:00:00 2001 From: lbollar Date: Thu, 29 Sep 2016 23:46:33 -0500 Subject: [PATCH] Issue #64: added n_init to kmeans --- src/kmeans.jl | 165 ++++++++++++++++++++++++++++--------------------- test/kmeans.jl | 6 +- 2 files changed, 97 insertions(+), 74 deletions(-) diff --git a/src/kmeans.jl b/src/kmeans.jl index 3dcaf9cb..f553b996 100644 --- a/src/kmeans.jl +++ b/src/kmeans.jl @@ -17,12 +17,14 @@ const _kmeans_default_init = :kmpp const _kmeans_default_maxiter = 100 const _kmeans_default_tol = 1.0e-6 const _kmeans_default_display = :none +const _kmeans_default_n_init = 10 function kmeans!{T<:AbstractFloat}(X::Matrix{T}, centers::Matrix{T}; weights=nothing, maxiter::Integer=_kmeans_default_maxiter, tol::Real=_kmeans_default_tol, display::Symbol=_kmeans_default_display) + m, n = size(X) m2, k = size(centers) @@ -43,18 +45,37 @@ function kmeans(X::Matrix, k::Int; weights=nothing, init=_kmeans_default_init, maxiter::Integer=_kmeans_default_maxiter, + n_init::Integer=_kmeans_default_n_init, tol::Real=_kmeans_default_tol, display::Symbol=_kmeans_default_display) + m, n = size(X) (2 <= k < n) || error("k must have 2 <= k < n.") - iseeds = initseeds(init, X, k) - centers = copyseeds(X, iseeds) - kmeans!(X, centers; - weights=weights, - maxiter=maxiter, - tol=tol, - display=display) + n_init > 0 || error("n_init must be greater than 0") + + lowestcost::Float64 = Inf + local bestresult::KmeansResult + + for i = 1:n_init + + iseeds = initseeds(init, X, k) + centers = copyseeds(X, iseeds) + result = kmeans!(X, centers; + weights=weights, + maxiter=maxiter, + tol=tol, + display=display) + + if result.totalcost < lowestcost + lowestcost = result.totalcost + bestresult = result + end + + end + + return bestresult + end #### Core implementation @@ -72,86 +93,88 @@ function _kmeans!{T<:AbstractFloat}( tol::Real, # in: tolerance of change at convergence displevel::Int) # in: the level of display - # initialize - - k = size(centers, 2) - to_update = Array(Bool, k) # indicators of whether a center needs to be updated - unused = Int[] - num_affected::Int = k # number of centers, to which the distances need to be recomputed - - dmat = pairwise(SqEuclidean(), centers, x) - dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T - update_assignments!(dmat, true, assignments, costs, counts, to_update, unused) - objv = w == nothing ? sum(costs) : dot(w, costs) - - # main loop - t = 0 - converged = false - if displevel >= 2 - @printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected" - println("-------------------------------------------------------------") - @printf("%7d %18.6e\n", t, objv) - end - while !converged && t < maxiter - t = t + 1 + + # initialize - # update (affected) centers + k = size(centers, 2) + to_update = Array(Bool, k) # indicators of whether a center needs to be updated + unused = Int[] + num_affected::Int = k # number of centers, to which the distances need to be recomputed - update_centers!(x, w, assignments, to_update, centers, cweights) + dmat = pairwise(SqEuclidean(), centers, x) + dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T + update_assignments!(dmat, true, assignments, costs, counts, to_update, unused) + objv = w == nothing ? sum(costs) : dot(w, costs) - if !isempty(unused) - repick_unused_centers(x, costs, centers, unused) - end + # main loop + t = 0 + converged = false + if displevel >= 2 + @printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected" + println("-------------------------------------------------------------") + @printf("%7d %18.6e\n", t, objv) + end - # update pairwise distance matrix + while !converged && t < maxiter + t = t + 1 - if !isempty(unused) - to_update[unused] = true - end + # update (affected) centers - if t == 1 || num_affected > 0.75 * k - pairwise!(dmat, SqEuclidean(), centers, x) - else - # if only a small subset is affected, only compute for that subset - affected_inds = find(to_update) - dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x) - dmat[affected_inds, :] = dmat_p - end + update_centers!(x, w, assignments, to_update, centers, cweights) - # update assignments + if !isempty(unused) + repick_unused_centers(x, costs, centers, unused) + end - update_assignments!(dmat, false, assignments, costs, counts, to_update, unused) - num_affected = sum(to_update) + length(unused) + # update pairwise distance matrix - # compute change of objective and determine convergence + if !isempty(unused) + to_update[unused] = true + end - prev_objv = objv - objv = w == nothing ? sum(costs) : dot(w, costs) - objv_change = objv - prev_objv + if t == 1 || num_affected > 0.75 * k + pairwise!(dmat, SqEuclidean(), centers, x) + else + # if only a small subset is affected, only compute for that subset + affected_inds = find(to_update) + dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x) + dmat[affected_inds, :] = dmat_p + end - if objv_change > tol - warn("The objective value changes towards an opposite direction") - end + # update assignments - if abs(objv_change) < tol - converged = true - end + update_assignments!(dmat, false, assignments, costs, counts, to_update, unused) + num_affected = sum(to_update) + length(unused) - # display iteration information (if asked) + # compute change of objective and determine convergence - if displevel >= 2 - @printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected) - end - end + prev_objv = objv + objv = w == nothing ? sum(costs) : dot(w, costs) + objv_change = objv - prev_objv - if displevel >= 1 - if converged - println("K-means converged with $t iterations (objv = $objv)") - else - println("K-means terminated without convergence after $t iterations (objv = $objv)") - end - end + if objv_change > tol + warn("The objective value changes towards an opposite direction") + end + + if abs(objv_change) < tol + converged = true + end + + # display iteration information (if asked) + + if displevel >= 2 + @printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected) + end + end + + if displevel >= 1 + if converged + println("K-means converged with $t iterations (objv = $objv)") + else + println("K-means terminated without convergence after $t iterations (objv = $objv)") + end + end return KmeansResult(centers, assignments, costs, counts, cweights, @compat(Float64(objv)), t, converged) diff --git a/test/kmeans.jl b/test/kmeans.jl index 7e678ca7..2c9acd79 100644 --- a/test/kmeans.jl +++ b/test/kmeans.jl @@ -12,7 +12,7 @@ k = 10 x = rand(m, n) # non-weighted -r = kmeans(x, k; maxiter=50) +r = kmeans(x, k; maxiter=50, n_init=2) @test isa(r, KmeansResult{Float64}) @test size(r.centers) == (m, k) @test length(r.assignments) == n @@ -24,7 +24,7 @@ r = kmeans(x, k; maxiter=50) @test_approx_eq sum(r.costs) r.totalcost # non-weighted (float32) -r = kmeans(@compat(map(Float32, x)), k; maxiter=50) +r = kmeans(@compat(map(Float32, x)), k; maxiter=50, n_init=2) @test isa(r, KmeansResult{Float32}) @test size(r.centers) == (m, k) @test length(r.assignments) == n @@ -37,7 +37,7 @@ r = kmeans(@compat(map(Float32, x)), k; maxiter=50) # weighted w = rand(n) -r = kmeans(x, k; maxiter=50, weights=w) +r = kmeans(x, k; maxiter=50, weights=w, n_init=2) @test isa(r, KmeansResult{Float64}) @test size(r.centers) == (m, k) @test length(r.assignments) == n