 #include <regex>
 #include <iostream>
 #include <iterator>
+#include <queue>
 #include <string>
 #include <math.h>
 
@@ -294,58 +295,146 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
-// TODO: Calculate this constant from the vocabulary
-#define MAX_TOKEN_LEN 18
-// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    std::vector<gpt_vocab::id> res;
-    std::vector<int> score;
-    std::vector<gpt_vocab::id> prev;
-    int len = text.length();
-
-    score.resize(len + 1);
-    prev.resize(len + 1);
-
-    // Forward pass
-    for (int i = 0; i < len; i++) {
-        int max_len = std::min(len - i, MAX_TOKEN_LEN);
-        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
-            auto sub = text.substr(i, sub_len);
-            auto token = vocab.token_to_id.find(sub);
-            if (token != vocab.token_to_id.end()) {
-                int token_score = sub.length() * sub.length();
-                int local_score = score[i] + token_score;
-                int next = i + sub_len;
-                if (score[next] < local_score) {
-                    score[next] = local_score;
-                    prev[next] = (*token).second;
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
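A quick sanity check on the lookup table: the top four bits of the lead byte determine the sequence length (0xxx/ASCII and stray continuation bytes map to 1, 110x to 2, 1110 to 3, 11110 to 4, so malformed input still advances instead of looping). A minimal sketch of the mapping, using only the function above:

    utf8_len('a');    // 0x61, high nibble 0x6 -> 1 (ASCII)
    utf8_len('\xC3'); // lead byte of a 2-byte sequence, nibble 0xC -> 2
    utf8_len('\xE2'); // lead byte of a 3-byte sequence, nibble 0xE -> 3
    utf8_len('\xF0'); // lead byte of a 4-byte sequence, nibble 0xF -> 4
    utf8_len('\x80'); // continuation byte, nibble 0x8 -> 1 (defensive)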
+struct llama_sp_symbol {
+    using index = int;
+    index prev;
+    index next;
+    std::string_view text;
+};
+
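The symbols form an intrusive doubly linked list over the input: prev and next are indices into the symbols_ vector, with -1 marking both ends, so a merge is an O(1) splice. A hypothetical picture for the input "abc":

    // [0:"a"] <-> [1:"b"] <-> [2:"c"]    (prev/next == -1 at the ends)
    // merging the bigram (0,1) grows symbol 0 and splices symbol 1 out:
    // [0:"ab"] <-> [2:"c"]               (symbol 1 keeps an empty text)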
+struct llama_sp_bigram {
+    struct comparator {
+        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
+            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+        }
+    };
+    using queue_storage = std::vector<llama_sp_bigram>;
+    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
+    llama_sp_symbol::index left;
+    llama_sp_symbol::index right;
+    float score;
+    size_t size;
+};
+
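std::priority_queue is a max-heap under its comparator, so operator() returns true when l should rank below r: the queue pops the highest-scoring bigram first and breaks score ties in favor of the leftmost pair. A small illustration with made-up scores, assuming only the structs above:

    llama_sp_bigram::queue q;
    q.push({ /*left=*/4, /*right=*/5, /*score=*/1.5f, /*size=*/2 });
    q.push({ /*left=*/0, /*right=*/1, /*score=*/1.5f, /*size=*/2 });
    q.push({ /*left=*/2, /*right=*/3, /*score=*/9.0f, /*size=*/2 });
    // q.top() is the 9.0-score pair; the two 1.5-score pairs then
    // come out left to right (left == 0 before left == 4).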
+struct llama_tokenizer {
+    llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {}
+
+    void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) {
+        // split string into utf8 chars
+        int index = 0;
+        while (!text.empty()) {
+            llama_sp_symbol sym;
+            size_t char_len = std::min(text.size(), utf8_len(text.data()[0]));
+            sym.text = std::string_view(text.data(), char_len);
+            sym.prev = index - 1;
+            text.remove_prefix(char_len);
+            sym.next = text.empty() ? -1 : index + 1;
+            index++;
+            symbols_.emplace_back(std::move(sym));
+        }
+
+        // seed the work queue with all possible 2-character tokens.
+        for (size_t i = 1; i < symbols_.size(); ++i) {
+            try_add_bigram(i - 1, i);
+        }
+
+        // keep substituting the highest-scoring pairs for as long as we can.
+        while (!work_queue_.empty()) {
+            auto bigram = work_queue_.top();
+            work_queue_.pop();
+
+            auto & left_sym = symbols_[bigram.left];
+            auto & right_sym = symbols_[bigram.right];
+
+            // if one of the symbols already got merged, skip it.
+            if (left_sym.text.empty() || right_sym.text.empty() ||
+                left_sym.text.size() + right_sym.text.size() != bigram.size) {
+                continue;
+            }
+
+            // merge the right sym into the left one
+            left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size());
+            right_sym.text = std::string_view("");
+
+            // remove the right sym from the chain
+            left_sym.next = right_sym.next;
+            if (right_sym.next >= 0) {
+                symbols_[right_sym.next].prev = bigram.left;
+            }
+
+            // find more substitutions
+            try_add_bigram(left_sym.prev, bigram.left);
+            try_add_bigram(bigram.left, left_sym.next);
+        }
+
+        for (int i = 0; i != -1; i = symbols_[i].next) {
+            auto & symbol = symbols_[i];
+            auto token = vocab_.token_to_id.find(std::string(symbol.text));
+
+            if (token == vocab_.token_to_id.end()) {
+                // output any symbols that did not form tokens as bytes (byte pieces start at id 3, after unk/bos/eos).
+                for (size_t j = 0; j < symbol.text.size(); ++j) {
+                    gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
+                    output.push_back(token_id);
                 }
+            } else {
+                output.push_back((*token).second);
             }
         }
     }
 
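To make the merge loop concrete, a hypothetical trace (which pieces exist in the vocabulary, and their scores, are invented for illustration; real values come from the model file):

    // tokenize("hello"):
    //   split:  [h] [e] [l] [l] [o]     queue seeded with: he, el, ll, lo
    //   pop "ll" (best score)        -> [h] [e] [ll] [o]
    //                                   new candidates queued: ell, llo
    //   pop "el" -- stale: its right symbol was merged away, the sizes
    //               no longer add up to bigram.size, so it is skipped
    //   pop "ell" (in vocab, say)    -> [h] [ell] [o]
    //   queue empty                  -> look up "h", "ell", "o" and emit ids;
    //                                   pieces missing from the vocab are
    //                                   emitted byte by byte (id = byte + 3)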
-    // Backward pass
-    int i = len;
-    while (i > 0) {
-        gpt_vocab::id token_id = prev[i];
-        if (token_id == 0) {
-            // TODO: Return error or something more meaningful
-            printf("failed to tokenize string!\n");
-            break;
+private:
+    void try_add_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size());
+        auto token = vocab_.token_to_id.find(std::string(text));
+
+        if (token == vocab_.token_to_id.end()) {
+            return;
         }
-        res.push_back(token_id);
-        auto token = (*vocab.id_to_token.find(token_id)).second;
-        i -= token.length();
+
+        auto score = vocab_.score.find((*token).second);
+
+        if (score == vocab_.score.end()) {
+            return;
+        }
+
+        llama_sp_bigram bigram;
+        bigram.left = left;
+        bigram.right = right;
+        bigram.score = (*score).second;
+        bigram.size = text.size();
+        work_queue_.push(bigram);
     }
 
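try_add_bigram only queues a pair when the concatenated text is itself a vocabulary entry that carries a score. For orientation, a sketch of the vocabulary shape these lookups rely on (field names are taken from the calls above; the exact container types in gpt_vocab may differ):

    struct gpt_vocab_sketch {
        std::map<std::string, gpt_vocab::id> token_to_id;
        std::map<gpt_vocab::id, std::string> id_to_token;
        std::map<gpt_vocab::id, float>       score;  // SentencePiece scores; higher wins
    };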
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
+    const gpt_vocab & vocab_;
+    std::vector<llama_sp_symbol> symbols_;
+    llama_sp_bigram::queue work_queue_;
+};
+
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) {
+    llama_tokenizer tokenizer(vocab);
+    std::vector<gpt_vocab::id> output;
+
+    if (text.size() == 0) {
+        return output;
     }
 
-    // Pieces are in reverse order so correct that
-    std::reverse(res.begin(), res.end());
+    if (bos) {
+        output.push_back(1); // TODO: replace with vocab.bos
+    }
 
-    return res;
+    tokenizer.tokenize(text, output);
+    return output;
 }
 
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
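To close the loop, a minimal usage sketch of the new entry point (the vocab file path is a placeholder; gpt_vocab_init and id_to_token come from this same file):

    int demo_tokenize() {
        gpt_vocab vocab;
        if (!gpt_vocab_init("models/vocab.bin", vocab)) {  // hypothetical path
            return 1;
        }

        // bos=true prepends token id 1 before the text pieces
        std::vector<gpt_vocab::id> ids = llama_tokenize(vocab, "Hello world", true);

        for (gpt_vocab::id id : ids) {
            std::cout << id << " = " << vocab.id_to_token.at(id) << "\n";
        }
        return 0;
    }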