forked from zenRRan/Sentiment-Analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecoder.py
131 lines (116 loc) · 5.13 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python
# encoding: utf-8
"""
@version: python3.6
@author: 'zenRRan'
@license: Apache Licence
@contact: [email protected]
@software: PyCharm
@file: decoder.py
@time: 2018/12/9 19:21
"""
import torch
from utils.Common import *
from utils.build_batch import Build_Batch
import argparse
from utils.opts import *
from torch.autograd import Variable
class Decoder:
    """Load a trained sentiment model, run it over one data split, and write
    every misclassified sentence (with predicted and gold labels) to
    ``opts.save_path``.

    Construction immediately triggers loading and decoding as a side effect,
    mirroring the original script-style usage: ``Decoder(opts)``.
    """

    def __init__(self, opts):
        self.opts = opts
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        self.model = torch.load(self.opts.model_path)
        self.features_list, self.vocab, self.char_vocab, self.label_vocab, self.rel_vocab \
            = None, None, None, None, None
        self.batch_size = self.opts.batch_size
        self.save_path = self.opts.save_path
        self.load_data(self.opts.dir, self.opts.type)
        self.decoder()

    def load_data(self, data_dir, type):
        """Load serialized features for split ``type`` (e.g. 'dev'/'test')
        plus the word/char/label/rel vocabularies from ``data_dir``.

        ``type`` shadows the builtin but is kept for caller compatibility.
        """
        self.features_list = torch.load(data_dir + '/' + type + '.sst')
        self.vocab = torch.load(data_dir + '/vocab.sst')
        self.char_vocab = torch.load(data_dir + '/char_vocab.sst')
        self.label_vocab = torch.load(data_dir + '/label_vocab.sst')
        self.rel_vocab = torch.load(data_dir + '/rel_vocab.sst')

    def decoder(self):
        """Batch the loaded split, run the model over every batch, collect the
        wrongly-predicted sentences, and write them out via :meth:`write`."""
        padding_id = self.vocab.from_string(padding_key)
        char_padding_id = self.char_vocab.from_string(padding_key)
        rel_padding_id = None
        if self.rel_vocab is not None:
            rel_padding_id = self.rel_vocab.from_string(padding_key)
        self.build_batch = Build_Batch(features=self.features_list,
                                       batch_size=self.batch_size,
                                       opts=self.opts, pad_idx=padding_id,
                                       char_padding_id=char_padding_id,
                                       rel_padding_id=rel_padding_id)
        self.batch_features, self.data_batchs = self.build_batch.create_sorted_normal_batch()
        wrongs = []
        # eval() is idempotent -- set inference mode once, not per batch.
        self.model.eval()
        for batch in self.data_batchs:
            if 'tree' in self.opts.model:
                sents = Variable(torch.LongTensor(batch[0]), requires_grad=False)
                label = Variable(torch.LongTensor(batch[1]), requires_grad=False)
                # batch layout (per Build_Batch -- assumed, verify there):
                # [0]=word ids, [1]=labels, [2]=char ids, [4]=heads,
                # [6]=lengths, [7]=dependency-relation ids.
                heads = batch[4]
                xlength = batch[6]
                tag_rels = Variable(torch.LongTensor(batch[7]), requires_grad=False)
                if self.opts.use_cuda:
                    sents = sents.cuda()
                    label = label.cuda()
                    tag_rels = tag_rels.cuda()
                if self.opts.model in ['treelstm', 'bitreelstm']:
                    pred = self.model(sents, heads, xlength)
                elif self.opts.model in ['lstm_treelstm_rel', 'treelstm_rel', 'bitreelstm_rel']:
                    # The two name lists are disjoint, so elif is equivalent to
                    # the original second `if`.
                    pred = self.model(sents, tag_rels, heads, xlength)
                else:
                    # BUG FIX: previously `pred` stayed undefined here and the
                    # code crashed later with UnboundLocalError; fail loudly.
                    raise ValueError('unsupported tree model: %s' % self.opts.model)
            else:
                sents = Variable(torch.LongTensor(batch[0]))
                label = Variable(torch.LongTensor(batch[1]))
                char_data = []
                if 'Char' in self.opts.model:
                    for char_list in batch[2]:
                        char_data.append(Variable(torch.LongTensor(char_list)))
                if self.opts.use_cuda:
                    sents = sents.cuda()
                    label = label.cuda()
                    char_data = [data.cuda() for data in char_data]
                if 'Char' in self.opts.model:
                    pred = self.model(sents, char_data)
                else:
                    pred = self.model(sents)
            # Argmax over classes, reshaped to match the gold labels.
            pred_index = torch.max(pred, 1)[1].view(label.size()).data.tolist()
            # Compare against the raw (un-tensorized) batch lists.
            sents = batch[0]
            label = batch[1]
            for index, (gold, predicted) in enumerate(zip(label, pred_index)):
                if gold != predicted:
                    wrong_sent, length = self.get_sent(sents[index])
                    right_label = self.get_label(gold)
                    wrong_label = self.get_label(predicted)
                    wrongs.append((wrong_sent, length, wrong_label, right_label))
        self.write(wrongs)

    def write(self, wrongs):
        """Dump each (sentence, length, predicted, gold) tuple as one line."""
        with open(self.save_path, 'w', encoding='utf8') as f:
            for wrong in wrongs:
                f.write('pred: ' + str(wrong[2]) + ' right: ' + str(wrong[3]) + ' length: ' + str(wrong[1]) + ' sent: ' + wrong[0] + '\n')

    def get_sent(self, idx):
        """Map a padded id sequence back to words.

        Returns ``(sentence, length)`` where padding tokens are dropped and
        ``length`` counts the remaining real words.
        """
        words = []
        for word_id in idx:  # renamed from `id` (shadowed the builtin)
            word = self.vocab.id2string[word_id]
            if word != padding_key:
                words.append(word)
        return ' '.join(words), len(words)

    def get_label(self, id):
        """Return the label string for a label id."""
        return self.label_vocab.id2string[id]
if __name__ == '__main__':
    # Build the CLI parser, attach the decoder-specific options, and run.
    arg_parser = argparse.ArgumentParser('Train opts')
    arg_parser = decoder_opts(arg_parser)
    parsed_opts = arg_parser.parse_args()
    Decoder(parsed_opts)