-
Notifications
You must be signed in to change notification settings - Fork 0
/
twc_clustering.py
177 lines (154 loc) · 6.74 KB
/
twc_clustering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from scipy.spatial.distance import cosine
import argparse
import json
import pdb
import torch
import torch.nn.functional as F
import numpy as np
import time
from collections import OrderedDict
class TWCClustering:
    """Cluster embeddings using z-score thresholds over a cosine-similarity matrix.

    The caller-supplied ``threshold`` is a z-score multiplier: the effective
    cosine cutoff is ``mean + threshold*std``, where mean/std are computed
    over the full pairwise similarity matrix.
    """

    def __init__(self):
        # Retained from the original: announces which clustering backend is active.
        print("In Zscore Clustering")

    def compute_matrix(self, embeddings):
        """Return the (n, n) cosine-similarity matrix for ``embeddings``.

        embeddings: sequence of n equal-length vectors. Rows are L2-normalized
        before the inner product, so every entry is a cosine similarity.
        Note: a zero vector would produce NaNs (division by zero norm).
        """
        embeddings = np.array(embeddings)
        # Norm is taken along axis 0 of the transpose, i.e. per original row,
        # so each embedding becomes a unit vector.
        vec_a = embeddings.T
        vec_a = vec_a / np.linalg.norm(vec_a, axis=0)
        vec_a = vec_a.T
        return np.inner(vec_a, vec_a)

    def get_terms_above_threshold(self, matrix, embeddings, pivot_index, threshold):
        """Return indices j >= pivot_index with matrix[pivot_index][j] >= threshold.

        Only scans forward from the pivot (earlier indices were handled by
        earlier pivots). Includes the pivot itself when its self-similarity
        clears the threshold (always true for unit vectors).
        """
        return [j for j in range(pivot_index, len(embeddings))
                if matrix[pivot_index][j] >= threshold]

    def update_picked_dict_arr(self, picked_dict, arr):
        """Mark every index in ``arr`` as already assigned to a cluster."""
        for idx in arr:
            picked_dict[idx] = 1

    def update_picked_dict(self, picked_dict, in_dict):
        """Mark every key of ``in_dict`` as already assigned to a cluster."""
        for key in in_dict:
            picked_dict[key] = 1

    def find_pivot_subgraph(self, pivot_index, arr, matrix, threshold, strict_cluster=True):
        """Choose the member of ``arr`` with the highest summed similarity to the others.

        When ``strict_cluster`` is True, pairs below ``threshold`` are excluded
        from both the score and the neighbor set.

        Returns {"pivot_index": chosen center, "orig_index": pivot_index,
        "neighs": OrderedDict of neighbor index -> similarity, sorted descending}.
        If ``arr`` is empty the seed pivot is returned with no neighbors.
        """
        center_index = pivot_index
        center_score = 0
        center_dict = {}
        for node_i in arr:
            running_score = 0
            neigh_sims = {}
            for node_j in arr:
                sim = matrix[node_i][node_j]
                if strict_cluster and sim < threshold:
                    continue
                running_score += sim
                neigh_sims[node_j] = sim
            # Strict '>' keeps the earliest candidate on ties, as before.
            if running_score > center_score:
                center_index = node_i
                center_dict = neigh_sims
                center_score = running_score
        sorted_d = OrderedDict(sorted(center_dict.items(), key=lambda kv: kv[1], reverse=True))
        return {"pivot_index": center_index, "orig_index": pivot_index, "neighs": sorted_d}

    def update_overlap_stats(self, overlap_dict, cluster_info):
        """Increment, per node, the number of clusters the node appears in."""
        for val in cluster_info["neighs"]:
            overlap_dict[val] = overlap_dict.get(val, 0) + 1

    def bucket_overlap(self, overlap_dict):
        """Histogram overlap counts: cluster-membership count -> number of nodes.

        Returned as an OrderedDict sorted ascending by frequency.
        """
        bucket_dict = {}
        for count in overlap_dict.values():
            bucket_dict[count] = bucket_dict.get(count, 0) + 1
        return OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))

    def merge_clusters(self, ref_cluster, curr_cluster):
        """Append to ``ref_cluster`` (in place) members of ``curr_cluster`` absent from it.

        Membership is tested against a snapshot of ref_cluster taken before the
        merge (matching the original's copied list), but via a set for O(1) lookups.
        """
        seen = set(ref_cluster)
        for member in curr_cluster:
            if member not in seen:
                ref_cluster.append(member)

    def non_overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Build disjoint clusters by merging any candidate groups that share a member.

        Appends cluster records to cluster_dict["clusters"]. Returns an empty
        dict — non-overlapped clustering has no overlap histogram.
        """
        picked_dict = {}
        zscore = mean + threshold * std  # hoisted: loop-invariant cosine cutoff
        candidates = []
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            candidates.append(arr)
            self.update_picked_dict_arr(picked_dict, arr)
        # Merge candidate groups until no two share a member. After every merge
        # the scan restarts from the beginning so transitive overlaps collapse.
        run_index_i = 0
        while run_index_i < len(candidates):
            ref_cluster = candidates[run_index_i]
            ref_members = set(ref_cluster)  # O(1) membership for overlap test
            run_index_j = run_index_i + 1
            found = False
            while run_index_j < len(candidates):
                curr_cluster = candidates[run_index_j]
                if any(member in ref_members for member in curr_cluster):
                    self.merge_clusters(ref_cluster, curr_cluster)
                    candidates.pop(run_index_j)
                    found = True
                    run_index_i = 0
                    break
                run_index_j += 1
            if not found:
                run_index_i += 1
        for arr in candidates:
            # strict_cluster=False: merged groups may contain below-threshold pairs.
            cluster_info = self.find_pivot_subgraph(arr[0], arr, matrix, zscore, strict_cluster=False)
            cluster_dict["clusters"].append(cluster_info)
        return {}

    def overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Greedy one-pass clustering; a node may land in more than one cluster.

        Appends cluster records to cluster_dict["clusters"] and returns an
        OrderedDict histogram of overlap counts (see bucket_overlap).
        """
        picked_dict = {}
        overlap_dict = {}
        zscore = mean + threshold * std
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            cluster_info = self.find_pivot_subgraph(i, arr, matrix, zscore, strict_cluster=True)
            self.update_picked_dict(picked_dict, cluster_info["neighs"])
            self.update_overlap_stats(overlap_dict, cluster_info)
            cluster_dict["clusters"].append(cluster_info)
        return self.bucket_overlap(overlap_dict)

    def cluster(self, output_file, texts, embeddings, threshold, clustering_type):
        """Cluster ``embeddings`` and return a result dict.

        output_file, texts: accepted for interface compatibility; unused here.
        threshold: z-score multiplier; the cosine cutoff is mean + threshold*std.
        clustering_type: "overlapped" for greedy overlapping clusters, anything
        else for disjoint (merged) clusters.

        Returns {"clusters": [...], "info": {mean, std, current_threshold,
        zscores, overlap}}.
        """
        is_overlapped = clustering_type == "overlapped"
        matrix = self.compute_matrix(embeddings)
        mean = np.mean(matrix)
        std = np.std(matrix)
        # Reference table mapping integer z-scores to cosine cutoffs below 1.0.
        zscores = []
        inc = 0
        value = mean
        while value < 1:
            zscores.append({"threshold": inc, "cosine": round(value, 2)})
            if std == 0:
                # Guard: with zero variance the cutoff never advances, which
                # would loop forever in the original. Emit one entry and stop.
                break
            inc += 1
            value = mean + inc * std
        cluster_dict = {"clusters": []}
        if is_overlapped:
            sorted_d = self.overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        else:
            sorted_d = self.non_overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
        cluster_dict["info"] = {"mean": mean, "std": std, "current_threshold": curr_threshold,
                                "zscores": zscores, "overlap": list(sorted_d.items())}
        return cluster_dict