-
Notifications
You must be signed in to change notification settings - Fork 0
/
quantization.py
53 lines (47 loc) · 2.01 KB
/
quantization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors, kneighbors_graph
from sklearn.cluster import AgglomerativeClustering
def RandomChoiceQuantization(X=None, D=None, prms=None):
    """Quantize a point set by picking random landmark points.

    Parameters
    ----------
    X : array-like of shape (n_points, n_features), optional
        Point coordinates. Only used when ``D`` is None.
    D : array-like, optional
        Dissimilarity matrix whose rows and columns are both indexed by point;
        takes precedence over ``X`` when given. Assumed to mean "smaller is
        closer" -- NOTE(review): confirm with callers.
    prms : dict
        Must contain ``'n_clusters'``. May contain ``'random_state'`` (an int
        seed, ``np.random.RandomState`` or ``np.random.Generator``) to make the
        landmark draw reproducible; when absent or None the global NumPy RNG
        is used.

    Returns
    -------
    indices : ndarray of shape (k,)
        Distinct row indices of the ``k = min(n_points, n_clusters)`` landmarks.
    labels : ndarray of shape (n_points,)
        ``labels[i]`` is the position in ``indices`` of point ``i``'s nearest
        landmark.

    Raises
    ------
    KeyError
        If ``prms`` has no ``'n_clusters'`` entry.
    """
    # Avoid a shared mutable default; the dict is only read here.
    prms = {} if prms is None else prms
    n_centroids = prms['n_clusters']
    npts = len(X) if X is not None else len(D)

    seed = prms.get('random_state')
    if seed is None:
        # Module-level functions draw from NumPy's global RNG (original behaviour).
        rng = np.random
    elif isinstance(seed, (np.random.RandomState, np.random.Generator)):
        rng = seed
    else:
        rng = np.random.RandomState(seed)

    # Never ask for more landmarks than there are points.
    indices = rng.choice(npts, min(npts, n_centroids), replace=False)
    if D is not None:
        # Column j of D[:, indices] holds distances to landmark indices[j].
        labels = np.argmin(D[:, indices], axis=1)
    else:
        nbrs = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit(X[indices, :])
        # Neighbour ids refer to rows of X[indices, :], i.e. positions in `indices`.
        _, labels = nbrs.kneighbors(X)
    return indices, labels.ravel()
def HierarchicalQuantization(X=None, D=None, prms=None):
    """Quantize a point set with single-linkage agglomerative clustering.

    Parameters
    ----------
    X : array-like of shape (n_points, n_features), optional
        Point coordinates. Clustered directly when ``D`` is None, and used to
        choose cluster representatives whenever it is given.
    D : array-like of shape (n_points, n_points), optional
        Precomputed pairwise distances. When given, clustering runs on ``D``
        instead of ``X``.
    prms : dict
        Must contain ``'n_clusters'``. The value is capped at the number of
        points, as the other quantizers in this module do.

    Returns
    -------
    indices : ndarray of shape (k,)
        ``indices[c]`` is the row index of the representative of cluster ``c``.
        With ``X`` it is the point nearest to the cluster mean (which may lie
        in a different cluster); without ``X`` it is the lowest-indexed member.
    labels : ndarray of shape (n_points,)
        Cluster id in ``range(k)`` for each point.

    Raises
    ------
    KeyError
        If ``prms`` has no ``'n_clusters'`` entry.
    """
    prms = {} if prms is None else prms
    n_clusters = prms['n_clusters']
    npts = len(X) if X is not None else len(D)
    # AgglomerativeClustering rejects n_clusters > n_samples.
    n_centroids = min(npts, n_clusters)

    if D is not None:
        try:
            # scikit-learn >= 1.2 names this option `metric`; `affinity` was
            # removed in 1.4.
            clus = AgglomerativeClustering(
                n_clusters=n_centroids, metric='precomputed', linkage='single')
        except TypeError:
            # Older scikit-learn releases only accept `affinity`.
            clus = AgglomerativeClustering(
                n_clusters=n_centroids, affinity='precomputed', linkage='single')
        clus.fit(D)
    else:
        clus = AgglomerativeClustering(n_clusters=n_centroids, linkage='single')
        clus.fit(X)
    labels = clus.labels_.ravel()

    if X is None:
        # No coordinates to average: take the first member of each cluster.
        indices = np.array([np.flatnonzero(labels == c)[0] for c in range(n_centroids)])
    else:
        centroids = np.vstack([
            np.mean(X[np.flatnonzero(labels == c), :], axis=0)
            for c in range(n_centroids)
        ])
        # Snap each mean to an actual data point.
        nbrs = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit(X)
        _, indices = nbrs.kneighbors(centroids)
    return indices.ravel(), labels
def KMeansQuantization(X=None, D=None, prms=None):
    """Quantize a point set with k-means, snapping each center to a data point.

    Parameters
    ----------
    X : array-like of shape (n_points, n_features)
        Point coordinates. Required, because k-means needs coordinates.
    D : ignored
        Accepted only for signature compatibility with the other quantizers.
    prms : dict
        Must contain ``'n_clusters'`` (capped at the number of points). Every
        other entry is forwarded unchanged to ``sklearn.cluster.KMeans``.

    Returns
    -------
    indices : ndarray of shape (k,)
        For each cluster ``i``, the row index of the point of ``X`` nearest
        to the cluster's mean.
    labels : ndarray of shape (n_points,)
        ``KMeans.labels_``: cluster id of each point.

    Raises
    ------
    ValueError
        If ``X`` is None.
    KeyError
        If ``prms`` has no ``'n_clusters'`` entry.
    """
    if X is None:
        raise ValueError("KMeansQuantization needs point coordinates X; "
                         "a distance matrix D alone is not supported")
    # Copy so that popping 'n_clusters' does not mutate the caller's dict.
    kmeans_kwargs = dict(prms) if prms is not None else {}
    n_requested = kmeans_kwargs.pop('n_clusters')
    new_nclus = min(len(X), n_requested)
    clus = KMeans(n_clusters=new_nclus, **kmeans_kwargs).fit(X)

    labels = clus.labels_
    centers = []
    for i in range(new_nclus):
        members = X[np.flatnonzero(labels == i), :]
        if members.shape[0] > 0:
            centers.append(np.mean(members, axis=0))
        else:
            # KMeans can end with fewer distinct clusters than requested
            # (e.g. duplicate points). The mean of an empty cluster is NaN,
            # which kneighbors rejects, so use the fitted center instead.
            centers.append(clus.cluster_centers_[i])
    cluster_centers = np.vstack(centers)

    nbrs = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit(X)
    _, indices = nbrs.kneighbors(cluster_centers)
    return indices.ravel(), clus.labels_
def QuantizationSizeMinParameters(n_clusters):
    """Build a parameter factory that caps ``n_clusters`` at the data size.

    The returned callable takes ``X`` and/or ``D`` and returns
    ``{'n_clusters': min(len(data), n_clusters)}``. Here ``data`` is ``D``
    when it is given, otherwise ``X``.
    """
    def Q(X=None, D=None):
        # D wins whenever it is supplied; otherwise measure X.
        data = X if D is None else D
        return {'n_clusters': min(len(data), n_clusters)}
    return Q