data_reader.py (forked from atpaino/deep-text-corrector)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import Counter

# Define constants associated with the usual special-case tokens.
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
PAD_TOKEN = "PAD"
EOS_TOKEN = "EOS"
GO_TOKEN = "GO"


class DataReader(object):

    def __init__(self, config, train_path, special_tokens=(), dataset_copies=1):
        self.config = config
        self.dataset_copies = dataset_copies

        # Construct the vocabulary from token counts over the training data.
        max_vocabulary_size = self.config.max_vocabulary_size
        token_counts = Counter()
        for tokens in self.read_tokens(train_path):
            token_counts.update(tokens)
        self.token_counts = token_counts

        # Order tokens by descending frequency (ties broken alphabetically) so
        # that only the most frequent max_vocabulary_size tokens receive ids.
        count_pairs = sorted(token_counts.items(), key=lambda x: (-x[1], x[0]))
        vocabulary, _ = list(zip(*count_pairs))
        vocabulary = list(vocabulary)

        # Insert the special tokens at the beginning so they receive the
        # lowest ids (matching PAD_ID, GO_ID, EOS_ID when passed in that order).
        vocabulary[0:0] = special_tokens

        # Materialize the (token, id) pairs as a list: in Python 3, zip()
        # returns a one-shot iterator that can be neither reused nor sliced.
        full_token_and_id = list(zip(vocabulary, range(len(vocabulary))))
        self.full_token_to_id = dict(full_token_and_id)
        self.token_to_id = dict(full_token_and_id[:max_vocabulary_size])
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    def read_tokens(self, path):
        """
        Reads the given file line by line and yields the list of tokens
        present in each line.

        :param path: path to the file to read.
        :return: generator of token lists, one per line.
        """
        raise NotImplementedError("Must implement read_tokens")

    def read_samples_by_string(self, path):
        """
        Reads the given file line by line and yields the word-form of each
        derived sample, i.e. (source tokens, target tokens) pairs.

        :param path: path to the file to read.
        :return: generator of (source_tokens, target_tokens) pairs.
        """
        raise NotImplementedError("Must implement read_samples_by_string")

    def unknown_token(self):
        """
        :return: the token used to represent out-of-vocabulary words.
        """
        raise NotImplementedError("Must implement unknown_token")

    def convert_token_to_id(self, token):
        """
        :param token: token to convert.
        :return: the token's id if it is in the vocabulary, otherwise the id
            of the unknown token.
        """
        token_with_id = token if token in self.token_to_id else \
            self.unknown_token()
        return self.token_to_id[token_with_id]

    def convert_id_to_token(self, token_id):
        return self.id_to_token[token_id]

    def is_unknown_token(self, token):
        """
        True if the given token is out of the vocabulary used or if it is the
        actual unknown token.

        :param token: token to check.
        :return: whether the token maps to the unknown token.
        """
        return token not in self.token_to_id or token == self.unknown_token()

    def sentence_to_token_ids(self, sentence):
        """
        Converts a whitespace-delimited sentence into a list of word ids.
        """
        return [self.convert_token_to_id(word) for word in sentence.split()]

    def token_ids_to_tokens(self, word_ids):
        """
        Converts a list of word ids to a list of their corresponding words.
        """
        return [self.convert_id_to_token(word) for word in word_ids]

    def read_samples(self, path):
        """
        Reads the given file and yields its samples as (source, target) pairs
        of token-id lists, appending EOS_ID to each target.

        :param path: path to the file to read.
        :return: generator of (source_ids, target_ids) pairs.
        """
        for source_words, target_words in self.read_samples_by_string(path):
            source = [self.convert_token_to_id(word) for word in source_words]
            target = [self.convert_token_to_id(word) for word in target_words]
            target.append(EOS_ID)

            yield source, target

    def build_dataset(self, path):
        dataset = [[] for _ in self.config.buckets]

        # Make multiple copies of the dataset so that we synthesize different
        # dropouts.
        for _ in range(self.dataset_copies):
            for source, target in self.read_samples(path):
                # Place each sample into the first (smallest) bucket whose
                # source and target sizes can hold it.
                for bucket_id, (source_size, target_size) in enumerate(
                        self.config.buckets):
                    if len(source) < source_size and len(target) < target_size:
                        dataset[bucket_id].append([source, target])
                        break

        return dataset
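

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# subclass showing how the abstract DataReader API above is meant to be
# filled in.  The class name, the "UNK" token, and the identity source/target
# pairing are assumptions made for demonstration only.
# ---------------------------------------------------------------------------
class WhitespaceDataReader(DataReader):
    """Toy reader that treats each line of a file as one whitespace-tokenized
    sample, using the sentence itself as both source and target."""

    UNKNOWN_TOKEN = "UNK"

    def read_tokens(self, path):
        # Yield the list of tokens on every line of the file.
        with open(path) as f:
            for line in f:
                yield line.split()

    def read_samples_by_string(self, path):
        # For demonstration, each sentence serves as its own correction target.
        for tokens in self.read_tokens(path):
            yield tokens, list(tokens)

    def unknown_token(self):
        return WhitespaceDataReader.UNKNOWN_TOKEN


# Hypothetical usage, assuming a config object exposing `max_vocabulary_size`
# and `buckets` (e.g. buckets = [(10, 10), (20, 20)]).  Passing the special
# tokens in this order makes their ids line up with PAD_ID, GO_ID and EOS_ID:
#
#   reader = WhitespaceDataReader(
#       config, "train.txt",
#       special_tokens=(PAD_TOKEN, GO_TOKEN, EOS_TOKEN,
#                       WhitespaceDataReader.UNKNOWN_TOKEN))
#   dataset = reader.build_dataset("train.txt")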