diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 7f63e60aa..16ba32f03 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -18,6 +18,7 @@
 pyo3 = { version = "0.21" }
 numpy = "0.21"
 ndarray = "0.15"
 itertools = "0.12"
+ahash = { version = "0.8.11", features = ["serde"] }
 [dependencies.tokenizers]
 path = "../../tokenizers"
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 424be9f57..7233fed14 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -1,9 +1,12 @@
 use std::collections::HashMap;
+use std::hash::Hash;
+use std::ops::{Deref, DerefMut};
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 
 use crate::token::PyToken;
 use crate::trainers::PyTrainer;
+use ahash::AHashMap;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -31,6 +34,53 @@ pub struct PyModel {
     pub model: Arc<RwLock<ModelWrapper>>,
 }
 
+// Newtype wrapper for AHashMap
+#[derive(Clone, Debug)]
+pub struct PyAHashMap<K, V>(pub AHashMap<K, V>);
+
+impl<K, V> IntoPy<PyObject> for PyAHashMap<K, V>
+where
+    K: IntoPy<PyObject> + Eq + Hash,
+    V: IntoPy<PyObject>,
+{
+    fn into_py(self, py: Python<'_>) -> PyObject {
+        let dict = PyDict::new_bound(py);
+        for (k, v) in self.0 {
+            dict.set_item(k.into_py(py), v.into_py(py)).unwrap();
+        }
+        dict.into()
+    }
+}
+
+impl<'source, K, V> FromPyObject<'source> for PyAHashMap<K, V>
+where
+    K: FromPyObject<'source> + Eq + Hash,
+    V: FromPyObject<'source>,
+{
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        let dict = ob.downcast::<PyDict>()?;
+        let mut map = AHashMap::new();
+        for (k, v) in dict.iter() {
+            map.insert(K::extract(k)?, V::extract(v)?);
+        }
+        Ok(PyAHashMap(map))
+    }
+}
+
+impl<K, V> Deref for PyAHashMap<K, V> {
+    type Target = AHashMap<K, V>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl<K, V> DerefMut for PyAHashMap<K, V> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
 impl PyModel {
     pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
         let base = self.clone();
@@ -62,6 +112,10 @@ impl Model for PyModel {
         self.model.read().unwrap().get_vocab()
     }
 
+    fn get_vocab_ahash(&self) -> AHashMap<String, u32> {
+        self.model.read().unwrap().get_vocab_ahash()
+    }
+
     fn get_vocab_size(&self) -> usize {
         self.model.read().unwrap().get_vocab_size()
     }
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 25396f192..3b1cc1171 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -68,6 +68,9 @@
 fancy-regex = { version = "0.13", optional = true}
 getrandom = { version = "0.2.10" }
 esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
 monostate = "0.1.12"
+ahash = { version = "0.8.11", features = ["serde"] }
+dary_heap = { version = "0.3.6", features = ["serde"] }
+compact_str = { version = "0.8.0", features = ["serde"] }
 [features]
 default = ["progressbar", "onig", "esaxx_fast"]
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 1585da761..977e251b8 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -2,19 +2,21 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY};
 use crate::utils::iter::ResultShunt;
+use ahash::AHashMap;
 use serde_json::Value;
 use std::borrow::Cow;
+
+use std::collections::HashMap;
 use std::{
-    collections::HashMap,
     fs::File,
     io::prelude::*,
     io::{BufRead, BufReader},
     path::{Path, PathBuf},
 };
 
-pub type Vocab = HashMap<String, u32>;
-type VocabR = HashMap<u32, String>;
-pub type MergeMap = HashMap<Pair, (u32, u32)>;
+pub type Vocab = AHashMap<String, u32>;
+type VocabR = AHashMap<u32, String>;
+pub type MergeMap = AHashMap<Pair, (u32, u32)>;
 pub type Merges = Vec<(String, String)>;
 
 struct Config {
@@ -41,7 +43,7 @@ impl Default for BpeBuilder {
         Self {
             config: Config {
                 files: None,
-                vocab: HashMap::new(),
+                vocab: AHashMap::new(),
                 merges: vec![],
                 cache_capacity: DEFAULT_CACHE_CAPACITY,
                 dropout: None,
@@ -324,7 +326,7 @@ impl BPE {
        let mut buffer = String::new();
         vocab_file.read_to_string(&mut buffer)?;
         let json: Value = serde_json::from_str(&buffer)?;
-        let mut vocab = HashMap::new();
+        let mut vocab = AHashMap::new();
         match json {
             Value::Object(m) => {
                 for (token, id) in m {
@@ -354,7 +356,11 @@ impl BPE {
         }
     }
 
-    pub fn get_vocab(&self) -> Vocab {
+    pub fn get_vocab(&self) -> HashMap<String, u32> {
+        self.vocab.clone().into_iter().collect()
+    }
+
+    pub fn get_vocab_ahash(&self) -> AHashMap<String, u32> {
         self.vocab.clone()
     }
 
@@ -481,6 +487,10 @@ impl Model for BPE {
     type Trainer = BpeTrainer;
 
     fn get_vocab(&self) -> HashMap<String, u32> {
+        self.vocab.clone().into_iter().collect()
+    }
+
+    fn get_vocab_ahash(&self) -> AHashMap<String, u32> {
         self.vocab.clone()
     }
 
diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs
index 98cc15102..98cf54944 100644
--- a/tokenizers/src/models/bpe/serialization.rs
+++ b/tokenizers/src/models/bpe/serialization.rs
@@ -1,10 +1,10 @@
 use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
+use ahash::AHashMap;
 use serde::{
     de::{Error, MapAccess, Visitor},
     ser::SerializeStruct,
     Deserialize, Deserializer, Serialize, Serializer,
 };
-use std::collections::HashMap;
 
 impl Serialize for BPE {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -80,7 +80,7 @@ impl<'de> Visitor<'de> for BPEVisitor {
         V: MapAccess<'de>,
     {
         let mut builder = BpeBuilder::new();
-        let mut vocab: Option<HashMap<String, u32>> = None;
+        let mut vocab: Option<AHashMap<String, u32>> = None;
 
         #[derive(Debug, Deserialize)]
         #[serde(untagged)]
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index 3689a856a..09e62be2c 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -4,15 +4,17 @@ use super::{Pair, WithFirstLastIterator, Word, BPE};
 use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use ahash::{AHashMap, AHashSet};
+use compact_str::CompactString;
+use dary_heap::OctonaryHeap;
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
-use std::collections::{BinaryHeap, HashMap, HashSet};
 
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
     count: u64,
-    pos: HashSet<usize>,
+    pos: AHashSet<usize>,
 }
 impl PartialEq for Merge {
     fn eq(&self, other: &Self) -> bool {
@@ -41,7 +43,7 @@ struct Config {
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
     limit_alphabet: Option<usize>,
-    initial_alphabet: HashSet<char>,
+    initial_alphabet: AHashSet<char>,
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     max_token_length: Option<usize>,
@@ -62,7 +64,7 @@ impl Default for BpeTrainerBuilder {
             show_progress: true,
             special_tokens: vec![],
             limit_alphabet: None,
-            initial_alphabet: HashSet::new(),
+            initial_alphabet: AHashSet::new(),
             continuing_subword_prefix: None,
             end_of_word_suffix: None,
             max_token_length: None,
@@ -114,7 +116,7 @@ impl BpeTrainerBuilder {
     /// Set the initial alphabet
     #[must_use]
-    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
+    pub fn initial_alphabet(mut self, alphabet: AHashSet<char>) -> Self {
         self.config.initial_alphabet = alphabet;
         self
     }
@@ -151,7 +153,7 @@ impl BpeTrainerBuilder {
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             max_token_length: self.config.max_token_length,
-            words: HashMap::new(),
+            words: AHashMap::new(),
         }
     }
 }
@@ -187,7 +189,7 @@ pub struct BpeTrainer {
     pub limit_alphabet: Option<usize>,
     /// The initial alphabet we want absolutely to include. This allows to cover
     /// some characters that are not necessarily in the training set
-    pub initial_alphabet: HashSet<char>,
+    pub initial_alphabet: AHashSet<char>,
     /// An optional prefix to use on any subword that exist only behind another one
     pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to caracterize and end-of-word subword
@@ -195,7 +197,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
 
-    words: HashMap<String, u64>,
+    words: AHashMap<CompactString, u64>,
 }
 
 impl Default for BpeTrainer {
@@ -251,11 +253,16 @@ impl BpeTrainer {
     }
 
     /// Add the provided special tokens to the initial vocabulary
-    fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
+    fn add_special_tokens(
+        &self,
+        w2id: &mut AHashMap<CompactString, u32>,
+        id2w: &mut Vec<CompactString>,
+    ) {
         for token in &self.special_tokens {
-            if !w2id.contains_key(&token.content) {
-                id2w.push(token.content.to_owned());
-                w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32);
+            // only insert the special token content if it is not already present
+            if !w2id.contains_key(&CompactString::from(&token.content)) {
+                id2w.push(CompactString::from(&token.content));
+                w2id.insert(CompactString::from(&token.content), (id2w.len() - 1) as u32);
             }
         }
     }
@@ -263,12 +270,12 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
-        wc: &HashMap<String, u64>,
-        w2id: &mut HashMap<String, u32>,
-        id2w: &mut Vec<String>,
+        wc: &AHashMap<CompactString, u64>,
+        w2id: &mut AHashMap<CompactString, u32>,
+        id2w: &mut Vec<CompactString>,
     ) {
         // Compute the alphabet from seen words
-        let mut alphabet: HashMap<char, usize> = HashMap::new();
+        let mut alphabet: AHashMap<char, usize> = AHashMap::new();
         for (word, count) in wc {
             for c in word.chars() {
                 alphabet
@@ -312,19 +319,26 @@ impl BpeTrainer {
         kept.sort_unstable_by_key(|k| (*k.0) as u32);
         kept.into_iter().for_each(|(c, _)| {
             let s = c.to_string();
+            /*
             if !w2id.contains_key(&s) {
                 id2w.push(s.clone());
                 w2id.insert(s, (id2w.len() - 1) as u32);
             }
+            */
+            // CompactString version of the block above
+            if !w2id.contains_key(&CompactString::from(&s)) {
+                id2w.push(CompactString::from(&s));
+                w2id.insert(CompactString::from(&s), (id2w.len() - 1) as u32);
+            }
         });
     }
 
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u64>,
-        w2id: &mut HashMap<String, u32>,
-        id2w: &mut Vec<String>,
+        wc: &AHashMap<CompactString, u64>,
+        w2id: &mut AHashMap<CompactString, u32>,
+        id2w: &mut Vec<CompactString>,
         p: &Option<ProgressBar>,
     ) -> (Vec<Word>, Vec<u64>) {
         let mut words: Vec<Word> = Vec::with_capacity(wc.len());
@@ -336,7 +350,7 @@ impl BpeTrainer {
             for (is_first, is_last, c) in word.chars().with_first_and_last() {
                 let mut s = c.to_string();
-                if w2id.contains_key(&s) {
+                if w2id.contains_key(&CompactString::from(&s)) {
                     // Found the initial char in the authorized alphabet
 
                     // Add the `continuing_subword_prefix` if relevant
@@ -353,11 +367,11 @@ impl BpeTrainer {
                     }
 
                     // Insert the new formed string if necessary
-                    if !w2id.contains_key(&s) {
-                        id2w.push(s.clone());
-                        w2id.insert(s.clone(), (id2w.len() - 1) as u32);
+                    if !w2id.contains_key(&CompactString::from(&s)) {
+                        id2w.push(CompactString::from(&s));
+                        w2id.insert(CompactString::from(&s), (id2w.len() - 1) as u32);
                     }
-                    current_word.add(w2id[&s], 1); // We do not care about the len here
+                    current_word.add(w2id[&CompactString::from(&s)], 1); // We do not care about the len here
                 }
             }
             words.push(current_word);
@@ -375,13 +389,13 @@ impl BpeTrainer {
         words: &[Word],
         counts: &[u64],
         p: &Option<ProgressBar>,
-    ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
+    ) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
         words
             .maybe_par_iter()
             .enumerate()
             .map(|(i, word)| {
-                let mut pair_counts = HashMap::new();
-                let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
+                let mut pair_counts = AHashMap::new();
+                let mut where_to_update: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
 
                 for window in word.get_chars().windows(2) {
                     let cur_pair: Pair = (window[0], window[1]);
@@ -399,7 +413,7 @@ impl BpeTrainer {
                             h.insert(i);
                         })
                         .or_insert_with(|| {
-                            let mut h = HashSet::new();
+                            let mut h = AHashSet::new();
                             h.insert(i);
                             h
                         });
@@ -413,7 +427,7 @@ impl BpeTrainer {
                 (pair_counts, where_to_update)
             })
             .reduce(
-                || (HashMap::new(), HashMap::new()),
+                || (AHashMap::new(), AHashMap::new()),
                 |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                     for (k, v) in pc {
                         pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v);
@@ -431,11 +445,11 @@ impl BpeTrainer {
     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u64>,
+        word_counts: &AHashMap<CompactString, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
-        let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
-        let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
+        let mut word_to_id: AHashMap<CompactString, u32> = AHashMap::with_capacity(self.vocab_size);
+        let mut id_to_word: Vec<CompactString> = Vec::with_capacity(self.vocab_size);
         let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);
 
         let progress = self.setup_progress();
@@ -464,7 +478,7 @@ impl BpeTrainer {
         self.update_progress(&progress, words.len(), "Count pairs");
         let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress);
         // Insert them in the queue
-        let mut queue = BinaryHeap::with_capacity(pair_counts.len());
+        let mut queue = OctonaryHeap::with_capacity(pair_counts.len());
         where_to_update.drain().for_each(|(pair, pos)| {
             let count = pair_counts[&pair];
             if count > 0 {
@@ -510,7 +524,7 @@ impl BpeTrainer {
             if let Some(prefix) = &self.continuing_subword_prefix {
                 if part_b.starts_with(prefix) {
                     let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
-                    part_b = part_b[prefix_byte_len..].to_string();
+                    part_b = CompactString::from(&part_b[prefix_byte_len..]);
                 }
             }
             let new_token = format!("{}{}", part_a, part_b);
@@ -520,12 +534,12 @@ impl BpeTrainer {
             // Insert new token if it does not already exist
             let new_token_id = word_to_id
-                .get(&new_token)
+                .get(&CompactString::from(&new_token))
                 .copied()
                 .unwrap_or(id_to_word.len() as u32);
-            if !word_to_id.contains_key(&new_token) {
-                id_to_word.push(new_token.clone());
-                word_to_id.insert(new_token.clone(), new_token_id);
+            if !word_to_id.contains_key(&CompactString::from(&new_token)) {
+                id_to_word.push(CompactString::from(&new_token));
+                word_to_id.insert(CompactString::from(&new_token), new_token_id);
             }
             merges.push((top.pair, new_token_id));
@@ -536,7 +550,7 @@ impl BpeTrainer {
                 .flat_map(|&i| {
                     let word = &words[i] as *const _ as *mut Word;
                     // We can merge each of these words in parallel here because each position
-                    // can be there only once (HashSet). So this is safe.
+                    // can be there only once (AHashSet). So this is safe.
                     unsafe {
                         // let word: &mut Word = &mut (*word);
                         (*word)
@@ -562,7 +576,7 @@ impl BpeTrainer {
                                     h.insert(iw);
                                 })
                                 .or_insert_with(|| {
-                                    let mut h = HashSet::new();
+                                    let mut h = AHashSet::new();
                                     h.insert(iw);
                                     h
                                 });
@@ -586,7 +600,12 @@ impl BpeTrainer {
         self.finalize_progress(&progress, merges.len());
 
         // Transfer new vocab & options to model
-        model.vocab = word_to_id;
+        //model.vocab = word_to_id;
+        model.vocab = word_to_id
+            .into_iter()
+            // the keys are CompactString, so rebuild the String keys from id_to_word
+            .map(|(_key, val)| (id_to_word[val as usize].to_string(), val))
+            .collect();
         model.vocab_r = model
             .vocab
             .iter()
@@ -632,18 +651,20 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u64>> = iterator
+        let words: Result<AHashMap<CompactString, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = AHashMap::new();
                 for word in words {
-                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                    map.entry(CompactString::from(word))
+                        .and_modify(|c| *c += 1)
+                        .or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(AHashMap::new()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
@@ -661,11 +682,12 @@
 #[cfg(test)]
 mod tests {
     use super::{BpeTrainer, Pair, BPE};
-    use std::collections::HashMap;
+    use ahash::AHashMap;
+    use compact_str::CompactString;
 
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u64> = [
+        let word_counts: AHashMap<CompactString, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),
@@ -690,7 +712,7 @@
 
         // Vocab should contain all of the characters from the `word_counts` mapping
         // as well as three merges: 're', 'are', and 'is'.
-        let expected_vocab: HashMap<String, u32> = [
+        let expected_vocab: AHashMap<String, u32> = [
             ("-".into(), 0),
             ("2".into(), 1),
             ("B".into(), 2),
@@ -726,7 +748,7 @@
         // where 'rank' determines the order in which this merge will be applied during
         // tokenization, and 'id' is the vocab id of the symbol resulting from merging
         // the pair of symbols in the corresponding key.
-        let expected_merges: HashMap<Pair, (u32, u32)> = [
+        let expected_merges: AHashMap<Pair, (u32, u32)> = [
             ((17, 11), (0, 22)), // 'r' + 'e' -> 're'
             ((8, 22), (1, 23)),  // 'a' + 're' -> 'are'
             ((13, 18), (2, 24)), // 'i' + 's' -> 'is'
@@ -744,7 +766,7 @@
         */
         let max_token_length = 16;
-        let long_word_counts: HashMap<String, u64> = [
+        let long_word_counts: AHashMap<CompactString, u64> = [
            ("singlelongtokenwithoutcasechange", 2),
            ("singleLongTokenWithCamelCaseChange", 2),
            ("Longsingletokenwithpunctu@t!onwithin", 2),
            ("Anotherlongsingletokenwithnumberw1th1n", 2),
            ("짧은한글문자열짧은한", 2),
            ("긴한글문자열긴한글문자열긴한글문", 2),
            ("短字符串短字符串短字", 2),
            ("长字符串长字符串长字符串长字符串", 2),
            ("短い文字列短い文字列", 2),
            ("長い文字列長い文字列長い文字列長い文字列", 2),
            ("so", 2),
            ("GPT-2", 2),
         ]
         .iter()
-        .map(|(key, value)| (key.to_string(), *value))
+        .map(|(key, value)| (CompactString::from(key.to_string()), *value))
         .collect();
         let trainer = BpeTrainer::builder()
             .max_token_length(Some(max_token_length))
@@ -784,7 +806,7 @@
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u64> = [
+        let long_word_counts: AHashMap<CompactString, u64> = [
             ("sin", 2),
             ("Sin", 2),
             ("Lon", 2),
@@ -799,7 +821,7 @@
             ("GP", 2),
         ]
         .iter()
-        .map(|(key, value)| (key.to_string(), *value))
+        .map(|(key, value)| (CompactString::from(key.to_string()), *value))
         .collect();
         let trainer = BpeTrainer::builder()
             .max_token_length(Some(2))
@@ -808,8 +830,8 @@
             .build();
         let mut model = BPE::default();
         trainer.do_train(&long_word_counts, &mut model).unwrap();
-        let trained_vocab: HashMap<String, u32> = model.get_vocab();
-        let expected_vocab: HashMap<String, u32> = [
+        let trained_vocab: AHashMap<String, u32> = model.get_vocab_ahash();
+        let expected_vocab: AHashMap<String, u32> = [
             ("短", 12),
             ("n", 6),
             ("i", 5),
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 6fc8033e3..df9d32a94 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -1,7 +1,8 @@
 use super::Pair;
+use ahash::AHashMap;
+use dary_heap::QuaternaryHeap;
 use rand::{thread_rng, Rng};
 use std::cmp::Ordering;
-use std::collections::{BinaryHeap, HashMap};
 
 #[derive(Debug, Eq)]
 struct Merge {
@@ -158,8 +159,8 @@ impl Word {
         changes
     }
 
-    pub(super) fn merge_all(&mut self, merges: &HashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
-        let mut queue = BinaryHeap::with_capacity(self.symbols.len());
+    pub(super) fn merge_all(&mut self, merges: &AHashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
+        let mut queue = QuaternaryHeap::with_capacity(self.symbols.len());
 
         let mut skip = Vec::with_capacity(queue.len());
         queue.extend(
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index cdfb731a8..8692e2502 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -5,6 +5,7 @@
 pub mod unigram;
 pub mod wordlevel;
 pub mod wordpiece;
 
+use ahash::AHashMap;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
@@ -19,11 +20,11 @@
 use crate::{AddedToken, Model, Result, Token, Trainer};
 
 /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
 /// of token ID, smallest to largest.
 struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
+    vocab_r: &'a AHashMap<u32, String>,
 }
 
 impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+    fn new(vocab_r: &'a AHashMap<u32, String>) -> Self {
         Self { vocab_r }
     }
 }
@@ -179,6 +180,15 @@ impl Model for ModelWrapper {
         }
     }
 
+    fn get_vocab_ahash(&self) -> AHashMap<String, u32> {
+        match self {
+            Self::WordLevel(t) => t.get_vocab_ahash(),
+            Self::WordPiece(t) => t.get_vocab_ahash(),
+            Self::BPE(t) => t.get_vocab_ahash(),
+            Self::Unigram(t) => t.get_vocab_ahash(),
+        }
+    }
+
     fn get_vocab_size(&self) -> usize {
         match self {
             Self::WordLevel(t) => t.get_vocab_size(),
@@ -284,8 +294,8 @@ mod tests {
     #[test]
     fn incomplete_ordered_vocab() {
-        let vocab_r: HashMap<u32, String> =
-            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
+        let vocab_r: AHashMap<u32, String> =
+            AHashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
 
         let ordered = OrderedVocabIter::new(&vocab_r);
diff --git a/tokenizers/src/models/unigram/lattice.rs b/tokenizers/src/models/unigram/lattice.rs
index 30b82245d..710112916 100644
--- a/tokenizers/src/models/unigram/lattice.rs
+++ b/tokenizers/src/models/unigram/lattice.rs
@@ -1,13 +1,13 @@
+use dary_heap::QuaternaryHeap;
 use rand::distributions::WeightedIndex;
 use rand::prelude::*;
 use std::cell::RefCell;
 use std::cmp::{min, Ordering};
-use std::collections::BinaryHeap;
 use std::rc::Rc;
 
 type NodeRef = Rc<RefCell<Node>>;
 type HypothesisRef = Rc<RefCell<Hypothesis>>;
-type Agenda = BinaryHeap<Hypothesis>;
+type Agenda = QuaternaryHeap<Hypothesis>;
 
 struct Hypothesis {
     node_ref: NodeRef,
@@ -240,7 +240,7 @@ impl<'a> Lattice<'a> {
             1 => vec![self.viterbi()],
             _ => {
                 // let k_reserved_hypothesis_size = 512;
-                let mut agenda: Agenda = BinaryHeap::new();
+                let mut agenda: Agenda = QuaternaryHeap::new();
                 let mut hypotheses: Vec<HypothesisRef> = vec![];
                 let eos = self.eos_node();
                 let score = eos.borrow().score;
@@ -282,7 +282,7 @@ impl<'a> Lattice<'a> {
                     let k_max_agenda_size = 100_000;
                     let k_min_agenda_size = 512;
                     if agenda.len() > k_max_agenda_size {
-                        let mut new_agenda = BinaryHeap::new();
+                        let mut new_agenda = QuaternaryHeap::new();
                         let len = min(k_min_agenda_size, n * 10);
                         for _i in 0..len {
                             new_agenda.push(agenda.pop().unwrap());
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index defc7d93d..ad406d493 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -5,13 +5,14 @@ use super::{
 };
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::Cache;
-
 use std::collections::HashMap;
+
+use ahash::AHashMap;
 use std::convert::TryInto;
 use std::fs::read_to_string;
 use std::path::{Path, PathBuf};
 
-type TokenMap = HashMap<String, u32>;
+type TokenMap = AHashMap<String, u32>;
 type Vocab = Vec<(String, f64)>;
 
 /// A `Unigram` model to encode sentences.
@@ -98,7 +99,7 @@ impl Unigram {
         byte_fallback: bool,
     ) -> Result<Self> {
         let n = vocab.len();
-        let mut token_to_ids: TokenMap = HashMap::new();
+        let mut token_to_ids: TokenMap = AHashMap::new();
         let mut builder = TrieBuilder::default();
 
         if let Some(unk_id) = unk_id {
@@ -404,6 +405,10 @@ impl Model for Unigram {
     type Trainer = UnigramTrainer;
 
     fn get_vocab(&self) -> HashMap<String, u32> {
+        self.token_to_ids.clone().into_iter().collect()
+    }
+
+    fn get_vocab_ahash(&self) -> AHashMap<String, u32> {
         self.token_to_ids.clone()
     }
 
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5d178e77b..e61e3898d 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -2,10 +2,10 @@
 use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use ahash::{AHashMap, AHashSet};
 use log::debug;
 use serde::{Deserialize, Serialize};
 use std::cmp::Reverse;
-use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
 
 // A token and a score
@@ -57,8 +57,8 @@ pub struct UnigramTrainer {
     pub shrinking_factor: f64,
     #[builder(default = "vec![]")]
     pub special_tokens: Vec<AddedToken>,
-    #[builder(default = "HashSet::new()")]
-    pub initial_alphabet: HashSet<char>,
+    #[builder(default = "AHashSet::new()")]
+    pub initial_alphabet: AHashSet<char>,
     #[builder(default = "None")]
     pub unk_token: Option<String>,
 
@@ -67,8 +67,8 @@ pub struct UnigramTrainer {
     pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
     seed_size: usize,
-    #[builder(default = "HashMap::new()")]
-    words: HashMap<String, u64>,
+    #[builder(default = "AHashMap::new()")]
+    words: AHashMap<String, u64>,
 }
 
 impl Default for UnigramTrainer {
@@ -110,17 +110,17 @@ impl UnigramTrainer {
         true
     }
 
-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
+    fn finalize(&self, model: Unigram, required_chars: AHashSet<String>) -> Result<Unigram> {
         let mut min_score_penalty = 0.0;
         let min_score_penalty_delta = 0.0001;
 
         let mut pieces: Vec<(String, f64)> = vec![];
-        let mut inserted: HashSet<String> = HashSet::new();
+        let mut inserted: AHashSet<String> = AHashSet::new();
 
         // We don't want to include the <UNK> that was used to train
         inserted.insert("<UNK>".into());
 
-        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
+        let existing_pieces: AHashMap<String, f64> = model.iter().cloned().collect();
         for c in required_chars {
             if let Some(t) = existing_pieces.get(&c) {
                 inserted.insert(c.clone());
@@ -185,7 +185,7 @@ impl UnigramTrainer {
         )
     }
 
-    fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
+    fn required_chars(&self, word_counts: &[Sentence]) -> AHashSet<String> {
         word_counts
             .iter()
             .flat_map(|(s, _count)| s.chars())
@@ -205,7 +205,7 @@ impl UnigramTrainer {
             .sum::<usize>()
             + sentences.len();
         let mut flat_string = String::with_capacity(total);
-        let mut all_chars: HashMap<char, u64> = HashMap::new();
+        let mut all_chars: AHashMap<char, u64> = AHashMap::new();
         let c_sentence_boundary = '\0';
         let k_sentence_boundary = '\0'.to_string();
         for (string, n) in sentences {
@@ -631,18 +631,18 @@ impl Trainer for UnigramTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u64>> = iterator
+        let words: Result<AHashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = AHashMap::new();
                 for word in words {
                     map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(AHashMap::new()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
@@ -716,7 +716,7 @@ mod tests {
     fn test_initial_alphabet() {
         let trainer = UnigramTrainerBuilder::default()
             .show_progress(false)
-            .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
+            .initial_alphabet(AHashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
             .build()
             .unwrap();
 
@@ -727,7 +727,7 @@
             vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
                 .into_iter()
                 .map(|s| s.to_owned())
-                .collect::<HashSet<_>>()
+                .collect::<AHashSet<_>>()
         );
     }
 
diff --git a/tokenizers/src/models/unigram/trie.rs b/tokenizers/src/models/unigram/trie.rs
index 2f94b1766..dd06f7f02 100644
--- a/tokenizers/src/models/unigram/trie.rs
+++ b/tokenizers/src/models/unigram/trie.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use ahash::AHashMap;
 use std::hash::Hash;
 
 #[derive(Default)]
@@ -78,14 +78,14 @@ impl