diff --git a/src/corebpe.rs b/src/corebpe.rs index 07fff28..1bee6fc 100644 --- a/src/corebpe.rs +++ b/src/corebpe.rs @@ -166,7 +166,7 @@ const MAX_NUM_THREADS: usize = 8; #[derive(Debug)] pub struct CoreBPE { - encoder: Arc, usize>>, + pub encoder: HashMap, usize>, special_tokens_encoder: HashMap, decoder: HashMap>, special_tokens_decoder: HashMap>, @@ -429,7 +429,7 @@ impl CoreBPE { impl CoreBPE { pub fn new( - encoder: Arc, usize>>, + encoder: HashMap, usize>, special_tokens_encoder: HashMap, pattern: &str, ) -> Result { diff --git a/src/encoding.rs b/src/encoding.rs index d77f6d7..7a9e0ca 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -13,8 +13,6 @@ pub struct Encoding { pub name: String, /// The regular expression pattern used to split text into pieces. pat_str: String, - /// The map from mergeable byte sequences to their ranks. - mergeable_ranks: Arc, usize>>, /// The maximum length of the keys in `mergeable_ranks`. mergeable_ranks_max_key_len: usize, /// All prefixes of the mergeable ranks. May or may not be tokens themselves! @@ -64,7 +62,7 @@ impl Encoding { pub fn new( name: &str, pat_str: &str, - mergeable_ranks: Arc, usize>>, + mergeable_ranks: HashMap, usize>, special_tokens: HashMap, explicit_n_vocab: Option, ) -> Result { @@ -113,7 +111,6 @@ impl Encoding { Ok(Self { name: name.to_string(), pat_str: pat_str.to_string(), - mergeable_ranks, mergeable_ranks_max_key_len, prefixes_of_mergeable_ranks, special_tokens, @@ -157,7 +154,7 @@ impl Encoding { if current_token.len() > 1 { new_current_token.clear(); new_current_token.push(current_token.pop().unwrap()); - while !self.mergeable_ranks.contains_key(¤t_token) { + while !self.core_bpe.encoder.contains_key(¤t_token) { if current_token.len() == 1 { break; } @@ -177,14 +174,14 @@ impl Encoding { } } - while !self.mergeable_ranks.contains_key(¤t_token) { + while !self.core_bpe.encoder.contains_key(¤t_token) { if current_token.len() == 0 { break; } if current_token.len() > 1 { new_current_token.clear(); new_current_token.push(current_token.pop().unwrap()); - while !self.mergeable_ranks.contains_key(¤t_token) { + while !self.core_bpe.encoder.contains_key(¤t_token) { if current_token.len() == 1 { break; } diff --git a/src/load.rs b/src/load.rs index 0b324d7..988ffe3 100644 --- a/src/load.rs +++ b/src/load.rs @@ -5,7 +5,6 @@ use sha2::Sha256; // call its methods without adding to the namespace. use base64::engine::general_purpose::STANDARD as BASE64; use base64::engine::Engine as _; -use std::sync::Arc; // define the error #[derive(Debug, Clone)] @@ -17,7 +16,7 @@ pub enum Error { pub fn load_tiktoken_bpe( tiktoken_bpe_contents: &[u8], shasum: &str, -) -> Result, usize>>, Error> { +) -> Result, usize>, Error> { // check the shasum let mut hasher = Sha256::new(); hasher.update(tiktoken_bpe_contents); @@ -43,5 +42,5 @@ pub fn load_tiktoken_bpe( map.insert(token, rank); } map.shrink_to_fit(); - Ok(Arc::new(map)) + Ok(map) }