Uses Regex instead of fancy-regex - 6x speedup #331

Open · wants to merge 1 commit into main
93 changes: 72 additions & 21 deletions src/lib.rs
@@ -5,7 +5,8 @@ use std::collections::HashSet;
use std::num::NonZeroU64;
use std::thread;

use fancy_regex::Regex;
use fancy_regex::Regex as FancyRegex;
use regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pyclass;
@@ -89,7 +90,7 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> V
.collect()
}
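
For context, a minimal standalone sketch (not part of the diff) of the API difference behind the dropped mat.unwrap() calls below: fancy-regex's find_iter yields Result<Match, Error>, because its backtracking engine can fail while matching, whereas the regex crate's find_iter yields Match directly and cannot fail mid-iteration. The pattern here is illustrative only.

    use fancy_regex::Regex as FancyRegex;
    use regex::Regex;

    fn main() {
        let text = "hello  world";
        // fancy-regex: supports lookahead, but every match comes wrapped in a Result.
        let fancy = FancyRegex::new(r"\s+(?!\S)|\S+").unwrap();
        for mat in fancy.find_iter(text) {
            let mat = mat.unwrap(); // matching itself can fail in fancy-regex
            println!("fancy: {:?}", mat.as_str());
        }
        // regex: no lookahead, but plain matches and no backtracking at match time.
        let plain = Regex::new(r"\S+").unwrap();
        for mat in plain.find_iter(text) {
            println!("plain: {:?}", mat.as_str());
        }
    }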

// Various performance notes:
// Various performance notes (these should be updated; note that PERFORMANCE.md has been removed):
//
// Regex
// =====
@@ -154,7 +155,7 @@ struct CoreBPE
decoder: HashMap<Rank, Vec<u8>>,
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
regex_tls: Vec<Regex>,
special_regex_tls: Vec<Regex>,
special_regex_tls: Vec<FancyRegex>,
sorted_token_bytes: Vec<Vec<u8>>,
}

@@ -166,7 +167,7 @@ impl CoreBPE {
&self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
}

fn _get_tl_special_regex(&self) -> &Regex {
fn _get_tl_special_regex(&self) -> &FancyRegex {
&self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
}

@@ -183,23 +184,82 @@ impl CoreBPE {
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
// This wrapper is needed for callers that do not pass in a ret buffer.
let mut ret = vec![];
self._encode_ordinary_native_impl(text, &mut ret);
ret
}

fn _encode_ordinary_native_impl(&self, text: &str, ret: &mut Vec<Rank>) -> usize {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
let regex = self._get_tl_regex();
let mut ret = vec![];
let mut last_end = 0;
let mut last_piece_token_len = 0;
let mut piece: &[u8] = &[];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
piece = mat.as_str().as_bytes();
let start = mat.start();
let end = mat.end();

// If there is a whitespace gap between this piece and the previous one, add its tokens
if last_end < start {
// If the current piece starts with whitespace, the whole gap is one new piece
if mat.as_str().chars().next().map_or(false, |c| c.is_whitespace()) {
let wpiece = text[last_end..start].as_bytes();
match self.encoder.get(wpiece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece, &self.encoder)),
}
// Otherwise the last char of the gap forms its own piece, and the rest (if any) forms another
} else {
let last_char_size = text[last_end..start].chars().next_back().unwrap().len_utf8();
// Example for gpt-4o: in the text "= 6", "=" and "6" are matches and " " is the gap,
// so the gap makes just one piece
if last_char_size < start - last_end {
let wpiece1 = text[last_end..start - last_char_size].as_bytes();
match self.encoder.get(wpiece1) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece1, &self.encoder)),
}
}
let wpiece2 = text[start - last_char_size..start].as_bytes();
match self.encoder.get(wpiece2) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece2, &self.encoder)),
}
}
}
last_end = end;

// Now add piece tokens
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
}
ret
// Handle a whitespace gap at the end of the text
if last_end < text.len() {
piece = text[last_end..].as_bytes();
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
}

if !piece.is_empty() {
last_piece_token_len =
match self.encoder.get(piece) {
Some(_) => 1,
None => byte_pair_encode(piece, &self.encoder).len()
};
};

last_piece_token_len
}
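
To make the gap rule above concrete, here is a small self-contained sketch of the splitting step, with split_gap as a hypothetical helper (the PR inlines this logic rather than defining such a function):

    // Hypothetical helper illustrating the whitespace-gap rule (not in the PR).
    // gap is the text between the previous match's end and the current match's
    // start; next_match is the current match itself.
    fn split_gap<'a>(gap: &'a str, next_match: &str) -> Vec<&'a [u8]> {
        if gap.is_empty() {
            return vec![];
        }
        // If the current match starts with whitespace, the whole gap is one piece.
        if next_match.chars().next().map_or(false, |c| c.is_whitespace()) {
            return vec![gap.as_bytes()];
        }
        // Otherwise the gap's last char is its own piece; any remainder is another.
        let last_char_size = gap.chars().next_back().unwrap().len_utf8();
        let split = gap.len() - last_char_size;
        let mut pieces: Vec<&[u8]> = Vec::new();
        if split > 0 {
            pieces.push(gap[..split].as_bytes());
        }
        pieces.push(gap[split..].as_bytes());
        pieces
    }

For the "= 6" example above, split_gap(" ", "6") returns just [b" "], i.e. the gap makes one piece, which is what the dropped \s+(?!\S)|\s+ tail would have matched.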

fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
let special_regex = self._get_tl_special_regex();
let regex = self._get_tl_regex();
let mut ret = vec![];

let mut start = 0;
@@ -221,18 +281,9 @@ impl CoreBPE {
}
}
let end = next_special.map_or(text.len(), |m| m.start());

// Okay, here we go, compare this logic to _encode_ordinary_native
for mat in regex.find_iter(&text[start..end]) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
last_piece_token_len = 1;
ret.push(*token);
continue;
}
let tokens = byte_pair_encode(piece, &self.encoder);
last_piece_token_len = tokens.len();
ret.extend(&tokens);
if end > start {
// The regex is not fetched here and passed in; _encode_ordinary_native_impl looks it up itself, which seems harmless.
last_piece_token_len = self._encode_ordinary_native_impl(&text[start..end], &mut ret);
}

match next_special {
@@ -425,7 +476,7 @@ impl CoreBPE {
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&_parts.join("|"))
FancyRegex::new(&_parts.join("|"))
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
};

16 changes: 9 additions & 7 deletions tiktoken_ext/openai_public.py
@@ -7,6 +7,7 @@
ENDOFPROMPT = "<|endofprompt|>"


# We drop |\s+(?!\S)|\s+ from the end of all patterns and implement that behavior in code in src/lib.rs, to keep the patterns compatible with Rust's regex crate (which does not support lookahead)
def gpt2():
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
@@ -20,7 +21,7 @@ def gpt2():
# The pattern in the original GPT-2 release is:
# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# This is equivalent, but executes faster:
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
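
To see concretely why the tail had to go, a short sketch (the assertion reflects the regex crate's documented lack of look-around support):

    use regex::Regex;

    fn main() {
        // The dropped tail uses lookahead, which the regex crate rejects at compile time:
        assert!(Regex::new(r"\s+(?!\S)|\s+").is_err());
        // The trimmed pattern compiles; whitespace runs between matches become
        // gaps that the new logic in src/lib.rs tokenizes separately:
        let pat =
            Regex::new(r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+").unwrap();
        let pieces: Vec<&str> = pat.find_iter("hello  world").map(|m| m.as_str()).collect();
        assert_eq!(pieces, vec!["hello", " world"]); // the first space is now a gap
    }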
@@ -34,7 +35,7 @@ def r50k_base():
return {
"name": "r50k_base",
"explicit_n_vocab": 50257,
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
@@ -48,7 +49,7 @@ def p50k_base():
return {
"name": "p50k_base",
"explicit_n_vocab": 50281,
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
@@ -62,7 +63,7 @@ def p50k_edit():
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
return {
"name": "p50k_edit",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
@@ -82,7 +83,10 @@ def cl100k_base():
}
return {
"name": "cl100k_base",
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
# The original pattern uses the possessive quantifiers ?+ and ++, which are not supported by Rust's regex crate:
# r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# It turns out that the greedy versions of these quantifiers give an equivalent pattern here, since the adjacent character classes are disjoint:
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
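
The equivalence holds because [^\r\n\p{L}\p{N}] can never match a letter, so the greedy quantifier never has a useful backtrack into \p{L}+. A sanity-check sketch (assuming fancy-regex for the possessive form):

    use fancy_regex::Regex as FancyRegex;
    use regex::Regex;

    fn main() {
        let possessive = FancyRegex::new(r"[^\r\n\p{L}\p{N}]?+\p{L}+").unwrap();
        let greedy = Regex::new(r"[^\r\n\p{L}\p{N}]?\p{L}+").unwrap();
        for text in ["foo", " foo", "-foo", "a-b c", "1a"] {
            let a: Vec<&str> = possessive
                .find_iter(text)
                .map(|m| m.unwrap().as_str())
                .collect();
            let b: Vec<&str> = greedy.find_iter(text).map(|m| m.as_str()).collect();
            assert_eq!(a, b); // the optional char can never steal a letter
        }
    }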
@@ -105,8 +109,6 @@ def o200k_base():
r"""\p{N}{1,3}""",
r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
r"""\s*[\r\n]+""",
r"""\s+(?!\S)""",
r"""\s+""",
]
)
return {