From 7f52f3dd300bb992f27f094a285cdabe906dfea0 Mon Sep 17 00:00:00 2001 From: Rishabh Yadav Date: Sun, 29 Sep 2024 18:02:05 -0700 Subject: [PATCH 1/6] add replace spaces flag --- src/encoding.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/encoding.rs b/src/encoding.rs index d772dc4..b500d39 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -127,7 +127,11 @@ impl Encoding { self.core_bpe.encode_ordinary(text) } - pub fn estimate_num_tokens_no_special_tokens_fast(&self, text: &str) -> usize { + pub fn estimate_num_tokens_no_special_tokens_fast(&self, text: &str, replace_spaces_with_lower_one_eighth_block: bool = false) -> usize { + if replace_spaces_with_lower_one_eighth_block { + text = text.replace(" ", "\u{2581}"); + } + let mut token_count = 0; let mut current_token = Vec::new(); let mut current_token_hash: i64 = 0; From 196de6e20927f3f7bf322107788c93ce36875337 Mon Sep 17 00:00:00 2001 From: Rishabh Yadav Date: Sun, 29 Sep 2024 18:12:46 -0700 Subject: [PATCH 2/6] fix the preprocessing --- src/encoding.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/encoding.rs b/src/encoding.rs index b500d39..6e15038 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -5,6 +5,7 @@ use rustc_hash::FxHashSet as HashSet; use std::sync::Arc; use thiserror::Error; use const_primes::is_prime; +use std::borrow::Cow; /// A struct that represents an encoding scheme based on byte-pair encoding (BPE). #[derive(Debug)] @@ -127,17 +128,19 @@ impl Encoding { self.core_bpe.encode_ordinary(text) } - pub fn estimate_num_tokens_no_special_tokens_fast(&self, text: &str, replace_spaces_with_lower_one_eighth_block: bool = false) -> usize { - if replace_spaces_with_lower_one_eighth_block { - text = text.replace(" ", "\u{2581}"); - } + pub fn estimate_num_tokens_no_special_tokens_fast(&self, text: &str, replace_spaces_with_lower_one_eighth_block: bool) -> usize { + let preprocessed_text = if replace_spaces_with_lower_one_eighth_block { + Cow::Owned(text.replace(" ", "\u{2581}")) + } else { + Cow::Borrowed(text) + }; let mut token_count = 0; let mut current_token = Vec::new(); let mut current_token_hash: i64 = 0; let mut new_current_token = Vec::new(); - for byte in text.bytes() { + for byte in preprocessed_text.bytes() { current_token.push(byte); current_token_hash = roll_hash(current_token_hash, byte); From e4795fbbd346a85565126a3043282b19406554c2 Mon Sep 17 00:00:00 2001 From: Rishabh Yadav Date: Sun, 29 Sep 2024 18:26:05 -0700 Subject: [PATCH 3/6] better --- src/encoding.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/encoding.rs b/src/encoding.rs index 6e15038..313434d 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -129,18 +129,22 @@ impl Encoding { } pub fn estimate_num_tokens_no_special_tokens_fast(&self, text: &str, replace_spaces_with_lower_one_eighth_block: bool) -> usize { - let preprocessed_text = if replace_spaces_with_lower_one_eighth_block { - Cow::Owned(text.replace(" ", "\u{2581}")) + let tokens = if replace_spaces_with_lower_one_eighth_block { + self.count_tokens(&text.replace(" ", "\u{2581}")) } else { - Cow::Borrowed(text) + self.count_tokens(text) }; + tokens + } + + fn count_tokens(&self, text: &str) -> usize { let mut token_count = 0; let mut current_token = Vec::new(); let mut current_token_hash: i64 = 0; let mut new_current_token = Vec::new(); - for byte in preprocessed_text.bytes() { + for byte in text.bytes() { current_token.push(byte); current_token_hash = roll_hash(current_token_hash, byte); From caafbc20c3fc850e2743a8b20a0d235ba26ca5ea Mon Sep 17 00:00:00 2001 From: Rishabh Yadav Date: Sun, 29 Sep 2024 18:52:35 -0700 Subject: [PATCH 4/6] remove unused import --- src/encoding.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/encoding.rs b/src/encoding.rs index 313434d..d3b9b7e 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -5,7 +5,6 @@ use rustc_hash::FxHashSet as HashSet; use std::sync::Arc; use thiserror::Error; use const_primes::is_prime; -use std::borrow::Cow; /// A struct that represents an encoding scheme based on byte-pair encoding (BPE). #[derive(Debug)] From bfe4968594df7cef7b5eb71985f99291312eafdd Mon Sep 17 00:00:00 2001 From: Rishabh Yadav Date: Sun, 29 Sep 2024 18:55:26 -0700 Subject: [PATCH 5/6] add flag to tests too --- src/tests.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tests.rs b/src/tests.rs index 8647ef8..a3fe6aa 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -205,7 +205,7 @@ fn estimation_is_close() { .unwrap() .len(); - let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(file); + let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(file, false); println!("Real count: {}", real_count); println!("Estimated count: {}", estimated_count); @@ -229,7 +229,7 @@ fn simple_estimation_is_close() { .unwrap() .len(); - let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(&test); + let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(&test, false); println!("Real count: {}", real_count); println!("Estimated count: {}", estimated_count); @@ -290,7 +290,7 @@ fn estimation_is_close_o200k() { .unwrap() .len(); - let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(file); + let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(file, false); println!("Real count: {}", real_count); println!("Estimated count: {}", estimated_count); @@ -314,7 +314,7 @@ fn simple_estimation_is_close_o200k() { .unwrap() .len(); - let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(&test); + let estimated_count = enc.estimate_num_tokens_no_special_tokens_fast(&test, false); println!("Real count: {}", real_count); println!("Estimated count: {}", estimated_count); From 38bfee2ae47c989d9985a5334b8f717a383fe0b3 Mon Sep 17 00:00:00 2001 From: Rishabh Yadav Date: Sun, 29 Sep 2024 18:57:42 -0700 Subject: [PATCH 6/6] also here! --- benches/bench.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/benches/bench.rs b/benches/bench.rs index c6a04bb..bf3fa39 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -9,6 +9,7 @@ fn cl100k_base_benchmark(c: &mut Criterion) { b.iter(|| { black_box(x.estimate_num_tokens_no_special_tokens_fast( &t, + false )); }); });