remove Arc<> for encoder hashmap and use directly (#12)
partial revert of #11

Was seeing a weird perf regression in benchmarks.

Rust's borrow checker makes it really hard for two things to own the same data, so take the simple way out by letting CoreBPE own mergeable_ranks.
tmm1 authored Oct 17, 2024
1 parent 617cd6a · commit 615bf1f
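To illustrate the ownership change the commit message describes, here is a standalone sketch (made-up type names, not the crate's actual CoreBPE and Encoding): under #11 two structs each held an Arc handle to the same rank map, while after this commit a single struct owns the HashMap outright and other code reaches it through the owner's pub field.

use std::collections::HashMap;
use std::sync::Arc;

// Before (#11): two owners share one map behind a refcount.
#[allow(dead_code)]
struct ArcOwnerA {
    ranks: Arc<HashMap<Vec<u8>, usize>>,
}
#[allow(dead_code)]
struct ArcOwnerB {
    ranks: Arc<HashMap<Vec<u8>, usize>>,
}

// After (#12): a single owner holds the map; callers go through its pub field,
// mirroring `self.core_bpe.encoder.contains_key(...)` in the diff below.
struct Owner {
    pub ranks: HashMap<Vec<u8>, usize>,
}
struct Caller {
    owner: Owner,
}
impl Caller {
    fn contains(&self, key: &[u8]) -> bool {
        self.owner.ranks.contains_key(key)
    }
}

fn main() {
    let ranks = HashMap::from([(b"ab".to_vec(), 0usize)]);

    // Shared ownership: cheap to clone the handle, but a second owner exists.
    let shared = Arc::new(ranks.clone());
    let _a = ArcOwnerA { ranks: Arc::clone(&shared) };
    let _b = ArcOwnerB { ranks: shared };

    // Sole ownership: the map is moved into Owner and lives in exactly one place.
    let caller = Caller { owner: Owner { ranks } };
    assert!(caller.contains(b"ab"));
}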
Showing 3 changed files with 8 additions and 12 deletions.
src/corebpe.rs (4 changes: 2 additions & 2 deletions)

@@ -166,7 +166,7 @@ const MAX_NUM_THREADS: usize = 8;
 
 #[derive(Debug)]
 pub struct CoreBPE {
-    encoder: Arc<HashMap<Vec<u8>, usize>>,
+    pub encoder: HashMap<Vec<u8>, usize>,
     special_tokens_encoder: HashMap<String, usize>,
     decoder: HashMap<usize, Vec<u8>>,
     special_tokens_decoder: HashMap<usize, Vec<u8>>,
@@ -429,7 +429,7 @@ impl CoreBPE {
 
 impl CoreBPE {
     pub fn new(
-        encoder: Arc<HashMap<Vec<u8>, usize>>,
+        encoder: HashMap<Vec<u8>, usize>,
         special_tokens_encoder: HashMap<String, usize>,
         pattern: &str,
     ) -> Result<Self, fancy_regex::Error> {
src/encoding.rs (11 changes: 4 additions & 7 deletions)

@@ -13,8 +13,6 @@ pub struct Encoding {
     pub name: String,
     /// The regular expression pattern used to split text into pieces.
     pat_str: String,
-    /// The map from mergeable byte sequences to their ranks.
-    mergeable_ranks: Arc<HashMap<Vec<u8>, usize>>,
     /// The maximum length of the keys in `mergeable_ranks`.
     mergeable_ranks_max_key_len: usize,
     /// All prefixes of the mergeable ranks. May or may not be tokens themselves!
@@ -64,7 +62,7 @@ impl Encoding {
     pub fn new(
         name: &str,
         pat_str: &str,
-        mergeable_ranks: Arc<HashMap<Vec<u8>, usize>>,
+        mergeable_ranks: HashMap<Vec<u8>, usize>,
         special_tokens: HashMap<String, usize>,
         explicit_n_vocab: Option<usize>,
     ) -> Result<Self, EncodingError> {
@@ -113,7 +111,6 @@ impl Encoding {
         Ok(Self {
             name: name.to_string(),
             pat_str: pat_str.to_string(),
-            mergeable_ranks,
             mergeable_ranks_max_key_len,
             prefixes_of_mergeable_ranks,
             special_tokens,
@@ -157,7 +154,7 @@ impl Encoding {
             if current_token.len() > 1 {
                 new_current_token.clear();
                 new_current_token.push(current_token.pop().unwrap());
-                while !self.mergeable_ranks.contains_key(&current_token) {
+                while !self.core_bpe.encoder.contains_key(&current_token) {
                     if current_token.len() == 1 {
                         break;
                     }
@@ -177,14 +174,14 @@ impl Encoding {
             }
         }
 
-        while !self.mergeable_ranks.contains_key(&current_token) {
+        while !self.core_bpe.encoder.contains_key(&current_token) {
            if current_token.len() == 0 {
                 break;
             }
             if current_token.len() > 1 {
                 new_current_token.clear();
                 new_current_token.push(current_token.pop().unwrap());
-                while !self.mergeable_ranks.contains_key(&current_token) {
+                while !self.core_bpe.encoder.contains_key(&current_token) {
                     if current_token.len() == 1 {
                         break;
                     }
src/load.rs (5 changes: 2 additions & 3 deletions)

@@ -5,7 +5,6 @@ use sha2::Sha256;
 // call its methods without adding to the namespace.
 use base64::engine::general_purpose::STANDARD as BASE64;
 use base64::engine::Engine as _;
-use std::sync::Arc;
 
 // define the error
 #[derive(Debug, Clone)]
@@ -17,7 +16,7 @@ pub enum Error {
 pub fn load_tiktoken_bpe(
     tiktoken_bpe_contents: &[u8],
     shasum: &str,
-) -> Result<Arc<HashMap<Vec<u8>, usize>>, Error> {
+) -> Result<HashMap<Vec<u8>, usize>, Error> {
     // check the shasum
     let mut hasher = Sha256::new();
     hasher.update(tiktoken_bpe_contents);
@@ -43,5 +42,5 @@ pub fn load_tiktoken_bpe(
         map.insert(token, rank);
     }
     map.shrink_to_fit();
-    Ok(Arc::new(map))
+    Ok(map)
 }
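Taken together, a caller-side sketch of the new signatures, assumed only from the diffs above; the file path, shasum string, regex pattern, and the empty special-token map are placeholders, not values from the repository.

use std::collections::HashMap;

// `load_tiktoken_bpe` and `CoreBPE` are the items changed above (src/load.rs,
// src/corebpe.rs); everything passed to them here is a placeholder.
fn build_core_bpe() -> CoreBPE {
    let contents = std::fs::read("ranks.tiktoken").expect("read BPE file"); // placeholder path
    // The loader now returns an owned HashMap<Vec<u8>, usize> instead of Arc<HashMap<...>>.
    let ranks: HashMap<Vec<u8>, usize> =
        load_tiktoken_bpe(&contents, "<expected sha256 hex>").expect("contents match shasum");
    // The map is moved straight into CoreBPE, which now owns it.
    CoreBPE::new(ranks, HashMap::new(), r"\S+").expect("pattern compiles")
}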
