forked from e-lab/clustering-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonline-kmeans.lua
91 lines (77 loc) · 2.32 KB
/
online-kmeans.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
--
-- The k-means algorithm. ONLINE VERSION by E. Culurciello Jan 2013
--
-- Runs k-means iterations over x in minibatches, starting either from
-- caller-supplied centroids (which are refined in place) or from random
-- Gaussian centroids.
--
-- > x: is supposed to be an MxN matrix, where M is the nb of samples and each sample is N-dim
-- > k: is the number of kernels (may be omitted when centroids is given; it
--      then defaults to the number of rows of centroids)
-- > centroids: optional kxN initial centroids; their rows are overwritten in place
-- > std: std-dev of the random initial centroids when none are given [default 0.1]
-- > niter: the number of iterations [default 1]
-- > batchsize: the batch size [large is good, to parallelize matrix multiplications; default 1000]
-- > callback: optional callback, at each iteration end, called as callback(centroids)
-- > verbose: prints a progress bar...
--
-- < returns the k means (centroids),
--   the per-centroid assignment counts accumulated over all iterations,
--   and the loss, sum over samples of 0.5*||x - nearest centroid||^2,
--   measured during the last iteration's assignment step (before that
--   iteration's centroid update); nil if niter < 1
--
function okmeans(x, k, centroids, std, niter, batchsize, callback, verbose)
   -- args
   batchsize = batchsize or 1000
   std = std or 0.1
   niter = niter or 1
   if centroids then
      k = k or centroids:size(1)
   end
   assert(k, 'okmeans: either k or initial centroids must be given')
   -- some shortcuts
   local sum = torch.sum
   local max = torch.max
   local pow = torch.pow
   local randn = torch.randn
   local zeros = torch.zeros
   -- dims
   local nsamples = x:size(1)
   local ndims = x:size(2)
   -- squared norms of the samples, reused by every iteration's loss
   local x2 = sum(pow(x,2),2)
   -- typeAs(x) keeps every buffer in x's tensor type, so inputs that are not
   -- the default tensor type (e.g. FloatTensor) do not hit mixed-type matmuls;
   -- for default-type x it is a no-op
   if not centroids then
      centroids = (randn(k,ndims)*std):typeAs(x)
   end
   local totalcounts = zeros(k):typeAs(x)
   local loss
   for iter = 1,niter do
      if verbose then xlua.progress(iter,niter) end
      -- half squared norms of the centroids: the nearest centroid to a sample
      -- s is argmax_j (c_j.s - |c_j|^2/2), so |s|^2 never has to be added in
      local c2 = sum(pow(centroids,2),2)*0.5
      local summation = zeros(k,ndims):typeAs(x)
      local counts = zeros(k):typeAs(x)
      loss = 0
      for first = 1,nsamples,batchsize do
         local last = math.min(first+batchsize-1,nsamples)
         local m = last - first + 1
         -- k-means step, on minibatch
         local batch = x[{ {first,last},{} }]
         -- scores[j][n] = c_j.x_n - |c_j|^2/2, one column per sample
         local scores = centroids * batch:t()
         scores:add(-1, c2:expandAs(scores))
         local val,labels = max(scores,1)
         -- 0.5*|x|^2 - max score = 0.5*||x - nearest centroid||^2
         loss = loss + sum(x2[{ {first,last} }]*0.5 - val:t())
         -- one-hot assignment matrix (m x k): count examplars per template
         local S = zeros(m,k):typeAs(x)
         for n = 1,m do
            S[n][labels[1][n]] = 1
         end
         summation:add( S:t() * batch )
         counts:add( sum(S,1) )
      end
      -- normalize: move each non-empty cluster to the mean of its members;
      -- empty clusters keep their previous centroid
      for j = 1,k do
         if counts[j] ~= 0 then
            centroids[j] = summation[j]:div(counts[j])
         end
      end
      -- total counts
      totalcounts:add(counts)
      if callback then callback(centroids) end
   end
   return centroids,totalcounts,loss
end