Skip to content

Commit

Permalink
unsupervised labels now works with labels directly but assumes K=num_c…
Browse files Browse the repository at this point in the history
…lasses.
  • Loading branch information
astirn committed Jul 18, 2019
1 parent 331294e commit 284213a
Showing 1 changed file with 15 additions and 42 deletions.
57 changes: 15 additions & 42 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import pickle
import numpy as np
from scipy import stats
from scipy.optimize import linear_sum_assignment


Expand Down Expand Up @@ -39,56 +38,30 @@ def save_performance(perf, epoch, save_path):
assert str(perf) == str(perf_load), 'performance saving failed'


def unsupervised_labels(y, y_hat, num_classes, num_clusters):
    """
    Compute the classification error of a clustering under the best
    one-to-one matching between predicted clusters and true classes.

    A (cluster x class) count matrix is built from the label pairs, the
    assignment that maximises the matched count is found with
    ``linear_sum_assignment``, and the error is the fraction of samples that
    fall outside that assignment. The error is also printed.

    :param y: true labels, one integer-valued entry per sample
    :param y_hat: predicted cluster indices, one integer-valued entry per sample
    :param num_classes: number of classes (determined by data)
    :param num_clusters: number of clusters (determined by model)
    :return: classification error rate
    :raises AssertionError: if num_classes != num_clusters
    :raises ValueError: if y and y_hat do not hold the same number of entries
    """
    # the one-to-one matching below assumes a square count matrix
    assert num_classes == num_clusters, \
        'unsupervised_labels requires num_classes == num_clusters'

    # coerce labels to flat integer index arrays (truncating like int())
    true_idx = np.asarray(y).astype(int).ravel()
    pred_idx = np.asarray(y_hat).astype(int).ravel()
    if true_idx.shape != pred_idx.shape:
        raise ValueError('y and y_hat must contain the same number of labels')

    # initialize count matrix: rows are predicted clusters, columns true classes
    cnt_mtx = np.zeros([num_classes, num_classes])

    # fill in matrix; np.add.at accumulates correctly over repeated index pairs
    np.add.at(cnt_mtx, (pred_idx, true_idx), 1)

    # find optimal permutation (negated because the solver minimises cost)
    row_ind, col_ind = linear_sum_assignment(-cnt_mtx)

    # compute error
    error = 1 - cnt_mtx[row_ind, col_ind].sum() / cnt_mtx.sum()

    # print results
    print('Classification error = {:.4f}'.format(error))

    return error

0 comments on commit 284213a

Please sign in to comment.