From e85ce9082ae38a1a3c30afcca5ffc9ed38af0cd1 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Fri, 8 Nov 2024 17:35:53 +0000
Subject: [PATCH] Extend vocabulary with eos token id & pretrained models

---
 src/prelude.rs    |   6 --
 src/regex.rs      |   1 +
 src/vocabulary.rs | 201 ++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 187 insertions(+), 21 deletions(-)

diff --git a/src/prelude.rs b/src/prelude.rs
index e196e474..d42516b9 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -2,9 +2,3 @@ pub use super::{
     primitives::{State, Token, TokenId, TransitionKey},
     vocabulary::Vocabulary,
 };
-
-pub(crate) use std::{
-    collections::{HashMap, HashSet},
-    fmt::{self, Display},
-    ops::Deref,
-};
diff --git a/src/regex.rs b/src/regex.rs
index a41bf862..b5658191 100644
--- a/src/regex.rs
+++ b/src/regex.rs
@@ -1,4 +1,5 @@
 use crate::prelude::*;
+use std::collections::{HashMap, HashSet};
 
 pub fn walk_fsm(
     fsm_transitions: &HashMap<(State, TransitionKey), State>,
diff --git a/src/vocabulary.rs b/src/vocabulary.rs
index f03df8f7..438e94b5 100644
--- a/src/vocabulary.rs
+++ b/src/vocabulary.rs
@@ -1,4 +1,12 @@
+use std::collections::HashMap;
+
+use tokenizers::normalizers::Sequence;
+use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
+
+use crate::locator::EosTokenLocator;
 use crate::prelude::*;
+use crate::processor::TokenProcessor;
+use crate::VocabularyError;
 
 /// Vocabulary of an LLM.
 ///
@@ -7,19 +15,116 @@ use crate::prelude::*;
 /// ```rust
 /// # use outlines_core::prelude::*;
 /// #
-/// let vocabulary = Vocabulary::new()
+/// let vocabulary = Vocabulary::new(None)
 ///     .insert("blah", 0)
 ///     .insert("1a", 1)
 ///     .insert("2", 2)
 ///     .insert("0", 3);
 /// ```
 #[derive(Clone, Debug, Default)]
-pub struct Vocabulary(pub(crate) HashMap<Token, Vec<TokenId>>);
+pub struct Vocabulary {
+    // TODO: Option is temporary, for backward compatibility
+    eos_token_id: Option<TokenId>,
+    map: HashMap<Token, Vec<TokenId>>,
+}
 
 impl Vocabulary {
     /// Creates an empty vocabulary.
-    pub fn new() -> Vocabulary {
-        Vocabulary::default()
+    pub fn new(eos_token_id: Option<TokenId>) -> Self {
+        Self {
+            eos_token_id,
+            map: HashMap::new(),
+        }
+    }
+
+    /// Creates the vocabulary of a pre-trained model from the Hugging Face Hub.
+    pub fn from_pretrained(
+        model: &str,
+        parameters: Option<FromPretrainedParameters>,
+    ) -> Result<Self, VocabularyError> {
+        let mut tokenizer =
+            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
+                VocabularyError::UnableToCreateTokenizer {
+                    model: model.to_string(),
+                    source: error,
+                }
+            })?;
+        Self::filter_normalizers(&mut tokenizer);
+
+        let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
+        let Some(eos_token_id) = eos_token_id else {
+            return Err(VocabularyError::UnableToLocateEosTokenId {
+                model: model.to_string(),
+            });
+        };
+
+        Vocabulary::try_from((&mut tokenizer, eos_token_id))
+    }
+
+    /// Returns the vector of `TokenId`s for the given token, if it is present in the vocabulary.
+    pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
+        self.map.get(token)
+    }
+
+    /// Gets the identifier of the special end-of-sentence token.
+    pub fn eos_token_id(&self) -> Option<TokenId> {
+        self.eos_token_id
+    }
+
+    fn filter_normalizers(tokenizer: &mut Tokenizer) {
+        // The main concern is prepend normalizers, for example https://github.com/google/sentencepiece
+        // In the `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
+        // e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
+        //
+        // We don't want to deal with the special characters, so we remove `Prepend` normalizers.
+        if let Some(normalizer) = tokenizer.get_normalizer() {
+            match normalizer {
+                NormalizerWrapper::Sequence(normalization_sequence) => {
+                    let new_sequence = Sequence::new(
+                        normalization_sequence
+                            .get_normalizers()
+                            .iter()
+                            .filter_map(|normalizer| match normalizer {
+                                NormalizerWrapper::Prepend(_) => None,
+                                _ => Some(normalizer.clone()),
+                            })
+                            .collect(),
+                    );
+                    tokenizer.with_normalizer(new_sequence.into());
+                }
+                NormalizerWrapper::Prepend(_) => {
+                    tokenizer.with_normalizer(None::<NormalizerWrapper>);
+                }
+                _ => {}
+            }
+        }
+    }
+}
+
+impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
+    type Error = VocabularyError;
+
+    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self, Self::Error> {
+        let (tokenizer, eos_token_id) = value;
+
+        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
+            if !added_token.special {
+                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
+            }
+        }
+
+        let processor = TokenProcessor::new(tokenizer)?;
+        for (token, token_id) in tokenizer.get_vocab(false) {
+            let token_bytes = processor.process(token)?;
+            // TODO: lossy conversion is temporary:
+            // - in python it was handled by the byte_symbol function
+            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
+            let processed_token = String::from_utf8_lossy(&token_bytes);
+            vocabulary = vocabulary.insert(processed_token, token_id);
+        }
+
+        Ok(vocabulary)
     }
 }
 
@@ -43,8 +148,9 @@ impl Vocabulary {
 impl Vocabulary {
     /// Inserts a token to the vocabulary with the specified identifier, in place.
     pub fn insert_in_place(&mut self, token: impl Into<Token>, id: TokenId) {
+        // TODO: return error if eos token id is inserted
         let token = token.into();
-        self.0.entry(token).or_default().push(id);
+        self.map.entry(token).or_default().push(id);
     }
 
     /// Extends the vocabulary with tokens and their identifiers, in place.
@@ -54,21 +160,21 @@ impl Vocabulary {
     ) {
         for (token, ids) in tokens_and_ids.into_iter() {
             let token = token.into();
-            self.0.entry(token).or_default().extend(ids);
+            self.map.entry(token).or_default().extend(ids);
         }
     }
 }
 
-impl Deref for Vocabulary {
+impl std::ops::Deref for Vocabulary {
     type Target = HashMap<Token, Vec<TokenId>>;
 
     fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
-        &self.0
+        &self.map
     }
 }
 
-impl Display for Vocabulary {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+impl std::fmt::Display for Vocabulary {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         for (index, (token, token_ids)) in self.iter().enumerate() {
             if index != (self.len() - 1) {
                 writeln!(f, "{:?} -> {:?}", token, token_ids)?;
@@ -82,7 +188,10 @@ impl Display for Vocabulary {
 
 impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
     fn from(map: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
-        Vocabulary(map)
+        Vocabulary {
+            eos_token_id: None,
+            map,
+        }
     }
 }
 
@@ -92,17 +201,17 @@ where
     I: IntoIterator<Item = TokenId>,
 {
     fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
-        Vocabulary::new().extend(tokens_and_ids)
+        Vocabulary::new(None).extend(tokens_and_ids)
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::prelude::*;
+    use super::*;
 
     #[test]
     fn insert() {
-        let vocabulary = Vocabulary::new()
+        let vocabulary = Vocabulary::new(None)
             .insert("blah", 0)
             .insert("1a", 1)
             .insert("2", 2)
@@ -117,7 +226,7 @@ mod tests {
 
     #[test]
     fn extend() {
-        let vocabulary = Vocabulary::new().extend([
+        let vocabulary = Vocabulary::new(None).extend([
             ("blah", vec![0]),
             ("1a", vec![1]),
             ("2", vec![2]),
@@ -130,4 +239,66 @@ mod tests {
         assert_eq!(vocabulary["2"], &[2]);
         assert_eq!(vocabulary["0"], &[3]);
     }
+
+    #[test]
+    fn pretrained_from_gpt2() {
+        let model = "openai-community/gpt2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");
+
+        let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
+        assert_eq!(v_eos, 50256);
+        assert_eq!(
+            tokenizer.id_to_token(v_eos).expect("Token not found"),
+            "<|endoftext|>"
+        );
+
+        let token = "Ġal";
+        assert!(vocabulary.token_to_ids(token).is_none());
+        assert!(tokenizer.token_to_id(token).is_some());
+
+        for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] {
+            let v_ids = vocabulary.token_to_ids(v_token);
+            assert!(v_ids.is_some());
+            for v_id in v_ids.unwrap() {
+                let t_token = tokenizer
+                    .id_to_token(*v_id)
+                    .expect("Token id not found in tokenizer");
+                assert_eq!(&t_token, t_token_expected);
+            }
+        }
+    }
+
+    #[test]
+    fn pretrained_from_llama() {
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");
+
+        let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
+        assert_eq!(v_eos, 2);
+        assert_eq!(
+            tokenizer.id_to_token(v_eos).expect("Token not found"),
+            "</s>"
+        );
+
+        for (v_token, t_token_expected) in [
+            ("abc", "abc"),
+            (" al", "▁al"),
+            (" O", "▁O"),
+            ("   ", "▁▁▁"),
+            // TODO: won't pass until the token type is changed to bytes
+            // ("<0xFF>", "ÿ"),
+            // ("<0x20>", "▁"),
+        ] {
+            let v_ids = vocabulary.token_to_ids(v_token);
+            assert!(v_ids.is_some());
+            for v_id in v_ids.unwrap() {
+                let t_token = tokenizer
+                    .id_to_token(*v_id)
+                    .expect("Token id not found in tokenizer");
+                assert_eq!(&t_token, t_token_expected);
+            }
+        }
+    }
 }
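
For review convenience, a minimal usage sketch of the API this patch introduces (`Vocabulary::from_pretrained`, `eos_token_id`, `token_to_ids`). It is not part of the patch: it assumes the `tokenizers` crate is built with its tokenizer-downloading feature enabled so `Tokenizer::from_pretrained` can fetch the model, reuses the `openai-community/gpt2` model name from the tests above, and uses `expect` only for brevity.

```rust
use outlines_core::prelude::*;

fn main() {
    // Build the vocabulary straight from a Hugging Face Hub model;
    // the tokenizer is downloaded on first use.
    let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)
        .expect("failed to build vocabulary from pretrained model");

    // The located EOS token id is carried by the vocabulary itself.
    let eos = vocabulary
        .eos_token_id()
        .expect("a pretrained vocabulary should carry an eos token id");
    println!("eos token id: {eos}");

    // A token may map to several ids; lookup goes through `token_to_ids`.
    if let Some(ids) = vocabulary.token_to_ids("abc") {
        println!("\"abc\" -> {ids:?}");
    }

    // Manual construction still works, now taking an optional eos token id.
    let manual = Vocabulary::new(Some(eos)).insert("blah", 0).insert("1a", 1);
    println!("{manual}");
}
```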