diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 00000000..ff977a7d
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,27 @@
+use thiserror::Error;
+
+#[derive(Error, Debug, PartialEq)]
+pub enum Error {
+    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
+    IndexError,
+    #[error("Unable to create tokenizer for {model}")]
+    UnableToCreateTokenizer { model: String },
+    #[error("Unable to locate EOS token for {model}")]
+    UnableToLocateEosTokenId { model: String },
+    #[error("Tokenizer is not supported by token processor")]
+    UnsupportedByTokenProcessor,
+    #[error("Decoder unpacking failed for token processor")]
+    DecoderUnpackingFailed,
+    #[error("Token processing failed for byte level processor")]
+    ByteProcessorFailed,
+    #[error("Token processing failed for byte fallback level processor")]
+    ByteFallbackProcessorFailed,
+}
+
+#[cfg(feature = "python-bindings")]
+impl From<Error> for pyo3::PyErr {
+    fn from(e: Error) -> Self {
+        use pyo3::{exceptions::PyValueError, PyErr};
+        PyErr::new::<PyValueError, _>(e.to_string())
+    }
+}
diff --git a/src/index.rs b/src/index.rs
index 587cd76a..cc1187e8 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -2,10 +2,9 @@
 use crate::prelude::{State, TransitionKey};
 use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
 use crate::vocabulary::Vocabulary;
+use crate::{Error, Result};
 use std::collections::{HashMap, HashSet};
 
-pub type Result<T> = std::result::Result<T, crate::Error>;
-
 #[derive(Debug)]
 pub struct FSMInfo {
     pub(crate) initial: State,
@@ -101,7 +100,7 @@
                 eos_token_id,
             })
         } else {
-            Err(crate::Error::IndexError)
+            Err(Error::IndexError)
         }
     }
 
diff --git a/src/lib.rs b/src/lib.rs
index 695c529e..4c45de4a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod error;
 pub mod index;
 pub mod json_schema;
 pub mod prelude;
@@ -5,56 +6,9 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;
 
-#[cfg(feature = "python-bindings")]
-mod python_bindings;
-
-use thiserror::Error;
-
-#[derive(Error, Debug, PartialEq)]
-pub enum Error {
-    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
-    IndexError,
-}
+use error::Error;
 
-#[derive(Error, Debug)]
-#[error("Tokenizer error")]
-pub struct TokenizerError(tokenizers::Error);
-
-impl PartialEq for TokenizerError {
-    fn eq(&self, other: &Self) -> bool {
-        self.0.to_string() == other.0.to_string()
-    }
-}
-
-#[derive(Error, Debug, PartialEq)]
-pub enum VocabularyError {
-    #[error("Unable to create tokenizer for {model}, source {source}")]
-    UnableToCreateTokenizer {
-        model: String,
-        source: TokenizerError,
-    },
-    #[error("Unable to locate EOS token for {model}")]
-    UnableToLocateEosTokenId { model: String },
-    #[error("Unable to process token")]
-    TokenProcessorError(#[from] TokenProcessorError),
-}
-
-#[derive(Error, Debug, PartialEq)]
-pub enum TokenProcessorError {
-    #[error("Tokenizer is not supported")]
-    UnsupportedTokenizer,
-    #[error("Decoder unpacking failed")]
-    DecoderUnpackingFailed,
-    #[error("Token processing failed for byte level processor")]
-    ByteProcessorFailed,
-    #[error("Token processing failed for byte fallback level processor")]
-    ByteFallbackProcessorFailed,
-}
+pub type Result<T> = std::result::Result<T, Error>;
 
 #[cfg(feature = "python-bindings")]
-impl From<Error> for pyo3::PyErr {
-    fn from(e: Error) -> Self {
-        use pyo3::{exceptions::PyValueError, PyErr};
-        PyErr::new::<PyValueError, _>(e.to_string())
-    }
-}
+mod python_bindings;
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index b62c22e7..45aaef53 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -3,7 +3,8 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 
-use crate::{prelude::*, TokenizerError, VocabularyError};
+use crate::prelude::*;
+use crate::{Error, Result};
 
 use locator::EosTokenLocator;
 use processor::TokenProcessor;
@@ -44,19 +45,18 @@
     pub fn from_pretrained(
         model: &str,
         parameters: Option<FromPretrainedParameters>,
-    ) -> Result<Self, VocabularyError> {
+    ) -> Result<Self> {
         let mut tokenizer =
-            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
-                VocabularyError::UnableToCreateTokenizer {
+            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
+                Error::UnableToCreateTokenizer {
                     model: model.to_string(),
-                    source: TokenizerError(error),
                 }
             })?;
         Self::filter_prepend_normalizers(&mut tokenizer);
 
         let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
-            return Err(VocabularyError::UnableToLocateEosTokenId {
+            return Err(Error::UnableToLocateEosTokenId {
                 model: model.to_string(),
             });
         };
@@ -106,9 +106,9 @@
 }
 
 impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
-    type Error = VocabularyError;
+    type Error = Error;
 
-    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self, Self::Error> {
+    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self> {
         let (tokenizer, eos_token_id) = value;
         let mut vocabulary = Vocabulary::new(Some(eos_token_id));
 
@@ -313,12 +313,7 @@
 
         assert!(vocabulary.is_err());
         if let Err(e) = vocabulary {
-            assert_eq!(
-                e,
-                VocabularyError::TokenProcessorError(
-                    crate::TokenProcessorError::UnsupportedTokenizer
-                )
-            )
+            assert_eq!(e, Error::UnsupportedByTokenProcessor)
         }
     }
 
@@ -328,9 +323,8 @@
 
         let vocabulary = Vocabulary::from_pretrained(model, None);
         assert!(vocabulary.is_err());
-        if let Err(VocabularyError::UnableToCreateTokenizer { model, source }) = vocabulary {
+        if let Err(Error::UnableToCreateTokenizer { model }) = vocabulary {
             assert_eq!(model, model.to_string());
-            assert_eq!(source.to_string(), "Tokenizer error".to_string());
         }
     }
 
diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 9488a78f..cec32f52 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -5,9 +5,7 @@ use serde::Deserialize;
 use tokenizers::normalizers::Replace;
 use tokenizers::{DecoderWrapper, Tokenizer};
 
-use crate::TokenProcessorError;
-
-type Result<T> = std::result::Result<T, TokenProcessorError>;
+use crate::{Error, Result};
 
 /// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
 /// UTF-8 characters, for example, byte ` \xf0` can be one token. These tokenizers map each
@@ -157,7 +155,7 @@
     /// Create new `TokenProcessor` with the level defined based on tokenizer's decoders.
     pub(crate) fn new(tokenizer: &Tokenizer) -> Result<Self> {
         match tokenizer.get_decoder() {
-            None => Err(TokenProcessorError::UnsupportedTokenizer),
+            None => Err(Error::UnsupportedByTokenProcessor),
             Some(decoder) => match decoder {
                 DecoderWrapper::ByteLevel(_) => Ok(Self {
                     level: TokenProcessorLevel::Byte,
@@ -188,10 +186,10 @@
                             level: TokenProcessorLevel::ByteFallback(Mods { spacechar }),
                         })
                     } else {
-                        Err(TokenProcessorError::UnsupportedTokenizer)
+                        Err(Error::UnsupportedByTokenProcessor)
                     }
                 }
-                _ => Err(TokenProcessorError::UnsupportedTokenizer),
+                _ => Err(Error::UnsupportedByTokenProcessor),
             },
         }
     }
@@ -199,23 +197,22 @@
     /// Operates on each token based on the level of `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
-            TokenProcessorLevel::Byte => {
-                let mut bytes = vec![];
-                for char in token.chars() {
-                    match CHAR_MAP.get(&char) {
-                        None => return Err(TokenProcessorError::ByteProcessorFailed),
-                        Some(b) => bytes.push(*b),
-                    }
-                }
-                Ok(bytes)
-            }
+            TokenProcessorLevel::Byte => token
+                .chars()
+                .map(|char| {
+                    CHAR_MAP
+                        .get(&char)
+                        .copied()
+                        .ok_or(Error::ByteProcessorFailed)
+                })
+                .collect(),
             TokenProcessorLevel::ByteFallback(mods) => {
                 // If the token is of form `<0x__>`:
                 if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
                     // Get to a single byte specified in the __ part and parse it in base 16 to a byte.
                     match u8::from_str_radix(&token[3..5], 16) {
                         Ok(byte) => Ok([byte].to_vec()),
-                        Err(_) => Err(TokenProcessorError::ByteFallbackProcessorFailed),
+                        Err(_) => Err(Error::ByteFallbackProcessorFailed),
                     }
                 } else {
                     Ok(mods.apply_default(token).as_bytes().to_vec())
@@ -228,10 +225,10 @@
     /// into local `ReplaceDecoder` structure.
     fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
         match serde_json::to_value(decoder) {
-            Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
+            Err(_) => Err(Error::DecoderUnpackingFailed),
             Ok(value) => match serde_json::from_value(value) {
                 Ok(d) => Ok(d),
-                Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
+                Err(_) => Err(Error::DecoderUnpackingFailed),
             },
         }
     }
@@ -324,7 +321,7 @@
         let result = TokenProcessor::new(&tokenizer);
         assert!(result.is_err());
         if let Err(e) = result {
-            assert_eq!(e, TokenProcessorError::UnsupportedTokenizer)
+            assert_eq!(e, Error::UnsupportedByTokenProcessor)
         }
     }
 
@@ -338,7 +335,7 @@
             let result = processor.process(token.to_string());
             assert!(result.is_err());
             if let Err(e) = result {
-                assert_eq!(e, TokenProcessorError::ByteProcessorFailed)
+                assert_eq!(e, Error::ByteProcessorFailed)
            }
         }
     }
@@ -352,7 +349,7 @@
         let result = processor.process("<0x6y>".to_string());
         assert!(result.is_err());
         if let Err(e) = result {
-            assert_eq!(e, TokenProcessorError::ByteFallbackProcessorFailed)
+            assert_eq!(e, Error::ByteFallbackProcessorFailed)
         }
     }
 }
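
For context, a minimal caller-side sketch of how the consolidated `Error` enum and the crate-wide `Result<T>` alias are meant to be consumed once `VocabularyError` and `TokenProcessorError` are gone. This is not part of the diff; the `load_vocabulary` helper and its logging are purely illustrative.

```rust
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};

// Hypothetical helper: every failure, whether it originates in tokenizer
// creation, EOS token lookup, or token processing, now surfaces as a variant
// of the single `Error` enum, so callers match on one type.
fn load_vocabulary(model: &str) -> Result<Vocabulary> {
    Vocabulary::from_pretrained(model, None).map_err(|e| {
        if let Error::UnableToCreateTokenizer { ref model } = e {
            eprintln!("could not build a tokenizer for {model}");
        }
        e
    })
}
```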