diff --git a/src/error.rs b/src/error.rs index 652fa740..f589731c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,9 @@ use thiserror::Error; +pub type Result = std::result::Result; + #[derive(Error, Debug)] +#[error("{0}")] pub struct TokenizersError(pub tokenizers::Error); impl PartialEq for TokenizersError { @@ -9,12 +12,6 @@ impl PartialEq for TokenizersError { } } -impl std::fmt::Display for TokenizersError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - #[derive(Error, Debug, PartialEq)] pub enum Error { #[error("The vocabulary does not allow us to build a sequence that matches the input")] diff --git a/src/lib.rs b/src/lib.rs index 4c45de4a..08c47def 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,9 +6,7 @@ pub mod primitives; pub mod regex; pub mod vocabulary; -use error::Error; - -pub type Result = std::result::Result; +use error::{Error, Result}; #[cfg(feature = "python-bindings")] mod python_bindings; diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs index 782b621a..d3f8bcfc 100644 --- a/src/vocabulary/locator.rs +++ b/src/vocabulary/locator.rs @@ -4,6 +4,7 @@ use tokenizers::{FromPretrainedParameters, Tokenizer}; use crate::primitives::*; +/// Mapping of characters to bytes for GPT-2 like tokenizers. /// List of common eos token locations appearing on hugging face hub, ordered by priority. const COMMON_LOCATIONS: &[EosTokenLocation] = &[ // Most projects have `generation_config.json` that looks like: @@ -71,6 +72,7 @@ struct Object { eos_token: Content, } +/// `eos_token` provided in a `Content`. #[derive(Debug, Serialize, Deserialize)] struct Content { content: String, @@ -91,6 +93,7 @@ struct EosTokenLocation { /// Locates eos token id. pub(crate) trait Locator { + /// Locates eos token id in defined locations by `Locator`. fn locate_eos_token_id( model: &str, tokenizer: &Tokenizer, @@ -102,6 +105,7 @@ pub(crate) trait Locator { pub(crate) struct HFLocator; impl Locator for HFLocator { + /// Locates eos token id in defined locations. fn locate_eos_token_id( model: &str, tokenizer: &Tokenizer, diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs index 55b6cde1..7426f249 100644 --- a/src/vocabulary/processor.rs +++ b/src/vocabulary/processor.rs @@ -77,12 +77,6 @@ static CHAR_MAP: Lazy> = Lazy::new(|| { char_map }); -/// Token processor to adjust tokens according to the tokenizer's level. -#[derive(Debug)] -pub(crate) struct TokenProcessor { - level: TokenProcessorLevel, -} - /// Recognizes different tokenizer's levels. #[derive(Debug, Clone, PartialEq)] pub(crate) enum TokenProcessorLevel { @@ -99,13 +93,17 @@ pub(crate) struct Mods { spacechar: char, } -/// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level. -static DEFAULT_MODS: Mods = Mods { spacechar: ' ' }; +impl Default for Mods { + /// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level. + fn default() -> Self { + Self { spacechar: ' ' } + } +} impl Mods { - /// Apply default modifications. + /// Apply default modifications to each token. fn apply_default(&self, token: String) -> String { - let to = DEFAULT_MODS.spacechar.to_string(); + let to = Self::default().spacechar.to_string(); token.replace(self.spacechar, &to) } } @@ -142,6 +140,12 @@ enum ReplacePattern { String(String), } +/// Token processor to adjust tokens according to the tokenizer's level. +#[derive(Debug)] +pub(crate) struct TokenProcessor { + level: TokenProcessorLevel, +} + impl TokenProcessor { /// Create new `TokenProcessor` with the level defined based on tokenizer's decoders. pub(crate) fn new(tokenizer: &Tokenizer) -> Result {