-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvocabulary.py
171 lines (135 loc) · 5.31 KB
/
vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import re
from typing import List, Union, Optional, Dict, Any, Set
class WordVocabulary:
def __init__(
self,
list_of_sentences: Optional[List[Union[str, List[str]]]] = None,
allow_any_word: bool = False,
split_punctuation: bool = True,
):
self.words: Dict[str, int] = {}
self.inv_words: Dict[int, str] = {}
# always map <pad> to 0!
self._add_word("<pad>")
self.to_save = [
"words",
"inv_words",
"_unk_index",
"allow_any_word",
"split_punctuation",
]
self.allow_any_word = allow_any_word
self.initialized = False
self.split_punctuation = split_punctuation
if list_of_sentences is not None:
words = set()
for s in list_of_sentences:
words |= set(self.split_sentence(s))
self._add_set(words)
self.finalize()
def finalize(self):
self._unk_index = self.words.get("<UNK>", self.words.get("<unk>"))
if self.allow_any_word:
assert self._unk_index is not None
self.initialized = True
def _add_word(self, w: str):
next_id = len(self.words)
self.words[w] = next_id
self.inv_words[next_id] = w
def _add_set(self, words: Set[str]):
for w in sorted(words):
self._add_word(w)
def _process_word(self, w: str) -> int:
res = self.words.get(w, self._unk_index)
assert (
res != self._unk_index
) or self.allow_any_word, f"WARNING: unknown word: '{w}'"
return res
def _process_index(self, i: int) -> str:
res = self.inv_words.get(i, None)
if res is None:
return f"<!INV: {i}!>"
return res
def __getitem__(self, item: Union[int, str]) -> Union[str, int]:
if isinstance(item, int):
return self._process_index(item)
else:
return self._process_word(item)
def split_sentence(self, sentence: Union[str, List[str]]) -> List[str]:
if isinstance(sentence, list):
# Already tokenized.
return sentence
if self.split_punctuation:
return re.findall(r"\w+|[^\w\s]", sentence, re.UNICODE)
else:
return [x for x in sentence.split(" ") if x]
def sentence_to_indices(self, sentence: Union[str, List[str]]) -> List[int]:
assert self.initialized
words = self.split_sentence(sentence)
return [self._process_word(w) for w in words]
def indices_to_sentence(self, indices: List[int]) -> List[str]:
assert self.initialized
return [self._process_index(i) for i in indices]
def __call__(self, seq: Union[List[Union[str, int]], str]) -> List[Union[int, str]]:
if seq is None or (isinstance(seq, list) and not seq):
return seq
if isinstance(seq, str) or isinstance(seq[0], str):
return self.sentence_to_indices(seq)
else:
return self.indices_to_sentence(seq)
def __len__(self) -> int:
return len(self.words)
def state_dict(self) -> Dict[str, Any]:
return {k: self.__dict__[k] for k in self.to_save}
def load_state_dict(self, state: Dict[str, Any]):
self.initialized = True
self.__dict__.update(state)
def __add__(self, other):
res = WordVocabulary(
allow_any_word=self.allow_any_word and other.allow_any_word,
split_punctuation=self.split_punctuation,
)
res._add_set(set(self.words.keys()) | set(other.words.keys()))
res.finalize()
return res
def mapfrom(self, other) -> Dict[int, int]:
return {other.words[w]: i for w, i in self.words.items() if w in other.words}
class CharVocabulary:
def __init__(self, chars: Optional[Set[str]]):
self.initialized = False
if chars is not None:
self.from_set(chars)
def from_set(self, chars: Set[str]):
chars = list(sorted(chars))
self.to_index = {c: i for i, c in enumerate(chars)}
self.from_index = {i: c for i, c in enumerate(chars)}
self.initialized = True
def __len__(self):
return len(self.to_index)
def state_dict(self) -> Dict[str, Any]:
return {"chars": set(self.to_index.keys())}
def load_state_dict(self, state: Dict[str, Any]):
self.from_set(state["chars"])
def str_to_ind(self, data: str) -> List[int]:
return [self.to_index[c] for c in data]
def ind_to_str(self, data: List[int]) -> str:
return "".join([self.from_index[i] for i in data])
def _is_string(self, i):
return isinstance(i, str)
def __call__(self, seq: Union[List[int], str]) -> Union[List[int], str]:
assert self.initialized
if seq is None or (isinstance(seq, list) and not seq):
return seq
if self._is_string(seq):
return self.str_to_ind(seq)
else:
return self.ind_to_str(seq)
def __add__(self, other):
return self.__class__(
set(self.to_index.values()) | set(other.to_index.values())
)
class ByteVocabulary(CharVocabulary):
def ind_to_str(self, data: List[int]) -> bytearray:
return bytearray([self.from_index[i] for i in data])
def _is_string(self, i):
return isinstance(i, bytearray)