predict.py · executable file · 55 lines (51 loc) · 1.71 KB
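
Batch prediction script for an LSTM-CRF tagger: it loads a trained checkpoint together with the saved word and tag vocabularies, character-tokenizes each line of the test file, and prints, for every input line, the line itself, its token indices, and the predicted tag sequence.
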
import sys
import re
from model import *
from utils import *
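
# Globals such as CUDA, BATCH_SIZE, EOS_IDX, PAD_IDX and UNK_IDX, as well as
# lstm_crf, load_checkpoint, load_word_to_idx, load_tag_to_idx, tokenize,
# Var and LongTensor, come from model.py and utils.py via the wildcard
# imports above.

# Rebuild the model from the command-line arguments: sys.argv[1] is the
# checkpoint, sys.argv[2] the saved word vocabulary, sys.argv[3] the saved
# tag vocabulary. idx_to_tag inverts tag_to_idx for decoding.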
def load_model():
    word_to_idx = load_word_to_idx(sys.argv[2])
    tag_to_idx = load_tag_to_idx(sys.argv[3])
    idx_to_tag = [tag for tag, _ in sorted(tag_to_idx.items(), key = lambda x: x[1])]
    model = lstm_crf(len(word_to_idx), len(tag_to_idx))
    if CUDA:
        model = model.cuda()
    print(model)
    load_checkpoint(sys.argv[1], model)
    return model, word_to_idx, tag_to_idx, idx_to_tag
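
# Decode one batch. The batch is padded with dummy ["", [EOS_IDX]] entries up
# to BATCH_SIZE, sorted by sequence length (longest first), and right-padded
# with PAD_IDX so every row has the same length. Note that the returned
# results therefore come back sorted by length within the batch, not in input
# order; each result still carries its original input line alongside its tags.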
def run_model(model, idx_to_tag, data):
    z = len(data)
    while len(data) < BATCH_SIZE:
        data.append(["", [EOS_IDX]])
    data.sort(key = lambda x: len(x[1]), reverse = True)
    batch_len = len(data[0][1])
    batch = [x + [PAD_IDX] * (batch_len - len(x)) for _, x in data]
    batch = Var(LongTensor(batch))
    result = model.decode(batch)
    for i in range(z):
        data[i].append([idx_to_tag[j] for j in result[i]])
    return data[:z]
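
# Read the test file line by line, map each character to its index (unknown
# characters fall back to UNK_IDX), and decode whenever BATCH_SIZE lines have
# accumulated; any leftover lines are decoded as a final partial batch.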
def predict():
    data = []
    model, word_to_idx, tag_to_idx, idx_to_tag = load_model()
    fo = open(sys.argv[4])
    for line in fo:
        line = line.strip()
        tokens = tokenize(line, "char")
        x = [word_to_idx[i] if i in word_to_idx else UNK_IDX for i in tokens] + [EOS_IDX]
        data.append([line, x])
        if len(data) == BATCH_SIZE:
            result = run_model(model, idx_to_tag, data)
            for x in result:
                print(x)
            data = []
    fo.close()
    if len(data):
        result = run_model(model, idx_to_tag, data)
        for x in result:
            print(x)
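
# Command-line entry point: requires the checkpoint, both vocabulary files
# and the test data file, in that order.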
if __name__ == "__main__":
    if len(sys.argv) != 5:
        sys.exit("Usage: %s model word_to_idx tag_to_idx test_data" % sys.argv[0])
    print("cuda: %s" % CUDA)
    predict()
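
A typical invocation follows the usage string above; the argument names below are placeholders for the trained checkpoint, the two saved vocabulary files, and the test file:

    python predict.py model word_to_idx tag_to_idx test_data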