use real thread-local storage
the upstream code in openai/tiktoken is wrapped with PyO3, so it has to worry about short-lived Python-land threads eating up memory

in our case, a fixed actor thread pool is dedicated to tokenization, so the extra regex copies and the hashed-thread-id scheme (with its potential collisions) just get in the way
tmm1 committed Oct 17, 2024
1 parent 2a6523f commit 57d81fe
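
For context, here is a minimal standalone sketch of the pattern this commit switches to, using the thread_local crate's ThreadLocal::get_or. It is not the actual CoreBPE code: the Tokenizer struct and its String field are hypothetical stand-ins for the compiled regexes. Each worker thread lazily clones a shared template once and then reuses its own copy on every later call.

use std::sync::Arc;
use std::thread;

use thread_local::ThreadLocal;

// Hypothetical stand-in for CoreBPE: `template` plays the role of the
// compiled Regex, `per_thread` holds one lazily created clone per thread.
struct Tokenizer {
    template: String,
    per_thread: ThreadLocal<String>,
}

impl Tokenizer {
    fn local(&self) -> &String {
        // The first call on a given thread clones `template`; later calls on
        // the same thread return that clone without hashing or locking.
        self.per_thread.get_or(|| self.template.clone())
    }
}

fn main() {
    let tok = Arc::new(Tokenizer {
        template: "compiled pattern".to_string(),
        per_thread: ThreadLocal::new(),
    });

    // A small fixed pool of worker threads, as in the actor setup above.
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let tok = Arc::clone(&tok);
            thread::spawn(move || {
                let first = tok.local() as *const String;
                let second = tok.local() as *const String;
                assert_eq!(first, second); // the per-thread copy is reused
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}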
Showing 3 changed files with 23 additions and 37 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -33,6 +33,7 @@ base64 = "0.21.0"
 thiserror = "1.0.38"
 const-primes = "0.8.7"
 odht = "0.3.1"
+thread_local = "1.1.8"
 
 [[bench]]
 name = "bench"
48 changes: 11 additions & 37 deletions src/corebpe.rs
@@ -1,10 +1,8 @@
-use std::num::NonZeroU64;
-use std::thread;
 
 use fancy_regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
 use rustc_hash::FxHashSet as HashSet;
-use std::sync::Arc;
+use thread_local::ThreadLocal;
 
 pub type Rank = u32;
 
@@ -129,44 +127,26 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> V
 // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
 // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
 
-pub struct FakeThreadId(NonZeroU64);
-
-fn hash_current_thread() -> usize {
-    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
-    // that works great for our use case of avoiding collisions in our array. Unfortunately,
-    // it's private. However, there are only so many ways you can layout a u64, so just transmute
-    // https://github.com/rust-lang/rust/issues/67939
-    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
-    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
-    let x = unsafe {
-        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
-    };
-    u64::from(x) as usize
-}
-
-const MAX_NUM_THREADS: usize = 8;
-
 #[derive(Debug)]
 pub struct CoreBPE {
     pub encoder: HashMap<Vec<u8>, Rank>,
     special_tokens_encoder: HashMap<String, Rank>,
     decoder: HashMap<Rank, &'static [u8]>,
     special_tokens_decoder: HashMap<Rank, Vec<u8>>,
-    regex_tls: Arc<[Regex]>,
-    special_regex_tls: Arc<[Regex]>,
+    regex: Regex,
+    special_regex: Regex,
+    regex_tls: ThreadLocal<Regex>,
+    special_regex_tls: ThreadLocal<Regex>,
     sorted_token_bytes: Vec<&'static [u8]>,
 }
 
 impl CoreBPE {
     fn _get_tl_regex(&self) -> &Regex {
-        // See performance notes above for what this is about
-        // It's also a little janky, please make a better version of it!
-        // However, it's nice that this doesn't leak memory to short-lived threads
-        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
+        self.regex_tls.get_or(|| self.regex.clone())
     }
 
     fn _get_tl_special_regex(&self) -> &Regex {
-        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
+        self.special_regex_tls.get_or(|| self.special_regex.clone())
     }
 
     fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
@@ -460,16 +440,10 @@ impl CoreBPE {
             special_tokens_encoder,
             decoder,
             special_tokens_decoder,
-            regex_tls: Arc::from(
-                (0..MAX_NUM_THREADS)
-                    .map(|_| regex.clone())
-                    .collect::<Vec<_>>(),
-            ),
-            special_regex_tls: Arc::from(
-                (0..MAX_NUM_THREADS)
-                    .map(|_| special_regex.clone())
-                    .collect::<Vec<_>>(),
-            ),
+            regex,
+            special_regex,
+            regex_tls: ThreadLocal::new(),
+            special_regex_tls: ThreadLocal::new(),
             sorted_token_bytes,
         })
     }
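A note on the trade-off the commit message describes: with the thread_local crate, each thread's value lives as long as the ThreadLocal itself, so memory stays bounded by the size of a fixed pool but would grow without limit under many short-lived threads (the upstream PyO3 concern). Below is a standalone sketch of that behavior, not repo code, assuming the thread_local 1.1 API.

use std::sync::Arc;
use std::thread;

use thread_local::ThreadLocal;

fn main() {
    let tls: Arc<ThreadLocal<String>> = Arc::new(ThreadLocal::new());

    // Simulate a fixed pool of four tokenization workers.
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let tls = Arc::clone(&tls);
            thread::spawn(move || {
                // Each worker materializes exactly one value, on first use.
                tls.get_or(|| format!("worker {} scratch", i));
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }

    // The per-thread values outlive their threads and are only dropped with
    // the ThreadLocal: one entry per worker, independent of how often get_or ran.
    assert_eq!(tls.iter().count(), 4);
}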
