-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvocab.py
73 lines (59 loc) · 2.23 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
from vars import *
from engine import normalizeString
class Lang:
    """Vocabulary for one language: word <-> index maps plus word frequencies.

    Indices 0 and 1 are reserved for the "SOS" and "EOS" markers, so the
    first real word added receives index 2.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}   # word -> integer index (markers not included)
        self.word2count = {}   # word -> number of occurrences seen
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS
        # Distinct words added via addWord; kept in sync with word2index
        # (previously declared but never populated).
        self.vocab = set()

    def addSentence(self, sentence):
        """Add every token of *sentence* to the vocabulary.

        Tokens are obtained with ``split(' ')`` (single-space split), so
        consecutive spaces would yield empty-string tokens.
        """
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register *word*: assign the next free index on first sight,
        otherwise increment its occurrence count."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.vocab.add(word)
            self.n_words += 1
        else:
            self.word2count[word] += 1
def readLangs(lang1, lang2, reverse=False, path='files/deu.txt'):
    """Read tab-separated sentence pairs and create empty vocabularies.

    Args:
        lang1: name of the language in the first column.
        lang2: name of the language in the second column.
        reverse: if True, swap the columns so that lang2 becomes the
            input side and lang1 the output side.
        path: corpus file with one pair per line, columns separated by
            tabs. Defaults to the previously hard-coded 'files/deu.txt'.

    Returns:
        (input_lang, output_lang, pairs): two empty Lang objects and a list
        of pairs, each a list of (up to) two sentences passed through
        normalizeString.
    """
    # Context manager so the file handle is closed promptly.
    with open(path, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # Keep only the first two tab-separated columns. Any extra column
    # (if present) would otherwise be normalized into the pair and, when
    # reverse=True, end up in the input position instead of a sentence.
    pairs = [[normalizeString(s) for s in line.split('\t')[:2]]
             for line in lines]
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    return input_lang, output_lang, pairs
def filterPair(p, max_length=None):
    """Return True when both sentences of pair *p* are short enough.

    A sentence passes when its number of space-separated tokens is strictly
    less than the limit. The limit is *max_length* if given, otherwise the
    module-level MAX_LENGTH (presumably provided by ``from vars import *``).
    """
    limit = MAX_LENGTH if max_length is None else max_length
    return len(p[0].split(' ')) < limit and len(p[1].split(' ')) < limit
def filterPairs(pairs):
    """Return a new list holding, in order, only the pairs that filterPair accepts."""
    return list(filter(filterPair, pairs))
def prepareData(lang1, lang2, reverse=True):
    """Load, length-filter, and index a parallel corpus.

    Reads pairs with readLangs (note: reverse defaults to True here, unlike
    readLangs), keeps only the pairs filterPairs accepts, then feeds the
    first sentence of each kept pair into the input vocabulary and the
    second into the output vocabulary.

    Returns:
        (input_lang, output_lang, pairs) with both vocabularies populated
        and pairs being the filtered list.
    """
    src_vocab, tgt_vocab, raw_pairs = readLangs(lang1, lang2, reverse)
    kept_pairs = filterPairs(raw_pairs)
    for entry in kept_pairs:
        src_vocab.addSentence(entry[0])
        tgt_vocab.addSentence(entry[1])
    return src_vocab, tgt_vocab, kept_pairs