Add more tests
torymur committed Nov 11, 2024
1 parent 5eb350d commit 9b42797
Showing 3 changed files with 116 additions and 7 deletions.
16 changes: 13 additions & 3 deletions src/lib.rs
@@ -10,26 +10,36 @@ mod python_bindings;
 
 use thiserror::Error;
 
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum Error {
     #[error("The vocabulary does not allow us to build a sequence that matches the input")]
     IndexError,
 }
 
+#[derive(Error, Debug)]
+#[error("Tokenizer error")]
+pub struct TokenizerError(tokenizers::Error);
+
+impl PartialEq for TokenizerError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_string() == other.0.to_string()
+    }
+}
+
 #[derive(Error, Debug, PartialEq)]
 pub enum VocabularyError {
     #[error("Unable to create tokenizer for {model}, source {source}")]
     UnableToCreateTokenizer {
         model: String,
-        source: tokenizers::Error,
+        source: TokenizerError,
     },
     #[error("Unable to locate EOS token for {model}")]
     UnableToLocateEosTokenId { model: String },
     #[error("Unable to process token")]
     TokenProcessorError(#[from] TokenProcessorError),
 }
 
-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum TokenProcessorError {
     #[error("Tokenizer is not supported")]
     UnsupportedTokenizer,
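Note on the lib.rs change: tokenizers::Error is a type alias for a boxed trait object, so it cannot derive PartialEq; the TokenizerError newtype falls back to comparing rendered messages, which is what makes assert_eq! usable on these errors in the new tests. A minimal standalone sketch of the same pattern (the BoxedError alias and the messages are illustrative, not taken from the crate):

use thiserror::Error;

// Stand-in for tokenizers::Error, which is a boxed trait object.
type BoxedError = Box<dyn std::error::Error + Send + Sync>;

#[derive(Error, Debug)]
#[error("Tokenizer error")]
struct TokenizerError(BoxedError);

impl PartialEq for TokenizerError {
    fn eq(&self, other: &Self) -> bool {
        // Trait objects have no structural equality; compare Display output.
        self.0.to_string() == other.0.to_string()
    }
}

fn main() {
    let a = TokenizerError("model not found".into());
    let b = TokenizerError("model not found".into());
    assert_eq!(a, b); // equal because the rendered messages match
}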
65 changes: 62 additions & 3 deletions src/vocabulary/mod.rs
@@ -3,8 +3,7 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 
-use crate::prelude::*;
-use crate::VocabularyError;
+use crate::{prelude::*, TokenizerError, VocabularyError};
 
 use locator::EosTokenLocator;
 use processor::TokenProcessor;
@@ -50,7 +49,7 @@ impl Vocabulary {
             Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
                 VocabularyError::UnableToCreateTokenizer {
                     model: model.to_string(),
-                    source: error,
+                    source: TokenizerError(error),
                 }
             })?;
         Self::filter_normalizers(&mut tokenizer);
@@ -305,4 +304,64 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn token_processor_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let vocabulary = Vocabulary::from_pretrained(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(e) = vocabulary {
+            assert_eq!(
+                e,
+                VocabularyError::TokenProcessorError(
+                    crate::TokenProcessorError::UnsupportedTokenizer
+                )
+            )
+        }
+    }
+
+    #[test]
+    fn tokenizer_error() {
+        let model = "hf-internal-testing/some-non-existent-model";
+        let vocabulary = Vocabulary::from_pretrained(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(VocabularyError::UnableToCreateTokenizer { model: m, source }) = vocabulary {
+            assert_eq!(m, model.to_string());
+            assert_eq!(source.to_string(), "Tokenizer error".to_string());
+        }
+    }
+
+    #[test]
+    fn prepend_normalizers_filtered_out() {
+        use tokenizers::normalizers::{Prepend, Sequence};
+
+        let prepend = Prepend::new("_".to_string());
+        let prepend_normalizer = NormalizerWrapper::Prepend(prepend);
+        let sequence = Sequence::new(vec![prepend_normalizer.clone()]);
+        let sequence_normalizer = NormalizerWrapper::Sequence(sequence);
+
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        for normalizer in [prepend_normalizer, sequence_normalizer] {
+            let mut normalized_t = tokenizer.clone();
+            normalized_t.with_normalizer(Some(normalizer));
+            Vocabulary::filter_normalizers(&mut normalized_t);
+            if let Some(n) = normalized_t.get_normalizer() {
+                match n {
+                    NormalizerWrapper::Sequence(seq) => {
+                        for n in seq.get_normalizers() {
+                            if let NormalizerWrapper::Prepend(_) = n {
+                                unreachable!()
+                            }
+                        }
+                    }
+                    NormalizerWrapper::Prepend(_) => unreachable!(),
+                    _ => {}
+                }
+            }
+        }
+    }
 }
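Note on the mod.rs tests: prepend_normalizers_filtered_out pins down the contract of Vocabulary::filter_normalizers: after filtering, no Prepend normalizer survives, either at the top level or nested inside a Sequence. A sketch of a filter with that contract, assuming only the with_normalizer/get_normalizers API already used in the test above (strip_prepend is a hypothetical helper, not the crate's actual implementation):

use tokenizers::normalizers::Sequence;
use tokenizers::{NormalizerWrapper, Tokenizer};

// Hypothetical helper mirroring the behavior the test asserts.
fn strip_prepend(tokenizer: &mut Tokenizer) {
    // Decide on a replacement first, while only borrowing immutably.
    let replacement: Option<Option<NormalizerWrapper>> = match tokenizer.get_normalizer() {
        // A bare Prepend normalizer is dropped entirely.
        Some(NormalizerWrapper::Prepend(_)) => Some(None),
        // Inside a Sequence, keep every normalizer except Prepend.
        Some(NormalizerWrapper::Sequence(seq)) => {
            let kept: Vec<NormalizerWrapper> = seq
                .get_normalizers()
                .iter()
                .filter(|n| !matches!(n, NormalizerWrapper::Prepend(_)))
                .cloned()
                .collect();
            Some(Some(NormalizerWrapper::Sequence(Sequence::new(kept))))
        }
        // Anything else is left untouched.
        _ => None,
    };
    if let Some(normalizer) = replacement {
        tokenizer.with_normalizer(normalizer);
    }
}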
42 changes: 41 additions & 1 deletion src/vocabulary/processor.rs
@@ -194,7 +194,7 @@ impl TokenProcessor {
         }
     }
 
-    /// Process each token based on the level ofTokenProcesso.
+    /// Process each token based on the level of `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => {
@@ -312,4 +312,44 @@ mod tests {
             assert_eq!(processed, expected);
         }
     }
+
+    #[test]
+    fn unsupported_tokenizer_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        let result = TokenProcessor::new(&tokenizer);
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, TokenProcessorError::UnsupportedTokenizer)
+        }
+    }
+
+    #[test]
+    fn byte_processor_error() {
+        let model = "openai-community/gpt2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        for token in ["π’œπ’·π’Έπ’Ÿπ“”", "πŸ¦„πŸŒˆπŸŒπŸ”₯πŸŽ‰", "δΊ¬δΈœθ΄­η‰©"] {
+            let result = processor.process(token.to_string());
+            assert!(result.is_err());
+            if let Err(e) = result {
+                assert_eq!(e, TokenProcessorError::ByteProcessorFailed)
+            }
+        }
+    }
+
+    #[test]
+    fn byte_fallback_processor_error() {
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        let result = processor.process("<0x6y>".to_string());
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, TokenProcessorError::ByteFallbackProcessorFailed)
+        }
+    }
 }
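Note on the processor.rs tests: byte_fallback_processor_error feeds in <0x6y>, which has the shape of a byte-fallback token but an invalid hex digit. Assuming byte-fallback tokens follow the <0xNN> convention of Llama-style tokenizers, the rule being exercised can be sketched as follows (parse_byte_fallback is hypothetical, not the crate's API):

// Byte-fallback tokens encode one raw byte as <0xNN> (two hex digits).
fn parse_byte_fallback(token: &str) -> Option<u8> {
    let hex = token.strip_prefix("<0x")?.strip_suffix('>')?;
    u8::from_str_radix(hex, 16).ok()
}

fn main() {
    assert_eq!(parse_byte_fallback("<0x6A>"), Some(0x6A)); // valid hex byte
    assert_eq!(parse_byte_fallback("<0x6y>"), None); // 'y' is not a hex digit
}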
