Skip to content

Commit

Permalink
remove unnecessary copies to reduce memory usage (#11)
Browse files Browse the repository at this point in the history
tiktoken-node memory usage drops from 194 MB to 140 MB, as shown by
`heap -s <nodejs_pid>` on macOS.
  • Loading branch information
tmm1 authored Oct 14, 2024
1 parent d08521b commit 7386c29
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/corebpe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ fn hash_current_thread() -> usize {
u64::from(x) as usize
}

const MAX_NUM_THREADS: usize = 128;
const MAX_NUM_THREADS: usize = 8;

#[derive(Debug)]
pub struct CoreBPE {
encoder: HashMap<Vec<u8>, usize>,
encoder: Arc<HashMap<Vec<u8>, usize>>,
special_tokens_encoder: HashMap<String, usize>,
decoder: HashMap<usize, Vec<u8>>,
special_tokens_decoder: HashMap<usize, Vec<u8>>,
Expand Down Expand Up @@ -429,7 +429,7 @@ impl CoreBPE {

impl CoreBPE {
pub fn new(
encoder: HashMap<Vec<u8>, usize>,
encoder: Arc<HashMap<Vec<u8>, usize>>,
special_tokens_encoder: HashMap<String, usize>,
pattern: &str,
) -> Result<Self, fancy_regex::Error> {
Expand Down
4 changes: 2 additions & 2 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Encoding {
/// The regular expression pattern used to split text into pieces.
pat_str: String,
/// The map from mergeable byte sequences to their ranks.
mergeable_ranks: HashMap<Vec<u8>, usize>,
mergeable_ranks: Arc<HashMap<Vec<u8>, usize>>,
/// The maximum length of the keys in `mergeable_ranks`.
mergeable_ranks_max_key_len: usize,
/// All prefixes of the mergeable ranks. May or may not be tokens themselves!
Expand Down Expand Up @@ -64,7 +64,7 @@ impl Encoding {
pub fn new(
name: &str,
pat_str: &str,
mergeable_ranks: HashMap<Vec<u8>, usize>,
mergeable_ranks: Arc<HashMap<Vec<u8>, usize>>,
special_tokens: HashMap<String, usize>,
explicit_n_vocab: Option<usize>,
) -> Result<Self, EncodingError> {
Expand Down
5 changes: 3 additions & 2 deletions src/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use sha2::Sha256;
// call its methods without adding to the namespace.
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::engine::Engine as _;
use std::sync::Arc;

// define the error
#[derive(Debug, Clone)]
Expand All @@ -16,7 +17,7 @@ pub enum Error {
pub fn load_tiktoken_bpe(
tiktoken_bpe_contents: &[u8],
shasum: &str,
) -> Result<HashMap<Vec<u8>, usize>, Error> {
) -> Result<Arc<HashMap<Vec<u8>, usize>>, Error> {
// check the shasum
let mut hasher = Sha256::new();
hasher.update(tiktoken_bpe_contents);
Expand All @@ -42,5 +43,5 @@ pub fn load_tiktoken_bpe(
map.insert(token, rank);
}
map.shrink_to_fit();
Ok(map)
Ok(Arc::new(map))
}

0 comments on commit 7386c29

Please sign in to comment.