predict.py
import argparse
import os
import time

import torch

import config
from inference.evaluate import evaluate_greedy_generator
from inference.sequence_generator import SequenceGenerator
from pykp.model import Seq2SeqModel
from pykp.utils.io import build_interactive_predict_dataset
from utils.data_loader import load_vocab, build_data_loader
from utils.functions import common_process_opt, read_tokenized_src_file, time_since

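# config, inference.*, pykp.*, utils.*, and retrievers.* are project-local
# modules from this repository, not third-party packages.
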
def process_opt(opt):
    opt = common_process_opt(opt)

    if not os.path.exists(opt.pred_path):
        os.makedirs(opt.pred_path)

    if torch.cuda.is_available():
        if not opt.gpuid:
            opt.gpuid = 0
        opt.device = torch.device("cuda:%d" % opt.gpuid)
    else:
        opt.device = torch.device("cpu")
        opt.gpuid = -1
        print("CUDA is not available, falling back to CPU.")

    return opt

def init_pretrained_model(opt):
    model = Seq2SeqModel(opt)
    # map_location ensures a GPU-trained checkpoint can also be loaded on CPU.
    model.load_state_dict(torch.load(opt.model, map_location=opt.device))
    model.to(opt.device)
    model.eval()
    return model

def predict(test_data_loader, model, opt):
    generator = SequenceGenerator.from_opt(model, opt)
    evaluate_greedy_generator(test_data_loader, generator, opt)

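# NOTE: main() relies on the module-level `logging` object created in the
# __main__ block below via config.init_logging, so this file is meant to be
# run as a script rather than imported.
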
def main(opt):
    vocab = load_vocab(opt)
    src_file = opt.src_file
    tokenized_src = read_tokenized_src_file(src_file)

    retriever = None
    if opt.use_multidoc_graph:
        # Imported lazily so the retriever dependencies are only required when used.
        from retrievers.retriever import Retriever
        logging.info("Initializing retriever and loading reference documents.")
        retriever = Retriever(opt)
    opt.retriever = retriever

    mode = 'one2many' if opt.one2many else 'one2one'

    test_data = build_interactive_predict_dataset(tokenized_src, opt, mode=mode, include_original=True)
    torch.save(test_data, os.path.join(opt.exp_path, "test_%s.pt" % mode))

    test_loader = build_data_loader(data=test_data, opt=opt, shuffle=False, load_train=False)
    logging.info('#(test data): #(batch)=%d' % len(test_loader))

    # Init the pretrained model.
    model = init_pretrained_model(opt)

    logging.info("Prediction path: %s" % opt.pred_path)

    # Predict the keyphrases of the src file and write them to opt.pred_path/predictions.txt.
    start_time = time.time()
    predict(test_loader, model, opt)
    predict_time = time_since(start_time)
    logging.info('Time for prediction: %.1f' % predict_time)

if __name__ == '__main__':
    # Load settings for prediction.
    parser = argparse.ArgumentParser(
        description='predict.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    config.vocab_opts(parser)
    config.model_opts(parser)
    config.predict_opts(parser)
    config.retriever_opts(parser)
    opt = parser.parse_args()
    opt = process_opt(opt)

    logging = config.init_logging(log_file=opt.exp_path + '/output.log', stdout=True)
    logging.info('Parameters:')
    for k, v in opt.__dict__.items():
        logging.info('%s : %s' % (k, str(v)))

    main(opt)
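
# Example invocation (a sketch: the actual flag names are defined by
# config.vocab_opts / config.model_opts / config.predict_opts /
# config.retriever_opts, which are not shown here, so the flags below are
# assumptions based on the opt fields this script reads):
#
#   python predict.py -model exp/model.ckpt -src_file data/test_src.txt \
#       -pred_path pred/ -exp_path exp/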