-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathberts.py
94 lines (73 loc) · 2.14 KB
/
berts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import mxnet as mx
from bert_embedding import BertEmbedding
import numpy as np
import os
def cos_sim(a, b):
    """Return the cosine similarity of *a* and *b*.

    Computed as ``np.inner(a, b)`` divided by the product of
    ``np.linalg.norm(a)`` and ``np.linalg.norm(b)`` (the Euclidean norm for
    1-D input). Zero-norm input is not guarded: NumPy returns nan/inf and
    emits a RuntimeWarning.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.inner(a, b) / norm_product
# ctx = mx.gpu()  # presumably meant to be passed as BertEmbedding(ctx=ctx) for GPU use

# Directory scanned for input files; every regular file in it is processed.
CORPUS_DIR = '/home/mukul/Mixed-Initiative/Switchboard-Corpus/'


def _extract_utterances(lines):
    """Return the stripped second '|'-separated field of every line that has one.

    Each field is stripped individually (instead of concatenating with '\\n'
    and re-splitting), so a field that ends its line cannot leave a stray
    newline behind as an empty "sentence". The result therefore has exactly
    one entry per line that contains a '|', in file order.
    """
    return [fields[1].strip()
            for fields in (line.split("|") for line in lines)
            if len(fields) > 1]


# Counters for the (currently disabled) topic-shift evaluation below.
pos = 0
count = 0
posx = 0
countx = 0

# Build the embedder once, outside the loop, instead of once per file.
# NOTE(review): bert-embedding's BertEmbedding defaults to max_seq_length=25,
# so longer utterances are truncated -- confirm that is acceptable.
bert = BertEmbedding()

# sorted() makes the processing (and any printed output) order reproducible;
# os.listdir order is arbitrary.
for name in sorted(os.listdir(CORPUS_DIR)):
    path = os.path.join(CORPUS_DIR, name)
    if not os.path.isfile(path):
        continue  # skip sub-directories and other non-regular entries
    with open(path, 'r') as f:
        lines = f.readlines()
    sentences = _extract_utterances(lines)
    if not sentences:
        continue  # no '|' lines: nothing to embed
    # result[k] is the embedding output for sentences[k]; the disabled code
    # below reads result[k][1] as a per-token embedding matrix.
    result = bert(sentences)
    # Disabled topic-shift analysis, kept for reference.
    # NOTE(review): before re-enabling, check that
    #   * `index` walks every raw line, while `result` has one entry per
    #     '|'-line -- they only line up if every line contains a '|';
    #   * the forward window originally read result[index+1] (same neighbour
    #     on every iteration); it is shown as result[index+i] below to mirror
    #     the backward window -- confirm intent;
    #   * print(pos/count) raises ZeroDivisionError if count is still 0.
    # scan_range = 8
    # print(path)
    # for index, utterance in enumerate(lines):
    # if index < scan_range or index >= len(lines) - scan_range:
    # continue
    # a1 = []
    # a2 = []
    # for i in range(-scan_range,0):
    # sen1 = np.mean(result[index ][1],axis=0)
    # sen2 = np.mean(result[index+i][1],axis=0)
    # a1.append(cos_sim(sen1, sen2))
    # for i in range(1,scan_range+1):
    # sen1 = np.mean(result[index][1],axis=0)
    # sen2 = np.mean(result[index+i][1],axis=0)
    # a2.append(cos_sim(sen1, sen2))
    # a1 = np.array(a1)
    # a2 = np.array(a2)
    # arr1 = a1
    # arr2 = a2
    # a1 = a1[a1 > 0.55]
    # a2 = a2[a2 > 0.55]
    # if len(lines[index].split("|")) > 3:
    # if "+" in lines[index].split("|")[3]:
    # if len(a1)-len(a2) <= 0:
    # sum1 = np.sum(arr1)
    # sum2 = np.sum(arr2)
    # val = np.exp(sum1-sum2)
    # pos += 1
    # if len(lines[index].split("|")) == 5:
    # if val >= 0.5:
    # if 's' in lines[index].split("|")[4]:
    # posx += 1
    # countx += 1
    # print("Correct",lines[index],sum1-sum2)
    # else:
    # countx += 1
    # print("Wrong=>sbd",lines[index],sum1-sum2)
    # else:
    # if 'i' in lines[index].split("|")[4]:
    # posx += 1
    # countx += 1
    # print("Correct",lines[index],sum1-sum2)
    # else:
    # countx += 1
    # print("wrong=>i",lines[index],sum1-sum2)
    # count += 1

# print("Topic Shift",pos,count)
# print("Sub vs I",posx,countx)
# print(pos/count)
# print(posx,countx)