-
Notifications
You must be signed in to change notification settings - Fork 0
/
bpe.py
304 lines (262 loc) · 13 KB
/
bpe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into
sequences of integers, where each integer represents a small chunk of commonly
occurring characters. This implementation is based on OpenAI's GPT-2 encoder.py:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
but was mildly modified because the original implementation is a bit confusing.
I also tried to add as many comments as possible, reflecting my own understanding
of what's going on.
"""
import os
import json
import regex as re
import requests
import torch
# -----------------------------------------------------------------------------
def bytes_to_unicode():
    """
    Build OpenAI's reversible byte -> printable-unicode-character table.

    All 256 byte values get a distinct one-character string. Bytes that already
    render as a visible character keep it: 33..126, 161..172 and 174..255, i.e.
    188 values, so that e.g. d[33] == "!". The remaining 68 "ugly" bytes
    (control characters, space, DEL, NBSP, soft hyphen, ...) are taken in
    ascending order and assigned chr(256), chr(257), ... in turn. That is why
    d[0] == 'Ā' (chr(256)) and the space byte 32 becomes d[32] == 'Ġ' (chr(288)).

    The returned dict lists the 188 visible bytes first, then the shifted ones.
    """
    visible = (
        [*range(ord("!"), ord("~") + 1)]
        + [*range(ord("¡"), ord("¬") + 1)]
        + [*range(ord("®"), ord("ÿ") + 1)]
    )
    visible_set = set(visible)
    # bytes without a nice glyph of their own, in ascending order
    needs_shift = [byte for byte in range(2**8) if byte not in visible_set]

    byte_order = visible + needs_shift
    # visible bytes map to themselves; the k-th shifted byte maps to 256 + k
    code_points = visible + [2**8 + k for k in range(len(needs_shift))]
    return {byte: chr(cp) for byte, cp in zip(byte_order, code_points)}
def get_pairs(word):
    """
    Return all bigrams of consecutive elements in word, as a set of tuples.

    word: an indexable sequence, e.g. a str or a tuple of symbol strings.
    Sequences of length 0 or 1 have no bigrams and yield an empty set
    (rather than failing on word[0]).

    >>> sorted(get_pairs(("a", "b", "c")))
    [('a', 'b'), ('b', 'c')]
    """
    # zip stops at the shorter argument, so this pairs word[k] with word[k+1]
    return set(zip(word, word[1:]))
class Encoder:
    """
    Byte-level BPE encoder/decoder modeled on OpenAI's GPT-2 encoder.py.

    encode(): text -> regex pre-tokenization -> utf-8 bytes -> printable
    unicode stand-ins (bytes_to_unicode) -> BPE merges (bpe) -> integer ids.
    decode(): integer ids -> bpe token strings -> bytes -> utf-8 text.
    """

    def __init__(self, encoder, bpe_merges):
        """
        encoder: dict mapping each bpe token string to its integer id.
        bpe_merges: sequence of (a, b) tuples. The position of a pair in this
            sequence is its rank, and lower-ranked pairs are merged first.
        """
        # byte encoder/decoder
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # bpe token encoder/decoder
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # the splitting pattern used for pre-tokenization
        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment
        #
        # ok so what is this regex looking for, exactly?
        # python re reference: https://docs.python.org/3/library/re.html
        # (the \p{...} property classes need the third-party `regex` module, which
        # this file imports as `re`; the stdlib `re` does not support them)
        # - the vertical bars | are OR, so findall chunks the text as the alternatives match, from left to right
        # - "'s" would split up things like Andrej's -> (Andrej, 's)
        # - ' ?\p{L}+': optional space followed by 1+ unicode code points in the category "letter"
        # - ' ?\p{N}+': optional space followed by 1+ unicode code points in the category "number"
        # - ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number
        # - '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) that are NOT followed by non-whitespace.
        #   Inside a run of whitespace that precedes a word, this consumes all but the last whitespace character.
        #   If that last character is a plain space it is then absorbed by the optional ' ?' of the next chunk.
        #   Any other whitespace character falls through to the final '\s+'.
        # - '\s+': 1+ whitespace characters, catching whatever whitespace the previous alternatives left over
        # So TLDR:
        # - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens
        # - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        # memoized results of bpe(), keyed by the byte-encoded token
        self.cache = {}

    def bpe(self, token):
        """
        Apply BPE merges to one pre-tokenized, byte-encoded 'word' (e.g. 'Ġthere').

        The symbols that remain after merging are returned joined by single
        spaces. Merging is greedy. At each step, the adjacent pair with the
        lowest rank in self.bpe_ranks is merged at every position where it
        occurs. This repeats until no adjacent pair has a rank or a single
        symbol remains. Results are memoized in self.cache.
        """
        # memoization, for efficiency
        if token in self.cache:
            return self.cache[token]

        word = tuple(token)  # individual characters that make up the token, in a tuple
        pairs = get_pairs(word)  # get all bigrams

        if not pairs:
            # a token of at most one character has nothing to merge
            return token

        while True:
            # find the next lowest rank bigram that can be merged
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break  # no more bigrams are eligible to be merged
            first, second = bigram

            # replace all occurrences of (first, second) in word with the single
            # merged symbol first+second, building the result in new_word
            new_word = []
            i = 0
            while i < len(word):
                # find the next occurrence of first at or after position i
                try:
                    j = word.index(first, i)
                except ValueError:
                    # first does not occur again: copy the remaining tail as-is
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                # if this occurrence is also followed by second, then merge them into one
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            # all occurrences of (first, second) have been merged to first_second
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)

        # Concatenate the symbols into one string with ' ' as the separator.
        # Byte encoding maps the space byte to 'Ġ', so a literal ' ' never
        # occurs inside a symbol and is safe to use as a delimiter.
        word = ' '.join(word)
        # cache the result and return
        self.cache[token] = word
        return word

    def _encode_token(self, token):
        """
        Run one pre-tokenized string through byte encoding, BPE and id lookup.

        Returns (token_bytes, token_translated, token_merged, token_ix).
        """
        # encode the token as a bytes (b'') object
        token_bytes = token.encode('utf-8')
        # translate all bytes to their unicode string representation and flatten
        token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
        # perform all the applicable bpe merges according to self.bpe_ranks
        token_merged = self.bpe(token_translated).split(' ')
        # translate all bpe tokens to integers
        token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
        return token_bytes, token_translated, token_merged, token_ix

    def encode(self, text):
        """ string goes in, list of integers comes out """
        bpe_idx = []
        # pre-tokenize the input text into string tokens (words, roughly speaking)
        # and append the BPE integer ids of each one, in order
        for token in self.pat.findall(text):
            *_, token_ix = self._encode_token(token)
            bpe_idx.extend(token_ix)
        return bpe_idx

    def encode_and_show_work(self, text):
        """ debugging function, same as encode but returns all intermediate work """
        bpe_idx = []
        parts = []
        tokens = self.pat.findall(text)
        for token in tokens:
            token_bytes, token_translated, token_merged, token_ix = self._encode_token(token)
            bpe_idx.extend(token_ix)
            parts.append({
                'token': token,
                'token_bytes': token_bytes,
                'token_translated': token_translated,
                'token_merged': token_merged,
                'token_ix': token_ix,
            })
        out = {
            'bpe_idx': bpe_idx,  # the actual output sequence
            'tokens': tokens,  # result of pre-tokenization
            'parts': parts,  # intermediates for each token part
        }
        return out

    def decode(self, bpe_idx):
        """
        list of integers comes in, string comes out

        Byte sequences that are not valid utf-8 (e.g. a multi-byte character
        cut in half) are decoded with U+FFFD replacement characters.
        """
        # inverse map the integers to get the tokens
        tokens_merged = [self.decoder[token] for token in bpe_idx]
        # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes
        tokens_flat = ''.join(tokens_merged)
        tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])
        # recover the full utf-8 string
        text = tokens_bytes.decode('utf-8', errors='replace')
        return text
def get_file(local_file, remote_file):
    """
    Download remote_file to local_file unless local_file already exists.

    The response is written to a temporary sibling file and moved into place
    only after it was received and written completely. An interrupted or
    failed download therefore never leaves a truncated local_file that later
    calls would mistake for a valid cached copy.

    Raises requests.HTTPError if the server answers with an error status.
    """
    if os.path.isfile(local_file):
        return
    print(f"downloading {remote_file} to {local_file}")
    response = requests.get(remote_file)
    # refuse to cache an error page (e.g. a 404 body) as if it were the data file
    response.raise_for_status()
    tmp_file = local_file + ".tmp"
    with open(tmp_file, "wb") as f:
        f.write(response.content)
    # same directory, so this is a rename (atomic on POSIX filesystems)
    os.replace(tmp_file, local_file)
def get_encoder(cache_dir="."):
    """
    Return an instance of the GPT BPE Encoder/Decoder.

    The two "database" files it needs (encoder.json and vocab.bpe from the
    124M GPT-2 release) are downloaded into cache_dir on first use and read
    from there on later calls.

    cache_dir: directory used to cache the downloaded files. Defaults to the
        current working directory (the previous, hard-coded behavior) and is
        created if missing.
    """
    os.makedirs(cache_dir, exist_ok=True)

    # load encoder.json that has the raw mappings from token -> bpe index
    encoder_local_file = os.path.join(cache_dir, 'encoder.json')
    encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json'
    get_file(encoder_local_file, encoder_remote_file)
    # JSON is UTF-8 by specification; decode it explicitly (as vocab.bpe is
    # below) instead of relying on the platform's locale-dependent default
    with open(encoder_local_file, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    assert len(encoder) == 50257  # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token

    # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure
    # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab
    vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
    vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe'
    get_file(vocab_local_file, vocab_remote_file)
    with open(vocab_local_file, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # light postprocessing: skip the version header on the first line, and skip
    # blank lines (the trailing newline yields one) whether or not it is present
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:] if merge_str]
    assert len(bpe_merges) == 50000  # 50,000 merged tokens

    # construct the Encoder object and return
    return Encoder(encoder, bpe_merges)
# -----------------------------------------------------------------------------
class BPETokenizer:
    """
    PyTorch-facing wrapper around the byte-level BPE Encoder.

    Calling the object on a single string returns a LongTensor of token ids
    with a leading batch dimension of 1, mirroring the huggingface/transformers
    tokenizer call signature. decode() maps a 1-D tensor of ids back to text.
    """

    def __init__(self):
        # get_encoder() fetches the vocabulary/merge files if they are not cached yet
        self.encoder = get_encoder()

    def __call__(self, text, return_tensors='pt'):
        """Encode one string into a LongTensor of shape (1, number_of_tokens)."""
        # only PyTorch output exists; the keyword is accepted for HF compatibility
        assert return_tensors == 'pt'
        # a list of strings (a real batch) is not supported yet
        assert isinstance(text, str)
        token_ids = self.encoder.encode(text)
        batch = [token_ids]  # outer list = batch dimension of size 1
        return torch.tensor(batch, dtype=torch.long)

    def decode(self, idx):
        """Decode a 1-D tensor of token ids back into a string."""
        assert idx.ndim == 1
        id_list = idx.tolist()
        return self.encoder.decode(id_list)
if __name__ == '__main__':
    # interactive round trip: encode a typed string, then decode ids pasted back in
    tokenizer = BPETokenizer()
    while True:
        text = input("Please input a string!\n")
        if text == "":
            print("Exit!\n")
            break
        print("\nCodes:\n")
        encoded = tokenizer(text)
        print(encoded)
        print("ready to feed into a Transformer!")
        ret = input("\nPlease input the output of Transformer!\n")
        # Extract the integer ids with a regex + int() instead of eval():
        # eval would execute any expression typed at the prompt. Ids may be
        # separated by tabs, spaces or commas.
        out_codes = [int(code) for code in re.findall(r"\d+", ret)]
        print("\nDecoded text:\n")
        # Build a flat 1-D integer tensor. decode() requires ndim == 1, which a
        # squeezed single-element tensor (0-d) would violate.
        print(tokenizer.decode(torch.tensor(out_codes, dtype=torch.long)))
        print("\n")