# translate.py -- greedy translation script (forked from tomsherborne/nlu_cw2)
import os
import logging
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.serialization import default_restore_location
from seq2seq import models, utils
from seq2seq.data.dictionary import Dictionary
from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler


def get_args():
    """Defines generation-specific hyper-parameters."""
    parser = argparse.ArgumentParser('Sequence to Sequence Model')
    parser.add_argument('--cuda', action='store_true', help='use a GPU')
    parser.add_argument('--seed', default=42, type=int, help='pseudo-random number generator seed')

    # Add data arguments
    parser.add_argument('--data', default='data-bin', help='path to data directory')
    parser.add_argument('--checkpoint-path', default='checkpoints/checkpoint_best.pt',
                        help='path to the model file')
    parser.add_argument('--batch-size', default=None, type=int, help='maximum number of sentences in a batch')
    parser.add_argument('--output', default='model_translations.txt', type=str,
                        help='path to the output file destination')
    parser.add_argument('--max-len', default=25, type=int, help='maximum length of generated sequence')
    return parser.parse_args()
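

# Example invocation (a sketch: the paths below are just the defaults from
# get_args() and assume you have already trained a model and binarized the
# test data):
#
#   python translate.py --data data-bin \
#       --checkpoint-path checkpoints/checkpoint_best.pt \
#       --output model_translations.txt --max-len 25 --cuda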


def main(args):
    """Main translation function."""
    torch.manual_seed(args.seed)

    # Load the arguments saved in the checkpoint and merge them into the
    # command-line arguments
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    saved_args = vars(state_dict['args'])
    # Rewrite the absolute paths stored in the checkpoint as relative paths
    for k in saved_args:
        if isinstance(saved_args[k], str) and '/home/lvyajie/nlu_cw2/' in saved_args[k]:
            saved_args[k] = saved_args[k].replace('/home/lvyajie/nlu_cw2/', '')
    args = argparse.Namespace(**{**vars(args), **saved_args})
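    # Because saved_args is unpacked second in the merge above, hyper-parameters
    # stored in the checkpoint take precedence over values passed on the command line.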
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load the test dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)
    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(test_dataset, 9999999, args.batch_size,
                                                                         1, 0, shuffle=False, seed=args.seed))
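    # A note on the BatchSampler positional arguments above: assuming the
    # fairseq-style signature used in seq2seq.data.dataset, 9999999 effectively
    # disables the max-tokens cap, args.batch_size limits the number of
    # sentences per batch, and 1 / 0 are the shard count and shard index for
    # single-process decoding.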
    # Build the model, restore its weights, and put it in evaluation mode
    model = models.build_model(args, src_dict, tgt_dict)
    model.load_state_dict(state_dict['model'])
    if args.cuda and torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):
        # Move the batch tensors to the GPU if one is being used
        if torch.cuda.is_available() and args.cuda:
            for k in sample:
                if isinstance(sample[k], torch.Tensor):
                    sample[k] = sample[k].cuda()

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])
            # Start each hypothesis with an end-of-sentence token
            go_slice = torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            prev_words = go_slice
            next_words = None
        for _ in range(args.max_len):
            with torch.no_grad():
                # Compute the decoder output by repeatedly feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # Suppress <UNK>s: take the two most likely tokens at each position and
            # back off to the second-best candidate wherever the best one is <UNK>
            _, next_candidates = torch.topk(decoder_out, 2, dim=-1)
            best_candidates = next_candidates[:, :, 0]
            backoff_candidates = next_candidates[:, :, 1]
            next_words = torch.where(best_candidates == tgt_dict.unk_idx, backoff_candidates, best_candidates)
            prev_words = torch.cat([go_slice, next_words], dim=1)
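        # NB: each step above re-runs the decoder on the full prefix, so the cost
        # of decoding grows quadratically with the output length; a decoder with
        # incremental state would avoid the recomputation, but this simple loop
        # keeps the generation logic easy to follow.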
        # Segment the batch into individual hypotheses
        if next_words.is_cuda:
            decoded_batch = next_words.detach().cpu().numpy()
        else:
            decoded_batch = next_words.numpy()
        output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])]
        assert len(output_sentences) == len(sample['id'].data)

        # Truncate each hypothesis at its first <EOS> token
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append([])
        output_sentences = temp

        # Convert the arrays of token indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]
        # Save the translations, keyed by sentence id
        assert len(output_sentences) == len(sample['id'].data)
        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write all hypotheses to the output file in sentence-id order
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                out_file.write(all_hyps[sent_id] + '\n')
        logging.info('Output {:d} translations to {:s}'.format(len(all_hyps), args.output))


if __name__ == '__main__':
    args = get_args()
    main(args)
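
# The output file contains one hypothesis per line, written in sentence-id order.
# Note this assumes the ids produced by Seq2SeqDataset are contiguous integers
# starting at 0; if any id were missing, the range-based lookup in main() would
# raise a KeyError.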