Extend vocabulary with eos token id & pretrained models
torymur committed Nov 8, 2024
1 parent c3b4430 commit e85ce90
Showing 3 changed files with 187 additions and 21 deletions.
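Taken together, the changes below give `Vocabulary` an optional EOS token id at construction time and a `from_pretrained` constructor. A minimal sketch of the extended builder API, assuming `TokenId` is the crate's unsigned integer id type as in the doctest further down:

```rust
use outlines_core::prelude::*;

fn main() {
    // `new` now takes an optional EOS token id; `None` keeps the previous behaviour.
    let vocabulary = Vocabulary::new(Some(50256))
        .insert("blah", 0)
        .insert("1a", 1);

    // The EOS token id is stored alongside the token map.
    assert_eq!(vocabulary.eos_token_id(), Some(50256));
    assert_eq!(vocabulary.token_to_ids("blah"), Some(&vec![0]));
}
```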
6 changes: 0 additions & 6 deletions src/prelude.rs
@@ -2,9 +2,3 @@ pub use super::{
primitives::{State, Token, TokenId, TransitionKey},
vocabulary::Vocabulary,
};

pub(crate) use std::{
collections::{HashMap, HashSet},
fmt::{self, Display},
ops::Deref,
};
1 change: 1 addition & 0 deletions src/regex.rs
@@ -1,4 +1,5 @@
use crate::prelude::*;
use std::collections::{HashMap, HashSet};

pub fn walk_fsm(
fsm_transitions: &HashMap<(State, TransitionKey), State>,
201 changes: 186 additions & 15 deletions src/vocabulary.rs
@@ -1,4 +1,12 @@
use std::collections::HashMap;

use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

use crate::locator::EosTokenLocator;
use crate::prelude::*;
use crate::processor::TokenProcessor;
use crate::VocabularyError;

/// Vocabulary of an LLM.
///
@@ -7,19 +15,116 @@ use crate::prelude::*;
/// ```rust
/// # use outlines_core::prelude::*;
/// #
/// let vocabulary = Vocabulary::new()
/// let vocabulary = Vocabulary::new(None)
/// .insert("blah", 0)
/// .insert("1a", 1)
/// .insert("2", 2)
/// .insert("0", 3);
/// ```
#[derive(Clone, Debug, Default)]
pub struct Vocabulary(pub(crate) HashMap<Token, Vec<TokenId>>);
pub struct Vocabulary {
// TODO: Option is temporary, kept for backward compatibility
eos_token_id: Option<TokenId>,
map: HashMap<Token, Vec<TokenId>>,
}

impl Vocabulary {
/// Creates an empty vocabulary.
pub fn new() -> Vocabulary {
Vocabulary::default()
pub fn new(eos_token_id: Option<TokenId>) -> Self {
Self {
eos_token_id,
map: HashMap::new(),
}
}

/// Creates the vocabulary of a pre-trained model from the Hugging Face Hub.
pub fn from_pretrained(
model: &str,
parameters: Option<FromPretrainedParameters>,
) -> Result<Self, VocabularyError> {
let mut tokenizer =
Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
VocabularyError::UnableToCreateTokenizer {
model: model.to_string(),
source: error,
}
})?;
Self::filter_normalizers(&mut tokenizer);

let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
let Some(eos_token_id) = eos_token_id else {
return Err(VocabularyError::UnableToLocateEosTokenId {
model: model.to_string(),
});
};

Vocabulary::try_from((&mut tokenizer, eos_token_id))
}

/// Returns the vector of `TokenId`s for the provided token, if it is present in the vocabulary.
pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
self.map.get(token)
}

/// Gets the identifier of the special end of sentence token.
pub fn eos_token_id(&self) -> Option<TokenId> {
self.eos_token_id
}

fn filter_normalizers(tokenizer: &mut Tokenizer) {
// The main concern is prepend normalizers, for example https://github.com/google/sentencepiece
// In the `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
// e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
//
// We don't want to deal with these special characters, so we remove `Prepend` normalizers.
if let Some(normalizer) = tokenizer.get_normalizer() {
match normalizer {
NormalizerWrapper::Sequence(normalization_sequence) => {
let new_sequence = Sequence::new(
normalization_sequence
.get_normalizers()
.iter()
.filter_map(|normalizer| match normalizer {
NormalizerWrapper::Prepend(_) => None,
_ => Some(normalizer.clone()),
})
.collect(),
);
tokenizer.with_normalizer(new_sequence.into());
}
NormalizerWrapper::Prepend(_) => {
tokenizer.with_normalizer(None::<NormalizerWrapper>);
}
_ => {}
}
}
}
}

impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
type Error = VocabularyError;

fn try_from(value: (&mut Tokenizer, u32)) -> Result<Vocabulary, VocabularyError> {
let (tokenizer, eos_token_id) = value;

let mut vocabulary = Vocabulary::new(Some(eos_token_id));
for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
if !added_token.special {
vocabulary = vocabulary.insert(added_token.content.clone(), *id);
}
}

let processor = TokenProcessor::new(tokenizer)?;
for (token, token_id) in tokenizer.get_vocab(false) {
let token_bytes = processor.process(token)?;
// TODO: lossy conversion is temporary:
// - in Python it was handled by the byte_symbol function
// - the interface needs to be redefined to treat the Token type as bytes: Vec<u8>
let processed_token = String::from_utf8_lossy(&token_bytes);
vocabulary = vocabulary.insert(processed_token, token_id);
}

Ok(vocabulary)
}
}

@@ -43,8 +148,9 @@ impl Vocabulary {
impl Vocabulary {
/// Inserts a token to the vocabulary with the specified identifier, in place.
pub fn insert_in_place(&mut self, token: impl Into<Token>, id: TokenId) {
// TODO: return error if eos token id is inserted
let token = token.into();
self.0.entry(token).or_default().push(id);
self.map.entry(token).or_default().push(id);
}

/// Extends the vocabulary with tokens and their identifiers, in place.
@@ -54,21 +160,21 @@ impl Vocabulary {
) {
for (token, ids) in tokens_and_ids.into_iter() {
let token = token.into();
self.0.entry(token).or_default().extend(ids);
self.map.entry(token).or_default().extend(ids);
}
}
}

impl Deref for Vocabulary {
impl std::ops::Deref for Vocabulary {
type Target = HashMap<Token, Vec<TokenId>>;

fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
&self.0
&self.map
}
}

impl Display for Vocabulary {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl std::fmt::Display for Vocabulary {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (index, (token, token_ids)) in self.iter().enumerate() {
if index != (self.len() - 1) {
writeln!(f, "{:?} -> {:?}", token, token_ids)?;
@@ -82,7 +188,10 @@ impl Display for Vocabulary {

impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
fn from(map: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
Vocabulary(map)
Vocabulary {
eos_token_id: None,
map,
}
}
}

@@ -92,17 +201,17 @@ where
I: IntoIterator<Item = TokenId>,
{
fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
Vocabulary::new().extend(tokens_and_ids)
Vocabulary::new(None).extend(tokens_and_ids)
}
}

#[cfg(test)]
mod tests {
use crate::prelude::*;
use super::*;

#[test]
fn insert() {
let vocabulary = Vocabulary::new()
let vocabulary = Vocabulary::new(None)
.insert("blah", 0)
.insert("1a", 1)
.insert("2", 2)
@@ -117,7 +226,7 @@ mod tests {

#[test]
fn extend() {
let vocabulary = Vocabulary::new().extend([
let vocabulary = Vocabulary::new(None).extend([
("blah", vec![0]),
("1a", vec![1]),
("2", vec![2]),
@@ -130,4 +239,66 @@ impl Vocabulary {
assert_eq!(vocabulary["2"], &[2]);
assert_eq!(vocabulary["0"], &[3]);
}

#[test]
fn pretrained_from_gpt2() {
let model = "openai-community/gpt2";
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");

let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
assert_eq!(v_eos, 50256);
assert_eq!(
tokenizer.id_to_token(v_eos).expect("Token not found"),
"<|endoftext|>"
);

let token = "Ġal";
assert!(vocabulary.token_to_ids(token).is_none());
assert!(tokenizer.token_to_id(token).is_some());

for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] {
let v_ids = vocabulary.token_to_ids(v_token);
assert!(v_ids.is_some());
for v_id in v_ids.unwrap() {
let t_token = tokenizer
.id_to_token(*v_id)
.expect("Token id not found in tokenizer");
assert_eq!(&t_token, t_token_expected);
}
}
}

#[test]
fn pretrained_from_llama() {
let model = "hf-internal-testing/llama-tokenizer";
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");

let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
assert_eq!(v_eos, 2);
assert_eq!(
tokenizer.id_to_token(v_eos).expect("Token not found"),
"</s>"
);

for (v_token, t_token_expected) in [
("abc", "abc"),
(" al", "▁al"),
(" O", "▁O"),
(" ", "▁▁▁"),
// TODO: these won't pass until the token type is changed to bytes
// ("<0xFF>", "ÿ"),
// ("<0x20>", "▁"),
] {
let v_ids = vocabulary.token_to_ids(v_token);
assert!(v_ids.is_some());
for v_id in v_ids.unwrap() {
let t_token = tokenizer
.id_to_token(*v_id)
.expect("Token id not found in tokenizer");
assert_eq!(&t_token, t_token_expected);
}
}
}
}
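For completeness, a minimal usage sketch of the new `from_pretrained` path, mirroring the GPT-2 test above (it pulls tokenizer data from the Hugging Face Hub, so it needs network access):

```rust
use outlines_core::prelude::*;

fn main() {
    // Build the vocabulary straight from a Hugging Face model id.
    let vocabulary =
        Vocabulary::from_pretrained("openai-community/gpt2", None).expect("Vocabulary failed");

    // GPT-2's EOS token <|endoftext|> has id 50256.
    assert_eq!(vocabulary.eos_token_id(), Some(50256));

    // Tokens are post-processed: the raw byte-level "Ġ"-prefixed form is not a key,
    // while its space-prefixed counterpart is.
    assert!(vocabulary.token_to_ids("Ġal").is_none());
    assert!(vocabulary.token_to_ids(" O").is_some());
}
```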
