-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.py
93 lines (72 loc) · 3.21 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import unicode_literals, print_function, division
import torch
from io import open
import unicodedata
import string
import re
import random
from vars import *
def indexesFromSentence(lang, sentence):
    """Look up the vocabulary index of every word in ``sentence``.

    Args:
        lang: Object exposing a ``word2index`` mapping (word -> index).
        sentence: Text to convert; it is split on single space characters,
            so consecutive spaces produce empty-string tokens.

    Returns:
        A new list with one index per token, in sentence order.

    Raises:
        KeyError: When ``word2index`` is a plain dict and a token is absent.
    """
    vocabulary = lang.word2index
    indexes = []
    for token in sentence.split(' '):
        indexes.append(vocabulary[token])
    return indexes
def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as a column tensor of word indices ending in EOS.

    The words are converted with ``indexesFromSentence``, ``EOS_token`` is
    appended, and the result is returned as a ``torch.long`` tensor on
    ``device`` reshaped to ``(number_of_tokens + 1, 1)``.
    """
    token_ids = indexesFromSentence(lang, sentence)
    token_ids.append(EOS_token)  # mark the end of the sequence
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    return flat.view(-1, 1)
def tensorsFromPair(pair):
    """Convert a ``(source, target)`` sentence pair into a pair of tensors.

    ``pair[0]`` is encoded with the module-level ``input_lang`` and
    ``pair[1]`` with the module-level ``output_lang``; both go through
    ``tensorFromSentence``.

    Returns:
        A 2-tuple ``(input_tensor, target_tensor)``.
    """
    source_tensor = tensorFromSentence(input_lang, pair[0])
    reference_tensor = tensorFromSentence(output_lang, pair[1])
    return source_tensor, reference_tensor
# Strip combining accent marks from a Unicode string
def unicodeToAscii(s):
    """Return ``s`` with all combining marks (category ``Mn``) removed.

    The string is first decomposed with NFD so that accented letters split
    into a base character plus combining marks; the marks are then dropped.
    Characters without such a decomposition are returned unchanged, so the
    result is not guaranteed to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    base_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(base_chars)
# Lowercase, trim, strip accents, and blank out disallowed characters
def normalizeString(s):
    """Normalise raw text into the space-separated form used for lookup.

    Steps, in order: lowercase and strip the input, remove combining marks
    via ``unicodeToAscii``, insert a space before each of ``. , ! ?``, and
    replace every run of characters outside the allowed set with a single
    space.

    The result is not stripped again after the substitutions, so it can
    start or end with a space (for example when the input begins with a
    digit).
    """
    cleaned = s.lower().strip()
    cleaned = unicodeToAscii(cleaned)
    # Detach sentence punctuation so it becomes a token of its own.
    cleaned = re.sub(r"([.,!?])", r" \1", cleaned)
    # NOTE(review): ü/ö/ä are listed as allowed here, but unicodeToAscii's
    # NFD step has already reduced them to u/o/a before this pattern runs,
    # so only ß and € can survive from the non-ASCII set - confirm intent.
    return re.sub(r"[^a-zA-Z.,!?ßüöä€]+", r" ", cleaned)
def evaluate(encoder, decoder, sentence, input_lang, output_lang, max_length=MAX_LENGTH):
    """Greedily decode ``sentence`` with an encoder/decoder pair.

    Args:
        encoder: Callable as ``encoder(token, hidden)`` returning
            ``(output, hidden)``; must provide ``initHidden()`` and
            ``hidden_size``.
        decoder: Callable as ``decoder(input, hidden, encoder_outputs)``
            returning ``(output, hidden, attention)``.
        sentence: Already-normalised source text.
        input_lang: Vocabulary used to index ``sentence``.
        output_lang: Vocabulary whose ``index2word`` maps predictions back
            to words.
        max_length: Number of rows in the encoder-output buffer and the
            maximum number of decoding steps.

    Returns:
        ``(decoded_words, attentions)``: the predicted words (the EOS token
        itself is not included) and the attention rows recorded for the
        decoding steps that were executed.

    NOTE(review): the encoder-output buffer has ``max_length`` rows, so a
    sentence with more than ``max_length`` tokens (EOS included) fails with
    an IndexError while filling it.
    """
    with torch.no_grad():
        source = tensorFromSentence(input_lang, sentence)
        source_len = source.size(0)

        hidden = encoder.initHidden()
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        # Run the encoder one token at a time, keeping each step's output.
        for pos in range(source_len):
            step_output, hidden = encoder(source[pos], hidden)
            # assumes step_output[0, 0] is a hidden_size vector - TODO confirm
            encoder_outputs[pos] += step_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = hidden  # decoding starts from the final encoder state
        decoded_words = []
        attention_rows = torch.zeros(max_length, max_length)

        for step in range(max_length):
            decoder_output, decoder_hidden, attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            attention_rows[step] = attention.data

            # Greedy choice: take the single highest-scoring entry.
            _, best = decoder_output.data.topk(1)
            word_id = best.item()
            if word_id == EOS_token:
                break
            decoded_words.append(output_lang.index2word[word_id])

            # The prediction becomes the next decoder input.
            decoder_input = best.squeeze().detach()

        # Keep only the attention rows for steps that actually ran.
        return decoded_words, attention_rows[:step + 1]
def deNormalize(input_string):
    """Re-attach detached punctuation to the word before it.

    Removes the single whitespace character in front of ``? . ! , "`` when
    that mark is followed by whitespace or the end of the string, undoing
    the spacing that ``normalizeString`` inserts before ``. , ! ?``.

    The following whitespace is checked with a lookahead rather than
    consumed, so consecutive marks such as ``"what ? !"`` are all
    re-attached (``"what?!"``). A mark directly followed by a
    non-whitespace character is left untouched.

    Args:
        input_string: Space-separated text, typically joined model output.

    Returns:
        The text with the space before each qualifying mark removed.
    """
    return re.sub(r'\s([?.!,"])(?=\s|$)', r'\1', input_string)
def evaluateRandomly(encoder, decoder, pairs, n=10):
    """Print ``n`` randomly chosen pairs next to the model's translation.

    For each sample the source sentence, the reference target and the
    de-normalised model output are printed, followed by a blank line.

    Args:
        encoder: Encoder passed through to ``evaluate``.
        decoder: Decoder passed through to ``evaluate``.
        pairs: Non-empty sequence of ``(source, target)`` sentence pairs.
        n: Number of samples to draw (with replacement).
    """
    for _ in range(n):
        pair = random.choice(pairs)
        print('Input : ', pair[0])
        print('Target: ', pair[1])
        # evaluate() requires both vocabularies; omitting them raised a
        # TypeError on every call. Use the module-level input_lang and
        # output_lang, the same ones tensorsFromPair reads (presumably
        # provided by `from vars import *` - verify).
        output_words, _attentions = evaluate(
            encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = deNormalize(' '.join(output_words))
        print('Output: ', output_sentence)
        print()
def translate(input_string, encoder, decoder, input_lang, output_lang):
    """Translate raw text end to end and return a readable sentence.

    The input is normalised with ``normalizeString``, decoded with
    ``evaluate`` using the given models and vocabularies, and the predicted
    words are joined with spaces and passed through ``deNormalize``. The
    attention values returned by ``evaluate`` are discarded.
    """
    normalized = normalizeString(input_string)
    words, _attentions = evaluate(encoder, decoder, normalized, input_lang, output_lang)
    return deNormalize(' '.join(words))