diff --git a/src/lib.rs b/src/lib.rs
index e3978000..695c529e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,18 +10,28 @@ mod python_bindings;
 
 use thiserror::Error;
 
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum Error {
     #[error("The vocabulary does not allow us to build a sequence that matches the input")]
     IndexError,
 }
 
 #[derive(Error, Debug)]
+#[error("Tokenizer error")]
+pub struct TokenizerError(tokenizers::Error);
+
+impl PartialEq for TokenizerError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_string() == other.0.to_string()
+    }
+}
+
+#[derive(Error, Debug, PartialEq)]
 pub enum VocabularyError {
     #[error("Unable to create tokenizer for {model}, source {source}")]
     UnableToCreateTokenizer {
         model: String,
-        source: tokenizers::Error,
+        source: TokenizerError,
     },
     #[error("Unable to locate EOS token for {model}")]
     UnableToLocateEosTokenId { model: String },
@@ -29,7 +39,7 @@ pub enum VocabularyError {
     TokenProcessorError(#[from] TokenProcessorError),
 }
 
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum TokenProcessorError {
     #[error("Tokenizer is not supported")]
     UnsupportedTokenizer,
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 2851752d..fef311f4 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -3,8 +3,7 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 
-use crate::prelude::*;
-use crate::VocabularyError;
+use crate::{prelude::*, TokenizerError, VocabularyError};
 
 use locator::EosTokenLocator;
 use processor::TokenProcessor;
@@ -50,7 +49,7 @@ impl Vocabulary {
             Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
                 VocabularyError::UnableToCreateTokenizer {
                     model: model.to_string(),
-                    source: error,
+                    source: TokenizerError(error),
                 }
             })?;
         Self::filter_normalizers(&mut tokenizer);
@@ -305,4 +304,64 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn token_processor_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let vocabulary = Vocabulary::from_pretrained(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(e) = vocabulary {
+            assert_eq!(
+                e,
+                VocabularyError::TokenProcessorError(
+                    crate::TokenProcessorError::UnsupportedTokenizer
+                )
+            )
+        }
+    }
+
+    #[test]
+    fn tokenizer_error() {
+        let model = "hf-internal-testing/some-non-existent-model";
+        let vocabulary = Vocabulary::from_pretrained(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(VocabularyError::UnableToCreateTokenizer { model, source }) = vocabulary {
+            assert_eq!(model, model.to_string());
+            assert_eq!(source.to_string(), "Tokenizer error".to_string());
+        }
+    }
+
+    #[test]
+    fn prepend_normalizers_filtered_out() {
+        use tokenizers::normalizers::{Prepend, Sequence};
+
+        let prepend = Prepend::new("_".to_string());
+        let prepend_normalizer = NormalizerWrapper::Prepend(prepend);
+        let sequence = Sequence::new(vec![prepend_normalizer.clone()]);
+        let sequence_normalizer = NormalizerWrapper::Sequence(sequence);
+
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        for normalizer in [prepend_normalizer, sequence_normalizer] {
+            let mut normalized_t = tokenizer.clone();
+            normalized_t.with_normalizer(Some(normalizer));
+            Vocabulary::filter_normalizers(&mut normalized_t);
+            if let Some(n) = normalized_t.get_normalizer() {
+                match n {
+                    NormalizerWrapper::Sequence(seq) => {
+                        for n in seq.get_normalizers() {
+                            if let NormalizerWrapper::Prepend(_) = n {
+                                unreachable!()
+                            }
+                        }
+                    }
+                    NormalizerWrapper::Prepend(_) => unreachable!(),
+                    _ => {}
+                }
+            }
+        }
+    }
 }
diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 5048b11e..ce149f80 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -194,7 +194,7 @@ impl TokenProcessor {
         }
     }
 
-    /// Process each token based on the level ofTokenProcesso.
+    /// Process each token based on the level of `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => {
@@ -312,4 +312,44 @@ mod tests {
             assert_eq!(processed, expected);
         }
     }
+
+    #[test]
+    fn unsupported_tokenizer_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        let result = TokenProcessor::new(&tokenizer);
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, TokenProcessorError::UnsupportedTokenizer)
+        }
+    }
+
+    #[test]
+    fn byte_processor_error() {
+        let model = "openai-community/gpt2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        for token in ["π’œπ’·π’Έπ’Ÿπ“”", "πŸ¦„πŸŒˆπŸŒπŸ”₯πŸŽ‰", "δΊ¬δΈœθ΄­η‰©"] {
+            let result = processor.process(token.to_string());
+            assert!(result.is_err());
+            if let Err(e) = result {
+                assert_eq!(e, TokenProcessorError::ByteProcessorFailed)
+            }
+        }
+    }
+
+    #[test]
+    fn byte_fallback_processor_error() {
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        let result = processor.process("<0x6y>".to_string());
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, TokenProcessorError::ByteFallbackProcessorFailed)
+        }
+    }
 }