diff --git a/Cargo.toml b/Cargo.toml
index 0e83d020..c6054953 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,10 @@ thiserror = "1.0"
 pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
 regex = "1.10.6"
 serde-pyobject = "0.4.0"
-serde_json = { version = "1.0.125", features = ["preserve_order"] }
+serde_json = { version = "1.0", features = ["preserve_order"] }
+serde = {version = "1", features = ["derive"]}
+hf-hub = "0.3.2"
+tokenizers = { version = "0.20.0", features = ["http"] }
 
 [features]
 python-bindings = ["pyo3"]
diff --git a/src/lib.rs b/src/lib.rs
index 71787e2e..576d9962 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,8 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;
 
+mod locator;
+
 #[cfg(feature = "python-bindings")]
 mod python_bindings;
 
diff --git a/src/locator.rs b/src/locator.rs
new file mode 100644
index 00000000..272e02e9
--- /dev/null
+++ b/src/locator.rs
@@ -0,0 +1,215 @@
+use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
+use serde::{Deserialize, Serialize};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
+
+use crate::primitives::*;
+
+/// List of common eos token locations appearing on hugging face hub, ordered by priority.
+const COMMON_LOCATIONS: &[EosTokenLocation] = &[
+    // Most projects have `generation_config.json` that looks like:
+    // {
+    //   ...
+    //   "eos_token_id": 50256,
+    //   ...
+    // }
+    // So it's the first place we look for the eos token id.
+    //
+    // For example:
+    // - https://huggingface.co/openai-community/gpt2/blob/main/generation_config.json
+    EosTokenLocation {
+        file: "generation_config.json",
+        location: EosTokenField::Id,
+    },
+    // The ones that don't have `generation_config.json` usually have `tokenizer_config.json`:
+    // {
+    //   ...
+    //   "eos_token": "<|endoftext|>",
+    //   ...
+    // }
+    // Once we have the eos token content, we can get its id from the tokenizer.
+    //
+    // For example:
+    // - https://huggingface.co/microsoft/phi-2/blob/main/tokenizer_config.json
+    EosTokenLocation {
+        file: "tokenizer_config.json",
+        location: EosTokenField::Value,
+    },
+    // Sometimes `tokenizer_config.json` can have the following format as well:
+    // {
+    //   "eos_token": {
+    //     ...
+    //     "content": "</s>",
+    //     ...
+    //   },
+    // }
+    // Once we have the eos token content, we can get its id from the tokenizer.
+    //
+    // For example:
+    // - https://huggingface.co/hf-internal-testing/llama-tokenizer/blob/main/tokenizer_config.json
+    EosTokenLocation {
+        file: "tokenizer_config.json",
+        location: EosTokenField::Object,
+    },
+];
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Id {
+    eos_token_id: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Value {
+    eos_token: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Object {
+    eos_token: Content,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Content {
+    content: String,
+}
+
+/// Kind of the json field which will be checked for eos token id.
+enum EosTokenField {
+    Id,
+    Value,
+    Object,
+}
+
+/// Location of the end of sentence token id in a config file.
+struct EosTokenLocation {
+    file: &'static str,
+    location: EosTokenField,
+}
+
+pub(crate) struct EosTokenLocator;
+
+impl EosTokenLocator {
+    pub(crate) fn locate(
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<TokenId> {
+        COMMON_LOCATIONS
+            .iter()
+            .find_map(|location| location.lookup(model, tokenizer, parameters))
+    }
+}
+
+impl EosTokenLocation {
+    /// Finds eos token within defined location in related config file.
+    fn lookup(
+        &self,
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<TokenId> {
+        let file_path = Self::download_config(model, self.file, parameters).ok()?;
+        let file = std::fs::File::open(file_path).ok()?;
+
+        match self.location {
+            EosTokenField::Id => {
+                let config: Id = serde_json::from_reader(file).ok()?;
+                u32::try_from(config.eos_token_id).ok()
+            }
+            EosTokenField::Value => {
+                let config: Value = serde_json::from_reader(file).ok()?;
+                tokenizer.token_to_id(&config.eos_token)
+            }
+            EosTokenField::Object => {
+                let config: Object = serde_json::from_reader(file).ok()?;
+                tokenizer.token_to_id(&config.eos_token.content)
+            }
+        }
+    }
+
+    /// Downloads a config file from Hugging Face Hub.
+    fn download_config(
+        project: &str,
+        file: &str,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> tokenizers::Result<std::path::PathBuf> {
+        // Adapted from
+        // https://github.com/huggingface/tokenizers/blob/9b77c054ef4297c7057fa8db875368c7c02f1bfc/tokenizers/src/utils/from_pretrained.rs#L26
+
+        let params = parameters.clone().unwrap_or_default();
+
+        Self::validate(project)?;
+        Self::validate(&params.revision)?;
+
+        let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision);
+        let api = ApiBuilder::new()
+            .with_token(params.auth_token)
+            .build()?
+            .repo(repo);
+
+        Ok(api.get(file)?)
+    }
+
+    fn validate(input: &str) -> tokenizers::Result<()> {
+        let valid_chars = ['-', '_', '.', '/'];
+
+        if !input
+            .chars()
+            .all(|c: char| c.is_alphanumeric() || valid_chars.contains(&c))
+        {
+            return Err(format!(
+                "Input {input} contains invalid characters, expected only alphanumeric or {}",
+                valid_chars
+                    .iter()
+                    .map(|x| format!("'{}'", x))
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )
+            .into());
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn common_locations() {
+        for (model, expected_token_id, expected_token) in &[
+            ("openai-community/gpt2", 50256, "<|endoftext|>"),
+            ("microsoft/phi-2", 50256, "<|endoftext|>"),
+            ("hf-internal-testing/llama-tokenizer", 2, "</s>"),
+        ] {
+            let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+            let located =
+                EosTokenLocator::locate(model, &tokenizer, &None).expect("Token id is not located");
+
+            assert_eq!(located, *expected_token_id);
+            assert_eq!(
+                tokenizer.id_to_token(located).expect("Token is not found"),
+                expected_token.to_string()
+            );
+        }
+    }
+
+    #[test]
+    fn bad_location() {
+        let bad_location = EosTokenLocation {
+            file: "tokenizer_config.json",
+            location: EosTokenField::Id,
+        };
+        let model = "microsoft/phi-2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        let token_id = bad_location.lookup(model, &tokenizer, &None);
+        assert!(token_id.is_none());
+
+        let bad_file = EosTokenLocation {
+            file: "generation_config.json",
+            location: EosTokenField::Value,
+        };
+        let token_id = bad_file.lookup(model, &tokenizer, &None);
+        assert!(token_id.is_none());
+    }
+}