From 606b4606fc9e369fd57a33a82e64559d1e101cbc Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 12 Dec 2024 20:04:58 +0000 Subject: [PATCH 01/22] Build Index from regex --- Cargo.toml | 1 + src/error.rs | 23 ++++---- src/index.rs | 124 +++++++++++++++++++++++++++++++++++++++++- src/vocabulary/mod.rs | 17 +++++- 4 files changed, 147 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9f3f31a..f49df1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bincode = "2.0.0-rc.3" hf-hub = "=0.3.2" tokenizers = { version = "=0.20.3", features = ["http"] } rustc-hash = "2.1.0" +regex-automata = "0.4.9" [features] python-bindings = ["pyo3"] diff --git a/src/error.rs b/src/error.rs index d7905ab..a8e5864 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,21 +3,18 @@ use thiserror::Error; pub type Result = std::result::Result; #[derive(Error, Debug)] -#[error("{0}")] -pub struct TokenizersError(pub tokenizers::Error); - -impl PartialEq for TokenizersError { - fn eq(&self, other: &Self) -> bool { - self.0.to_string() == other.0.to_string() - } -} - -#[derive(Error, Debug, PartialEq)] pub enum Error { - #[error("The vocabulary does not allow us to build a sequence that matches the input")] - IndexError, + #[error("The vocabulary does not allow to build an index that matches the input")] + InsufficientVocabulary, + // TODO: this error will be removed once eos_token_id for vocabulary won't be optional + #[error("Index failed since vocabulary doesn't provide eos token id")] + IndexEosTokenIdNotAvailable, + #[error("Failed to build DFA {0}")] + IndexDfaError(#[from] Box), + #[error("Index failed since anchored universal start state doesn't exist")] + IndexNoAnchoredUniversalStartState, #[error(transparent)] - TokenizersError(#[from] TokenizersError), + TokenizersError(#[from] tokenizers::Error), #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")] UnsupportedTokenizer { model: String, reason: String }, #[error("Unable to locate EOS token for {model}")] diff --git a/src/index.rs b/src/index.rs index 5fcc3e9..127ea4c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,9 +1,12 @@ /// Construct an Index. 
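+// An `Index` records, for every reachable automaton state, which token ids may
+// be generated in that state and which state each of those tokens leads to;
+// final states additionally map `eos_token_id` back onto themselves.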
-use crate::prelude::{State, TransitionKey}; +use crate::prelude::*; use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens}; use crate::vocabulary::Vocabulary; use crate::{Error, Result}; use bincode::{Decode, Encode}; +use regex_automata::dfa::{dense::DFA, Automaton}; +use regex_automata::util::primitives::StateID as AutomataStateId; +use regex_automata::Anchored; use rustc_hash::{FxHashMap, FxHashSet}; #[derive(Debug)] @@ -101,7 +104,96 @@ impl Index { eos_token_id, }) } else { - Err(Error::IndexError) + Err(Error::InsufficientVocabulary) + } + } + + pub(crate) fn from_regex(regex: &str, vocabulary: &Vocabulary) -> Result { + let eos_token_id = match vocabulary.eos_token_id() { + Some(s) => s, + None => return Err(Error::IndexEosTokenIdNotAvailable), + }; + + let dfa = DFA::builder().build(regex).map_err(Box::new)?; + let start_state = match dfa.universal_start_state(Anchored::Yes) { + Some(s) => s, + None => return Err(Error::IndexNoAnchoredUniversalStartState), + }; + + let mut index: FxHashMap> = FxHashMap::default(); + let mut seen: FxHashSet = FxHashSet::default(); + let mut final_states: FxHashSet = FxHashSet::default(); + let mut next_states: FxHashSet = FxHashSet::from_iter([start_state]); + + while let Some(start_state) = next_states.iter().cloned().next() { + next_states.remove(&start_state); + seen.insert(start_state); + + if dfa.is_match_state(dfa.next_eoi_state(start_state)) { + final_states.insert(start_state.as_u32()); + } + + 'token_loop: for (token, ids) in vocabulary.tokens_to_ids().iter() { + if ids.contains(&eos_token_id) { + continue; + } + + let mut next_state = start_state; + for transition_byte in token.as_bytes() { + next_state = dfa.next_state(next_state, *transition_byte); + if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) { + continue 'token_loop; + } + } + + if dfa.is_match_state(next_state) { + // Token either matched or matched except the last character. + // Check what happens if the input suddenly ends after reaching this state. + // If the automata still matches, then token is exactly matched, if not + // then token didn't match. 
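+                    // (regex-automata's dense DFA reports matches delayed by one byte,
+                    // so reaching a match state here only proves that the token minus
+                    // its final byte matched; the end-of-input probe below confirms
+                    // whether the whole token is accepted.)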
+ let next_eoi_state = dfa.next_eoi_state(next_state); + let token_matched = dfa.is_match_state(next_eoi_state); + if !token_matched { + continue; + } + } + + for token_id in ids { + let mapping = index.entry(start_state.as_u32()).or_default(); + mapping.insert(*token_id, next_state.as_u32()); + + if !seen.contains(&next_state) { + next_states.insert(next_state); + } + } + } + } + + let start_state = start_state.as_u32(); + + // Populate `index` with mappings from `final_states` to `eos_token_id` + for &final_state in &final_states { + index + .entry(final_state) + .or_default() + .insert(eos_token_id, final_state); + } + // Check if there is at least one valid mapping + let is_valid = index.values().any(|mapping| { + mapping + .values() + .any(|end_state| final_states.contains(end_state)) + }); + + if is_valid { + Ok(Self { + initial: start_state, + finals: final_states, + states_to_token_subsets: index, + eos_token_id, + }) + } else { + Err(Error::InsufficientVocabulary) } } @@ -126,7 +218,35 @@ impl Index { self.finals.contains(&state) } + pub(crate) fn final_states(&self) -> &FxHashSet { + &self.finals + } + pub(crate) fn transitions(&self) -> &FxHashMap> { &self.states_to_token_subsets } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn index_from_regex() { + let regex = "0|[1-9][0-9]*"; + let vocabulary = Vocabulary::new(Some(4)) + .insert("blah", 0) + .insert("1a", 1) + .insert("2", 2) + .insert("0", 3) + .insert("", 4); + + let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); + assert_eq!(index.initial(), 40); + assert_eq!(index.final_states(), &FxHashSet::from_iter([24, 48, 56])); + assert_eq!( + "{24: {3: 24, 4: 24, 2: 24}, 48: {4: 48}, 40: {3: 48, 2: 56}, 56: {3: 24, 4: 56, 2: 24}}", + format!("{:?}", index.transitions()) + ); + } +} diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 0b3eaa2..c95ed99 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -3,7 +3,7 @@ use rustc_hash::FxHashMap; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; -use crate::{error, prelude::*}; +use crate::prelude::*; use crate::{Error, Result}; use locator::{HFLocator, Locator}; @@ -41,6 +41,13 @@ impl Vocabulary { } } + pub fn with_eos_token_id(self, eos_token_id: Option) -> Self { + Self { + eos_token_id, + ..self + } + } + /// Creates the vocabulary of pre-trained model from Hugging Face Hub. pub fn from_pretrained( model: &str, @@ -55,8 +62,7 @@ impl Vocabulary { model: &str, parameters: Option, ) -> Result { - let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone()) - .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?; + let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())?; Self::filter_prepend_normalizers(&mut tokenizer); // Locate eos_token_id in defined locations. @@ -95,6 +101,11 @@ impl Vocabulary { Ok(vocabulary) } + /// Returns all tokens with their token ids in vocabulary + pub fn tokens_to_ids(&self) -> &FxHashMap> { + &self.tokens + } + /// Per provided token returns vector of `TokenId`s if available in the vocabulary. 
pub fn token_to_ids(&self, token: &str) -> Option<&Vec> { self.tokens.get(token) From bdc120d023df8af2143ee7bd73ac7d4239bb28ec Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 12 Dec 2024 20:06:28 +0000 Subject: [PATCH 02/22] Test Index from regex in Guide --- python/outlines_core/fsm/guide.py | 16 +++++++++++--- python/outlines_core/fsm/outlines_core_rs.pyi | 17 ++++++++++++++ src/python_bindings/mod.rs | 22 +++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py index 18b4523..bb7561a 100644 --- a/python/outlines_core/fsm/guide.py +++ b/python/outlines_core/fsm/guide.py @@ -7,9 +7,10 @@ create_fsm_index_tokenizer, make_byte_level_fsm, make_deterministic_fsm, + reduced_vocabulary, ) -from .outlines_core_rs import Index +from .outlines_core_rs import Index, Vocabulary @dataclass(frozen=True) @@ -137,8 +138,17 @@ def create_states_mapping( final_states: A set of final states in the FSM. """ - regex_fsm = regex_parser(regex_string).to_fsm() - return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens) + # regex_fsm = regex_parser(regex_string).to_fsm() + # return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens) + + # inlining logic of create_fsm_index_tokenizer + tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer) + vocabulary = Vocabulary.from_dict_with_eos_token_id( + tokens_to_token_ids, tokenizer.eos_token_id + ) + index = Index.from_regex(regex_string, vocabulary) + + return index, empty_token_ids, set(index.final_states()) def create_states_mapping_from_fsm( diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index dae645e..81181ed 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -78,6 +78,14 @@ class Vocabulary: Creates a vocabulary from a dictionary of tokens to token IDs. """ ... + @staticmethod + def from_dict_with_eos_token_id( + map: Dict[str, List[int]], eos_token_id: int + ) -> "Vocabulary": + """ + Creates a vocabulary from a dictionary of tokens to token IDs and eos token id. + """ + ... def __repr__(self) -> str: """ Gets the debug string representation of the vocabulary. @@ -90,6 +98,12 @@ class Vocabulary: ... class Index: + @staticmethod + def from_regex(regex: str, vocabulary: "Vocabulary") -> "Index": + """ + Creates an index from a regex and vocabulary. + """ + ... def get_allowed_tokens(self, state: int) -> Optional[List[int]]: """Returns allowed tokens in this state.""" ... @@ -99,6 +113,9 @@ class Index: def is_final_state(self, state: int) -> bool: """Determines whether the current state is a final state.""" ... + def final_states(self) -> List[int]: + """Get all final states.""" + ... def get_index_dict(self) -> Dict[int, Dict[int, int]]: """Returns the Index as a Python Dict object.""" ... diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index f9c4936..059596b 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -93,6 +93,15 @@ impl PyIndex { }) } + #[staticmethod] + fn from_regex(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { + py.allow_threads(|| { + Index::from_regex(regex, &vocabulary.0) + .map(PyIndex) + .map_err(Into::into) + }) + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
@@ -126,6 +135,10 @@ impl PyIndex { self.0.is_final(state) } + fn final_states(&self) -> FxHashSet { + self.0.final_states().clone() + } + fn get_transitions(&self) -> FxHashMap> { self.0.transitions().clone() } @@ -291,6 +304,15 @@ impl PyVocabulary { PyVocabulary(Vocabulary::from(map)) } + #[staticmethod] + fn from_dict_with_eos_token_id( + map: FxHashMap>, + eos_token_id: TokenId, + ) -> PyVocabulary { + let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); + PyVocabulary(v) + } + fn __repr__(&self) -> String { format!("{:#?}", self.0) } From 6c5b85334e9ad3ef77dd042bfa99f738f595017f Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 13 Dec 2024 21:00:51 +0000 Subject: [PATCH 03/22] Use FxHash* as default Hash* --- src/python_bindings/mod.rs | 50 +++++++++++++++++++------------------- src/regex.rs | 26 ++++++++++---------- src/vocabulary/mod.rs | 18 +++++++------- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 059596b..c7676fb 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -10,21 +10,21 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3::wrap_pyfunction; -use rustc_hash::{FxHashMap, FxHashSet}; use serde_json::Value; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; #[pyclass(name = "FSMInfo")] pub struct PyFSMInfo { #[pyo3(get)] initial: State, #[pyo3(get)] - finals: FxHashSet, + finals: HashSet, #[pyo3(get)] - transitions: FxHashMap<(State, TransitionKey), State>, + transitions: HashMap<(State, TransitionKey), State>, #[pyo3(get)] alphabet_anything_value: TransitionKey, #[pyo3(get)] - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, } impl From for PyFSMInfo { @@ -57,10 +57,10 @@ impl PyFSMInfo { #[new] fn new( initial: State, - finals: FxHashSet, - transitions: FxHashMap<(State, TransitionKey), State>, + finals: HashSet, + transitions: HashMap<(State, TransitionKey), State>, alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, ) -> Self { FSMInfo::new( initial, @@ -84,7 +84,7 @@ impl PyIndex { fsm_info: &PyFSMInfo, vocabulary: &PyVocabulary, eos_token_id: u32, - frozen_tokens: FxHashSet, + frozen_tokens: HashSet, ) -> PyResult { py.allow_threads(|| { Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) @@ -135,11 +135,11 @@ impl PyIndex { self.0.is_final(state) } - fn final_states(&self) -> FxHashSet { + fn final_states(&self) -> HashSet { self.0.final_states().clone() } - fn get_transitions(&self) -> FxHashMap> { + fn get_transitions(&self) -> HashMap> { self.0.transitions().clone() } @@ -171,9 +171,9 @@ pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyR text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" )] pub fn walk_fsm_py( - fsm_transitions: FxHashMap<(State, TransitionKey), State>, + fsm_transitions: HashMap<(State, TransitionKey), State>, fsm_initial: State, - fsm_finals: FxHashSet, + fsm_finals: HashSet, token_transition_keys: Vec, start_state: State, full_match: bool, @@ -193,13 +193,13 @@ pub fn walk_fsm_py( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" )] pub fn state_scan_tokens_py( - fsm_transitions: FxHashMap<(State, TransitionKey), State>, + fsm_transitions: HashMap<(State, TransitionKey), State>, fsm_initial: State, - 
fsm_finals: FxHashSet, + fsm_finals: HashSet, vocabulary: &PyVocabulary, - vocabulary_transition_keys: FxHashMap>, + vocabulary_transition_keys: HashMap>, start_state: State, -) -> PyResult> { +) -> PyResult> { Ok(state_scan_tokens( &fsm_transitions, fsm_initial, @@ -213,7 +213,7 @@ pub fn state_scan_tokens_py( #[pyfunction(name = "get_token_transition_keys")] #[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] pub fn get_token_transition_keys_py( - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, alphabet_anything_value: TransitionKey, token_str: String, ) -> PyResult> { @@ -229,11 +229,11 @@ pub fn get_token_transition_keys_py( text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" )] pub fn get_vocabulary_transition_keys_py( - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, alphabet_anything_value: TransitionKey, vocabulary: &PyVocabulary, - frozen_tokens: FxHashSet, -) -> PyResult>> { + frozen_tokens: HashSet, +) -> PyResult>> { Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, @@ -248,11 +248,11 @@ pub fn create_fsm_index_end_to_end_py<'py>( py: Python<'py>, fsm_info: &PyFSMInfo, vocabulary: &PyVocabulary, - frozen_tokens: FxHashSet, + frozen_tokens: HashSet, ) -> PyResult> { let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter(vec![fsm_info.initial]); + let mut seen: HashSet = HashSet::default(); + let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, @@ -300,13 +300,13 @@ pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { #[staticmethod] - fn from_dict(map: FxHashMap>) -> PyVocabulary { + fn from_dict(map: HashMap>) -> PyVocabulary { PyVocabulary(Vocabulary::from(map)) } #[staticmethod] fn from_dict_with_eos_token_id( - map: FxHashMap>, + map: HashMap>, eos_token_id: TokenId, ) -> PyVocabulary { let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); diff --git a/src/regex.rs b/src/regex.rs index 24687f1..c9270b6 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,10 +1,10 @@ use crate::prelude::*; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; pub fn walk_fsm( - fsm_transitions: &FxHashMap<(State, TransitionKey), State>, + fsm_transitions: &HashMap<(State, TransitionKey), State>, _fsm_initial: State, - fsm_finals: &FxHashSet, + fsm_finals: &HashSet, token_transition_keys: &[TransitionKey], start_state: State, full_match: bool, @@ -39,14 +39,14 @@ pub fn walk_fsm( } pub fn state_scan_tokens( - fsm_transitions: &FxHashMap<(State, TransitionKey), State>, + fsm_transitions: &HashMap<(State, TransitionKey), State>, fsm_initial: State, - fsm_finals: &FxHashSet, + fsm_finals: &HashSet, vocabulary: &Vocabulary, - vocabulary_transition_keys: &FxHashMap>, + vocabulary_transition_keys: &HashMap>, start_state: State, -) -> FxHashSet<(TokenId, State)> { - let mut res = FxHashSet::default(); +) -> HashSet<(TokenId, State)> { + let mut res = HashSet::default(); for (token, token_ids) in vocabulary.iter() { let token_transition_keys = &vocabulary_transition_keys[token]; @@ -72,7 +72,7 @@ pub fn state_scan_tokens( } pub fn get_token_transition_keys( - alphabet_symbol_mapping: &FxHashMap, + alphabet_symbol_mapping: &HashMap, 
alphabet_anything_value: TransitionKey, token_str: &str, ) -> Vec { @@ -105,12 +105,12 @@ pub fn get_token_transition_keys( } pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: &FxHashMap, + alphabet_symbol_mapping: &HashMap, alphabet_anything_value: TransitionKey, vocabulary: &Vocabulary, - frozen_tokens: &FxHashSet, -) -> FxHashMap> { - let mut vocab_transition_keys = FxHashMap::default(); + frozen_tokens: &HashSet, +) -> HashMap> { + let mut vocab_transition_keys = HashMap::default(); for item in vocabulary.iter() { let token_str = item.0.clone(); diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index c95ed99..13156ad 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap; +use rustc_hash::FxHashMap as HashMap; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; @@ -29,7 +29,7 @@ mod processor; pub struct Vocabulary { // TODO: Option is temp for back compatibility eos_token_id: Option, - tokens: FxHashMap>, + tokens: HashMap>, } impl Vocabulary { @@ -37,7 +37,7 @@ impl Vocabulary { pub fn new(eos_token_id: Option) -> Self { Self { eos_token_id, - tokens: FxHashMap::default(), + tokens: HashMap::default(), } } @@ -102,7 +102,7 @@ impl Vocabulary { } /// Returns all tokens with their token ids in vocabulary - pub fn tokens_to_ids(&self) -> &FxHashMap> { + pub fn tokens_to_ids(&self) -> &HashMap> { &self.tokens } @@ -185,9 +185,9 @@ impl Vocabulary { } impl std::ops::Deref for Vocabulary { - type Target = FxHashMap>; + type Target = HashMap>; - fn deref(&self) -> &FxHashMap> { + fn deref(&self) -> &HashMap> { &self.tokens } } @@ -205,8 +205,8 @@ impl std::fmt::Display for Vocabulary { } } -impl From>> for Vocabulary { - fn from(tokens: FxHashMap>) -> Vocabulary { +impl From>> for Vocabulary { + fn from(tokens: HashMap>) -> Vocabulary { Vocabulary { eos_token_id: None, tokens, @@ -268,7 +268,7 @@ mod tests { #[test] fn new_empty_vocabulary_from_hashmap() { - let map = FxHashMap::default(); + let map = HashMap::default(); let vocabulary = Vocabulary::from(map); assert!(vocabulary.eos_token_id.is_none()); assert!(vocabulary.tokens.is_empty()); From f3494040bef8ce72dcaaaa3734d93da26513e7aa Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 13 Dec 2024 21:23:08 +0000 Subject: [PATCH 04/22] Cleaner from_regex logic --- src/error.rs | 2 +- src/index.rs | 108 ++++++++++++++++++------------------- src/python_bindings/mod.rs | 2 +- 3 files changed, 54 insertions(+), 58 deletions(-) diff --git a/src/error.rs b/src/error.rs index a8e5864..53a8728 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,7 +12,7 @@ pub enum Error { #[error("Failed to build DFA {0}")] IndexDfaError(#[from] Box), #[error("Index failed since anchored universal start state doesn't exist")] - IndexNoAnchoredUniversalStartState, + DfaHasNoStartState, #[error(transparent)] TokenizersError(#[from] tokenizers::Error), #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")] diff --git a/src/index.rs b/src/index.rs index 127ea4c..a915766 100644 --- a/src/index.rs +++ b/src/index.rs @@ -7,24 +7,24 @@ use bincode::{Decode, Encode}; use regex_automata::dfa::{dense::DFA, Automaton}; use regex_automata::util::primitives::StateID as AutomataStateId; use regex_automata::Anchored; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; 
#[derive(Debug)] pub struct FSMInfo { pub(crate) initial: State, - pub(crate) finals: FxHashSet, - pub(crate) transitions: FxHashMap<(State, TransitionKey), State>, + pub(crate) finals: HashSet, + pub(crate) transitions: HashMap<(State, TransitionKey), State>, pub(crate) alphabet_anything_value: TransitionKey, - pub(crate) alphabet_symbol_mapping: FxHashMap, + pub(crate) alphabet_symbol_mapping: HashMap, } impl FSMInfo { pub fn new( initial: State, - finals: FxHashSet, - transitions: FxHashMap<(State, TransitionKey), State>, + finals: HashSet, + transitions: HashMap<(State, TransitionKey), State>, alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, ) -> Self { Self { initial, @@ -39,8 +39,8 @@ impl FSMInfo { #[derive(Debug, Encode, Decode)] pub struct Index { initial: u32, - finals: FxHashSet, - states_to_token_subsets: FxHashMap>, + finals: HashSet, + states_to_token_subsets: HashMap>, eos_token_id: u32, } @@ -49,11 +49,11 @@ impl Index { fsm_info: &FSMInfo, vocabulary: &Vocabulary, eos_token_id: u32, - frozen_tokens: FxHashSet, + frozen_tokens: HashSet, ) -> Result { - let mut states_to_token_subsets: FxHashMap> = FxHashMap::default(); - let mut seen: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter([fsm_info.initial]); + let mut states_to_token_subsets: HashMap> = HashMap::default(); + let mut seen: HashSet = HashSet::default(); + let mut next_states: HashSet = HashSet::from_iter([fsm_info.initial]); let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, @@ -111,26 +111,25 @@ impl Index { pub(crate) fn from_regex(regex: &str, vocabulary: &Vocabulary) -> Result { let eos_token_id = match vocabulary.eos_token_id() { Some(s) => s, + // TODO: this error will be removed once eos_token_id for vocabulary won't be optional None => return Err(Error::IndexEosTokenIdNotAvailable), }; - let dfa = DFA::builder().build(regex).map_err(Box::new)?; + let dfa = DFA::new(regex).map_err(Box::new)?; let start_state = match dfa.universal_start_state(Anchored::Yes) { Some(s) => s, - None => return Err(Error::IndexNoAnchoredUniversalStartState), + None => return Err(Error::DfaHasNoStartState), }; - let mut index: FxHashMap> = FxHashMap::default(); - let mut seen: FxHashSet = FxHashSet::default(); - let mut final_states: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter([start_state]); + let mut transitions: HashMap> = HashMap::default(); + let mut final_states: HashSet = HashSet::default(); - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - seen.insert(start_state); + let mut seen: HashSet = HashSet::from_iter([start_state]); + let mut next_states: Vec = vec![start_state]; - if dfa.is_match_state(dfa.next_eoi_state(start_state)) { - final_states.insert(start_state.as_u32()); + while let Some(current_state) = next_states.pop() { + if dfa.is_match_state(dfa.next_eoi_state(current_state)) { + final_states.insert(current_state.as_u32()); } 'token_loop: for (token, ids) in vocabulary.tokens_to_ids().iter() { @@ -138,7 +137,7 @@ impl Index { continue; } - let mut next_state = start_state; + let mut next_state = current_state; for transition_byte in token.as_bytes() { next_state = dfa.next_state(next_state, *transition_byte); if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) { @@ -146,40 +145,33 @@ impl Index { } } - if dfa.is_match_state(next_state) { - // 
Token either matched or matched except the last character. - // Check what happens if the input suddenly ends after reaching this state. - // If the automata still matches, then token is exactly matched, if not - // then token didn't match. - let next_eoi_state = dfa.next_eoi_state(next_state); - let token_matched = dfa.is_match_state(next_eoi_state); - if !token_matched { - continue; + let is_intermediate_state = !dfa.is_match_state(next_state); + let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state)); + if is_intermediate_state || is_full_match_state { + for token_id in ids { + transitions + .entry(current_state.as_u32()) + .or_default() + .insert(*token_id, next_state.as_u32()); } } - - for token_id in ids { - let mapping = index.entry(start_state.as_u32()).or_default(); - mapping.insert(*token_id, next_state.as_u32()); - - if !seen.contains(&next_state) { - next_states.insert(next_state); - } + if !seen.contains(&next_state) { + seen.insert(next_state); + next_states.push(next_state); } } } - let start_state = start_state.as_u32(); - - // Populate `index` with mappings from `final_states` to `eos_token_id` + // Populate `transitions` with mappings from `final_states` to `eos_token_id` for &final_state in &final_states { - index + transitions .entry(final_state) .or_default() .insert(eos_token_id, final_state); } + // Check if there is at least one valid mapping - let is_valid = index.values().any(|mapping| { + let is_valid = transitions.values().any(|mapping| { mapping .values() .any(|end_state| final_states.contains(end_state)) @@ -187,9 +179,9 @@ impl Index { if is_valid { Ok(Self { - initial: start_state, + initial: start_state.as_u32(), finals: final_states, - states_to_token_subsets: index, + states_to_token_subsets: transitions, eos_token_id, }) } else { @@ -218,11 +210,11 @@ impl Index { self.finals.contains(&state) } - pub(crate) fn final_states(&self) -> &FxHashSet { + pub(crate) fn final_states(&self) -> &HashSet { &self.finals } - pub(crate) fn transitions(&self) -> &FxHashMap> { + pub(crate) fn transitions(&self) -> &HashMap> { &self.states_to_token_subsets } } @@ -243,10 +235,14 @@ mod tests { let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); assert_eq!(index.initial(), 40); - assert_eq!(index.final_states(), &FxHashSet::from_iter([24, 48, 56])); - assert_eq!( - "{24: {3: 24, 4: 24, 2: 24}, 48: {4: 48}, 40: {3: 48, 2: 56}, 56: {3: 24, 4: 56, 2: 24}}", - format!("{:?}", index.transitions()) - ); + assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56])); + + let expected: HashMap> = HashMap::from_iter([ + (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])), + (48, HashMap::from_iter([(4, 48)])), + (40, HashMap::from_iter([(3, 48), (2, 56)])), + (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])), + ]); + assert_eq!(&expected, index.transitions()); } } diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index c7676fb..2db679a 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -10,8 +10,8 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3::wrap_pyfunction; -use serde_json::Value; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use serde_json::Value; #[pyclass(name = "FSMInfo")] pub struct PyFSMInfo { From 15a85aa571edd19c3d4a1a7c7b219e6f4ec9e19e Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Wed, 18 Dec 2024 15:48:36 +0000 Subject: [PATCH 05/22] Use bytes as Token type, more tests for Index --- 
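Note: with `Token` now defined as `Vec<u8>`, a vocabulary can mix string tokens and raw byte
tokens, including bytes that are not valid UTF-8 on their own. A minimal sketch of the resulting
flow, assuming the crate-internal API exercised by the new tests (`Vocabulary::new`/`insert`/
`token_to_ids`, `Index::from_regex`/`initial`/`allowed_tokens`):

    let vocabulary = Vocabulary::new(Some(4))
        .insert("0", 0)                          // &str token, stored as UTF-8 bytes
        .insert(vec![0xF0, 0x9F, 0x98, 0x87], 1) // raw bytes of "😇"
        .insert(vec![0xFF], 2);                  // a single byte that is not valid UTF-8
    assert!(vocabulary.token_to_ids(&vec![0xFF]).is_some());

    let index = Index::from_regex("0|[1-9][0-9]*", &vocabulary).expect("index");
    let allowed = index.allowed_tokens(index.initial()).expect("allowed tokens");
    assert!(allowed.contains(&0)); // only "0" matches the regex from the start state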
src/index.rs | 55 +++++++++++++++++++++++--- src/primitives.rs | 2 +- src/vocabulary/mod.rs | 91 +++++++++++++++++++++++++------------------ 3 files changed, 105 insertions(+), 43 deletions(-) diff --git a/src/index.rs b/src/index.rs index a915766..3df6e74 100644 --- a/src/index.rs +++ b/src/index.rs @@ -138,7 +138,7 @@ impl Index { } let mut next_state = current_state; - for transition_byte in token.as_bytes() { + for transition_byte in token { next_state = dfa.next_state(next_state, *transition_byte); if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) { continue 'token_loop; @@ -230,19 +230,64 @@ mod tests { .insert("blah", 0) .insert("1a", 1) .insert("2", 2) - .insert("0", 3) - .insert("", 4); + .insert("0", 3); let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); assert_eq!(index.initial(), 40); assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56])); - let expected: HashMap> = HashMap::from_iter([ + let expected = HashMap::from_iter([ (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])), (48, HashMap::from_iter([(4, 48)])), (40, HashMap::from_iter([(3, 48), (2, 56)])), (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])), ]); - assert_eq!(&expected, index.transitions()); + assert_eq!(index.transitions(), &expected); + } + + #[test] + fn index_from_regex_initital_in_allowed() { + let regex = "`\\n(\\.\\n)?`\\n"; + let vocabulary = Vocabulary::new(Some(104)) + .insert("\n", 103) + .insert(".", 102) + .insert("`", 101); + + let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); + let allowed = index + .allowed_tokens(index.initial()) + .expect("No allowed tokens"); + assert!(allowed.contains(&101)); + } + + #[test] + fn index_from_regex_multibyte() { + let regex = "😇| [😈-😍][😇-😎]*"; + let vocabulary = Vocabulary::new(Some(8)) + .insert(" 😍", 5) + .insert("blah", 0) + .insert("😇", 2) + .insert("😈a", 1) + .insert("😍", 3) + .insert(vec![32, 240, 159, 152], 7) + .insert(vec![32, 240, 159, 152, 141], 6) + .insert(vec![240, 159, 152, 141], 4); + + let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); + + assert_eq!(index.final_states(), &HashSet::from_iter([208, 128])); + + let expected = HashMap::from_iter([ + ( + 208, + HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]), + ), + ( + 80, + HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]), + ), + (128, HashMap::from_iter([(8, 128)])), + ]); + assert_eq!(index.transitions(), &expected); } } diff --git a/src/primitives.rs b/src/primitives.rs index e12bf03..0976f76 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -2,7 +2,7 @@ pub type TransitionKey = u32; /// Token content. -pub type Token = String; +pub type Token = Vec; /// Token identifier. 
pub type TokenId = u32; diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 13156ad..613d735 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap as HashMap; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; @@ -90,11 +90,7 @@ impl Vocabulary { }); }; for (token, token_id) in tokenizer.get_vocab(false) { - let token_bytes = processor.process(token)?; - // TODO: lossy is temp: - // - in python in was handled by byte_symbol function - // - interface needs to be redefined to treat Token type as bytes: Vec - let processed_token = String::from_utf8_lossy(&token_bytes); + let processed_token= processor.process(token)?; vocabulary = vocabulary.insert(processed_token, token_id); } @@ -107,7 +103,7 @@ impl Vocabulary { } /// Per provided token returns vector of `TokenId`s if available in the vocabulary. - pub fn token_to_ids(&self, token: &str) -> Option<&Vec> { + pub fn token_to_ids(&self, token: &Token) -> Option<&Vec> { self.tokens.get(token) } @@ -214,6 +210,18 @@ impl From>> for Vocabulary { } } +impl From>> for Vocabulary { + fn from(tokens: HashMap>) -> Vocabulary { + Vocabulary { + eos_token_id: None, + tokens: tokens + .into_iter() + .map(|(k,v)| (k.as_bytes().to_vec(), v)) + .collect::>>(), + } + } +} + impl FromIterator<(T, I)> for Vocabulary where T: Into, @@ -237,10 +245,10 @@ mod tests { .insert("0", 3); assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah"], &[0]); - assert_eq!(vocabulary["1a"], &[1]); - assert_eq!(vocabulary["2"], &[2]); - assert_eq!(vocabulary["0"], &[3]); + assert_eq!(vocabulary["blah".as_bytes()], &[0]); + assert_eq!(vocabulary["1a".as_bytes()], &[1]); + assert_eq!(vocabulary["2".as_bytes()], &[2]); + assert_eq!(vocabulary["0".as_bytes()], &[3]); } #[test] @@ -253,10 +261,10 @@ mod tests { ]); assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah"], &[0]); - assert_eq!(vocabulary["1a"], &[1]); - assert_eq!(vocabulary["2"], &[2]); - assert_eq!(vocabulary["0"], &[3]); + assert_eq!(vocabulary["blah".as_bytes()], &[0]); + assert_eq!(vocabulary["1a".as_bytes()], &[1]); + assert_eq!(vocabulary["2".as_bytes()], &[2]); + assert_eq!(vocabulary["0".as_bytes()], &[3]); } #[test] @@ -268,7 +276,7 @@ mod tests { #[test] fn new_empty_vocabulary_from_hashmap() { - let map = HashMap::default(); + let map: HashMap> = HashMap::default(); let vocabulary = Vocabulary::from(map); assert!(vocabulary.eos_token_id.is_none()); assert!(vocabulary.tokens.is_empty()); @@ -276,7 +284,7 @@ mod tests { #[test] fn new_vocabulary_from_iterator() { - let token: Token = "abc".to_string(); + let token: Token = "abc".as_bytes().to_vec(); let id: Vec = vec![1]; let it = vec![(token, id)]; let vocabulary = Vocabulary::from_iter(it); @@ -330,11 +338,12 @@ mod tests { ); let token = "Ġal"; - assert!(vocabulary.token_to_ids(token).is_none()); - assert!(tokenizer.token_to_id(token).is_some()); + let btoken = token.as_bytes().to_vec(); + assert!(vocabulary.token_to_ids(&btoken).is_none()); + assert!(tokenizer.token_to_id(&token).is_some()); for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] { - let v_ids = vocabulary.token_to_ids(v_token); + let v_ids = vocabulary.token_to_ids(&v_token.as_bytes().to_vec()); assert!(v_ids.is_some()); for v_id in v_ids.unwrap() { let t_token = tokenizer @@ -361,24 +370,32 @@ mod tests { tokenizer.id_to_token(v_eos).expect("Token not found"), "" ); - - 
for (v_token, t_token_expected) in [ - ("abc", "abc"), - (" al", "▁al"), - (" O", "▁O"), - (" ", "▁▁▁"), - // TODO: won't pass since first we need to change token's type to bytes - // ("<0xFF>", "ÿ"), - // ("<0x20>", "▁"), - ] { - let v_ids = vocabulary.token_to_ids(v_token); + + let tests: &[(Vec, &[&str])] = &[ + ("abc".as_bytes().to_vec(), &["abc"]), + (" al".as_bytes().to_vec(), &["▁al"]), + (" O".as_bytes().to_vec(), &["▁O"]), + (" ".as_bytes().to_vec(), &["▁▁▁"]), + (" ".as_bytes().to_vec(), &["▁", "<0x20>"]), + ("a".as_bytes().to_vec(), &["a", "<0x61>"]), + (vec![0xFF], &["<0xFF>"]), + (vec![0x20], &["▁", "<0x20>"]), + ]; + for (v_token, t_tokens_expected) in tests { + let v_ids = vocabulary.token_to_ids(&v_token); assert!(v_ids.is_some()); - for v_id in v_ids.unwrap() { - let t_token = tokenizer - .id_to_token(*v_id) - .expect("Token id not found in tokenizer"); - assert_eq!(&t_token, t_token_expected); - } + + let t_tokens = v_ids.unwrap() + .iter() + .map(|v_id| { + tokenizer + .id_to_token(*v_id) + .expect("Token id not found in tokenizer") + } + ) + .collect::>(); + let expected = HashSet::from_iter(t_tokens_expected.iter().map(|s| s.to_string())); + assert_eq!(t_tokens, expected) } } From f02faec858897c1f0ae90fcfdb0f40283905ba2d Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Wed, 18 Dec 2024 16:54:54 +0000 Subject: [PATCH 06/22] Drop majority of intermediate structures --- python/outlines_core/fsm/guide.py | 319 ---------- python/outlines_core/fsm/outlines_core_rs.pyi | 48 -- python/outlines_core/fsm/regex.py | 482 -------------- src/index.rs | 105 +--- src/lib.rs | 1 - src/prelude.rs | 2 +- src/primitives.rs | 7 +- src/python_bindings/mod.rs | 231 +------ src/regex.rs | 141 ----- tests/fsm/test_guide.py | 24 +- tests/fsm/test_regex.py | 587 ------------------ 11 files changed, 34 insertions(+), 1913 deletions(-) delete mode 100644 python/outlines_core/fsm/guide.py delete mode 100644 python/outlines_core/fsm/regex.py delete mode 100644 src/regex.rs delete mode 100644 tests/fsm/test_regex.py diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py deleted file mode 100644 index bb7561a..0000000 --- a/python/outlines_core/fsm/guide.py +++ /dev/null @@ -1,319 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Protocol, Set, Tuple, Union - -import interegular -import torch -from outlines_core.fsm.regex import ( - create_fsm_index_tokenizer, - make_byte_level_fsm, - make_deterministic_fsm, - reduced_vocabulary, -) - -from .outlines_core_rs import Index, Vocabulary - - -@dataclass(frozen=True) -class Write: - """Write instruction. - - Attributes - ---------- - tokens - The sequence of tokens to be added to the current sequence by the - generation process. - - """ - - tokens: List[int] - - -@dataclass(frozen=True) -class Generate: - """Generate instruction - - Attributes - ---------- - tokens - The tokens that lead to a valid completion if generated. A value - of ``None`` indicates that all tokens are allowed. - """ - - tokens: Optional[List[int]] - - -Instruction = Union[Write, Generate] - - -class Guide(Protocol): - """Base definition of a generation guide. - - A generation guide defines the behavior of a finite-state machine that guides - a text generation procedure. Unlike the DFAs built from regular expressions - guides can also emit a `Write` instructions which tells the model that it can - append a sequence of tokens (or token word) instead of generating it. 
- - """ - - initial_state: Any - - def get_next_instruction(self, state: Any) -> Instruction: - ... - - def get_next_state(self, state: Any, token_id: int) -> Any: - ... - - def is_final_state(self, state: Any) -> bool: - ... - - def copy(self) -> "Guide": - ... - - -class StopAtEOSGuide(Guide): - """Guide to generate tokens until the EOS token has been generated.""" - - final_state = 1 - start_state = 0 # TODO: remove start_state, use only initial_state - initial_state = 0 - - def __init__(self, tokenizer): - """Initialize the generation guide. - - model - The logit generator used to generate the next token. - - """ - self.eos_token_id = tokenizer.eos_token_id - self.vocabulary = tokenizer.vocabulary.values() - - def get_next_instruction(self, state: int) -> Instruction: - if self.is_final_state(state): - return Write([self.eos_token_id]) - return Generate(None) - - def get_next_state(self, state: int, token_id: int) -> int: - if token_id == self.eos_token_id or state == self.final_state: - return self.final_state - - return self.initial_state - - def is_final_state(self, state: int): - return state == self.final_state - - def copy(self): - return self - - -def create_states_mapping( - regex_string: str, - tokenizer, - regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, - frozen_tokens: List[str] = [], -) -> Tuple[Index, Set[int], Set[int]]: - """Create the variables related to the mapping between states and tokens from a regex string. - - The parameters of the function are used for caching purpose. - - Parameters - ---------- - regex_string: - The regular expression string to generate a states mapping for. - tokenizer: - The model's tokenizer. - regex_parser: - A function that parses a regex string into an `interegular` Pattern object. - frozen_tokens: - A list of tokens that should be kept as-is when expanding the token-level FSM - into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: - A set of token ids that correspond to empty strings. - final_states: - A set of final states in the FSM. - """ - # regex_fsm = regex_parser(regex_string).to_fsm() - # return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens) - - # inlining logic of create_fsm_index_tokenizer - tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer) - vocabulary = Vocabulary.from_dict_with_eos_token_id( - tokens_to_token_ids, tokenizer.eos_token_id - ) - index = Index.from_regex(regex_string, vocabulary) - - return index, empty_token_ids, set(index.final_states()) - - -def create_states_mapping_from_fsm( - fsm: interegular.fsm.FSM, - tokenizer, - frozen_tokens: List[str] = [], -) -> Tuple[Index, Set[int], Set[int]]: - """Create the variables related to the mapping between states and tokens from an FSM. - - The parameters of the function are used for caching purpose. - - Parameters - ---------- - fsm: - An FSM for the regular expression. - tokenizer: - The model's tokenizer. - frozen_tokens: - A list of tokens that should be kept as-is when expanding the token-level FSM - into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. 
The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: - A set of token ids that correspond to empty strings. - final_states: - A set of final states in the FSM. - """ - byte_fsm = make_byte_level_fsm( - fsm.reduce(), keep_utf8=True, frozen_tokens=frozen_tokens - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer - ) - - return states_to_token_maps, empty_token_ids, regex_fsm.finals - - -class RegexGuide(Guide): - """Guide to generate text in the language of a regular expression.""" - - initial_state = 0 - - def __init__( - self, states_to_token_maps, empty_token_ids, eos_tensor, initial_state - ): - self.states_to_token_maps = states_to_token_maps - self.empty_token_ids = empty_token_ids - self.eos_tensor = eos_tensor - self.initial_state = initial_state - - @classmethod - def from_regex( - cls, - regex_string: str, - tokenizer, - _create_states_mapping=create_states_mapping, - device=None, - regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, - frozen_tokens: List[str] = [], - ): - ( - states_to_token_maps, - empty_token_ids, - fsm_finals, - ) = _create_states_mapping( - regex_string, - tokenizer, - regex_parser=regex_parser, - frozen_tokens=frozen_tokens, - ) - eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device) - initial_state = states_to_token_maps.get_initial_state() - return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state) - - @classmethod - def from_interegular_fsm( - cls, - interegular_fsm: interegular.fsm.FSM, - tokenizer, - _create_states_mapping_from_fsm=create_states_mapping_from_fsm, - device=None, - frozen_tokens: List[str] = [], - ): - ( - states_to_token_maps, - empty_token_ids, - fsm_finals, - ) = _create_states_mapping_from_fsm( - interegular_fsm, tokenizer, frozen_tokens=frozen_tokens - ) - eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device) - initial_state = states_to_token_maps.get_initial_state() - return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state) - - def get_next_instruction(self, state: int) -> Instruction: - """Return the next instruction for guided generation. - - The initialization of the guide builds an index which maps FSM states to a - map from authorized tokens to the state in which the guide needs to move - if said token is generated. Therefore the authorized tokens at the - current state are the keys of the map returned by the value of the index - for current state. - - If the current state is not contained in the end this means that we are - in a final state of the guide. We only authorize EOS tokens in the final - state. - - Parameters - ---------- - state - The current state of the guide. - - Returns - ------- - A `Generate` instance that contains the model and the allowed token ids. - - """ - if state == -1: - return Write(self.eos_tensor) - next_tokens_mask = self.states_to_token_maps.get_allowed_tokens(state) - # TODO: Create the Write and Generate objects within Rust instead? - if next_tokens_mask is None: - return Write(self.eos_tensor) - - return Generate(torch.tensor(next_tokens_mask)) - - def get_next_state(self, state: int, token_id: int) -> int: - """Update the state of the guide. - - We use the index to determine to which state the guide should transition - given the token that was just generated. - - Parameters - ---------- - state - The current state of the guide. 
- token_id - The id of the token that was just generated. - - Returns - ------- - The new state of the guide. - - """ - if state == -1: - return -1 - next_state = self.states_to_token_maps.get_next_state(state, token_id) - if next_state is None: - return -1 - else: - return next_state - - def is_final_state(self, state: int) -> bool: - """Determine whether the current state of the guide is a final state.""" - return state == -1 or self.states_to_token_maps.is_final_state(state) - - def copy(self): - return self - - def get_index_dict(self): - """Returns the Index as a Python Dict object.""" - return self.states_to_token_maps.get_transitions() diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index 81181ed..c7a3700 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -1,57 +1,9 @@ from typing import Dict, List, Optional, Set, Tuple -class FSMInfo: - initial: int - finals: Set[int] - transitions: Dict[Tuple[int, int], int] - alphabet_anything_value: int - alphabet_symbol_mapping: Dict[str, int] - - def __init__( - self, - initial: int, - finals: Set[int], - transitions: Dict[Tuple[int, int], int], - alphabet_anything_value: int, - alphabet_symbol_mapping: Dict[str, int], - ) -> None: ... - def build_regex_from_schema( json: str, whitespace_pattern: Optional[str] = None ) -> str: ... def to_regex(json: Dict, whitespace_pattern: Optional[str] = None) -> str: ... -def _walk_fsm( - fsm_transitions: Dict[Tuple[int, int], int], - fsm_initial: int, - fsm_finals: Set[int], - token_transition_keys: List[int], - start_state: int, - full_match: bool, -) -> List[int]: ... -def state_scan_tokens( - fsm_transitions: Dict[Tuple[int, int], int], - fsm_initial: int, - fsm_finals: Set[int], - vocabulary: Vocabulary, - vocabulary_transition_keys: Dict[str, List[int]], - start_state: int, -) -> Set[Tuple[int, int]]: ... -def get_token_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - token_str: str, -) -> List[int]: ... -def get_vocabulary_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - vocabulary: Vocabulary, - frozen_tokens: Set[str], -) -> Dict[str, List[int]]: ... -def create_fsm_index_end_to_end( - fsm_info: FSMInfo, - vocabulary: Vocabulary, - frozen_tokens: frozenset[str], -) -> Dict[int, Dict[int, int]]: ... 
BOOLEAN: str DATE: str diff --git a/python/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py deleted file mode 100644 index e4b93b7..0000000 --- a/python/outlines_core/fsm/regex.py +++ /dev/null @@ -1,482 +0,0 @@ -import re -from functools import lru_cache -from typing import ( - Dict, - FrozenSet, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, -) - -from interegular.fsm import ( - FSM, - Alphabet, - State, - TransitionKey, - _AnythingElseCls, - anything_else, -) - -from .outlines_core_rs import ( # noqa: F401 - FSMInfo, - Index, - Vocabulary, - _walk_fsm, - create_fsm_index_end_to_end, - get_token_transition_keys, - get_vocabulary_transition_keys, - state_scan_tokens, -) - - -class BetterAlphabet(Alphabet): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert anything_else in self._symbol_mapping - self.anything_value = self._symbol_mapping[anything_else] - - def __getitem__(self, item): - return self._symbol_mapping.get(item, self.anything_value) - - def copy(self): - return BetterAlphabet(self._symbol_mapping.copy()) - - -class BetterFSM(FSM): - flat_transition_map: Dict[Tuple[int, int], int] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not isinstance(self.alphabet, BetterAlphabet): - self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) - - flat_transition_map = {} - for from_state, trans_map in self.map.items(): - for trans_key, to_state in trans_map.items(): - flat_transition_map[(from_state, trans_key)] = to_state - - self.__dict__["flat_transition_map"] = flat_transition_map - self.__dict__["_fsm_info"] = None - - def copy(self): - return BetterFSM( - alphabet=self.alphabet.copy(), - states=self.states.copy(), - initial=self.initial, - finals=self.finals.copy(), - map=self.map.copy(), - __no_validation__=True, - ) - - @property - def fsm_info(self): - if self._fsm_info is None: - anything_value = self.alphabet.anything_value - self.__dict__["_fsm_info"] = FSMInfo( - self.initial, - self.finals, - self.flat_transition_map, - anything_value, - # TODO FIXME: Perform this conversion in Rust? - { - k: v - for k, v in self.alphabet._symbol_mapping.items() - if not isinstance(k, _AnythingElseCls) - }, - ) - - return self._fsm_info - - -TransitionTrie = Dict[TransitionKey, "Union[TransitionTrie, State, None]"] - - -def add_to_transition_trie( - trie: TransitionTrie, - key_seq: Sequence[TransitionKey], - value: Union[State, None], -): - for key in key_seq[:-1]: - trie = cast(TransitionTrie, trie.setdefault(key, {})) - assert isinstance(trie, dict), "key sequence of incompatible length" - trie[key_seq[-1]] = value - - -# merge default_trie into the trie, only updating entries not present in the trie -def transition_trie_setdefault( - trie: TransitionTrie, - default_trie: TransitionTrie, -): - for key, default_value in default_trie.items(): - dest_value = trie.get(key) - if isinstance(dest_value, dict) and isinstance(default_value, dict): - transition_trie_setdefault(dest_value, default_value) - elif key not in trie: - trie[key] = default_value - - -def byte_symbol(byte: int) -> str: - return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte) - - -def make_byte_level_fsm( - fsm: FSM, keep_utf8: bool = False, frozen_tokens: List[str] = [] -) -> FSM: - """Convert an FSM to a byte-level FSM, expanding multi-byte characters as - sequences of single-byte transitions. 
- - Parameters - ---------- - fsm: (`interegular.FSM`): - The token-level FSM to convert to a byte-level FSM. - keep_utf8: (`bool`, *optional*): - If set to True, the original utf-8 characters are kept as-is. Defaults to - False. NOTE: we're representing bytes as strings to keep it type-compatible. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is in the byte-level FSM. That is, - these tokens will not be expanded into byte-level transitions. Defaults to - an empty list. - - Returns - ------- - `interegular.FSM`: A byte-level FSM. - """ - - anything_else_key = fsm.alphabet[anything_else] - symbol_mapping: Dict[Union[str, _AnythingElseCls], TransitionKey] = {} - map: Dict[State, Dict[TransitionKey, State]] = {} - states: List[State] = list(fsm.states) - - # identify all multi-byte characters in the alphabet and build a mapping - # from the original transition keys to sequences of new keys for each byte - key_to_key_seqs: Dict[TransitionKey, Set[Tuple[TransitionKey, ...]]] = {} - all_key_seqs: Set[Tuple[TransitionKey, ...]] = set() - all_bytes: Set[int] = set() - max_key = max(fsm.alphabet.values()) - for symbol, transition_key in fsm.alphabet.items(): - assert symbol == anything_else or symbol in frozen_tokens or len(symbol) == 1 - if symbol == anything_else or symbol in frozen_tokens or ord(symbol) < 0x80: - symbol_mapping[symbol] = transition_key - else: - if keep_utf8: - symbol_mapping[symbol] = transition_key - key_list: List[TransitionKey] = [] - for byte in symbol.encode("utf-8"): - symbol = byte_symbol(byte) - if symbol not in symbol_mapping: - symbol_mapping[symbol] = max_key = TransitionKey(max_key + 1) - all_bytes.add(byte) - key_list.append(symbol_mapping[symbol]) - key_seq = tuple(key_list) - key_to_key_seqs.setdefault(transition_key, set()).add(key_seq) - all_key_seqs.add(key_seq) - - # add all remaining multi-byte utf-8 bytes to the alphabet - # (this is required to represent `anything_else`) - utf8_ranges = { - 1: (0x80, 0xC0), # continuation bytes - 2: (0xC0, 0xE0), # 2-byte sequences - 3: (0xE0, 0xF0), # 3-byte sequences - 4: (0xF0, 0xF8), # 4-byte sequences - } - utf8_all_keys: Dict[int, Set[TransitionKey]] = { - n: set() for n in utf8_ranges.keys() - } - for n, (start, end) in utf8_ranges.items(): - range_key = max_key = TransitionKey(max_key + 1) - for byte in range(start, end): - byte_key = symbol_mapping.setdefault(byte_symbol(byte), range_key) - utf8_all_keys[n].add(byte_key) - - # cache of intermediate transition states by transitions from that state - state_cache: Dict[FrozenSet[Tuple[TransitionKey, State]], State] = {} - - # helper function to create multi-step transitions between states - max_state = max(fsm.states) - - def create_seq_transitions( - seq_transitions_trie: TransitionTrie, - ) -> Dict[TransitionKey, State]: - nonlocal max_state - result: Dict[TransitionKey, State] = {} - - for next_key, next_trie in seq_transitions_trie.items(): - if isinstance(next_trie, dict): - next_transitions = create_seq_transitions(next_trie) - if not next_transitions: - continue - cache_key = frozenset(next_transitions.items()) - next_state = state_cache.get(cache_key) - if next_state is None: - next_state = max_state = State(max_state + 1) - map[next_state] = next_transitions - state_cache[cache_key] = next_state - states.append(next_state) - result[next_key] = next_state - elif next_trie is not None: - result[next_key] = next_trie - - return result - - # create new states and transitions - for state, transitions in fsm.map.items(): - 
seq_transitions_trie: TransitionTrie = {} - state_map: Dict[TransitionKey, State] = {} - - for transition_key, to_state in transitions.items(): - if transition_key in key_to_key_seqs: - if keep_utf8: - state_map[transition_key] = to_state - for key_seq in key_to_key_seqs[transition_key]: - add_to_transition_trie(seq_transitions_trie, key_seq, to_state) - else: # keep single-byte transitions as is - state_map[transition_key] = to_state - - # handle multi-byte anything_else sequences - if anything_else_key in transitions: - for key_seq in all_key_seqs: - add_to_transition_trie(seq_transitions_trie, key_seq, None) - - anything_else_trie: TransitionTrie = {} - cont_trie: Union[TransitionTrie, State] = transitions[anything_else_key] - for n in range(2, 5): - cont_trie = {key: cont_trie for key in utf8_all_keys[1]} - for key in utf8_all_keys[n]: - anything_else_trie[key] = cont_trie - - transition_trie_setdefault(seq_transitions_trie, anything_else_trie) - - # create new states and transitions - next_transitions = create_seq_transitions(seq_transitions_trie) - state_map.update(next_transitions) - map[state] = state_map - - return FSM( - alphabet=Alphabet(symbol_mapping), - states=states, - initial=fsm.initial, - finals=fsm.finals, - map=map, - ) - - -def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: - """Construct an equivalent FSM with deterministic state labels.""" - old_to_new_trans_keys = { - trans_key: i - for i, (trans_key, _) in enumerate( - sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) - ) - } - - new_symbol_mapping = { - symbol: old_to_new_trans_keys[trans_key] - for symbol, trans_key in fsm.alphabet._symbol_mapping.items() - } - - new_alphabet = BetterAlphabet(new_symbol_mapping) - - new_map = { - from_state: { - old_to_new_trans_keys[trans_key]: to_state - for trans_key, to_state in trans_map.items() - } - for from_state, trans_map in fsm.map.items() - } - - old_to_new_states = {} - old_to_new_states[fsm.initial] = 0 - - i = 0 - seen = {fsm.initial} - old_state_queue = [fsm.initial] - while old_state_queue: - old_state = old_state_queue.pop(-1) - transitions = new_map[old_state] - sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) - for _, old_state in sorted_transitions: - if old_state not in seen: - old_state_queue.append(old_state) - seen.add(old_state) - if old_state not in old_to_new_states: - i += 1 - old_to_new_states[old_state] = i - - new_map = dict( - sorted( - ( - ( - old_to_new_states[from_state], - dict( - sorted( - ( - (trans_key, old_to_new_states[to_state]) - for trans_key, to_state in trans_map.items() - ), - key=lambda v: v[0], - ) - ), - ) - for from_state, trans_map in new_map.items() - ), - key=lambda v: v[0], - ) - ) - - new_initial = 0 - new_finals = frozenset( - sorted(old_to_new_states[old_state] for old_state in fsm.finals) - ) - new_states = frozenset(sorted(new_map.keys())) - - new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) - - return new_fsm, old_to_new_states - - -re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") - -# The ".?" prefix and suffix is to handle special cases in some model vocabularies. This -# includes Gemma models (which use "▁�" as a token), NorwAI models (which use ".�" as a -# token), Salamandra models (which use ".�" and "�?" as tokens) and OpenCoder models -# (which use "�s" as a token). 
-re_replacement_seq = re.compile(r"^.?�+.?$") - - -# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode -@lru_cache() -def gpt2_bytes_to_unicode(): - """ - Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control - characters the bpe code barfs on. - - The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab - if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for - decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup - tables between utf-8 bytes and unicode strings. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -@lru_cache() -def gpt2_unicode_to_bytes(): - return {v: k for k, v in gpt2_bytes_to_unicode().items()} - - -@lru_cache -def reduced_vocabulary( - tokenizer, -) -> Tuple[Dict[str, List[int]], Set[int]]: - """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" - # TODO FIXME: See if we can get the underlying Rust tokenizers from HF and - # do all this in Rust - empty_token_ids = set() - vocabulary: Dict[str, List[int]] = {} - for token, token_idx in tokenizer.vocabulary.items(): - if token in tokenizer.special_tokens: - continue - - token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string( - token - ) - - if token_str: - if isinstance(token, bytes): - # Handle BPE tokenizers where the tokens are directly stored as bytes - # https://github.com/QwenLM/Qwen/blob/main/tokenization_note.md#regular-tokens - token_str = "".join(byte_symbol(b) for b in token) - - elif "\ufffd" in token_str and not re_replacement_seq.match(token): - # invalid utf-8 sequences are replaced with � (\ufffd), but there - # might also be tokens specifically for �, ��, ���, etc. - - if re_llama_byte_token.match(token): - # llama-like tokenizers have <0xXX> tokens for all - # bytes >= 0x80 and represent all incomplete utf-8 - # sequences using such tokens - token_bytes = [int(token[3:5], 16)] - else: - # gpt2-like tokenizers have multi-byte tokens that can - # have a mix of full and incomplete utf-8 characters, - # for example, b` \xf0` can be one token; these tokenizers - # map each byte to a valid utf-8 character - token_bytes = cast( - List[int], [gpt2_unicode_to_bytes().get(c) for c in token] - ) - if None in token_bytes: - raise RuntimeError( - f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" - ) - token_str = "".join(byte_symbol(b) for b in token_bytes) - - assert isinstance(token_str, str) - - vocabulary.setdefault(token_str, []).append(token_idx) - else: - empty_token_ids.add(token_idx) - - return vocabulary, empty_token_ids - - -def create_fsm_index_tokenizer( - fsm: BetterFSM, - tokenizer, - frozen_tokens: Optional[Iterable[str]] = None, -) -> Tuple[Index, Set[int]]: - """Construct an FMS index from a tokenizer. - - This uses the end-to-end approach of `create_fsm_index_end_to_end`. - - Parameters - ---------- - fsm: (`BetterFSM`): - A cache-friendly FSM. Other interegular FSMs can also be used, but caching - may not work as expected. - tokenizer: (`Tokenizer`): - The model's tokenizer. 
- frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is when expanding the token-level - FSM into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: (`Dict[int, Dict[int, int]]`): - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: (`Set[int]`): - A set of token ids that correspond to empty strings. - - .. warning:: - - `fsm` needs to be deterministically ordered so that future caching makes sense. - """ - tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer) - - states_to_token_subsets = Index( # type: ignore - fsm.fsm_info, - Vocabulary.from_dict(tokens_to_token_ids), - tokenizer.eos_token_id, - frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(), - ) - - return states_to_token_subsets, empty_token_ids diff --git a/src/index.rs b/src/index.rs index 3df6e74..81ca852 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,6 +1,5 @@ /// Construct an Index. use crate::prelude::*; -use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens}; use crate::vocabulary::Vocabulary; use crate::{Error, Result}; use bincode::{Decode, Encode}; @@ -9,33 +8,6 @@ use regex_automata::util::primitives::StateID as AutomataStateId; use regex_automata::Anchored; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -#[derive(Debug)] -pub struct FSMInfo { - pub(crate) initial: State, - pub(crate) finals: HashSet, - pub(crate) transitions: HashMap<(State, TransitionKey), State>, - pub(crate) alphabet_anything_value: TransitionKey, - pub(crate) alphabet_symbol_mapping: HashMap, -} - -impl FSMInfo { - pub fn new( - initial: State, - finals: HashSet, - transitions: HashMap<(State, TransitionKey), State>, - alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: HashMap, - ) -> Self { - Self { - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, - } - } -} - #[derive(Debug, Encode, Decode)] pub struct Index { initial: u32, @@ -45,70 +17,7 @@ pub struct Index { } impl Index { - pub fn new( - fsm_info: &FSMInfo, - vocabulary: &Vocabulary, - eos_token_id: u32, - frozen_tokens: HashSet, - ) -> Result { - let mut states_to_token_subsets: HashMap> = HashMap::default(); - let mut seen: HashSet = HashSet::default(); - let mut next_states: HashSet = HashSet::from_iter([fsm_info.initial]); - - let vocabulary_transition_keys = get_vocabulary_transition_keys( - &fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - vocabulary, - &frozen_tokens, - ); - - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - - let token_ids_end_states = state_scan_tokens( - &fsm_info.transitions, - fsm_info.initial, - &fsm_info.finals, - vocabulary, - &vocabulary_transition_keys, - start_state, - ); - - for (token_id, end_state) in &token_ids_end_states { - let inner_map = states_to_token_subsets.entry(start_state).or_default(); - inner_map.insert(*token_id, *end_state); - - if !seen.contains(end_state) { - next_states.insert(*end_state); - } - } - - if fsm_info.finals.contains(&start_state) && !token_ids_end_states.is_empty() { - let inner_map = states_to_token_subsets.entry(start_state).or_default(); - inner_map.insert(eos_token_id, start_state); - } - - seen.insert(start_state); - } - - let is_valid = states_to_token_subsets - .values() - 
.flat_map(|token_id_end_states| token_id_end_states.values()) - .any(|end_state| fsm_info.finals.contains(end_state)); - - if is_valid { - Ok(Self { - initial: fsm_info.initial, - finals: fsm_info.finals.clone(), - states_to_token_subsets, - eos_token_id, - }) - } else { - Err(Error::InsufficientVocabulary) - } - } - - pub(crate) fn from_regex(regex: &str, vocabulary: &Vocabulary) -> Result { + pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result { let eos_token_id = match vocabulary.eos_token_id() { Some(s) => s, // TODO: this error will be removed once eos_token_id for vocabulary won't be optional @@ -121,8 +30,8 @@ impl Index { None => return Err(Error::DfaHasNoStartState), }; - let mut transitions: HashMap> = HashMap::default(); - let mut final_states: HashSet = HashSet::default(); + let mut transitions: HashMap> = HashMap::default(); + let mut final_states: HashSet = HashSet::default(); let mut seen: HashSet = HashSet::from_iter([start_state]); let mut next_states: Vec = vec![start_state]; @@ -210,7 +119,7 @@ impl Index { self.finals.contains(&state) } - pub(crate) fn final_states(&self) -> &HashSet { + pub(crate) fn final_states(&self) -> &HashSet { &self.finals } @@ -232,7 +141,7 @@ mod tests { .insert("2", 2) .insert("0", 3); - let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); + let index = Index::new(regex, &vocabulary).expect("Index failed"); assert_eq!(index.initial(), 40); assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56])); @@ -253,7 +162,7 @@ mod tests { .insert(".", 102) .insert("`", 101); - let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); + let index = Index::new(regex, &vocabulary).expect("Index failed"); let allowed = index .allowed_tokens(index.initial()) .expect("No allowed tokens"); @@ -273,7 +182,7 @@ mod tests { .insert(vec![32, 240, 159, 152, 141], 6) .insert(vec![240, 159, 152, 141], 4); - let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); + let index = Index::new(regex, &vocabulary).expect("Index failed"); assert_eq!(index.final_states(), &HashSet::from_iter([208, 128])); diff --git a/src/lib.rs b/src/lib.rs index cb56f86..538152f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ pub mod index; pub mod json_schema; pub mod prelude; pub mod primitives; -pub mod regex; pub mod vocabulary; pub use error::{Error, JsonSchemaParserError, Result}; diff --git a/src/prelude.rs b/src/prelude.rs index d42516b..27236b9 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,4 +1,4 @@ pub use super::{ - primitives::{State, Token, TokenId, TransitionKey}, + primitives::{StateId, Token, TokenId}, vocabulary::Vocabulary, }; diff --git a/src/primitives.rs b/src/primitives.rs index 0976f76..5691bf8 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -1,11 +1,8 @@ -/// Interegular transition key. -pub type TransitionKey = u32; - /// Token content. pub type Token = Vec; /// Token identifier. pub type TokenId = u32; -/// Interegular state. -pub type State = u32; +/// State. 
+pub type StateId = u32; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 2db679a..643980d 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,10 +1,6 @@ -use crate::index::{FSMInfo, Index}; +use crate::index::Index; use crate::json_schema; use crate::prelude::*; -use crate::regex::get_token_transition_keys; -use crate::regex::get_vocabulary_transition_keys; -use crate::regex::state_scan_tokens; -use crate::regex::walk_fsm; use bincode::config; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -13,65 +9,6 @@ use pyo3::wrap_pyfunction; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use serde_json::Value; -#[pyclass(name = "FSMInfo")] -pub struct PyFSMInfo { - #[pyo3(get)] - initial: State, - #[pyo3(get)] - finals: HashSet, - #[pyo3(get)] - transitions: HashMap<(State, TransitionKey), State>, - #[pyo3(get)] - alphabet_anything_value: TransitionKey, - #[pyo3(get)] - alphabet_symbol_mapping: HashMap, -} - -impl From for PyFSMInfo { - fn from(fsm_info: FSMInfo) -> Self { - PyFSMInfo { - initial: fsm_info.initial, - finals: fsm_info.finals, - transitions: fsm_info.transitions, - alphabet_anything_value: fsm_info.alphabet_anything_value, - alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping, - } - } -} - -// FIXME: could be costly, confirm if FSMInfo will actually be part of the interface -impl From<&PyFSMInfo> for FSMInfo { - fn from(fsm_info: &PyFSMInfo) -> Self { - FSMInfo { - initial: fsm_info.initial, - finals: fsm_info.finals.clone(), - transitions: fsm_info.transitions.clone(), - alphabet_anything_value: fsm_info.alphabet_anything_value, - alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping.clone(), - } - } -} - -#[pymethods] -impl PyFSMInfo { - #[new] - fn new( - initial: State, - finals: HashSet, - transitions: HashMap<(State, TransitionKey), State>, - alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: HashMap, - ) -> Self { - FSMInfo::new( - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, - ) - .into() - } -} #[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] pub struct PyIndex(Index); @@ -79,24 +16,9 @@ pub struct PyIndex(Index); #[pymethods] impl PyIndex { #[new] - fn new( - py: Python<'_>, - fsm_info: &PyFSMInfo, - vocabulary: &PyVocabulary, - eos_token_id: u32, - frozen_tokens: HashSet, - ) -> PyResult { - py.allow_threads(|| { - Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) - .map(PyIndex) - .map_err(Into::into) - }) - } - - #[staticmethod] - fn from_regex(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { + fn new(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { py.allow_threads(|| { - Index::from_regex(regex, &vocabulary.0) + Index::new(regex, &vocabulary.0) .map(PyIndex) .map_err(Into::into) }) @@ -135,7 +57,7 @@ impl PyIndex { self.0.is_final(state) } - fn final_states(&self) -> HashSet { + fn final_states(&self) -> HashSet { self.0.final_states().clone() } @@ -166,153 +88,31 @@ pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyR .map_err(|e| PyValueError::new_err(e.to_string())) } -#[pyfunction(name = "_walk_fsm")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" -)] -pub fn walk_fsm_py( - fsm_transitions: HashMap<(State, TransitionKey), State>, - fsm_initial: State, - fsm_finals: HashSet, - token_transition_keys: Vec, - start_state: State, - 
full_match: bool, -) -> PyResult> { - Ok(walk_fsm( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &token_transition_keys, - start_state, - full_match, - )) -} - -#[pyfunction(name = "state_scan_tokens")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" -)] -pub fn state_scan_tokens_py( - fsm_transitions: HashMap<(State, TransitionKey), State>, - fsm_initial: State, - fsm_finals: HashSet, - vocabulary: &PyVocabulary, - vocabulary_transition_keys: HashMap>, - start_state: State, -) -> PyResult> { - Ok(state_scan_tokens( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &vocabulary.0, - &vocabulary_transition_keys, - start_state, - )) -} - -#[pyfunction(name = "get_token_transition_keys")] -#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] -pub fn get_token_transition_keys_py( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: TransitionKey, - token_str: String, -) -> PyResult> { - Ok(get_token_transition_keys( - &alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - )) -} - -#[pyfunction(name = "get_vocabulary_transition_keys")] -#[pyo3( - text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" -)] -pub fn get_vocabulary_transition_keys_py( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: TransitionKey, - vocabulary: &PyVocabulary, - frozen_tokens: HashSet, -) -> PyResult>> { - Ok(get_vocabulary_transition_keys( - &alphabet_symbol_mapping, - alphabet_anything_value, - &vocabulary.0, - &frozen_tokens, - )) -} - -#[pyfunction(name = "create_fsm_index_end_to_end")] -#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] -pub fn create_fsm_index_end_to_end_py<'py>( - py: Python<'py>, - fsm_info: &PyFSMInfo, - vocabulary: &PyVocabulary, - frozen_tokens: HashSet, -) -> PyResult> { - let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: HashSet = HashSet::default(); - let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); - - let vocabulary_transition_keys = get_vocabulary_transition_keys( - &fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - &vocabulary.0, - &frozen_tokens, - ); - - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - - // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens( - &fsm_info.transitions, - fsm_info.initial, - &fsm_info.finals, - &vocabulary.0, - &vocabulary_transition_keys, - start_state, - ); - - for (token_id, end_state) in token_ids_end_states { - if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { - existing_dict.set_item(token_id, end_state)?; - } else { - let new_dict = PyDict::new_bound(py); - new_dict.set_item(token_id, end_state)?; - states_to_token_subsets.set_item(start_state, new_dict)?; - } - - if !seen.contains(&end_state) { - next_states.insert(end_state); - } - } - - seen.insert(start_state); - } - - Ok(states_to_token_subsets) -} - #[pyclass(name = "Vocabulary")] pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { #[staticmethod] - fn from_dict(map: HashMap>) -> PyVocabulary { + fn from_dict(map: HashMap>) -> PyVocabulary { PyVocabulary(Vocabulary::from(map)) } #[staticmethod] fn from_dict_with_eos_token_id( - map: HashMap>, + map: HashMap>, eos_token_id: TokenId, ) -> PyVocabulary { let v = 
Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); PyVocabulary(v) } + #[staticmethod] + fn from_pretrained(model: String) -> PyResult { + let v = Vocabulary::from_pretrained(model.as_str(), None)?; + Ok(PyVocabulary(v)) + } + fn __repr__(&self) -> String { format!("{:#?}", self.0) } @@ -324,12 +124,6 @@ impl PyVocabulary { #[pymodule] fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens_py, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys_py, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?; - m.add("BOOLEAN", json_schema::BOOLEAN)?; m.add("DATE", json_schema::DATE)?; m.add("DATE_TIME", json_schema::DATE_TIME)?; @@ -349,7 +143,6 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; Ok(()) } diff --git a/src/regex.rs b/src/regex.rs deleted file mode 100644 index c9270b6..0000000 --- a/src/regex.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::prelude::*; -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; - -pub fn walk_fsm( - fsm_transitions: &HashMap<(State, TransitionKey), State>, - _fsm_initial: State, - fsm_finals: &HashSet, - token_transition_keys: &[TransitionKey], - start_state: State, - full_match: bool, -) -> Vec { - let mut state = start_state; - let mut accepted_states = Vec::new(); - let mut last_final_idx = 0; - - for (i, &trans_key) in token_transition_keys.iter().enumerate() { - match fsm_transitions.get(&(state, trans_key)) { - Some(&new_state) => { - state = new_state; - if fsm_finals.contains(&state) { - last_final_idx = i + 1; - } - accepted_states.push(state); - } - None => { - if !full_match && last_final_idx > 0 { - return accepted_states[..last_final_idx].to_vec(); - } - return Vec::new(); - } - } - } - - if full_match && last_final_idx != token_transition_keys.len() { - return Vec::new(); - } - - accepted_states -} - -pub fn state_scan_tokens( - fsm_transitions: &HashMap<(State, TransitionKey), State>, - fsm_initial: State, - fsm_finals: &HashSet, - vocabulary: &Vocabulary, - vocabulary_transition_keys: &HashMap>, - start_state: State, -) -> HashSet<(TokenId, State)> { - let mut res = HashSet::default(); - - for (token, token_ids) in vocabulary.iter() { - let token_transition_keys = &vocabulary_transition_keys[token]; - let state_seq = walk_fsm( - fsm_transitions, - fsm_initial, - fsm_finals, - token_transition_keys, - start_state, - false, - ); - - if state_seq.len() < token_transition_keys.len() { - continue; - } - - for &token_id in token_ids { - res.insert((token_id, *state_seq.last().unwrap())); - } - } - - res -} - -pub fn get_token_transition_keys( - alphabet_symbol_mapping: &HashMap, - alphabet_anything_value: TransitionKey, - token_str: &str, -) -> Vec { - let mut token_transition_keys = Vec::new(); - let mut i = 0; - let chars: Vec = token_str.chars().collect(); - - while i < chars.len() { - let symbol; - if chars[i] == '\0' && i != chars.len() - 1 { - if i + 2 < chars.len() { - symbol = format!("\0{}{}", chars[i + 1], chars[i + 2]); - i += 3; - } else { - symbol = chars[i].to_string(); - i += 1; - } - } else { - symbol = chars[i].to_string(); - i += 1; - } - - let transition_key = *alphabet_symbol_mapping - .get(&symbol) - .unwrap_or(&alphabet_anything_value); - token_transition_keys.push(transition_key); - } - - 
token_transition_keys -} - -pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: &HashMap, - alphabet_anything_value: TransitionKey, - vocabulary: &Vocabulary, - frozen_tokens: &HashSet, -) -> HashMap> { - let mut vocab_transition_keys = HashMap::default(); - - for item in vocabulary.iter() { - let token_str = item.0.clone(); - - let mut token_transition_keys; - - // Since these tokens are not expanded into byte-level transitions, we - // can simply get their transition keys directly. - if frozen_tokens.contains(&token_str) { - token_transition_keys = Vec::new(); - token_transition_keys.push( - *alphabet_symbol_mapping - .get(&token_str) - .unwrap_or(&alphabet_anything_value), - ) - } else { - token_transition_keys = get_token_transition_keys( - alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - ); - } - - vocab_transition_keys.insert(token_str, token_transition_keys); - } - - vocab_transition_keys -} diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 905bfde..9bae619 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -48,7 +48,7 @@ def convert_token_to_string(self, token): def test_from_regex(): class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} + vocabulary = {"1": [1], "a": [2], "eos": [3]} special_tokens = {"eos"} eos_token_id = 3 @@ -152,17 +152,17 @@ def convert_token_to_string(self, token): def test_regex_multi_byte_gpt2_like(): class MockTokenizer: vocabulary = { - "1": 1, - "a": 2, - "eos": 3, - "😍": 4, - " ": 5, - "\ufffd": 6, - "\ufffd\ufffd": 7, - "ðŁĺ": 8, - "Ī": 9, # '😈' - "Ġð": 10, - "ŁĺĪ": 11, # ' 😈' + "1": [1], + "a": [2], + "eos": [3], + "😍": [4], + " ": [5], + "\ufffd": [6], + "\ufffd\ufffd": [7], + "ðŁĺ": [8], + "Ī": [9], # '😈' + "Ġð": [10], + "ŁĺĪ": [11], # ' 😈' } special_tokens = {"eos"} eos_token_id = 3 diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py deleted file mode 100644 index 1007d8b..0000000 --- a/tests/fsm/test_regex.py +++ /dev/null @@ -1,587 +0,0 @@ -from typing import List, Tuple, Union - -import interegular -import pytest -import torch -from datasets.fingerprint import Hasher -from outlines_core.fsm.outlines_core_rs import Vocabulary -from outlines_core.fsm.regex import ( - BetterAlphabet, - BetterFSM, - _walk_fsm, - create_fsm_index_end_to_end, - create_fsm_index_tokenizer, - get_token_transition_keys, - get_vocabulary_transition_keys, - make_byte_level_fsm, - make_deterministic_fsm, - reduced_vocabulary, -) -from transformers import AutoTokenizer, PreTrainedTokenizer - - -def get_llama_tokenizer_types(): - """Get all the Llama tokenizer types/classes that need work-arounds. - - When they can't be imported, a dummy class is created. 
- - """ - try: - from transformers.models.llama import LlamaTokenizer - except ImportError: - - class LlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.llama import LlamaTokenizerFast - except ImportError: - - class LlamaTokenizerFast: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizer - except ImportError: - - class CodeLlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizerFast - except ImportError: - - class CodeLlamaTokenizerFast: # type: ignore - pass - - return ( - LlamaTokenizer, - LlamaTokenizerFast, - CodeLlamaTokenizer, - CodeLlamaTokenizerFast, - ) - - -class TransformerTokenizer: - """Represents a tokenizer for models in the `transformers` library.""" - - def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs): - self.tokenizer = tokenizer - self.eos_token_id = self.tokenizer.eos_token_id - self.eos_token = self.tokenizer.eos_token - - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.pad_token_id = self.eos_token_id - else: - self.pad_token_id = self.tokenizer.pad_token_id - self.pad_token = self.tokenizer.pad_token - - self.special_tokens = set(self.tokenizer.all_special_tokens) - - self.vocabulary = self.tokenizer.get_vocab() - self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) - - def encode( - self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple[torch.LongTensor, torch.LongTensor]: - kwargs["padding"] = True - kwargs["return_tensors"] = "pt" - output = self.tokenizer(prompt, **kwargs) - return output["input_ids"], output["attention_mask"] - - def decode(self, token_ids: torch.LongTensor) -> List[str]: - text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) - return text - - def convert_token_to_string(self, token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = self.tokenizer.convert_tokens_to_string([token]) - - if self.is_llama: - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def __hash__(self): - return hash(Hasher.hash(self.tokenizer)) - - def __eq__(self, other): - if isinstance(other, type(self)): - if hasattr(self, "model_name") and hasattr(self, "kwargs"): - return ( - other.model_name == self.model_name and other.kwargs == self.kwargs - ) - else: - return other.tokenizer == self.tokenizer - return NotImplemented - - def __getstate__(self): - state = {"tokenizer": self.tokenizer} - return state - - def __setstate__(self, state): - self.__init__(state["tokenizer"]) - - -def identity(s): - return s - - -def to_bytes(s): - return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] - - -def merge_symbols(byte_hexs): - return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) - - -def token_str_to_trans_key(fsm, input_string): - return get_token_transition_keys( - fsm.fsm_info.alphabet_symbol_mapping, - fsm.fsm_info.alphabet_anything_value, - input_string, - ) - - -def walk_fsm_from_token_str_rust( - fsm, - input_string: str, - start_state: int, - full_match: bool = True, -): - return _walk_fsm( - fsm.fsm_info.transitions, - fsm.fsm_info.initial, - fsm.fsm_info.finals, - token_str_to_trans_key(fsm, input_string), - start_state, - full_match=full_match, - ) - - -def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: - new_fsm = 
make_byte_level_fsm(fsm, keep_utf8) - return BetterFSM( - alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), - states=new_fsm.states, - initial=new_fsm.initial, - finals=new_fsm.finals, - map=new_fsm.map, - ) - - -def test_walk_fsm(): - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "0", regex_fsm.initial, full_match=True) - ) - assert res == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "00", regex_fsm.initial, full_match=False - ) - ) - assert res == (1,) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "!", regex_fsm.initial, full_match=True) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "00", regex_fsm.initial, full_match=True - ) - ) - assert res == tuple() - - # This should fail, because state `1` reads nothing - res = tuple(walk_fsm_from_token_str_rust(regex_fsm, "0", 1, full_match=True)) - assert res == tuple() - - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "1", regex_fsm.initial, full_match=True) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "1", regex_fsm.initial, full_match=False - ) - ) - assert res == (2,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "12", regex_fsm.initial, full_match=True - ) - ) - assert res == (2, 3) - - pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") - fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) - - res = tuple(walk_fsm_from_token_str_rust(fsm, "x ", fsm.initial, full_match=False)) - assert res == (2,) - - start_state = list(fsm.finals)[0] - res = tuple(walk_fsm_from_token_str_rust(fsm, "!", start_state, full_match=False)) - assert res == tuple() - - -@pytest.mark.parametrize( - "transform", - [ - identity, - to_bytes, - ], -) -def test_walk_fsm_multi_bytes(transform): - regex_pattern = interegular.parse_pattern("😂|[😇-😍][😈-😍]*") - str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, merge_symbols(transform("😂")), regex_fsm.initial, full_match=True - ) - ) - assert res[-1:] == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, - merge_symbols(transform("😂😂")), - regex_fsm.initial, - full_match=False, - ) - ) - assert res[-1:] == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True - ) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, - merge_symbols(transform("😂😂")), - regex_fsm.initial, - full_match=True, - ) - ) - assert res == tuple() - - -def test_create_fsm_index_end_to_end(): - regex_str = "0|[1-9][0-9]*" - - regex_pattern = interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - tokens_to_token_ids = { - "blah": [0], - "1a": [1], - "2": [2], - "0": [3], - "": [4], - } - - res = create_fsm_index_end_to_end( - regex_fsm.fsm_info, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - assert res == {0: {2: 2, 3: 1}, 2: {2: 2, 3: 2}} - - -def test_create_fsm_index_end_to_end_multi_byte(): - regex_str = "😇| [😈-😍][😇-😎]*" - - regex_pattern 
= interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - - tokens_to_token_ids = { - "blah": [0], - "😈a": [1], - "😇": [2], - "😍": [3], - merge_symbols(("F0", "9F", "98", "8D")): [4], # '😍' - " 😍": [5], - merge_symbols((" ", "F0", "9F", "98", "8D")): [6], # ' 😍' - merge_symbols((" ", "F0", "9F", "98")): [7], # ' 😍' incomplete - "": [8], - } - - res = create_fsm_index_end_to_end( - byte_fsm.fsm_info, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - assert res == {0: {5: 3, 6: 3, 7: 7, 2: 2}, 3: {2: 3, 3: 3, 4: 3}} - - -@pytest.mark.parametrize( - "hf_tokenizer_uri, revision", - [ - ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), - ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), - ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), - ( - "NousResearch/Hermes-2-Pro-Llama-3-8B", - "783fd50eb82d7f57758de033861f54d62dde234f", - ), - ], -) -def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision): - # The combined regular expressions of a lexer state in a Python grammar - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - regex_pattern = interegular.parse_pattern(regex_str) - # Not reduced, so that there are many states - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) - bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - - num_fsm_states = len(regex_fsm.states) - assert num_fsm_states == 220 - - num_bytes_fsm_states = len(bytes_fsm.states) - assert num_bytes_fsm_states == 235 - - tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) - tokenizer = TransformerTokenizer(tokenizer) - - states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( - bytes_fsm, tokenizer - ) - - assert not empty_token_ids - assert len(states_to_token_subsets.get_transitions()) / num_fsm_states > 0.94 - - -@pytest.mark.parametrize( - "regex,string,should_accept", - [ - ("[a-c]+", "😀", False), - ("[^a-c]+", "😀", True), - ("😀+", "😀😀😀", True), - ("😀+", "a", False), - ("[😀-😍]{2}", "😈😈", True), - ("[😀-😍]{2}", "aa", False), - ("[^😀-😍]{2}", "aa", True), - ("[^😀-😍]{2}", "😈😈", False), - ("[^😀-😍]{2}", "😎😎", True), - ("[^😀-😍]{2}", "😎😓", True), - ("[^😀-😍]{2}", "😎😈", False), - ("[😀-🙌]{2}", "😎😈", True), - ("[^😀-🙌]{2}", "😎😈", False), - ("[^😀-🙌]{2}", "🙏🙏", True), - ("[^😀-🙌]{2}", "🙏😎", False), - ], -) -def test_make_byte_level_fsm(regex, string, should_accept): - str_fsm = interegular.parse_pattern(regex).to_fsm() - str_accepts = str_fsm.accepts(string) - assert str_accepts == should_accept - - byte_fsm = 
make_byte_level_fsm(str_fsm) - byte_accepts = byte_fsm.accepts(to_bytes(string)) # type: ignore - assert byte_accepts == str_accepts - - mix_fsm = make_byte_level_fsm(str_fsm, keep_utf8=True) - mix_accepts = mix_fsm.accepts(to_bytes(string)) # type: ignore - assert mix_accepts == str_accepts - - mix_accepts_utf8 = mix_fsm.accepts(string) # type: ignore - assert mix_accepts_utf8 == str_accepts - - def advance(fsm, state, seq): - for symbol in seq: - if state is None: - return None - key = fsm.alphabet[symbol] - state = fsm.map[state].get(key) - return state - - # verify each state along the pattern - str_state = str_fsm.initial - byte_state = byte_fsm.initial - mix_state = byte_fsm.initial - for symbol in string: - str_state = advance(str_fsm, str_state, symbol) - byte_state = advance(byte_fsm, byte_state, to_bytes(symbol)) - mix_state_utf8 = advance(mix_fsm, mix_state, symbol) - mix_state = advance(mix_fsm, mix_state, to_bytes(symbol)) - assert byte_state == str_state - assert mix_state == str_state - assert mix_state_utf8 == str_state - - -@pytest.mark.skip(reason="Only for local profiling") -def test_regex_index_performance(): - from line_profiler import LineProfiler # type: ignore [import] - - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - regex_pattern = interegular.parse_pattern(regex_str) - # Not reduced, so that there are many states - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) - - num_fsm_states = len(regex_fsm.states) - assert num_fsm_states == 220 - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer = TransformerTokenizer(tokenizer) - - res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) - assert len(res) > 1 - - profiler = LineProfiler(create_fsm_index_end_to_end) - - profiler.runctx( - "create_fsm_index_tokenizer(regex_fsm, tokenizer)", - globals(), - locals(), - ) - profiler.dump_stats("line-profiler-create_fsm_index.pkl") - profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) - - -def test_token_trans_keys_identical(): - """assert two tokens w/ identical behavior wrt FSM have same trans key seq""" - - class MockTokenizer: - vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4} - special_tokens = {"eos"} - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - tokenizer = MockTokenizer() - - pattern = r"z[ab]z" - regex_pattern = interegular.parse_pattern(pattern) - interegular_fsm = regex_pattern.to_fsm().reduce() - regex_fsm, _ = make_deterministic_fsm(interegular_fsm) - tokens_to_token_ids, _ = reduced_vocabulary(tokenizer) - token_str_to_tranition_keys = get_vocabulary_transition_keys( - regex_fsm.fsm_info.alphabet_symbol_mapping, 
- regex_fsm.fsm_info.alphabet_anything_value, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - # `a` and `b` both are workable, but `z` has distinct transition rules - assert interegular_fsm.accepts("zaz") - assert interegular_fsm.accepts("zbz") - assert token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"] - assert not token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] - - -def test_token_trans_keys_walk_fsm(): - """assert _walk_fsm works using transition keys""" - - class MockTokenizer: - vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4} - special_tokens = {"eos"} - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - tokenizer = MockTokenizer() - - pattern = r"a[bc]z" - regex_pattern = interegular.parse_pattern(pattern) - interegular_fsm = regex_pattern.to_fsm().reduce() - regex_fsm, _ = make_deterministic_fsm(interegular_fsm) - tokens_to_token_ids, _ = reduced_vocabulary(tokenizer) - token_str_to_tranition_keys = get_vocabulary_transition_keys( - regex_fsm.fsm_info.alphabet_symbol_mapping, - regex_fsm.fsm_info.alphabet_anything_value, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - # verify initial state valid only for "ab" and "ac" using transition key seq - token_acceptance = {"ab": True, "ac": True, "az": False} - for token, should_accept in token_acceptance.items(): - token_trans_key_seq = token_str_to_tranition_keys[token] - state_seq = _walk_fsm( - regex_fsm.fsm_info.transitions, - regex_fsm.fsm_info.initial, - regex_fsm.fsm_info.finals, - token_trans_key_seq, - regex_fsm.initial, - False, - ) - is_accepted = len(state_seq) >= len(token_trans_key_seq) - assert should_accept == is_accepted - - -@pytest.mark.parametrize( - "rare_token", - [ - "�", - "��", - "�.", - ".�", - ".�.", - "▁�", - "�▁", - "▁�▁", - "?�", - "�?", - "?�?", - ], -) -def test_reduced_vocabulary_with_rare_tokens(rare_token): - """Assert reduced_vocabulary works with rare tokens. - - See [1] and [2] for context. 
- - [1]: https://github.com/dottxt-ai/outlines/pull/763 - [2]: https://github.com/dottxt-ai/outlines/pull/948 - [3]: https://github.com/dottxt-ai/outlines/pull/1153 - """ - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - tokenizer = TransformerTokenizer(tokenizer=tokenizer) - tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1 - reduced_vocabulary(tokenizer) - - -def test_reduced_vocabulary_with_byte_tokens(): - class MockTokenizer: - vocabulary = { - "string": 1, - b"\xa1": 2, # Qwen-Style - "eos": 3, - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return b"\xef\xbf\xbd".decode() - - tokens_to_token_ids = reduced_vocabulary(MockTokenizer()) - - # See fsm.regex.get_token_transition_keys() - # FSM transition keys represents bytes as - assert tokens_to_token_ids[0] == {"string": [1], "\x00A1": [2]} From 70b4bc6eaf70835f1a3b83c74854d1948b3c0730 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Wed, 18 Dec 2024 18:15:09 +0000 Subject: [PATCH 07/22] Add PyGuide, use proper types for Index --- src/index.rs | 20 +++++----- src/python_bindings/mod.rs | 80 +++++++++++++++++++++++++++++++++++--- src/vocabulary/mod.rs | 21 +++++----- 3 files changed, 96 insertions(+), 25 deletions(-) diff --git a/src/index.rs b/src/index.rs index 81ca852..fc35aa7 100644 --- a/src/index.rs +++ b/src/index.rs @@ -8,12 +8,12 @@ use regex_automata::util::primitives::StateID as AutomataStateId; use regex_automata::Anchored; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -#[derive(Debug, Encode, Decode)] +#[derive(Clone, Debug, Encode, Decode)] pub struct Index { - initial: u32, - finals: HashSet, - states_to_token_subsets: HashMap>, - eos_token_id: u32, + initial: StateId, + finals: HashSet, + states_to_token_subsets: HashMap>, + eos_token_id: TokenId, } impl Index { @@ -98,24 +98,24 @@ impl Index { } } - pub(crate) fn allowed_tokens(&self, state: u32) -> Option> { + pub(crate) fn allowed_tokens(&self, state: StateId) -> Option> { self.states_to_token_subsets .get(&state) .map_or_else(|| None, |res| Some(res.keys().cloned().collect())) } - pub(crate) fn next_state(&self, state: u32, token_id: u32) -> Option { + pub(crate) fn next_state(&self, state: StateId, token_id: TokenId) -> Option { if token_id == self.eos_token_id { return None; } Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?) 
} - pub(crate) fn initial(&self) -> u32 { + pub(crate) fn initial(&self) -> StateId { self.initial } - pub(crate) fn is_final(&self, state: u32) -> bool { + pub(crate) fn is_final(&self, state: StateId) -> bool { self.finals.contains(&state) } @@ -123,7 +123,7 @@ impl Index { &self.finals } - pub(crate) fn transitions(&self) -> &HashMap> { + pub(crate) fn transitions(&self) -> &HashMap> { &self.states_to_token_subsets } } diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 643980d..9e9495b 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -2,6 +2,7 @@ use crate::index::Index; use crate::json_schema; use crate::prelude::*; use bincode::config; +use bincode::{Decode, Encode}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -9,8 +10,76 @@ use pyo3::wrap_pyfunction; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use serde_json::Value; +#[pyclass(name = "Guide", module = "outlines_core.fsm.outlines_core_rs")] +#[derive(Clone, Debug, Encode, Decode)] +pub struct PyGuide { + state: StateId, + index: PyIndex, +} + +#[pymethods] +impl PyGuide { + #[new] + fn new(index: PyIndex) -> Self { + PyGuide { + state: index.get_initial_state(), + index, + } + } + + fn get_start_tokens(&self) -> PyResult> { + self.index + .get_allowed_tokens(self.index.get_initial_state()) + .ok_or(PyErr::new::( + "Initial state is not in allowed tokens", + )) + } + + fn read_next_token(&mut self, token_id: TokenId) -> PyResult> { + match self.index.get_next_state(self.state, token_id) { + Some(new_state) => { + self.state = new_state; + self.index + .get_allowed_tokens(new_state) + .ok_or(PyErr::new::(format!( + "No token ids found for the next state {new_state}" + ))) + } + None => Err(PyErr::new::(format!( + "Next state is not found for {} and token id {token_id}", + self.state + ))), + } + } + + fn is_finished(&self) -> bool { + self.index.is_final_state(self.state) + } + + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
+ .getattr("PyGuide")?; + let binary_data: Vec = + bincode::encode_to_vec(self, config::standard()).map_err(|e| { + PyErr::new::(format!("Serialization of Guide failed: {}", e)) + })?; + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) + }) + } + + #[staticmethod] + fn from_binary(binary_data: Vec) -> PyResult { + let (guide, _): (PyGuide, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::(format!("Deserialization of Guide failed: {}", e)) + })?; + Ok(guide) + } +} #[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] +#[derive(Clone, Debug, Encode, Decode)] pub struct PyIndex(Index); #[pymethods] @@ -45,15 +114,15 @@ impl PyIndex { Ok(PyIndex(index)) } - fn get_allowed_tokens(&self, state: u32) -> Option> { + fn get_allowed_tokens(&self, state: StateId) -> Option> { self.0.allowed_tokens(state) } - fn get_next_state(&self, state: u32, token_id: u32) -> Option { + fn get_next_state(&self, state: StateId, token_id: TokenId) -> Option { self.0.next_state(state, token_id) } - fn is_final_state(&self, state: u32) -> bool { + fn is_final_state(&self, state: StateId) -> bool { self.0.is_final(state) } @@ -61,11 +130,11 @@ impl PyIndex { self.0.final_states().clone() } - fn get_transitions(&self) -> HashMap> { + fn get_transitions(&self) -> HashMap> { self.0.transitions().clone() } - fn get_initial_state(&self) -> u32 { + fn get_initial_state(&self) -> StateId { self.0.initial() } } @@ -143,6 +212,7 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 613d735..7efc3c7 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,4 +1,4 @@ -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use rustc_hash::FxHashMap as HashMap; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; @@ -90,7 +90,7 @@ impl Vocabulary { }); }; for (token, token_id) in tokenizer.get_vocab(false) { - let processed_token= processor.process(token)?; + let processed_token = processor.process(token)?; vocabulary = vocabulary.insert(processed_token, token_id); } @@ -216,7 +216,7 @@ impl From>> for Vocabulary { eos_token_id: None, tokens: tokens .into_iter() - .map(|(k,v)| (k.as_bytes().to_vec(), v)) + .map(|(k, v)| (k.as_bytes().to_vec(), v)) .collect::>>(), } } @@ -235,6 +235,7 @@ where #[cfg(test)] mod tests { use super::*; + use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; #[test] fn insert() { @@ -340,7 +341,7 @@ mod tests { let token = "Ġal"; let btoken = token.as_bytes().to_vec(); assert!(vocabulary.token_to_ids(&btoken).is_none()); - assert!(tokenizer.token_to_id(&token).is_some()); + assert!(tokenizer.token_to_id(token).is_some()); for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] { let v_ids = vocabulary.token_to_ids(&v_token.as_bytes().to_vec()); @@ -370,7 +371,7 @@ mod tests { tokenizer.id_to_token(v_eos).expect("Token not found"), "" ); - + let tests: &[(Vec, &[&str])] = &[ ("abc".as_bytes().to_vec(), &["abc"]), (" al".as_bytes().to_vec(), &["▁al"]), @@ -382,17 +383,17 @@ mod tests { (vec![0x20], &["▁", "<0x20>"]), ]; for (v_token, t_tokens_expected) in tests { - let v_ids = vocabulary.token_to_ids(&v_token); + let v_ids = vocabulary.token_to_ids(v_token); assert!(v_ids.is_some()); - - let t_tokens = v_ids.unwrap() + + let t_tokens = v_ids + .unwrap() .iter() .map(|v_id| { 
tokenizer .id_to_token(*v_id) .expect("Token id not found in tokenizer") - } - ) + }) .collect::>(); let expected = HashSet::from_iter(t_tokens_expected.iter().map(|s| s.to_string())); assert_eq!(t_tokens, expected) From b477598857d2eb6a86196942796470d8cc017a44 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 19 Dec 2024 18:59:05 +0000 Subject: [PATCH 08/22] Provide basic Guide binding, test it --- python/outlines_core/fsm/__init__.py | 1 + python/outlines_core/fsm/outlines_core_rs.pyi | 41 +++- src/python_bindings/mod.rs | 22 +- src/vocabulary/mod.rs | 2 +- tests/fsm/test_guide.py | 227 ++---------------- 5 files changed, 66 insertions(+), 227 deletions(-) diff --git a/python/outlines_core/fsm/__init__.py b/python/outlines_core/fsm/__init__.py index e69de29..9e167c4 100644 --- a/python/outlines_core/fsm/__init__.py +++ b/python/outlines_core/fsm/__init__.py @@ -0,0 +1 @@ +from .outlines_core_rs import Guide, Index, Vocabulary \ No newline at end of file diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index c7a3700..d3a25c0 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -19,23 +19,42 @@ WHITESPACE: str EMAIL: str URI: str +class Guide: + def __init__(self, index: Index): + """ + Defines a guide object an index an initializes it in its start state. + """ + def get_start_tokens(self) -> List[int]: + """ + Gets the list of allowed tokens from the start state. + """ + ... + def read_next_token(self, token_id: int) -> List[int]: + """ + Reads the next token according to the model and returns a list of allowable tokens. + """ + ... + def is_finished(self) -> bool: + """ + Checks if the automaton is in a final state. + """ + ... + class Vocabulary: """ Vocabulary of an LLM. """ @staticmethod - def from_dict(map: Dict[str, List[int]]) -> "Vocabulary": + def from_dict(eos_token_id: int, map: Dict[str, List[int]]) -> "Vocabulary": """ - Creates a vocabulary from a dictionary of tokens to token IDs. + Creates a vocabulary from a map of tokens to token ids and eos token id. """ ... @staticmethod - def from_dict_with_eos_token_id( - map: Dict[str, List[int]], eos_token_id: int - ) -> "Vocabulary": + def from_pretrained(model: str) -> "Vocabulary": """ - Creates a vocabulary from a dictionary of tokens to token IDs and eos token id. + Creates the vocabulary of a pre-trained model. """ ... def __repr__(self) -> str: @@ -48,6 +67,16 @@ class Vocabulary: Gets the string representation of the vocabulary. """ ... + def __eq__(self, other: object) -> bool: + """ + Gets whether two vocabularies are the same. + """ + ... + def get_eos_token_id(self) -> Optional[int]: + """ + Gets the end of sentence token id. + """ + ... 
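# Usage sketch (illustrative, not part of the generated stubs): how the Guide,
# Index and Vocabulary pieces stubbed above fit together, mirroring the
# test_stop_at_eos_txt test added later in this patch. The regex and token map
# are example values only.
from outlines_core.fsm import Guide, Index, Vocabulary

vocabulary = Vocabulary.from_dict(3, {"1": [1], "a": [2]})   # eos_token_id is 3
index = Index(r"[1-9]", vocabulary)
guide = Guide(index)

token_ids = guide.get_start_tokens()        # only "1" can start a match of [1-9]
while not guide.is_finished():
    token_ids = guide.read_next_token(token_ids[0])   # a sampler would choose here

# once the guide reaches a final state, the only allowed continuation is eos
assert list(token_ids) == [vocabulary.get_eos_token_id()]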
class Index: @staticmethod diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 9e9495b..f64d4af 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -42,11 +42,11 @@ impl PyGuide { self.index .get_allowed_tokens(new_state) .ok_or(PyErr::new::(format!( - "No token ids found for the next state {new_state}" + "No token ids found for the next state: {new_state}" ))) } None => Err(PyErr::new::(format!( - "Next state is not found for {} and token id {token_id}", + "No next state found for the current state: {} with token ID: {token_id}", self.state ))), } @@ -163,15 +163,7 @@ pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { #[staticmethod] - fn from_dict(map: HashMap>) -> PyVocabulary { - PyVocabulary(Vocabulary::from(map)) - } - - #[staticmethod] - fn from_dict_with_eos_token_id( - map: HashMap>, - eos_token_id: TokenId, - ) -> PyVocabulary { + fn from_dict(eos_token_id: TokenId, map: HashMap>) -> PyVocabulary { let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); PyVocabulary(v) } @@ -182,6 +174,10 @@ impl PyVocabulary { Ok(PyVocabulary(v)) } + fn get_eos_token_id(&self) -> Option { + self.0.eos_token_id() + } + fn __repr__(&self) -> String { format!("{:#?}", self.0) } @@ -189,6 +185,10 @@ impl PyVocabulary { fn __str__(&self) -> String { format!("{}", self.0) } + + fn __eq__(&self, other: &PyVocabulary) -> bool { + self.0 == other.0 + } } #[pymodule] diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 7efc3c7..62ff748 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -25,7 +25,7 @@ mod processor; /// .insert("2", 2) /// .insert("0", 3); /// ``` -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct Vocabulary { // TODO: Option is temp for back compatibility eos_token_id: Option, diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 9bae619..50d9599 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -1,216 +1,25 @@ -import interegular import pytest -from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write +from outlines_core.fsm import Guide, Index, Vocabulary -def assert_expected_tensor_ids(tensor, ids): - assert len(tensor) == len(ids) - norm_tensor = sorted(map(int, tensor)) - norm_ids = sorted(map(int, tensor)) - assert norm_tensor == norm_ids, (norm_tensor, norm_ids) +def test_stop_at_eos_txt(): + eos_token_id = 3 + # TODO: support bytes from python + # tokens = {b"1": {1}, b"a": {2}} + tokens = {"1": [1], "a": [2]} + regex = r"[1-9]" + vocabulary = Vocabulary.from_dict(eos_token_id, tokens) -def test_stop_at_eos(): - class MockTokenizer: - vocabulary = {"a": 1, "eos": 2} - eos_token_id = 2 + index = Index(regex, vocabulary) + guide = Guide(index) - fsm = StopAtEOSGuide(MockTokenizer()) + assert list(guide.get_start_tokens()) == [1] + assert list(guide.read_next_token(1)) == [vocabulary.get_eos_token_id()] + assert guide.is_finished() - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert instruction.tokens is None - - instruction = fsm.get_next_instruction(fsm.final_state) - assert isinstance(instruction, Write) - assert instruction.tokens == [2] - - assert fsm.get_next_state(fsm.start_state, 2) == fsm.final_state - assert fsm.get_next_state(fsm.start_state, 1) == fsm.start_state - assert fsm.is_final_state(fsm.start_state) is False - assert fsm.is_final_state(fsm.final_state) is True - - -def test_regex_vocabulary_error(): - class 
MockTokenizer: - vocabulary = {"a": 1} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - - with pytest.raises(ValueError, match="The vocabulary"): - RegexGuide.from_regex(regex_str, MockTokenizer()) - - -def test_from_regex(): - class MockTokenizer: - vocabulary = {"1": [1], "a": [2], "eos": [3]} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - assert fsm.get_index_dict() == {0: {1: 1}} - - instruction = fsm.get_next_instruction(-1) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) - - instruction = fsm.get_next_instruction(3) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - - assert fsm.get_next_state(state=0, token_id=1) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_from_fsm(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_interegular_fsm( - interegular.parse_pattern(regex_str).to_fsm(), tokenizer - ) - - assert fsm.get_index_dict() == {0: {1: 1}} - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - - assert fsm.get_next_state(state=0, token_id=1) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_regex_multi_byte_llama_like(): - class MockTokenizer: - vocabulary = { - "1": 1, - "a": 2, - "eos": 3, - "😍": 4, - "<0xF0>": 5, - "<0x9F>": 6, - "<0x98>": 7, - "<0x88>": 8, # 😈 - "\ufffd": 9, - "\ufffd\ufffd": 10, - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - if token[0] == "<": - return "\ufffd" - return token - - regex_str = "[😁-😎]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - assert fsm.get_index_dict() == { - 0: {5: 1, 4: 2}, - 1: {6: 3}, - 3: {7: 4}, - 4: {8: 2}, - } - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [5, 4]) - - assert fsm.get_next_state(state=0, token_id=5) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_regex_multi_byte_gpt2_like(): - class MockTokenizer: - vocabulary = { - "1": [1], - "a": [2], - "eos": [3], - "😍": [4], - " ": [5], - "\ufffd": [6], - "\ufffd\ufffd": [7], - "ðŁĺ": [8], - "Ī": [9], # '😈' - "Ġð": [10], - "ŁĺĪ": [11], # ' 😈' - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - if self.vocabulary[token] >= 8: - return "\ufffd" - return token - - regex_str = " [😁-😎]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - assert fsm.get_index_dict() == { - 0: {5: 1, 10: 2}, - 1: {8: 5, 4: 3}, - 2: {11: 3}, - 5: {9: 3}, - } - - instruction = fsm.get_next_instruction(0) - assert 
isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [5, 10]) - - assert fsm.get_next_state(state=0, token_id=5) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_regex_final_state(): - """Make sure that the FSM stays in the final state as we keep generating""" - - class MockTokenizer: - vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} - special_tokens = {"eos"} - eos_token_id = 104 - - def convert_token_to_string(self, token): - return token - - regex_str = r"`\n(\.\n)?`\n" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - state = fsm.get_next_state(state=4, token_id=103) - assert state == 5 - assert fsm.is_final_state(state) - - state = fsm.get_next_state(state=5, token_id=103) - assert fsm.is_final_state(state) + with pytest.raises( + ValueError, + match="No next state found for the current state", + ): + assert list(guide.read_next_token(4)) == [] From f3266eefecb0ab8808f966385ce865d14265e8a6 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 20 Dec 2024 13:56:26 +0000 Subject: [PATCH 09/22] Improve Vocabulary python binding, add tests --- src/primitives.rs | 2 +- src/python_bindings/mod.rs | 61 +++++++++++++++++++++++++++++++++--- src/vocabulary/mod.rs | 27 +++++++++------- tests/fsm/test_vocabulary.py | 43 +++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 18 deletions(-) create mode 100644 tests/fsm/test_vocabulary.py diff --git a/src/primitives.rs b/src/primitives.rs index 5691bf8..f5571fe 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -4,5 +4,5 @@ pub type Token = Vec; /// Token identifier. pub type TokenId = u32; -/// State. +/// State id. 
pub type StateId = u32; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index f64d4af..f57e569 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -5,7 +5,7 @@ use bincode::config; use bincode::{Decode, Encode}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyAny, PyDict}; use pyo3::wrap_pyfunction; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use serde_json::Value; @@ -157,15 +157,23 @@ pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyR .map_err(|e| PyValueError::new_err(e.to_string())) } -#[pyclass(name = "Vocabulary")] +#[pyclass(name = "Vocabulary", module = "outlines_core.fsm.outlines_core_rs")] +#[derive(Clone, Debug, Encode, Decode)] pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { #[staticmethod] - fn from_dict(eos_token_id: TokenId, map: HashMap>) -> PyVocabulary { - let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); - PyVocabulary(v) + fn from_dict(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { + if let Ok(dict) = map.extract::>>(py) { + return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); + } + if let Ok(dict) = map.extract::, Vec>>(py) { + return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); + } + Err(PyErr::new::( + "Expected a dictionary with keys of type String or Bytes", + )) } #[staticmethod] @@ -178,6 +186,18 @@ impl PyVocabulary { self.0.eos_token_id() } + fn get(&self, py: Python<'_>, token: Py) -> PyResult>> { + if let Ok(t) = token.extract::(py) { + return Ok(self.0.token_to_ids(&t.into_bytes()).cloned()); + } + if let Ok(t) = token.extract::(py) { + return Ok(self.0.token_to_ids(&t).cloned()); + } + Err(PyErr::new::( + "Expected a token of type String or Bytes", + )) + } + fn __repr__(&self) -> String { format!("{:#?}", self.0) } @@ -189,6 +209,37 @@ impl PyVocabulary { fn __eq__(&self, other: &PyVocabulary) -> bool { self.0 == other.0 } + + fn __len__(&self) -> usize { + self.0.tokens_to_ids().len() + } + + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
+ .getattr("PyVocabulary")?; + let binary_data: Vec = + bincode::encode_to_vec(self, config::standard()).map_err(|e| { + PyErr::new::(format!( + "Serialization of Vocabulary failed: {}", + e + )) + })?; + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) + }) + } + + #[staticmethod] + fn from_binary(binary_data: Vec) -> PyResult { + let (guide, _): (PyVocabulary, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::(format!( + "Deserialization of Vocabulary failed: {}", + e + )) + })?; + Ok(guide) + } } #[pymodule] diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 62ff748..5488ac6 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,4 +1,6 @@ +use bincode::{Decode, Encode}; use rustc_hash::FxHashMap as HashMap; +use std::borrow::Borrow; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; @@ -25,7 +27,7 @@ mod processor; /// .insert("2", 2) /// .insert("0", 3); /// ``` -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq, Encode, Decode)] pub struct Vocabulary { // TODO: Option is temp for back compatibility eos_token_id: Option, @@ -103,8 +105,8 @@ impl Vocabulary { } /// Per provided token returns vector of `TokenId`s if available in the vocabulary. - pub fn token_to_ids(&self, token: &Token) -> Option<&Vec> { - self.tokens.get(token) + pub fn token_to_ids>(&self, token: &T) -> Option<&Vec> { + self.tokens.get(token.borrow()) } /// Gets the identifier of the special end of the sentence token. @@ -201,19 +203,21 @@ impl std::fmt::Display for Vocabulary { } } -impl From>> for Vocabulary { - fn from(tokens: HashMap>) -> Vocabulary { +impl From<(TokenId, HashMap>)> for Vocabulary { + fn from(values: (TokenId, HashMap>)) -> Vocabulary { + let (eos_token_id, tokens) = values; Vocabulary { - eos_token_id: None, + eos_token_id: Some(eos_token_id), tokens, } } } -impl From>> for Vocabulary { - fn from(tokens: HashMap>) -> Vocabulary { +impl From<(TokenId, HashMap>)> for Vocabulary { + fn from(values: (TokenId, HashMap>)) -> Vocabulary { + let (eos_token_id, tokens) = values; Vocabulary { - eos_token_id: None, + eos_token_id: Some(eos_token_id), tokens: tokens .into_iter() .map(|(k, v)| (k.as_bytes().to_vec(), v)) @@ -235,7 +239,7 @@ where #[cfg(test)] mod tests { use super::*; - use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; + use rustc_hash::FxHashSet as HashSet; #[test] fn insert() { @@ -277,8 +281,7 @@ mod tests { #[test] fn new_empty_vocabulary_from_hashmap() { - let map: HashMap> = HashMap::default(); - let vocabulary = Vocabulary::from(map); + let vocabulary = Vocabulary::new(None); assert!(vocabulary.eos_token_id.is_none()); assert!(vocabulary.tokens.is_empty()); } diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py new file mode 100644 index 0000000..764e11c --- /dev/null +++ b/tests/fsm/test_vocabulary.py @@ -0,0 +1,43 @@ +import pickle +import pytest + +from outlines_core.fsm import Vocabulary + +def test_supports_strings_as_keys(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + + assert vocabulary.get_eos_token_id() == eos_token_id + assert vocabulary.get("1") == [1] + assert vocabulary.get(b"1") == [1] + assert len(vocabulary) == 2 + +def test_supports_bytes_as_keys(): + eos_token_id = 3 + tokens = {b"1": [1], b"a": [2]} + vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + + assert 
vocabulary.get_eos_token_id() == eos_token_id + assert vocabulary.get(b"1") == [1] + assert vocabulary.get("1") == [1] + assert len(vocabulary) == 2 + +def test_do_not_supports_other_types_as_keys(): + eos_token_id = 3 + tokens = {1: [1], 2: [2]} + + with pytest.raises( + TypeError, + match="Expected a dictionary with keys of type String or Bytes" + ): + Vocabulary.from_dict(eos_token_id, tokens) + +def test_pickling(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + + serialized = pickle.dumps(vocabulary) + deserialized = pickle.loads(serialized) + assert deserialized == vocabulary \ No newline at end of file From 7edb8315df3028f88b4b73156420ba18dade2312 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 20 Dec 2024 15:02:52 +0000 Subject: [PATCH 10/22] Non-optional eos_token_id --- src/error.rs | 3 -- src/index.rs | 13 +++----- src/python_bindings/mod.rs | 2 +- src/vocabulary/mod.rs | 60 ++++++++---------------------------- tests/fsm/test_vocabulary.py | 11 ++++--- 5 files changed, 25 insertions(+), 64 deletions(-) diff --git a/src/error.rs b/src/error.rs index 53a8728..4ffe7ed 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,9 +6,6 @@ pub type Result = std::result::Result; pub enum Error { #[error("The vocabulary does not allow to build an index that matches the input")] InsufficientVocabulary, - // TODO: this error will be removed once eos_token_id for vocabulary won't be optional - #[error("Index failed since vocabulary doesn't provide eos token id")] - IndexEosTokenIdNotAvailable, #[error("Failed to build DFA {0}")] IndexDfaError(#[from] Box), #[error("Index failed since anchored universal start state doesn't exist")] diff --git a/src/index.rs b/src/index.rs index fc35aa7..e8d839d 100644 --- a/src/index.rs +++ b/src/index.rs @@ -18,12 +18,7 @@ pub struct Index { impl Index { pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result { - let eos_token_id = match vocabulary.eos_token_id() { - Some(s) => s, - // TODO: this error will be removed once eos_token_id for vocabulary won't be optional - None => return Err(Error::IndexEosTokenIdNotAvailable), - }; - + let eos_token_id = vocabulary.eos_token_id(); let dfa = DFA::new(regex).map_err(Box::new)?; let start_state = match dfa.universal_start_state(Anchored::Yes) { Some(s) => s, @@ -135,7 +130,7 @@ mod tests { #[test] fn index_from_regex() { let regex = "0|[1-9][0-9]*"; - let vocabulary = Vocabulary::new(Some(4)) + let vocabulary = Vocabulary::new(4) .insert("blah", 0) .insert("1a", 1) .insert("2", 2) @@ -157,7 +152,7 @@ mod tests { #[test] fn index_from_regex_initital_in_allowed() { let regex = "`\\n(\\.\\n)?`\\n"; - let vocabulary = Vocabulary::new(Some(104)) + let vocabulary = Vocabulary::new(104) .insert("\n", 103) .insert(".", 102) .insert("`", 101); @@ -172,7 +167,7 @@ mod tests { #[test] fn index_from_regex_multibyte() { let regex = "😇| [😈-😍][😇-😎]*"; - let vocabulary = Vocabulary::new(Some(8)) + let vocabulary = Vocabulary::new(8) .insert(" 😍", 5) .insert("blah", 0) .insert("😇", 2) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index f57e569..255dae8 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -182,7 +182,7 @@ impl PyVocabulary { Ok(PyVocabulary(v)) } - fn get_eos_token_id(&self) -> Option { + fn get_eos_token_id(&self) -> TokenId { self.0.eos_token_id() } diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 5488ac6..c1dd1ea 100644 --- a/src/vocabulary/mod.rs +++ 
b/src/vocabulary/mod.rs @@ -29,27 +29,19 @@ mod processor; /// ``` #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)] pub struct Vocabulary { - // TODO: Option is temp for back compatibility - eos_token_id: Option, + eos_token_id: TokenId, tokens: HashMap>, } impl Vocabulary { /// Creates an empty vocabulary. - pub fn new(eos_token_id: Option) -> Self { + pub fn new(eos_token_id: TokenId) -> Self { Self { eos_token_id, tokens: HashMap::default(), } } - pub fn with_eos_token_id(self, eos_token_id: Option) -> Self { - Self { - eos_token_id, - ..self - } - } - /// Creates the vocabulary of pre-trained model from Hugging Face Hub. pub fn from_pretrained( model: &str, @@ -77,7 +69,7 @@ impl Vocabulary { }; // Start building the vocabulary from eos_token_id and added tokens. - let mut vocabulary = Vocabulary::new(Some(eos_token_id)); + let mut vocabulary = Vocabulary::new(eos_token_id); for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() { if !added_token.special { vocabulary = vocabulary.insert(added_token.content.clone(), *id); @@ -110,7 +102,7 @@ impl Vocabulary { } /// Gets the identifier of the special end of the sentence token. - pub fn eos_token_id(&self) -> Option { + pub fn eos_token_id(&self) -> TokenId { self.eos_token_id } @@ -207,7 +199,7 @@ impl From<(TokenId, HashMap>)> for Vocabulary { fn from(values: (TokenId, HashMap>)) -> Vocabulary { let (eos_token_id, tokens) = values; Vocabulary { - eos_token_id: Some(eos_token_id), + eos_token_id, tokens, } } @@ -217,7 +209,7 @@ impl From<(TokenId, HashMap>)> for Vocabulary { fn from(values: (TokenId, HashMap>)) -> Vocabulary { let (eos_token_id, tokens) = values; Vocabulary { - eos_token_id: Some(eos_token_id), + eos_token_id, tokens: tokens .into_iter() .map(|(k, v)| (k.as_bytes().to_vec(), v)) @@ -226,16 +218,6 @@ impl From<(TokenId, HashMap>)> for Vocabulary { } } -impl FromIterator<(T, I)> for Vocabulary -where - T: Into, - I: IntoIterator, -{ - fn from_iter>(tokens_and_ids: A) -> Self { - Vocabulary::new(None).extend(tokens_and_ids) - } -} - #[cfg(test)] mod tests { use super::*; @@ -243,7 +225,7 @@ mod tests { #[test] fn insert() { - let vocabulary = Vocabulary::new(None) + let vocabulary = Vocabulary::new(4) .insert("blah", 0) .insert("1a", 1) .insert("2", 2) @@ -258,7 +240,7 @@ mod tests { #[test] fn extend() { - let vocabulary = Vocabulary::new(None).extend([ + let vocabulary = Vocabulary::new(4).extend([ ("blah", vec![0]), ("1a", vec![1]), ("2", vec![2]), @@ -274,28 +256,19 @@ mod tests { #[test] fn new_empty_vocabulary() { - let vocabulary = Vocabulary::new(None); - assert!(vocabulary.eos_token_id.is_none()); + let vocabulary = Vocabulary::new(1); + assert_eq!(vocabulary.eos_token_id, 1); assert!(vocabulary.tokens.is_empty()); } #[test] fn new_empty_vocabulary_from_hashmap() { - let vocabulary = Vocabulary::new(None); - assert!(vocabulary.eos_token_id.is_none()); + let map: HashMap> = HashMap::default(); + let vocabulary = Vocabulary::from((1_u32, map)); + assert_eq!(vocabulary.eos_token_id, 1); assert!(vocabulary.tokens.is_empty()); } - #[test] - fn new_vocabulary_from_iterator() { - let token: Token = "abc".as_bytes().to_vec(); - let id: Vec = vec![1]; - let it = vec![(token, id)]; - let vocabulary = Vocabulary::from_iter(it); - assert!(vocabulary.eos_token_id.is_none()); - assert!(!vocabulary.tokens.is_empty()); - } - #[test] fn supported_pretrained_models() { // Support is expected for these: @@ -315,7 +288,6 @@ mod tests { let vocabulary = Vocabulary::from_pretrained(model, None); match vocabulary 
{ Ok(v) => { - assert!(v.eos_token_id().is_some()); assert_eq!(v.eos_token_id, v.eos_token_id()); assert!(!v.tokens.is_empty()); } @@ -332,9 +304,6 @@ mod tests { let v_eos = vocabulary.eos_token_id; assert_eq!(v_eos, vocabulary.eos_token_id()); - assert!(v_eos.is_some()); - - let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 50256); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), @@ -366,9 +335,6 @@ mod tests { let v_eos = vocabulary.eos_token_id; assert_eq!(v_eos, vocabulary.eos_token_id()); - assert!(v_eos.is_some()); - - let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 2); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index 764e11c..a714eae 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -1,18 +1,20 @@ import pickle -import pytest +import pytest from outlines_core.fsm import Vocabulary + def test_supports_strings_as_keys(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} vocabulary = Vocabulary.from_dict(eos_token_id, tokens) - + assert vocabulary.get_eos_token_id() == eos_token_id assert vocabulary.get("1") == [1] assert vocabulary.get(b"1") == [1] assert len(vocabulary) == 2 + def test_supports_bytes_as_keys(): eos_token_id = 3 tokens = {b"1": [1], b"a": [2]} @@ -23,16 +25,17 @@ def test_supports_bytes_as_keys(): assert vocabulary.get("1") == [1] assert len(vocabulary) == 2 + def test_do_not_supports_other_types_as_keys(): eos_token_id = 3 tokens = {1: [1], 2: [2]} with pytest.raises( - TypeError, - match="Expected a dictionary with keys of type String or Bytes" + TypeError, match="Expected a dictionary with keys of type String or Bytes" ): Vocabulary.from_dict(eos_token_id, tokens) + def test_pickling(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} From 03e5561e8d8b864f6bbeb53c8ec606ca49c03425 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 2 Jan 2025 13:02:50 +0000 Subject: [PATCH 11/22] Stabilize vocabulary interface --- src/index.rs | 38 +++++----- src/prelude.rs | 1 + src/python_bindings/mod.rs | 26 +++++-- src/vocabulary/mod.rs | 141 ++++++++++++----------------------- tests/fsm/test_vocabulary.py | 30 ++++++-- 5 files changed, 109 insertions(+), 127 deletions(-) diff --git a/src/index.rs b/src/index.rs index e8d839d..21fab51 100644 --- a/src/index.rs +++ b/src/index.rs @@ -130,11 +130,10 @@ mod tests { #[test] fn index_from_regex() { let regex = "0|[1-9][0-9]*"; - let vocabulary = Vocabulary::new(4) - .insert("blah", 0) - .insert("1a", 1) - .insert("2", 2) - .insert("0", 3); + let mut vocabulary = Vocabulary::new(4); + for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] { + vocabulary.insert(token, token_id as u32); + } let index = Index::new(regex, &vocabulary).expect("Index failed"); assert_eq!(index.initial(), 40); @@ -152,10 +151,10 @@ mod tests { #[test] fn index_from_regex_initital_in_allowed() { let regex = "`\\n(\\.\\n)?`\\n"; - let vocabulary = Vocabulary::new(104) - .insert("\n", 103) - .insert(".", 102) - .insert("`", 101); + let mut vocabulary = Vocabulary::new(104); + for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] { + vocabulary.insert(token, token_id as u32); + } let index = Index::new(regex, &vocabulary).expect("Index failed"); let allowed = index @@ -167,15 +166,18 @@ mod tests { #[test] fn index_from_regex_multibyte() { let regex = "😇| [😈-😍][😇-😎]*"; - let vocabulary = Vocabulary::new(8) - .insert(" 😍", 5) - .insert("blah", 0) - .insert("😇", 2) - .insert("😈a", 1) - 
.insert("😍", 3) - .insert(vec![32, 240, 159, 152], 7) - .insert(vec![32, 240, 159, 152, 141], 6) - .insert(vec![240, 159, 152, 141], 4); + let mut vocabulary = Vocabulary::new(8); + for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)] + { + vocabulary.insert(token, token_id as u32); + } + for (token, token_id) in [ + (vec![32, 240, 159, 152], 7), + (vec![32, 240, 159, 152, 141], 6), + (vec![240, 159, 152, 141], 4), + ] { + vocabulary.insert(token, token_id as u32); + } let index = Index::new(regex, &vocabulary).expect("Index failed"); diff --git a/src/prelude.rs b/src/prelude.rs index 27236b9..8c1a853 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,4 +1,5 @@ pub use super::{ + index::Index, primitives::{StateId, Token, TokenId}, vocabulary::Vocabulary, }; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 255dae8..daf14ce 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -10,6 +10,13 @@ use pyo3::wrap_pyfunction; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use serde_json::Value; +macro_rules! type_name { + ($obj:expr) => { + // Safety: obj is always initialized and tp_name is a C-string + unsafe { std::ffi::CStr::from_ptr((&*(&*$obj.as_ptr()).ob_type).tp_name) } + }; +} + #[pyclass(name = "Guide", module = "outlines_core.fsm.outlines_core_rs")] #[derive(Clone, Debug, Encode, Decode)] pub struct PyGuide { @@ -59,7 +66,7 @@ impl PyGuide { fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? - .getattr("PyGuide")?; + .getattr("Guide")?; let binary_data: Vec = bincode::encode_to_vec(self, config::standard()).map_err(|e| { PyErr::new::(format!("Serialization of Guide failed: {}", e)) @@ -163,8 +170,8 @@ pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { - #[staticmethod] - fn from_dict(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { + #[new] + fn new(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { if let Ok(dict) = map.extract::>>(py) { return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); } @@ -172,7 +179,10 @@ impl PyVocabulary { return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); } Err(PyErr::new::( - "Expected a dictionary with keys of type String or Bytes", + format!( + "Expected a dictionary with keys of type str or bytes and values of type list[int], got {:?}", + type_name!(map) + ), )) } @@ -188,13 +198,13 @@ impl PyVocabulary { fn get(&self, py: Python<'_>, token: Py) -> PyResult>> { if let Ok(t) = token.extract::(py) { - return Ok(self.0.token_to_ids(&t.into_bytes()).cloned()); + return Ok(self.0.token_to_ids(t.into_bytes()).cloned()); } if let Ok(t) = token.extract::(py) { return Ok(self.0.token_to_ids(&t).cloned()); } Err(PyErr::new::( - "Expected a token of type String or Bytes", + format!("Expected a token of type str or bytes, got {:?}", type_name!(token)), )) } @@ -217,7 +227,7 @@ impl PyVocabulary { fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
- .getattr("PyVocabulary")?; + .getattr("Vocabulary")?; let binary_data: Vec = bincode::encode_to_vec(self, config::standard()).map_err(|e| { PyErr::new::(format!( @@ -266,4 +276,4 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; Ok(()) -} +} \ No newline at end of file diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index c1dd1ea..edafcf7 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,6 +1,5 @@ use bincode::{Decode, Encode}; use rustc_hash::FxHashMap as HashMap; -use std::borrow::Borrow; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; @@ -18,14 +17,19 @@ mod processor; /// /// ## Examples /// +/// ### Create a vocabulary from a pretrained model. /// ```rust /// # use outlines_core::prelude::*; /// # -/// let vocabulary = Vocabulary::new(None) -/// .insert("blah", 0) -/// .insert("1a", 1) -/// .insert("2", 2) -/// .insert("0", 3); +/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None); +/// ``` +/// +/// ### Create an empty vocabulary. +/// ```rust +/// # use outlines_core::prelude::*; +/// # +/// let mut vocabulary = Vocabulary::new(1); +/// vocabulary.insert("token", 0); /// ``` #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)] pub struct Vocabulary { @@ -42,6 +46,12 @@ impl Vocabulary { } } + /// Inserts a token to the vocabulary with the specified identifier. + pub fn insert(&mut self, token: impl Into, id: TokenId) { + let token = token.into(); + self.tokens.entry(token).or_default().push(id); + } + /// Creates the vocabulary of pre-trained model from Hugging Face Hub. pub fn from_pretrained( model: &str, @@ -72,7 +82,7 @@ impl Vocabulary { let mut vocabulary = Vocabulary::new(eos_token_id); for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() { if !added_token.special { - vocabulary = vocabulary.insert(added_token.content.clone(), *id); + vocabulary.insert(added_token.content.clone(), *id); } } @@ -85,7 +95,7 @@ impl Vocabulary { }; for (token, token_id) in tokenizer.get_vocab(false) { let processed_token = processor.process(token)?; - vocabulary = vocabulary.insert(processed_token, token_id); + vocabulary.insert(processed_token, token_id); } Ok(vocabulary) @@ -97,8 +107,8 @@ impl Vocabulary { } /// Per provided token returns vector of `TokenId`s if available in the vocabulary. - pub fn token_to_ids>(&self, token: &T) -> Option<&Vec> { - self.tokens.get(token.borrow()) + pub fn token_to_ids(&self, token: impl AsRef<[u8]>) -> Option<&Vec> { + self.tokens.get(token.as_ref()) } /// Gets the identifier of the special end of the sentence token. @@ -137,59 +147,11 @@ impl Vocabulary { } } -impl Vocabulary { - /// Inserts a token to the vocabulary with the specified identifier. - pub fn insert(mut self, token: impl Into, id: TokenId) -> Vocabulary { - self.insert_in_place(token, id); - self - } - - /// Extends the vocabulary with tokens and their identifiers. - pub fn extend, I: IntoIterator>( - mut self, - tokens_and_ids: impl IntoIterator, - ) -> Vocabulary { - self.extend_in_place(tokens_and_ids); - self - } -} - -impl Vocabulary { - /// Inserts a token to the vocabulary with the specified identifier, in place. - pub fn insert_in_place(&mut self, token: impl Into, id: TokenId) { - // TODO: return error if eos token id is inserted - let token = token.into(); - self.tokens.entry(token).or_default().push(id); - } - - /// Extends the vocabulary with tokens and their identifiers, in place. 
- pub fn extend_in_place, I: IntoIterator>( - &mut self, - tokens_and_ids: impl IntoIterator, - ) { - for (token, ids) in tokens_and_ids.into_iter() { - let token = token.into(); - self.tokens.entry(token).or_default().extend(ids); - } - } -} - -impl std::ops::Deref for Vocabulary { - type Target = HashMap>; - - fn deref(&self) -> &HashMap> { - &self.tokens - } -} - impl std::fmt::Display for Vocabulary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for (index, (token, token_ids)) in self.iter().enumerate() { - if index != (self.len() - 1) { - writeln!(f, "{:?} -> {:?}", token, token_ids)?; - } else { - write!(f, "{:?} -> {:?}", token, token_ids)?; - } + writeln!(f, "[{:?}]", self.eos_token_id)?; + for (token, token_ids) in self.tokens.iter() { + writeln!(f, "{:?} -> {:?}", token, token_ids)?; } Ok(()) } @@ -224,41 +186,30 @@ mod tests { use rustc_hash::FxHashSet as HashSet; #[test] - fn insert() { - let vocabulary = Vocabulary::new(4) - .insert("blah", 0) - .insert("1a", 1) - .insert("2", 2) - .insert("0", 3); - - assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah".as_bytes()], &[0]); - assert_eq!(vocabulary["1a".as_bytes()], &[1]); - assert_eq!(vocabulary["2".as_bytes()], &[2]); - assert_eq!(vocabulary["0".as_bytes()], &[3]); - } - - #[test] - fn extend() { - let vocabulary = Vocabulary::new(4).extend([ - ("blah", vec![0]), - ("1a", vec![1]), - ("2", vec![2]), - ("0", vec![3]), - ]); - - assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah".as_bytes()], &[0]); - assert_eq!(vocabulary["1a".as_bytes()], &[1]); - assert_eq!(vocabulary["2".as_bytes()], &[2]); - assert_eq!(vocabulary["0".as_bytes()], &[3]); - } + fn basic_interface() { + let eos_token_id = 3; + let mut vocabulary = Vocabulary::new(eos_token_id); - #[test] - fn new_empty_vocabulary() { - let vocabulary = Vocabulary::new(1); - assert_eq!(vocabulary.eos_token_id, 1); + // New empty vocabulary. + assert_eq!(vocabulary.eos_token_id, eos_token_id); assert!(vocabulary.tokens.is_empty()); + + for (token, id) in [("zero", 0), ("one", 1), ("two", 2)] { + vocabulary.insert(token, id); + assert_eq!(vocabulary.token_to_ids(token), Some(&vec![id])); + } + assert_eq!(vocabulary.tokens.len(), 3); + assert_eq!(vocabulary.tokens_to_ids().len(), 3); + + // Confirm different types. 
+ vocabulary.insert(b"four", 4); + assert_eq!(vocabulary.token_to_ids("four"), Some(&vec![4])); + + vocabulary.insert(b"five".to_vec(), 5); + assert_eq!(vocabulary.token_to_ids("five"), Some(&vec![5])); + + vocabulary.insert("six".to_string(), 6); + assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6])); } #[test] @@ -316,7 +267,7 @@ mod tests { assert!(tokenizer.token_to_id(token).is_some()); for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] { - let v_ids = vocabulary.token_to_ids(&v_token.as_bytes().to_vec()); + let v_ids = vocabulary.token_to_ids(v_token.as_bytes()); assert!(v_ids.is_some()); for v_id in v_ids.unwrap() { let t_token = tokenizer diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index a714eae..b0de5a9 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -7,7 +7,7 @@ def test_supports_strings_as_keys(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} - vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + vocabulary = Vocabulary(eos_token_id, tokens) assert vocabulary.get_eos_token_id() == eos_token_id assert vocabulary.get("1") == [1] @@ -18,7 +18,7 @@ def test_supports_strings_as_keys(): def test_supports_bytes_as_keys(): eos_token_id = 3 tokens = {b"1": [1], b"a": [2]} - vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + vocabulary = Vocabulary(eos_token_id, tokens) assert vocabulary.get_eos_token_id() == eos_token_id assert vocabulary.get(b"1") == [1] @@ -31,16 +31,34 @@ def test_do_not_supports_other_types_as_keys(): tokens = {1: [1], 2: [2]} with pytest.raises( - TypeError, match="Expected a dictionary with keys of type String or Bytes" + TypeError, + match=r"Expected a dictionary with keys of type str or bytes and values of type list\[int\], got", ): - Vocabulary.from_dict(eos_token_id, tokens) + Vocabulary(eos_token_id, tokens) + + +def test_get_bad_type(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + vocabulary = Vocabulary(eos_token_id, tokens) + + with pytest.raises( + TypeError, + match="Expected a token of type str or bytes, got", + ): + vocabulary.get(1) + + +def test_from_pretrained(): + vocabulary = Vocabulary.from_pretrained("gpt2") + assert vocabulary.get_eos_token_id() == 50256 def test_pickling(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} - vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + vocabulary = Vocabulary(eos_token_id, tokens) serialized = pickle.dumps(vocabulary) deserialized = pickle.loads(serialized) - assert deserialized == vocabulary \ No newline at end of file + assert deserialized == vocabulary From 64f0d7320ad9eacb02518b3d345fa29e8cec508a Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 3 Jan 2025 13:04:47 +0000 Subject: [PATCH 12/22] Add tests for Guide --- src/python_bindings/mod.rs | 27 +++++--- tests/fsm/test_guide.py | 125 ++++++++++++++++++++++++++++++++--- tests/fsm/test_vocabulary.py | 15 +++-- 3 files changed, 144 insertions(+), 23 deletions(-) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index daf14ce..dc766f8 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -178,12 +178,18 @@ impl PyVocabulary { if let Ok(dict) = map.extract::, Vec>>(py) { return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); } - Err(PyErr::new::( - format!( - "Expected a dictionary with keys of type str or bytes and values of type list[int], got {:?}", - type_name!(map) - ), - )) + + let message = "Expected a dict with keys of type str or bytes and values of type list[int]"; + 
let tname = type_name!(map).to_string_lossy(); + if tname == "dict" { + Err(PyErr::new::(format!( + "Dict keys or/and values of the wrong types. {message}" + ))) + } else { + Err(PyErr::new::(format!( + "{message}, got {tname}" + ))) + } } #[staticmethod] @@ -203,9 +209,10 @@ impl PyVocabulary { if let Ok(t) = token.extract::(py) { return Ok(self.0.token_to_ids(&t).cloned()); } - Err(PyErr::new::( - format!("Expected a token of type str or bytes, got {:?}", type_name!(token)), - )) + Err(PyErr::new::(format!( + "Expected a token of type str or bytes, got {:?}", + type_name!(token) + ))) } fn __repr__(&self) -> String { @@ -276,4 +283,4 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; Ok(()) -} \ No newline at end of file +} diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 50d9599..45b3d18 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -1,25 +1,134 @@ +import pickle + import pytest from outlines_core.fsm import Guide, Index, Vocabulary -def test_stop_at_eos_txt(): +def test_stop_at_eos(): eos_token_id = 3 - # TODO: support bytes from python - # tokens = {b"1": {1}, b"a": {2}} tokens = {"1": [1], "a": [2]} - regex = r"[1-9]" - vocabulary = Vocabulary.from_dict(eos_token_id, tokens) + vocabulary = Vocabulary(eos_token_id, tokens) index = Index(regex, vocabulary) guide = Guide(index) - assert list(guide.get_start_tokens()) == [1] - assert list(guide.read_next_token(1)) == [vocabulary.get_eos_token_id()] + assert guide.get_start_tokens() == [1] + assert guide.read_next_token(1) == [vocabulary.get_eos_token_id()] assert guide.is_finished() with pytest.raises( ValueError, match="No next state found for the current state", ): - assert list(guide.read_next_token(4)) == [] + assert guide.read_next_token(4) == [] + + +def test_regex_final_state_walk(): + # Make sure that the Guide can walk to the final state correctly. 
+ eos_token_id = 104 + tokens = {b"\n": [103], b".": [102], b"`": [101]} + regex = r"`\n(\.\n)?`\n" + + vocabulary = Vocabulary(eos_token_id, tokens) + index = Index(regex, vocabulary) + guide = Guide(index) + + assert guide.get_start_tokens() == [101] + assert guide.read_next_token(101) == [103] + assert sorted(guide.read_next_token(103)) == [101, 102] + assert guide.read_next_token(101) == [103] + assert guide.read_next_token(103) == [vocabulary.get_eos_token_id()] + assert guide.is_finished() + + +def test_token_trans_keys_identical(): + tokens = {"a": [1], "b": [2], "z": [3]} + eos_token_id = 4 + regex = r"z[ab]z" + + vocabulary = Vocabulary(eos_token_id, tokens) + index = Index(regex, vocabulary) + + guide1 = Guide(index) + guide2 = Guide(index) + + assert guide1.read_next_token(3) == guide2.read_next_token(3) + # `a` and `b` have similar transitions to `z` + assert guide1.read_next_token(1) == guide2.read_next_token(2) + assert guide1.read_next_token(3) == guide2.read_next_token(3) == [eos_token_id] + assert guide1.is_finished() + assert guide2.is_finished() + + +def test_str_and_bytes_produce_the_same(): + tokens1 = {"a": [1], "b": [2], "z": [3]} + tokens2 = {b"a": [1], b"b": [2], b"z": [3]} + eos_token_id = 4 + regex = r"z[ab]z" + + vocabulary1 = Vocabulary(eos_token_id, tokens1) + vocabulary2 = Vocabulary(eos_token_id, tokens2) + index1 = Index(regex, vocabulary1) + index2 = Index(regex, vocabulary2) + guide1 = Guide(index1) + guide2 = Guide(index2) + + assert guide1.read_next_token(3) == guide2.read_next_token(3) + # `a` and `b` have similar transitions to `z` + assert guide1.read_next_token(1) == guide2.read_next_token(2) + assert guide1.read_next_token(3) == guide2.read_next_token(3) == [eos_token_id] + assert guide1.is_finished() + assert guide2.is_finished() + + +def test_pickling(): + eos_token_id = 3 + tokens = {"1": [1], "2": [2]} + regex = r"[1-9]" + + vocabulary = Vocabulary(eos_token_id, tokens) + index = Index(regex, vocabulary) + guide = Guide(index) + + serialized = pickle.dumps(guide) + deserialized = pickle.loads(serialized) + assert sorted(deserialized.get_start_tokens()) == sorted(guide.get_start_tokens()) + + +# @pytest.mark.parametrize( +# "hf_tokenizer_uri, revision", +# [ +# ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), +# ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), +# ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), +# ( +# "NousResearch/Hermes-2-Pro-Llama-3-8B", +# "783fd50eb82d7f57758de033861f54d62dde234f", +# ), +# ], +# ) +# def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision): +# # The combined regular expressions of a lexer state in a Python grammar +# regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t 
\x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + +# regex_pattern = interegular.parse_pattern(regex_str) +# # Not reduced, so that there are many states +# regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) +# bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) + +# num_fsm_states = len(regex_fsm.states) +# assert num_fsm_states == 220 + +# num_bytes_fsm_states = len(bytes_fsm.states) +# assert num_bytes_fsm_states == 235 + +# tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) +# tokenizer = TransformerTokenizer(tokenizer) + +# states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( +# bytes_fsm, tokenizer +# ) + +# assert not empty_token_ids +# assert len(states_to_token_subsets.get_transitions()) / num_fsm_states > 0.94 diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index b0de5a9..4bc54fa 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -26,15 +26,20 @@ def test_supports_bytes_as_keys(): assert len(vocabulary) == 2 -def test_do_not_supports_other_types_as_keys(): - eos_token_id = 3 - tokens = {1: [1], 2: [2]} +def test_do_not_supports_other_types(): + eos_token_id = 0 + + with pytest.raises( + TypeError, + match=r"Expected a dict with keys of type str or bytes and values of type list\[int\], got", + ): + Vocabulary(eos_token_id, 1) with pytest.raises( TypeError, - match=r"Expected a dictionary with keys of type str or bytes and values of type list\[int\], got", + match="Dict keys or/and values of the wrong types", ): - Vocabulary(eos_token_id, tokens) + Vocabulary(eos_token_id, {1: [1], 2: [2]}) def test_get_bad_type(): From 2ab0007fcadd6d9fd9e91c0cb58161444451f1fe Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 3 Jan 2025 14:04:47 +0000 Subject: [PATCH 13/22] Python vocabulary to accept pretrained params --- src/python_bindings/mod.rs | 17 ++++++++-- tests/fsm/test_guide.py | 59 +++++++++++++-------------------- tests/fsm/test_serialization.py | 56 ------------------------------- 3 files changed, 38 insertions(+), 94 deletions(-) delete mode 100644 tests/fsm/test_serialization.py diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index dc766f8..405972e 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -9,6 +9,7 @@ use pyo3::types::{PyAny, PyDict}; use pyo3::wrap_pyfunction; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use serde_json::Value; +use tokenizers::FromPretrainedParameters; macro_rules! 
type_name { ($obj:expr) => { @@ -193,8 +194,20 @@ impl PyVocabulary { } #[staticmethod] - fn from_pretrained(model: String) -> PyResult { - let v = Vocabulary::from_pretrained(model.as_str(), None)?; + #[pyo3(signature = (model, revision=None, token=None))] + fn from_pretrained( + model: String, + revision: Option, + token: Option, + ) -> PyResult { + let mut params = FromPretrainedParameters::default(); + if let Some(r) = revision { + params.revision = r + } + if token.is_some() { + params.token = token + } + let v = Vocabulary::from_pretrained(model.as_str(), Some(params))?; Ok(PyVocabulary(v)) } diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 45b3d18..d48738a 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -96,39 +96,26 @@ def test_pickling(): assert sorted(deserialized.get_start_tokens()) == sorted(guide.get_start_tokens()) -# @pytest.mark.parametrize( -# "hf_tokenizer_uri, revision", -# [ -# ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), -# ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), -# ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), -# ( -# "NousResearch/Hermes-2-Pro-Llama-3-8B", -# "783fd50eb82d7f57758de033861f54d62dde234f", -# ), -# ], -# ) -# def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision): -# # The combined regular expressions of a lexer state in a Python grammar -# regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - -# regex_pattern = interegular.parse_pattern(regex_str) -# # Not reduced, so that there are many states -# regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) -# bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - -# num_fsm_states = len(regex_fsm.states) -# assert num_fsm_states == 220 - -# num_bytes_fsm_states = len(bytes_fsm.states) -# assert num_bytes_fsm_states == 235 - -# tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) -# tokenizer = TransformerTokenizer(tokenizer) - -# states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( -# bytes_fsm, tokenizer -# ) - -# assert not empty_token_ids -# assert len(states_to_token_subsets.get_transitions()) / num_fsm_states > 0.94 +@pytest.mark.parametrize( + "model, revision", + [ + ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), + ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), + ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), + ( + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "783fd50eb82d7f57758de033861f54d62dde234f", + ), + ], +) +def test_pickling_from_pretrained_with_revision(model, revision): + regex = 
"(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + vocabulary = Vocabulary.from_pretrained(model, revision=revision) + index = Index(regex, vocabulary) + assert len(index.get_transitions()) == 810 + + guide = Guide(index) + serialized = pickle.dumps(guide) + deserialized = pickle.loads(serialized) + assert sorted(deserialized.get_start_tokens()) == sorted(guide.get_start_tokens()) diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py deleted file mode 100644 index d3c3836..0000000 --- a/tests/fsm/test_serialization.py +++ /dev/null @@ -1,56 +0,0 @@ -import pickle - -import pytest -from outlines_core.fsm.guide import RegexGuide -from transformers import AutoTokenizer - -from tests.fsm.test_regex import TransformerTokenizer - - -def test_serialization(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - serialized = pickle.dumps(fsm) - deserialized = pickle.loads(serialized) - - assert fsm.eos_tensor == deserialized.eos_tensor - assert fsm.initial_state == deserialized.initial_state - - -@pytest.mark.parametrize( - "hf_tokenizer_uri, revision", - [ - ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), - ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), - ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), - ( - "NousResearch/Hermes-2-Pro-Llama-3-8B", - "783fd50eb82d7f57758de033861f54d62dde234f", - ), - ], -) -def test_complex_serialization(hf_tokenizer_uri, revision): - # The combined regular expressions of a lexer state in a Python grammar - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t 
\x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
-
-    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision)
-    tokenizer = TransformerTokenizer(tokenizer)
-
-    fsm = RegexGuide.from_regex(regex_str, tokenizer)
-
-    serialized = pickle.dumps(fsm)
-    deserialized = pickle.loads(serialized)
-
-    assert fsm.eos_tensor == deserialized.eos_tensor
-    assert fsm.initial_state == deserialized.initial_state

From 063d1c2fe871323ce8acc926a7f2d619553dbc13 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Fri, 3 Jan 2025 17:28:32 +0000
Subject: [PATCH 14/22] Correct interface in pyi, reprs for all python bindings

---
 python/outlines_core/fsm/outlines_core_rs.pyi | 79 ++++++++-----------
 src/index.rs                                  | 10 +++
 src/python_bindings/mod.rs                    | 63 ++++++++-----
 src/vocabulary/mod.rs                         | 16 +++-
 tests/fsm/test_guide.py                       |  1 +
 5 files changed, 101 insertions(+), 68 deletions(-)

diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index d3a25c0..efbaa6a 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -21,69 +21,53 @@ URI: str
 
 class Guide:
     def __init__(self, index: Index):
-        """
-        Defines a guide object an index an initializes it in its start state.
-        """
+        """Defines a guide object from an index and initializes it in its start state."""
     def get_start_tokens(self) -> List[int]:
-        """
-        Gets the list of allowed tokens from the start state.
-        """
+        """Gets the list of allowed tokens from the start state."""
         ...
     def read_next_token(self, token_id: int) -> List[int]:
-        """
-        Reads the next token according to the model and returns a list of allowable tokens.
-        """
+        """Reads the next token according to the model and returns a list of allowable tokens."""
         ...
     def is_finished(self) -> bool:
-        """
-        Checks if the automaton is in a final state.
-        """
+        """Checks if the automaton is in a final state."""
         ...
+    def __repr__(self) -> str:
+        """Gets the debug string representation of the guide."""
+        ...
+    def __str__(self) -> str:
+        """Gets the string representation of the guide."""
 
 class Vocabulary:
-    """
-    Vocabulary of an LLM.
-    """
-
-    @staticmethod
-    def from_dict(eos_token_id: int, map: Dict[str, List[int]]) -> "Vocabulary":
-        """
-        Creates a vocabulary from a map of tokens to token ids and eos token id.
-        """
+    def __init__(
+        self, eos_token_id: int, map: Dict[Union[str, bytes], List[int]]
+    ) -> "Vocabulary":
+        """Creates a vocabulary from a map of tokens to token ids and eos token id."""
         ...
     @staticmethod
-    def from_pretrained(model: str) -> "Vocabulary":
-        """
-        Creates the vocabulary of a pre-trained model.
-        """
+    def from_pretrained(
+        model: str, revision: Optional[str], token: Optional[str]
+    ) -> "Vocabulary":
+        """Creates the vocabulary of a pre-trained model."""
+        ...
+    def get_eos_token_id(self) -> Optional[int]:
+        """Gets the end of sentence token id."""
+        ...
+    def get(self, token: Union[str, bytes]) -> Optional[List[int]]:
+        """Gets the token ids of the given token, if it is present in the vocabulary."""
         ...
     def __repr__(self) -> str:
-        """
-        Gets the debug string representation of the vocabulary.
-        """
+        """Gets the debug string representation of the vocabulary."""
         ...
     def __str__(self) -> str:
-        """
-        Gets the string representation of the vocabulary.
-        """
+        """Gets the string representation of the vocabulary."""
         ...
     def __eq__(self, other: object) -> bool:
-        """
-        Gets whether two vocabularies are the same.
-        """
-        ...
- def get_eos_token_id(self) -> Optional[int]: - """ - Gets the end of sentence token id. - """ + """Compares whether two vocabularies are the same.""" ... class Index: - @staticmethod - def from_regex(regex: str, vocabulary: "Vocabulary") -> "Index": - """ - Creates an index from a regex and vocabulary. - """ + def __init__(self, regex: str, vocabulary: "Vocabulary") -> "Index": + """Creates an index from a regex and vocabulary.""" ... def get_allowed_tokens(self, state: int) -> Optional[List[int]]: """Returns allowed tokens in this state.""" @@ -97,9 +81,14 @@ class Index: def final_states(self) -> List[int]: """Get all final states.""" ... - def get_index_dict(self) -> Dict[int, Dict[int, int]]: + def get_transitions(self) -> Dict[int, Dict[int, int]]: """Returns the Index as a Python Dict object.""" ... def get_initial_state(self) -> int: """Returns the ID of the initial state of the input FSM automata.""" ... + def __repr__(self) -> str: + """Gets the debug string representation of the index.""" + ... + def __str__(self) -> str: + """Gets the string representation of the index.""" diff --git a/src/index.rs b/src/index.rs index 21fab51..7725acc 100644 --- a/src/index.rs +++ b/src/index.rs @@ -123,6 +123,16 @@ impl Index { } } +impl std::fmt::Display for Index { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Index object with transitions:")?; + for (state_id, token_ids) in self.states_to_token_subsets.iter() { + writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 405972e..a702be2 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -64,6 +64,20 @@ impl PyGuide { self.index.is_final_state(self.state) } + fn __repr__(&self) -> String { + format!( + "Guide object with the state={:#?} and {:#?}", + self.state, self.index + ) + } + + fn __str__(&self) -> String { + format!( + "Guide object with the state={} and {}", + self.state, self.index.0 + ) + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? @@ -101,27 +115,6 @@ impl PyIndex { }) } - fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { - Python::with_gil(|py| { - let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? - .getattr("Index")?; - let binary_data: Vec = bincode::encode_to_vec(&self.0, config::standard()) - .map_err(|e| { - PyErr::new::(format!("Serialization of Index failed: {}", e)) - })?; - Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) - }) - } - - #[staticmethod] - fn from_binary(binary_data: Vec) -> PyResult { - let (index, _): (Index, usize) = - bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { - PyErr::new::(format!("Deserialization of Index failed: {}", e)) - })?; - Ok(PyIndex(index)) - } - fn get_allowed_tokens(&self, state: StateId) -> Option> { self.0.allowed_tokens(state) } @@ -145,6 +138,34 @@ impl PyIndex { fn get_initial_state(&self) -> StateId { self.0.initial() } + fn __repr__(&self) -> String { + format!("{:#?}", self.0) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } + + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
+ .getattr("Index")?; + let binary_data: Vec = bincode::encode_to_vec(&self.0, config::standard()) + .map_err(|e| { + PyErr::new::(format!("Serialization of Index failed: {}", e)) + })?; + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) + }) + } + + #[staticmethod] + fn from_binary(binary_data: Vec) -> PyResult { + let (index, _): (Index, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::(format!("Deserialization of Index failed: {}", e)) + })?; + Ok(PyIndex(index)) + } } #[pyfunction(name = "build_regex_from_schema")] diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index edafcf7..dbae9b6 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -149,9 +149,21 @@ impl Vocabulary { impl std::fmt::Display for Vocabulary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "[{:?}]", self.eos_token_id)?; + writeln!( + f, + "Vocabulary object with eos_token_id={:?} and the following tokens to token_ids:", + self.eos_token_id + )?; for (token, token_ids) in self.tokens.iter() { - writeln!(f, "{:?} -> {:?}", token, token_ids)?; + writeln!( + f, + "{:?} -> {:?}", + token + .iter() + .map(|b| format!("0x{:02X}", b)) + .collect::>(), + token_ids + )?; } Ok(()) } diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index d48738a..db7ab07 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -73,6 +73,7 @@ def test_str_and_bytes_produce_the_same(): index2 = Index(regex, vocabulary2) guide1 = Guide(index1) guide2 = Guide(index2) + assert False assert guide1.read_next_token(3) == guide2.read_next_token(3) # `a` and `b` have similar transitions to `z` From f65d86f2a10781022218f718ab6fbe147c5451d6 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Mon, 6 Jan 2025 15:59:45 +0000 Subject: [PATCH 15/22] Adjust benchmarks --- benchmarks/bench_json_schema.py | 8 +- benchmarks/bench_regex_guide.py | 22 ++-- benchmarks/common.py | 117 ------------------ python/outlines_core/fsm/__init__.py | 2 +- python/outlines_core/fsm/outlines_core_rs.pyi | 10 +- 5 files changed, 20 insertions(+), 139 deletions(-) delete mode 100644 benchmarks/common.py diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py index 964caa3..f88ed6c 100644 --- a/benchmarks/bench_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,8 +1,6 @@ -from outlines_core.fsm.guide import RegexGuide +from outlines_core.fsm import Index, Vocabulary from outlines_core.fsm.json_schema import build_regex_from_schema -from .common import setup_tokenizer # noqa: E402 - simple_schema = """{ "$defs": { "Armor": { @@ -66,7 +64,7 @@ class JsonSchemaBenchmark: params = schemas.keys() def setup(self, schema_name): - self.tokenizer = setup_tokenizer() + self.vocabulary = Vocabulary.from_pretrained("gpt2") self.schema = schemas[schema_name] def time_json_schema_to_regex(self, schema_name): @@ -74,4 +72,4 @@ def time_json_schema_to_regex(self, schema_name): def time_json_schema_to_fsm(self, schema_name): regex = build_regex_from_schema(self.schema) - RegexGuide.from_regex(regex, self.tokenizer) + Index(regex, self.vocabulary) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index 5dda576..d921a7d 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,9 +1,7 @@ from concurrent.futures import ThreadPoolExecutor import psutil -from outlines_core.fsm.guide import RegexGuide - -from .common import setup_tokenizer +from 
outlines_core.fsm import Index, Vocabulary
 
 regex_samples = {
     "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
@@ -18,15 +16,15 @@
 }
 
 
-class RegexGuideBenchmark:
+class RegexIndexBenchmark:
     params = regex_samples.keys()
 
     def setup(self, pattern_name):
-        self.tokenizer = setup_tokenizer()
+        self.vocabulary = Vocabulary.from_pretrained("gpt2")
         self.pattern = regex_samples[pattern_name]
 
     def time_regex_to_guide(self, pattern_name):
-        RegexGuide.from_regex(self.pattern, self.tokenizer)
+        Index(self.pattern, self.vocabulary)
 
     def time_regex_to_guide_parallel(self, pattern_name):
         # Default GIL switch interval is 5ms (0.005), which isn't helpful for cpu heavy tasks,
@@ -37,6 +35,10 @@
         list(executor.map(self._from_regex, [pattern_name] * core_count))
 
     def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name):
+        # Note: after moving to the full Rust implementation for index and guide creation, this experiment
+        # no longer shows the drastic difference it once showed when Python was heavily involved,
+        # because of an average speedup of roughly 100 times.
+
        # This test is to show, that if GIL's switch interval is set to be longer, then the parallel
         # test's runtime on physical cores will be much closer to the one-threaded case.
         import sys
@@ -48,15 +50,15 @@
         list(executor.map(self._from_regex, [pattern_name] * core_count))
 
     def _from_regex(self, pattern_name):
-        RegexGuide.from_regex(self.pattern, self.tokenizer)
+        Index(self.pattern, self.vocabulary)
 
 
-class MemoryRegexGuideBenchmark:
+class MemoryRegexIndexBenchmark:
     params = ["simple_phone", "complex_span_constrained_relation_extraction"]
 
     def setup(self, pattern_name):
-        self.tokenizer = setup_tokenizer()
+        self.vocabulary = Vocabulary.from_pretrained("gpt2")
        self.pattern = regex_samples[pattern_name]
 
     def peakmem_regex_to_guide(self, pattern_name):
-        RegexGuide.from_regex(self.pattern, self.tokenizer)
+        Index(self.pattern, self.vocabulary)
diff --git a/benchmarks/common.py b/benchmarks/common.py
deleted file mode 100644
index b56677e..0000000
--- a/benchmarks/common.py
+++ /dev/null
@@ -1,117 +0,0 @@
-from typing import List, Tuple, Union
-
-import torch
-from datasets.fingerprint import Hasher
-from transformers import AutoTokenizer, PreTrainedTokenizer
-
-
-def get_llama_tokenizer_types():
-    """Get all the Llama tokenizer types/classes that need work-arounds.
-
-    When they can't be imported, a dummy class is created.
- - """ - try: - from transformers.models.llama import LlamaTokenizer - except ImportError: - - class LlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.llama import LlamaTokenizerFast - except ImportError: - - class LlamaTokenizerFast: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizer - except ImportError: - - class CodeLlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizerFast - except ImportError: - - class CodeLlamaTokenizerFast: # type: ignore - pass - - return ( - LlamaTokenizer, - LlamaTokenizerFast, - CodeLlamaTokenizer, - CodeLlamaTokenizerFast, - ) - - -class TransformerTokenizer: - """Represents a tokenizer for models in the `transformers` library.""" - - def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs): - self.tokenizer = tokenizer - self.eos_token_id = self.tokenizer.eos_token_id - self.eos_token = self.tokenizer.eos_token - - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.pad_token_id = self.eos_token_id - else: - self.pad_token_id = self.tokenizer.pad_token_id - self.pad_token = self.tokenizer.pad_token - - self.special_tokens = set(self.tokenizer.all_special_tokens) - - self.vocabulary = self.tokenizer.get_vocab() - self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) - - def encode( - self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple[torch.LongTensor, torch.LongTensor]: - kwargs["padding"] = True - kwargs["return_tensors"] = "pt" - output = self.tokenizer(prompt, **kwargs) - return output["input_ids"], output["attention_mask"] - - def decode(self, token_ids: torch.LongTensor) -> List[str]: - text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) - return text - - def convert_token_to_string(self, token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = self.tokenizer.convert_tokens_to_string([token]) - - if self.is_llama: - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def __hash__(self): - return hash(Hasher.hash(self.tokenizer)) - - def __eq__(self, other): - if isinstance(other, type(self)): - if hasattr(self, "model_name") and hasattr(self, "kwargs"): - return ( - other.model_name == self.model_name and other.kwargs == self.kwargs - ) - else: - return other.tokenizer == self.tokenizer - return NotImplemented - - def __getstate__(self): - state = {"tokenizer": self.tokenizer} - return state - - def __setstate__(self, state): - self.__init__(state["tokenizer"]) - - -def setup_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - return TransformerTokenizer(tokenizer) diff --git a/python/outlines_core/fsm/__init__.py b/python/outlines_core/fsm/__init__.py index 9e167c4..f4769f6 100644 --- a/python/outlines_core/fsm/__init__.py +++ b/python/outlines_core/fsm/__init__.py @@ -1 +1 @@ -from .outlines_core_rs import Guide, Index, Vocabulary \ No newline at end of file +from .outlines_core_rs import Guide, Index, Vocabulary diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index efbaa6a..d661d34 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, 
Union def build_regex_from_schema( json: str, whitespace_pattern: Optional[str] = None @@ -38,14 +38,12 @@ class Guide: """Gets the string representation of the guide.""" class Vocabulary: - def __init__( - self, eos_token_id: int, map: Dict[Union[str, bytes], List[int]] - ) -> "Vocabulary": + def __init__(self, eos_token_id: int, map: Dict[Union[str, bytes], List[int]]): """Creates a vocabulary from a map of tokens to token ids and eos token id.""" ... @staticmethod def from_pretrained( - model: str, revision: Optional[String], token: Optional[String] + model: str, revision: Optional[str], token: Optional[str] ) -> "Vocabulary": """Creates the vocabulary of a pre-trained model.""" ... @@ -66,7 +64,7 @@ class Vocabulary: ... class Index: - def __init__(self, regex: str, vocabulary: "Vocabulary") -> "Index": + def __init__(self, regex: str, vocabulary: "Vocabulary"): """Creates an index from a regex and vocabulary.""" ... def get_allowed_tokens(self, state: int) -> Optional[List[int]]: From 30e29efee98ce7dbc1718e8df18676f6e2027088 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Mon, 6 Jan 2025 17:08:00 +0000 Subject: [PATCH 16/22] Drop unused dependencies --- pyproject.toml | 9 --------- tests/fsm/test_json_schema.py | 4 ---- 2 files changed, 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3bde97b..ec43669 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "interegular", "jsonschema", ] dynamic = ["version"] @@ -39,15 +38,8 @@ test = [ "pytest-mock", "coverage[toml]>=5.1", "diff-cover", - "accelerate", - "beartype<0.16.0", - "huggingface_hub", - "torch", "numpy", "scipy", - "transformers", - "datasets", - "pillow", "asv", "psutil", "setuptools-rust", @@ -95,7 +87,6 @@ module = [ "jsonschema.*", "pydantic.*", "pytest", - "interegular.*", "setuptools.*", "setuptools_rust.*", ] diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 5dd8b5d..02c7348 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -2,7 +2,6 @@ import re from typing import Literal, Union -import interegular import pytest from outlines_core.fsm.json_schema import build_regex_from_schema, to_regex from pydantic import BaseModel, Field @@ -55,9 +54,6 @@ class Model(BaseModel): json_schema = json.dumps(Model.model_json_schema()) pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) - # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() - interegular.parse_pattern(pattern).to_fsm() - def test_match_object(): test_regex = to_regex( From e04e5bec5b3db3a2c3f4c9d465845452e3f1e3d6 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Tue, 7 Jan 2025 15:47:06 +0000 Subject: [PATCH 17/22] Index by ref in Guide --- benchmarks/bench_regex_guide.py | 25 +++++++++++++++++++++++-- src/python_bindings/mod.rs | 14 ++++++++------ tests/fsm/test_guide.py | 1 - tests/fsm/test_json_schema.py | 2 +- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index d921a7d..d077cd2 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,7 +1,8 @@ +import os from concurrent.futures import ThreadPoolExecutor import psutil -from outlines_core.fsm import Index, Vocabulary +from outlines_core.fsm import Guide, Index, Vocabulary regex_samples = { "email": 
r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", @@ -60,5 +61,25 @@ def setup(self, pattern_name): self.vocabulary = Vocabulary.from_pretrained("gpt2") self.pattern = regex_samples[pattern_name] - def peakmem_regex_to_guide(self, pattern_name): + def peakmem_regex_to_index(self, pattern_name): Index(self.pattern, self.vocabulary) + + +class MemoryStabilityBenchmark: + params = [1, 10_000] + + def setup(self, num): + self.vocabulary = Vocabulary.from_pretrained("gpt2") + self.index = Index(".*", self.vocabulary) + self.process = psutil.Process(os.getpid()) + + def _memory_usage(self): + return self.process.memory_info().rss / 1024**2 + + def peakmem_guides_per_index(self, num_guides): + initial = self._memory_usage() + objects = [Guide(self.index) for i in range(num_guides)] + final = self._memory_usage() + + assert len(objects) == num_guides + assert final - initial < 5 diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index a702be2..255b091 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::index::Index; use crate::json_schema; use crate::prelude::*; @@ -28,7 +30,7 @@ pub struct PyGuide { #[pymethods] impl PyGuide { #[new] - fn new(index: PyIndex) -> Self { + fn __new__(index: PyIndex) -> Self { PyGuide { state: index.get_initial_state(), index, @@ -102,15 +104,15 @@ impl PyGuide { #[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] #[derive(Clone, Debug, Encode, Decode)] -pub struct PyIndex(Index); +pub struct PyIndex(Arc); #[pymethods] impl PyIndex { #[new] - fn new(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { + fn __new__(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { py.allow_threads(|| { Index::new(regex, &vocabulary.0) - .map(PyIndex) + .map(|x| PyIndex(Arc::new(x))) .map_err(Into::into) }) } @@ -164,7 +166,7 @@ impl PyIndex { bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { PyErr::new::(format!("Deserialization of Index failed: {}", e)) })?; - Ok(PyIndex(index)) + Ok(PyIndex(Arc::new(index))) } } @@ -193,7 +195,7 @@ pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { #[new] - fn new(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { + fn __new__(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { if let Ok(dict) = map.extract::>>(py) { return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); } diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index db7ab07..d48738a 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -73,7 +73,6 @@ def test_str_and_bytes_produce_the_same(): index2 = Index(regex, vocabulary2) guide1 = Guide(index1) guide2 = Guide(index2) - assert False assert guide1.read_next_token(3) == guide2.read_next_token(3) # `a` and `b` have similar transitions to `z` diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 02c7348..36a269e 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -52,7 +52,7 @@ class Model(BaseModel): n: int json_schema = json.dumps(Model.model_json_schema()) - pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) + build_regex_from_schema(json_schema, whitespace_pattern=None) def test_match_object(): From 7b6781b2066c23a8cbab0d6c3f86b19902f0d18b Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Wed, 8 Jan 
2025 18:31:37 +0000 Subject: [PATCH 18/22] Extend interface of python bindings --- python/outlines_core/fsm/outlines_core_rs.pyi | 15 ++++ src/index.rs | 2 +- src/python_bindings/mod.rs | 35 ++++++++- tests/fsm/test_guide.py | 38 +++++++--- tests/fsm/test_index.py | 40 +++++++++++ tests/fsm/test_vocabulary.py | 71 ++++++++++++++----- 6 files changed, 171 insertions(+), 30 deletions(-) create mode 100644 tests/fsm/test_index.py diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index d661d34..a3f4ec7 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -36,6 +36,9 @@ class Guide: ... def __str__(self) -> str: """Gets the string representation of the guide.""" + def __eq__(self, other: object) -> bool: + """Compares whether two guides are the same.""" + ... class Vocabulary: def __init__(self, eos_token_id: int, map: Dict[Union[str, bytes], List[int]]): @@ -47,6 +50,9 @@ class Vocabulary: ) -> "Vocabulary": """Creates the vocabulary of a pre-trained model.""" ... + def insert(self, token: Union[str, bytes], token_id: int): + """Inserts new token with token_id or extends list of token_ids if token already present.""" + ... def get_eos_token_id(self) -> Optional[int]: """Gets the end of sentence token id.""" ... @@ -62,6 +68,9 @@ class Vocabulary: def __eq__(self, other: object) -> bool: """Compares whether two vocabularies are the same.""" ... + def __deepcopy__(self, memo: dict) -> "Vocabulary": + """Makes a deep copy of the Vocabulary.""" + ... class Index: def __init__(self, regex: str, vocabulary: "Vocabulary"): @@ -90,3 +99,9 @@ class Index: ... def __str__(self) -> str: """Gets the string representation of the index.""" + def __eq__(self, other: object) -> bool: + """Compares whether two indexes are the same.""" + ... + def __deepcopy__(self, memo: dict) -> "Index": + """Makes a deep copy of the Index.""" + ... diff --git a/src/index.rs b/src/index.rs index 7725acc..c005587 100644 --- a/src/index.rs +++ b/src/index.rs @@ -8,7 +8,7 @@ use regex_automata::util::primitives::StateID as AutomataStateId; use regex_automata::Anchored; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -#[derive(Clone, Debug, Encode, Decode)] +#[derive(Clone, Debug, PartialEq, Encode, Decode)] pub struct Index { initial: StateId, finals: HashSet, diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 255b091..3fbdd30 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -21,7 +21,7 @@ macro_rules! type_name { } #[pyclass(name = "Guide", module = "outlines_core.fsm.outlines_core_rs")] -#[derive(Clone, Debug, Encode, Decode)] +#[derive(Clone, Debug, PartialEq, Encode, Decode)] pub struct PyGuide { state: StateId, index: PyIndex, @@ -80,6 +80,10 @@ impl PyGuide { ) } + fn __eq__(&self, other: &PyGuide) -> bool { + self == other + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
@@ -103,7 +107,7 @@ impl PyGuide { } #[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] -#[derive(Clone, Debug, Encode, Decode)] +#[derive(Clone, Debug, PartialEq, Encode, Decode)] pub struct PyIndex(Arc); #[pymethods] @@ -148,6 +152,14 @@ impl PyIndex { format!("{}", self.0) } + fn __eq__(&self, other: &PyIndex) -> bool { + *self.0 == *other.0 + } + + fn __deepcopy__(&self, _py: Python<'_>, _memo: Py) -> Self { + PyIndex(Arc::new((*self.0).clone())) + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? @@ -234,6 +246,21 @@ impl PyVocabulary { Ok(PyVocabulary(v)) } + fn insert(&mut self, py: Python<'_>, token: Py, token_id: TokenId) -> PyResult<()> { + if let Ok(t) = token.extract::(py) { + self.0.insert(t, token_id); + return Ok(()); + } + if let Ok(t) = token.extract::(py) { + self.0.insert(t, token_id); + return Ok(()); + } + Err(PyErr::new::(format!( + "Expected a token of type str or bytes, got {:?}", + type_name!(token) + ))) + } + fn get_eos_token_id(&self) -> TokenId { self.0.eos_token_id() } @@ -267,6 +294,10 @@ impl PyVocabulary { self.0.tokens_to_ids().len() } + fn __deepcopy__(&self, _py: Python<'_>, _memo: Py) -> Self { + PyVocabulary(self.0.clone()) + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index d48738a..8413ee9 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -1,9 +1,22 @@ +import copy import pickle +from typing import Dict, List, Union import pytest from outlines_core.fsm import Guide, Index, Vocabulary +@pytest.fixture(scope="session") +def index() -> Index: + eos_token_id = 3 + # types here only to please mypy checks + tokens: Dict[Union[str, bytes], List[int]] = {"1": [1], "2": [2]} + regex = r"[1-9]" + + vocabulary = Vocabulary(eos_token_id, tokens) + return Index(regex, vocabulary) + + def test_stop_at_eos(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} @@ -82,15 +95,8 @@ def test_str_and_bytes_produce_the_same(): assert guide2.is_finished() -def test_pickling(): - eos_token_id = 3 - tokens = {"1": [1], "2": [2]} - regex = r"[1-9]" - - vocabulary = Vocabulary(eos_token_id, tokens) - index = Index(regex, vocabulary) +def test_pickling(index): guide = Guide(index) - serialized = pickle.dumps(guide) deserialized = pickle.loads(serialized) assert sorted(deserialized.get_start_tokens()) == sorted(guide.get_start_tokens()) @@ -119,3 +125,19 @@ def test_pickling_from_pretrained_with_revision(model, revision): serialized = pickle.dumps(guide) deserialized = pickle.loads(serialized) assert sorted(deserialized.get_start_tokens()) == sorted(guide.get_start_tokens()) + + +def test_equality(index): + guide1 = Guide(index) + guide2 = Guide(index) + assert guide1 == guide2 + + # confirm that equality is about inner index, not reference difference + index2 = copy.deepcopy(index) + guide3 = Guide(index2) + assert guide3 == guide2 == guide1 + + # progress one of the guides, confirm different state == different guide + guide1.read_next_token(guide1.get_start_tokens()[-1]) + assert guide1 != guide2 + assert guide3 == guide2 diff --git a/tests/fsm/test_index.py b/tests/fsm/test_index.py new file mode 100644 index 0000000..799b468 --- /dev/null +++ b/tests/fsm/test_index.py @@ -0,0 +1,40 @@ +import copy +import gc +import pickle +from typing import Dict, List, 
Union + +import pytest +from outlines_core.fsm import Index, Vocabulary + + +@pytest.fixture(scope="session") +def index() -> Index: + eos_token_id = 3 + # types here only to please mypy checks + tokens: Dict[Union[str, bytes], List[int]] = {"1": [1], "2": [2]} + regex = r"[1-9]" + + vocabulary = Vocabulary(eos_token_id, tokens) + return Index(regex, vocabulary) + + +def test_pickling(index): + serialized = pickle.dumps(index) + deserialized = pickle.loads(serialized) + assert deserialized == index + + +def test_deepcopy(index): + index2 = copy.deepcopy(index) + assert index2 == index + + copy_index2 = copy.deepcopy(index2) + assert copy_index2 == index2 + + index2_id = id(index2) + del index2 + gc.collect() + is_deleted = not any(id(o) == index2_id for o in gc.get_objects()) + assert is_deleted + + assert copy_index2 == index diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index 4bc54fa..447b260 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -1,29 +1,48 @@ +import copy import pickle import pytest from outlines_core.fsm import Vocabulary -def test_supports_strings_as_keys(): +@pytest.fixture(scope="session") +def vocabulary(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + return Vocabulary(eos_token_id, tokens) + + +def test_basic_vocabulary_interface(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} vocabulary = Vocabulary(eos_token_id, tokens) assert vocabulary.get_eos_token_id() == eos_token_id - assert vocabulary.get("1") == [1] - assert vocabulary.get(b"1") == [1] + assert vocabulary.get("1") == vocabulary.get(b"1") == [1] assert len(vocabulary) == 2 + vocabulary.insert("b", 4) + assert vocabulary.get("b") == [4] + assert len(vocabulary) == 3 -def test_supports_bytes_as_keys(): + vocabulary.insert(b"b", 5) + assert vocabulary.get("b") == vocabulary.get(b"b") == [4, 5] + assert len(vocabulary) == 3 + + +def test_string_and_bytes_as_tokens(): eos_token_id = 3 - tokens = {b"1": [1], b"a": [2]} + tokens = {"1": [1], "a": [2]} + btokens = {b"1": [1], b"a": [2]} vocabulary = Vocabulary(eos_token_id, tokens) + bvocabulary = Vocabulary(eos_token_id, btokens) - assert vocabulary.get_eos_token_id() == eos_token_id - assert vocabulary.get(b"1") == [1] - assert vocabulary.get("1") == [1] - assert len(vocabulary) == 2 + assert ( + vocabulary.get_eos_token_id() == bvocabulary.get_eos_token_id() == eos_token_id + ) + assert vocabulary.get(b"1") == vocabulary.get("1") == [1] + assert bvocabulary.get(b"1") == bvocabulary.get("1") == [1] + assert len(vocabulary) == len(bvocabulary) == 2 def test_do_not_supports_other_types(): @@ -42,11 +61,7 @@ def test_do_not_supports_other_types(): Vocabulary(eos_token_id, {1: [1], 2: [2]}) -def test_get_bad_type(): - eos_token_id = 3 - tokens = {"1": [1], "a": [2]} - vocabulary = Vocabulary(eos_token_id, tokens) - +def test_get_bad_type(vocabulary): with pytest.raises( TypeError, match="Expected a token of type str or bytes, got", @@ -54,16 +69,34 @@ def test_get_bad_type(): vocabulary.get(1) +def test_insert_bad_type(vocabulary): + with pytest.raises( + TypeError, + match="Expected a token of type str or bytes, got", + ): + vocabulary.insert(1, 6) + + def test_from_pretrained(): vocabulary = Vocabulary.from_pretrained("gpt2") assert vocabulary.get_eos_token_id() == 50256 -def test_pickling(): - eos_token_id = 3 - tokens = {"1": [1], "a": [2]} - vocabulary = Vocabulary(eos_token_id, tokens) - +def test_pickling(vocabulary): serialized = pickle.dumps(vocabulary) deserialized = pickle.loads(serialized) 
assert deserialized == vocabulary + + +def test_deepcopy(vocabulary): + vocabulary2 = copy.deepcopy(vocabulary) + assert vocabulary2 == vocabulary + + copy_vocabulary2 = copy.deepcopy(vocabulary2) + assert copy_vocabulary2 == vocabulary2 + + vocabulary2.insert("new", 4) + assert vocabulary2 != copy_vocabulary2 + assert len(vocabulary2) - 1 == len(copy_vocabulary2) + assert copy_vocabulary2 == vocabulary + assert len(copy_vocabulary2) == len(vocabulary) From 1fab8726ba8f2f20739e58fbb0b8c7772d98e454 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 9 Jan 2025 12:51:30 +0000 Subject: [PATCH 19/22] Disallow insert of eos token into Vocabulary --- src/error.rs | 4 ++ src/index.rs | 16 ++++++-- src/python_bindings/mod.rs | 10 ++--- src/vocabulary/mod.rs | 75 +++++++++++++++++++++++++----------- src/vocabulary/processor.rs | 12 +++--- tests/fsm/test_vocabulary.py | 7 ++++ 6 files changed, 85 insertions(+), 39 deletions(-) diff --git a/src/error.rs b/src/error.rs index 4ffe7ed..e5781e8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,12 +4,16 @@ pub type Result = std::result::Result; #[derive(Error, Debug)] pub enum Error { + // Index Errors #[error("The vocabulary does not allow to build an index that matches the input")] InsufficientVocabulary, #[error("Failed to build DFA {0}")] IndexDfaError(#[from] Box), #[error("Index failed since anchored universal start state doesn't exist")] DfaHasNoStartState, + // Vocabulary Errors + #[error("EOS token should not be inserted into Vocabulary")] + EOSTokenDisallowed, #[error(transparent)] TokenizersError(#[from] tokenizers::Error), #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")] diff --git a/src/index.rs b/src/index.rs index c005587..407e471 100644 --- a/src/index.rs +++ b/src/index.rs @@ -142,7 +142,9 @@ mod tests { let regex = "0|[1-9][0-9]*"; let mut vocabulary = Vocabulary::new(4); for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] { - vocabulary.insert(token, token_id as u32); + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); } let index = Index::new(regex, &vocabulary).expect("Index failed"); @@ -163,7 +165,9 @@ mod tests { let regex = "`\\n(\\.\\n)?`\\n"; let mut vocabulary = Vocabulary::new(104); for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] { - vocabulary.insert(token, token_id as u32); + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); } let index = Index::new(regex, &vocabulary).expect("Index failed"); @@ -179,14 +183,18 @@ mod tests { let mut vocabulary = Vocabulary::new(8); for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)] { - vocabulary.insert(token, token_id as u32); + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); } for (token, token_id) in [ (vec![32, 240, 159, 152], 7), (vec![32, 240, 159, 152, 141], 6), (vec![240, 159, 152, 141], 4), ] { - vocabulary.insert(token, token_id as u32); + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); } let index = Index::new(regex, &vocabulary).expect("Index failed"); diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 3fbdd30..eca8e7b 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -209,10 +209,10 @@ impl PyVocabulary { #[new] fn __new__(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { if let Ok(dict) = map.extract::>>(py) { - 
return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); + return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?)); } if let Ok(dict) = map.extract::, Vec>>(py) { - return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict)))); + return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?)); } let message = "Expected a dict with keys of type str or bytes and values of type list[int]"; @@ -248,12 +248,10 @@ impl PyVocabulary { fn insert(&mut self, py: Python<'_>, token: Py, token_id: TokenId) -> PyResult<()> { if let Ok(t) = token.extract::(py) { - self.0.insert(t, token_id); - return Ok(()); + return Ok(self.0.try_insert(t, token_id)?); } if let Ok(t) = token.extract::(py) { - self.0.insert(t, token_id); - return Ok(()); + return Ok(self.0.try_insert(t, token_id)?); } Err(PyErr::new::(format!( "Expected a token of type str or bytes, got {:?}", diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index dbae9b6..71f2c42 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -24,12 +24,13 @@ mod processor; /// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None); /// ``` /// -/// ### Create an empty vocabulary. +/// ### Create an empty vocabulary and manually insert tokens. /// ```rust /// # use outlines_core::prelude::*; /// # -/// let mut vocabulary = Vocabulary::new(1); -/// vocabulary.insert("token", 0); +/// let eos_token_id = 1; +/// let mut vocabulary = Vocabulary::new(eos_token_id); +/// vocabulary.try_insert("token", 0).expect("New token inserted"); /// ``` #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)] pub struct Vocabulary { @@ -47,9 +48,13 @@ impl Vocabulary { } /// Inserts a token to the vocabulary with the specified identifier. - pub fn insert(&mut self, token: impl Into, id: TokenId) { + pub fn try_insert(&mut self, token: impl Into, id: TokenId) -> Result<(), Error> { + if id == self.eos_token_id { + return Err(Error::EOSTokenDisallowed); + } let token = token.into(); self.tokens.entry(token).or_default().push(id); + Ok(()) } /// Creates the vocabulary of pre-trained model from Hugging Face Hub. @@ -81,8 +86,8 @@ impl Vocabulary { // Start building the vocabulary from eos_token_id and added tokens. let mut vocabulary = Vocabulary::new(eos_token_id); for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() { - if !added_token.special { - vocabulary.insert(added_token.content.clone(), *id); + if !added_token.special && id != &eos_token_id { + vocabulary.try_insert(added_token.content.clone(), *id)? 
} } @@ -94,8 +99,10 @@ impl Vocabulary { }); }; for (token, token_id) in tokenizer.get_vocab(false) { - let processed_token = processor.process(token)?; - vocabulary.insert(processed_token, token_id); + if token_id != eos_token_id { + let processed_token = processor.process(&token)?; + vocabulary.try_insert(processed_token, token_id)?; + } } Ok(vocabulary) @@ -169,26 +176,39 @@ impl std::fmt::Display for Vocabulary { } } -impl From<(TokenId, HashMap>)> for Vocabulary { - fn from(values: (TokenId, HashMap>)) -> Vocabulary { +impl TryFrom<(TokenId, HashMap>)> for Vocabulary { + type Error = Error; + + fn try_from(values: (TokenId, HashMap>)) -> Result { let (eos_token_id, tokens) = values; - Vocabulary { + if tokens.iter().any(|(_, ids)| ids.contains(&eos_token_id)) { + return Err(Error::EOSTokenDisallowed); + } + Ok(Vocabulary { eos_token_id, tokens, - } + }) } } -impl From<(TokenId, HashMap>)> for Vocabulary { - fn from(values: (TokenId, HashMap>)) -> Vocabulary { +impl TryFrom<(TokenId, HashMap>)> for Vocabulary { + type Error = Error; + + fn try_from(values: (TokenId, HashMap>)) -> Result { let (eos_token_id, tokens) = values; - Vocabulary { + Ok(Vocabulary { eos_token_id, tokens: tokens .into_iter() - .map(|(k, v)| (k.as_bytes().to_vec(), v)) - .collect::>>(), - } + .map(|(k, v)| { + if v.contains(&eos_token_id) { + Err(Error::EOSTokenDisallowed) + } else { + Ok((k.as_bytes().to_vec(), v)) + } + }) + .collect::>, _>>()?, + }) } } @@ -202,32 +222,41 @@ mod tests { let eos_token_id = 3; let mut vocabulary = Vocabulary::new(eos_token_id); + match vocabulary.try_insert("eos-token", eos_token_id) { + Err(Error::EOSTokenDisallowed) => {} + _ => unreachable!(), + } + // New empty vocabulary. assert_eq!(vocabulary.eos_token_id, eos_token_id); assert!(vocabulary.tokens.is_empty()); for (token, id) in [("zero", 0), ("one", 1), ("two", 2)] { - vocabulary.insert(token, id); + vocabulary.try_insert(token, id).expect("Insert failed"); assert_eq!(vocabulary.token_to_ids(token), Some(&vec![id])); } assert_eq!(vocabulary.tokens.len(), 3); assert_eq!(vocabulary.tokens_to_ids().len(), 3); // Confirm different types. - vocabulary.insert(b"four", 4); + vocabulary.try_insert(b"four", 4).expect("Insert failed"); assert_eq!(vocabulary.token_to_ids("four"), Some(&vec![4])); - vocabulary.insert(b"five".to_vec(), 5); + vocabulary + .try_insert(b"five".to_vec(), 5) + .expect("Insert failed"); assert_eq!(vocabulary.token_to_ids("five"), Some(&vec![5])); - vocabulary.insert("six".to_string(), 6); + vocabulary + .try_insert("six".to_string(), 6) + .expect("Insert failed"); assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6])); } #[test] fn new_empty_vocabulary_from_hashmap() { let map: HashMap> = HashMap::default(); - let vocabulary = Vocabulary::from((1_u32, map)); + let vocabulary = Vocabulary::try_from((1_u32, map)).expect("Vocabulary failed"); assert_eq!(vocabulary.eos_token_id, 1); assert!(vocabulary.tokens.is_empty()); } diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs index 7426f24..1e10bf0 100644 --- a/src/vocabulary/processor.rs +++ b/src/vocabulary/processor.rs @@ -102,7 +102,7 @@ impl Default for Mods { impl Mods { /// Apply default modifications to each token. - fn apply_default(&self, token: String) -> String { + fn apply_default(&self, token: &str) -> String { let to = Self::default().spacechar.to_string(); token.replace(self.spacechar, &to) } @@ -190,7 +190,7 @@ impl TokenProcessor { } /// Operates on each token based on the level of `TokenProcessor`. 
- pub(crate) fn process(&self, token: String) -> Result> { + pub(crate) fn process(&self, token: &str) -> Result> { match &self.level { TokenProcessorLevel::Byte => token .chars() @@ -275,7 +275,7 @@ mod tests { ('þ', 0xFE), ('ÿ', 0xFF), ] { - let processed = processor.process(ch.to_string()).expect("Not processed"); + let processed = processor.process(&ch.to_string()).expect("Not processed"); assert_eq!(processed, [byte]); } } @@ -304,7 +304,7 @@ mod tests { vec![0x20, 0x20, 0x20], ), ] { - let processed = processor.process(input.to_string()).expect("Not processed"); + let processed = processor.process(input).expect("Not processed"); assert_eq!(processed, expected); } } @@ -328,7 +328,7 @@ mod tests { let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); for token in ["𝒜𝒷𝒸𝒟𝓔", "🦄🌈🌍🔥🎉", "京东购物"] { - let result = processor.process(token.to_string()); + let result = processor.process(token); match result { Err(Error::ByteProcessorFailed) => {} _ => unreachable!(), @@ -342,7 +342,7 @@ mod tests { let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); - let result = processor.process("<0x6y>".to_string()); + let result = processor.process("<0x6y>"); match result { Err(Error::ByteFallbackProcessorFailed) => {} _ => unreachable!(), diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index 447b260..f4879d6 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -77,6 +77,13 @@ def test_insert_bad_type(vocabulary): vocabulary.insert(1, 6) +def test_insert_eos_token(vocabulary): + with pytest.raises( + ValueError, match="EOS token should not be inserted into Vocabulary" + ): + vocabulary.insert("eos-token", 3) + + def test_from_pretrained(): vocabulary = Vocabulary.from_pretrained("gpt2") assert vocabulary.get_eos_token_id() == 50256 From 15a45c0ca6aafcccab47ae72491be579bd453cff Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 9 Jan 2025 15:00:56 +0000 Subject: [PATCH 20/22] Stabilize Index interfaces --- python/outlines_core/fsm/outlines_core_rs.pyi | 5 ++- src/index.rs | 44 +++++++++---------- src/python_bindings/mod.rs | 4 +- tests/fsm/test_index.py | 25 +++++++++++ 4 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index a3f4ec7..77b0823 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -68,6 +68,9 @@ class Vocabulary: def __eq__(self, other: object) -> bool: """Compares whether two vocabularies are the same.""" ... + def __len__(self) -> int: + """Returns length of Vocabulary's tokens, excluding EOS token.""" + ... def __deepcopy__(self, memo: dict) -> "Vocabulary": """Makes a deep copy of the Vocabulary.""" ... @@ -85,7 +88,7 @@ class Index: def is_final_state(self, state: int) -> bool: """Determines whether the current state is a final state.""" ... - def final_states(self) -> List[int]: + def get_final_states(self) -> List[int]: """Get all final states.""" ... 
def get_transitions(self) -> Dict[int, Dict[int, int]]: diff --git a/src/index.rs b/src/index.rs index 407e471..f6c433b 100644 --- a/src/index.rs +++ b/src/index.rs @@ -10,14 +10,14 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; #[derive(Clone, Debug, PartialEq, Encode, Decode)] pub struct Index { - initial: StateId, - finals: HashSet, - states_to_token_subsets: HashMap>, + initial_state: StateId, + final_states: HashSet, + transitions: HashMap>, eos_token_id: TokenId, } impl Index { - pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result { + pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result { let eos_token_id = vocabulary.eos_token_id(); let dfa = DFA::new(regex).map_err(Box::new)?; let start_state = match dfa.universal_start_state(Anchored::Yes) { @@ -83,9 +83,9 @@ impl Index { if is_valid { Ok(Self { - initial: start_state.as_u32(), - finals: final_states, - states_to_token_subsets: transitions, + initial_state: start_state.as_u32(), + final_states, + transitions, eos_token_id, }) } else { @@ -93,40 +93,40 @@ impl Index { } } - pub(crate) fn allowed_tokens(&self, state: StateId) -> Option> { - self.states_to_token_subsets + pub fn allowed_tokens(&self, state: StateId) -> Option> { + self.transitions .get(&state) .map_or_else(|| None, |res| Some(res.keys().cloned().collect())) } - pub(crate) fn next_state(&self, state: StateId, token_id: TokenId) -> Option { + pub fn next_state(&self, state: StateId, token_id: TokenId) -> Option { if token_id == self.eos_token_id { return None; } - Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?) + Some(*self.transitions.get(&state)?.get(&token_id)?) } - pub(crate) fn initial(&self) -> StateId { - self.initial + pub fn initial_state(&self) -> StateId { + self.initial_state } - pub(crate) fn is_final(&self, state: StateId) -> bool { - self.finals.contains(&state) + pub fn is_final(&self, state: StateId) -> bool { + self.final_states.contains(&state) } - pub(crate) fn final_states(&self) -> &HashSet { - &self.finals + pub fn final_states(&self) -> &HashSet { + &self.final_states } - pub(crate) fn transitions(&self) -> &HashMap> { - &self.states_to_token_subsets + pub fn transitions(&self) -> &HashMap> { + &self.transitions } } impl std::fmt::Display for Index { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "Index object with transitions:")?; - for (state_id, token_ids) in self.states_to_token_subsets.iter() { + for (state_id, token_ids) in self.transitions.iter() { writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?; } Ok(()) @@ -148,7 +148,7 @@ mod tests { } let index = Index::new(regex, &vocabulary).expect("Index failed"); - assert_eq!(index.initial(), 40); + assert_eq!(index.initial_state(), 40); assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56])); let expected = HashMap::from_iter([ @@ -172,7 +172,7 @@ mod tests { let index = Index::new(regex, &vocabulary).expect("Index failed"); let allowed = index - .allowed_tokens(index.initial()) + .allowed_tokens(index.initial_state()) .expect("No allowed tokens"); assert!(allowed.contains(&101)); } diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index eca8e7b..5655baa 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -133,7 +133,7 @@ impl PyIndex { self.0.is_final(state) } - fn final_states(&self) -> HashSet { + fn get_final_states(&self) -> HashSet { self.0.final_states().clone() } @@ -142,7 +142,7 @@ impl PyIndex { } fn get_initial_state(&self) -> StateId { - 
self.0.initial() + self.0.initial_state() } fn __repr__(&self) -> String { format!("{:#?}", self.0) diff --git a/tests/fsm/test_index.py b/tests/fsm/test_index.py index 799b468..5b56088 100644 --- a/tests/fsm/test_index.py +++ b/tests/fsm/test_index.py @@ -18,6 +18,31 @@ def index() -> Index: return Index(regex, vocabulary) +def test_basic_interface(index): + init_state = index.get_initial_state() + assert init_state == 12 + assert index.is_final_state(init_state) is False + + allowed_tokens = index.get_allowed_tokens(init_state) + assert allowed_tokens == [1, 2] + + next_state = index.get_next_state(init_state, allowed_tokens[-1]) + assert next_state == 20 + assert index.is_final_state(next_state) is True + assert index.get_final_states() == {20} + + expected_transitions = { + 12: { + 1: 20, + 2: 20, + }, + 20: { + 3: 20, + }, + } + assert index.get_transitions() == expected_transitions + + def test_pickling(index): serialized = pickle.dumps(index) deserialized = pickle.loads(serialized) From bf6e8a6ade5dbf43eb2855a1cce713abdb29164c Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 9 Jan 2025 18:03:20 +0000 Subject: [PATCH 21/22] Use new interface in statistical --- tests/fsm/test_statistical.py | 37 +++++++++++------------------------ 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/tests/fsm/test_statistical.py b/tests/fsm/test_statistical.py index 20ef28c..942796a 100644 --- a/tests/fsm/test_statistical.py +++ b/tests/fsm/test_statistical.py @@ -1,27 +1,12 @@ from typing import Callable, List, Optional import numpy as np -from outlines_core.fsm.guide import RegexGuide +from outlines_core.fsm import Guide, Index, Vocabulary from pytest import approx from scipy.stats import ks_2samp def test_generate_length(): - class MockTokenizer: - vocabulary = {"0": 1, "1": 2, "eos": 3} - inverse_vocabulary = {1: "0", 2: "1", 3: ""} - special_tokens = {"eos"} - eos_token_id = 3 - - def length(self): - return len(self.vocabulary) - - def convert_token_to_string(self, token): - return token - - def decode(self, token): - return self.inverse_vocabulary[token] - class NextToken: def __init__( self, @@ -43,17 +28,18 @@ def __call__( next_t = [self.rng.choice(self.states, p=prob / np.sum(prob))] return tokens + next_t if tokens is not None else next_t - def generate(model, tokenizer, regex_str) -> Optional[List[int]]: - n_tokens = tokenizer.length() + def generate(model, regex_str) -> Optional[List[int]]: + vocabulary = Vocabulary(3, {"0": [1], "1": [2], "2": [4]}) + index = Index(regex_str, vocabulary) + guide = Guide(index) - fsm = RegexGuide.from_regex(regex_str, tokenizer) - state: int = fsm.initial_state + n_tokens = len(vocabulary) tokens = None - while state != -1: - allowed = fsm.get_next_instruction(state).tokens + allowed = guide.get_start_tokens() + while not guide.is_finished(): mask: List[int] = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)] tokens = model(tokens, mask=mask) - state = fsm.get_next_state(state, tokens[-1]) + allowed = guide.read_next_token(tokens[-1]) return tokens def prob_non_markov(tokens: List[int]) -> np.array: @@ -75,16 +61,15 @@ def prob_markov(token: List[int]) -> np.array: n_samples: int = 250 regex_str: str = r"11[01]+|0[01]*" - tokenizer = MockTokenizer() model1 = NextToken(prob_markov, p0, states, 30127) model2 = NextToken(prob_non_markov, p0, states, 24601) lengths1: np.array = np.zeros((n_samples,)) lengths2: np.array = np.zeros((n_samples,)) for i in range(n_samples): - out1: List[int] = generate(model1, tokenizer, 
regex_str)
+        out1: List[int] = generate(model1, regex_str)
         lengths1[i] = len(out1) - 1  # take off the eos token
-        out2: List[int] = generate(model2, tokenizer, regex_str)
+        out2: List[int] = generate(model2, regex_str)
         lengths2[i] = len(out2) - 1  # take off the eos token
 
     # 2 sample KS test to check that lengths has the same distribution as

From 3fef1d86ad78a4633c23021513ee07d576558f1d Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)" 
Date: Fri, 10 Jan 2025 13:00:14 +0000
Subject: [PATCH 22/22] Add `remove` to vocabulary interfaces

---
 benchmarks/bench_regex_guide.py               |  2 +-
 python/outlines_core/fsm/outlines_core_rs.pyi |  3 +++
 src/python_bindings/mod.rs                    | 15 +++++++++++++++
 src/vocabulary/mod.rs                         | 15 +++++++++++++++
 tests/fsm/test_vocabulary.py                  | 19 +++++++++++++------
 5 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index d077cd2..f5a42af 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -38,7 +38,7 @@ def time_regex_to_guide_parallel(self, pattern_name):
 
     def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name):
         # Note: after moving to the full rust implementation for index and guide creation, this experiment
         # no longer shows the drastic difference it once did when python was heavily involved,
-        # due to an average speedup of ~100 times.
+        # due to a speedup of up to ~100 times.
diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index 77b0823..a2d7921 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -53,6 +53,9 @@ class Vocabulary:
     def insert(self, token: Union[str, bytes], token_id: int):
         """Inserts new token with token_id or extends list of token_ids if token already present."""
         ...
+    def remove(self, token: Union[str, bytes]):
+        """Removes a token from the vocabulary."""
+        ...
     def get_eos_token_id(self) -> Optional[int]:
         """Gets the end of sentence token id."""
         ...
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 5655baa..5f1fac3 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -259,6 +259,21 @@ impl PyVocabulary {
         )))
     }
 
+    fn remove(&mut self, py: Python<'_>, token: Py) -> PyResult<()> {
+        if let Ok(t) = token.extract::(py) {
+            self.0.remove(t);
+            return Ok(());
+        }
+        if let Ok(t) = token.extract::(py) {
+            self.0.remove(t);
+            return Ok(());
+        }
+        Err(PyErr::new::(format!(
+            "Expected a token of type str or bytes, got {:?}",
+            type_name!(token)
+        )))
+    }
+
     fn get_eos_token_id(&self) -> TokenId {
         self.0.eos_token_id()
     }
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 71f2c42..821abd5 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -57,6 +57,12 @@ impl Vocabulary {
         Ok(())
     }
 
+    /// Removes a token from the vocabulary.
+    pub fn remove(&mut self, token: impl Into) {
+        let token = token.into();
+        self.tokens.remove(&token);
+    }
+
     /// Creates the vocabulary of pre-trained model from Hugging Face Hub.
    pub fn from_pretrained(
         model: &str,
@@ -251,6 +257,15 @@ mod tests {
             .try_insert("six".to_string(), 6)
             .expect("Insert failed");
         assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6]));
+
+        vocabulary.remove(b"four");
+        assert_eq!(vocabulary.token_to_ids("four"), None);
+
+        vocabulary.remove(b"five".to_vec());
+        assert_eq!(vocabulary.token_to_ids("five"), None);
+
+        vocabulary.remove("six".to_string());
+        assert_eq!(vocabulary.token_to_ids("six"), None);
     }
 
     #[test]
diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py
index f4879d6..e44e2da 100644
--- a/tests/fsm/test_vocabulary.py
+++ b/tests/fsm/test_vocabulary.py
@@ -12,12 +12,8 @@ def vocabulary():
     return Vocabulary(eos_token_id, tokens)
 
 
-def test_basic_vocabulary_interface():
-    eos_token_id = 3
-    tokens = {"1": [1], "a": [2]}
-    vocabulary = Vocabulary(eos_token_id, tokens)
-
-    assert vocabulary.get_eos_token_id() == eos_token_id
+def test_basic_vocabulary_interface(vocabulary):
+    assert vocabulary.get_eos_token_id() == 3
     assert vocabulary.get("1") == vocabulary.get(b"1") == [1]
     assert len(vocabulary) == 2
 
@@ -29,6 +25,17 @@ def test_basic_vocabulary_interface():
     assert vocabulary.get("b") == vocabulary.get(b"b") == [4, 5]
     assert len(vocabulary) == 3
 
+    vocabulary.remove("b")
+    assert vocabulary.get("b") is None
+
+    # a second remove doesn't fail either
+    vocabulary.remove("b")
+    assert vocabulary.get("b") is None
+
+    assert vocabulary.get("a") == [2]
+    vocabulary.remove(b"a")
+    assert vocabulary.get("a") is None
+
 
 def test_string_and_bytes_as_tokens():
     eos_token_id = 3