# embedding_gen.py (forked from rsingha108/TransLIST)
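"""Generate random (untrained) 50-dimensional character, bigram, and word
embeddings for the H22 dataset and write them in word2vec-style text format
(one token followed by its space-separated vector per line). The word and
character vocabularies are also merged into a single mixed lexicon file."""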
import pandas
import torch
from torch import nn
from tqdm import tqdm
# Test split the vocabularies are built from; the commented-out path is the
# alternative overall test set.
data = pandas.read_csv('H22-data/zero-shot_collection3_test.csv')  # 'H22-data/overall_test.csv'

# Output paths for the generated embedding files.
char_emb_path = 'H22_embeds/H22.char.vec'
word_emb_path = 'H22_embeds/H22.word.vec'
word_char_mix_path = 'H22_embeds/H22_word_char_mix.txt'
bigram_emb_path = 'H22_embeds/H22.bigram.vec'
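# The open(..., 'w') calls below assume the output directory exists; creating
# it up front avoids a FileNotFoundError on a fresh checkout.
import os
os.makedirs('H22_embeds', exist_ok=True)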
# Vocabularies collected from the data; sets avoid repeated deduplication.
characters = set()
bigrams = set()
# Word buckets, keyed by row index modulo 5 and re-united after the loop.
words0, words1, words2, words3, words4 = set(), set(), set(), set(), set()
for i, row in tqdm(data.iterrows()):
    inp = row['input'].replace(' ', '_')
    out = row['output'].replace(' ', '_')
    # Characters and character bigrams come from both columns.
    characters.update(inp)
    characters.update(out)
    bigrams.update(inp[j:j + 2] for j in range(len(inp) - 1))
    bigrams.update(out[j:j + 2] for j in range(len(out) - 1))
    # Words come from the 'output' column, split on '_'.
    temp = row['output'].split('_')
    if i % 5 == 0: words0.update(temp)
    if i % 5 == 1: words1.update(temp)
    if i % 5 == 2: words2.update(temp)
    if i % 5 == 3: words3.update(temp)
    if i % 5 == 4: words4.update(temp)

# Drop empty tokens produced by leading/trailing separators.
words = [w for w in (words0 | words1 | words2 | words3 | words4) if w]
characters = list(characters)
bigrams = list(bigrams)
# Write randomly initialised 50-dimensional vectors for each vocabulary in
# word2vec text format: one "token v1 ... v50" line per token. Note that
# .tolist() on a (1, 50) tensor returns a nested list, so the inner list
# must be unwrapped before joining.
embedding = nn.Embedding(len(characters), 50)
with open(char_emb_path, 'w') as f:
    for i in range(len(characters)):
        temp = embedding(torch.LongTensor([i])).tolist()[0]
        f.write(characters[i] + ' ' + ' '.join(str(v) for v in temp) + '\n')

embedding = nn.Embedding(len(bigrams), 50)
with open(bigram_emb_path, 'w') as f:
    for i in range(len(bigrams)):
        temp = embedding(torch.LongTensor([i])).tolist()[0]
        f.write(bigrams[i] + ' ' + ' '.join(str(v) for v in temp) + '\n')

embedding = nn.Embedding(len(words), 50)
with open(word_emb_path, 'w') as f:
    for i in range(len(words)):
        temp = embedding(torch.LongTensor([i])).tolist()[0]
        f.write(words[i] + ' ' + ' '.join(str(v) for v in temp) + '\n')
# Merge the word and character vectors into one mixed lexicon file: keep
# multi-character tokens from the word file (single characters are already
# covered by the character file), then append all character vectors.
with open(word_emb_path) as lexicon_f, \
        open(char_emb_path) as char_f, \
        open(word_char_mix_path, 'w') as output_f:
    for l in lexicon_f:
        l_split = l.strip().split()
        if len(l_split[0]) != 1:
            print(l.strip(), file=output_f)
    for l in char_f:
        print(l.strip(), file=output_f)
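# Optional sanity check (a minimal sketch, assuming tokens themselves contain
# no spaces): every line of the mixed file should end in 50 float fields.
with open(word_char_mix_path) as f:
    for line_no, line in enumerate(f, 1):
        parts = line.split()
        assert len(parts) >= 51, f'line {line_no} is too short'
        [float(v) for v in parts[-50:]]  # raises ValueError on a bad field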