# get_pretrain_vecs.py (executable file; 64 lines, 50 loc, 1.92 KB)
import numpy as np
import h5py
import re
import sys
import operator
import argparse
def load_glove_vec(fname, vocab):
    """Load GloVe-style text vectors for the words that appear in ``vocab``.

    Each non-blank line of ``fname`` is expected to hold a token followed by
    its whitespace-separated float components. The vector width is inferred
    from the first non-blank line (token count minus one). On later lines,
    any extra leading tokens are treated as a multi-word token and re-joined
    with single spaces.

    Args:
        fname: Path to the text file of pretrained vectors (read as UTF-8).
        vocab: Container supporting ``in`` (e.g. a dict keyed by word); only
            words found in it are kept.

    Returns:
        A ``(word_vecs, dim)`` tuple. ``word_vecs`` maps each matched word to
        a 1-D ``numpy`` float array; ``dim`` is the inferred vector width
        (0 if the file contains no non-blank lines).

    Raises:
        ValueError: If a component of a matched word's vector cannot be
            parsed as a float.
    """
    word_vecs = {}
    word_vec_size = None
    # Explicit encoding: do not depend on the platform default. The context
    # manager guarantees the handle is closed even if parsing raises.
    with open(fname, 'r', encoding='utf-8') as f:
        for line in f:
            d = line.split()
            if not d:
                # Blank line (e.g. a trailing empty line): nothing to parse,
                # and it must not influence the reported dimension.
                continue
            # Infer the vector width from the first non-blank line.
            if word_vec_size is None:
                word_vec_size = len(d) - 1
            # Everything before the last `word_vec_size` fields is the token;
            # more than one such field means a multi-word token.
            word = ' '.join(d[:len(d) - word_vec_size])
            if word in vocab:
                # Only convert to floats for words we keep; the pretrained
                # file is usually far larger than the vocabulary.
                word_vecs[word] = np.array(list(map(float, d[-word_vec_size:])))
    return word_vecs, (word_vec_size or 0)
def main():
    """Build an HDF5 embedding matrix for a vocabulary from pretrained vectors.

    Reads ``<dir>/<dict>`` (one ``word index`` pair per line), looks each word
    up in the ``--glove`` text file via ``load_glove_vec``, and writes a
    ``(number of dict entries, dim)`` float matrix to ``<dir>/<output>.hdf5``
    under the dataset name ``word_vecs``. Every row starts as a draw from a
    normal distribution (scale 0.05); the row at each matched word's index is
    then overwritten with its pretrained vector.

    Exits through ``parser.error`` (status 2) when ``--glove`` is not given.
    """
    import os  # local import: only this function needs path joining

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--dir', help="The path to data dir", type=str, default='data/squad-v1.1/')
    parser.add_argument('--dict', help="*.dict file", type=str, default='squad.word.dict')
    parser.add_argument('--glove', help='pretrained word vectors', type=str, default='')
    parser.add_argument('--output', help="output hdf5 file", type=str, default='glove')
    args = parser.parse_args()

    # Fail with a usage message instead of an obscure open('') error.
    if not args.glove:
        parser.error("--glove is required: path to the pretrained word vectors")

    # os.path.join works whether or not --dir ends with a path separator.
    args.dict = os.path.join(args.dir, args.dict)
    args.output = os.path.join(args.dir, args.output)

    # Read "word index" pairs line by line. Blank lines are skipped, and the
    # final entry is kept even when the file has no trailing newline.
    # NOTE(review): the dict file is read with the platform default encoding,
    # as before; confirm it matches how the file was written.
    vocab = []
    with open(args.dict, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            vocab.append((fields[0], int(fields[1])))
    word2idx = dict(vocab)
    print("vocab size: " + str(len(vocab)))

    w2v, dim = load_glove_vec(args.glove, word2idx)
    print("matched word vector size: {0}, dim: {1}".format(len(w2v), dim))

    # Random initialisation for every row; matched words are overwritten below.
    # NOTE(review): indices from the dict file are used directly as row
    # numbers, so they are presumably 0-based and < len(vocab) — confirm
    # against the script that writes the dict file.
    rs = np.random.normal(scale=0.05, size=(len(vocab), dim))
    print("num words in pretrained model is " + str(len(w2v)))
    for word, vec in w2v.items():
        rs[word2idx[word]] = vec

    with h5py.File(args.output + '.hdf5', "w") as f:
        f["word_vecs"] = rs
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()