diff --git a/Cargo.toml b/Cargo.toml index 9f3f31ad..f49df1ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bincode = "2.0.0-rc.3" hf-hub = "=0.3.2" tokenizers = { version = "=0.20.3", features = ["http"] } rustc-hash = "2.1.0" +regex-automata = "0.4.9" [features] python-bindings = ["pyo3"] diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py index 964caa35..f88ed6c6 100644 --- a/benchmarks/bench_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,8 +1,6 @@ -from outlines_core.fsm.guide import RegexGuide +from outlines_core.fsm import Index, Vocabulary from outlines_core.fsm.json_schema import build_regex_from_schema -from .common import setup_tokenizer # noqa: E402 - simple_schema = """{ "$defs": { "Armor": { @@ -66,7 +64,7 @@ class JsonSchemaBenchmark: params = schemas.keys() def setup(self, schema_name): - self.tokenizer = setup_tokenizer() + self.vocabulary = Vocabulary.from_pretrained("gpt2") self.schema = schemas[schema_name] def time_json_schema_to_regex(self, schema_name): @@ -74,4 +72,4 @@ def time_json_schema_to_fsm(self, schema_name): regex = build_regex_from_schema(self.schema) - RegexGuide.from_regex(regex, self.tokenizer) + Index(regex, self.vocabulary) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index 5dda576a..4d9a343c 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,9 +1,8 @@ +import os from concurrent.futures import ThreadPoolExecutor import psutil -from outlines_core.fsm.guide import RegexGuide - -from .common import setup_tokenizer +from outlines_core.fsm import Guide, Index, Vocabulary regex_samples = { "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", @@ -18,17 +17,17 @@ } -class RegexGuideBenchmark: +class RegexIndexBenchmark: params = regex_samples.keys() def setup(self, pattern_name): - self.tokenizer = setup_tokenizer() + self.vocabulary = Vocabulary.from_pretrained("gpt2") self.pattern = regex_samples[pattern_name] def time_regex_to_guide(self, pattern_name): - RegexGuide.from_regex(self.pattern, self.tokenizer) + Index(self.pattern, self.vocabulary) - def time_regex_to_guide_parallel(self, pattern_name): + def time_regex_to_guide_threads(self, pattern_name): # Default GIL switch interval is 5ms (0.005), which isn't helpful for cpu heavy tasks, # this parallel case should be relatively close in runtime to one thread, but it is not, # because of the GIL. @@ -36,7 +35,11 @@ def time_regex_to_guide_parallel(self, pattern_name): with ThreadPoolExecutor(max_workers=core_count) as executor: list(executor.map(self._from_regex, [pattern_name] * core_count)) - def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name): + def time_regex_to_guide_threads_with_custom_switch_interval(self, pattern_name): + # Note: after moving to a full Rust implementation for index and guide creation, this experiment + # no longer shows the drastic difference it once did when Python was heavily involved, + # since the average speedup is roughly 10x. + + # This test is to show, that if GIL's switch interval is set to be longer, then the parallel # test's runtime on physical cores will be much closer to the one-threaded case.
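For reference, a minimal sketch of what this switch-interval variant boils down to; the 0.1 s interval, the helper name, and the lambda-based submission are illustrative rather than the exact benchmark code:

    import sys
    from concurrent.futures import ThreadPoolExecutor

    import psutil
    from outlines_core.fsm import Index, Vocabulary

    def build_index_in_threads(pattern: str, vocabulary: Vocabulary) -> None:
        # A longer switch interval gives each thread longer uninterrupted slices,
        # so the GIL changes hands less often during the CPU-heavy Index build.
        sys.setswitchinterval(0.1)
        core_count = psutil.cpu_count(logical=False)
        with ThreadPoolExecutor(max_workers=core_count) as executor:
            list(executor.map(lambda _: Index(pattern, vocabulary), range(core_count)))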
import sys @@ -48,15 +51,35 @@ def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name) list(executor.map(self._from_regex, [pattern_name] * core_count)) def _from_regex(self, pattern_name): - RegexGuide.from_regex(self.pattern, self.tokenizer) + Index(self.pattern, self.vocabulary) -class MemoryRegexGuideBenchmark: +class MemoryRegexIndexBenchmark: params = ["simple_phone", "complex_span_constrained_relation_extraction"] def setup(self, pattern_name): - self.tokenizer = setup_tokenizer() + self.vocabulary = Vocabulary.from_pretrained("gpt2") self.pattern = regex_samples[pattern_name] - def peakmem_regex_to_guide(self, pattern_name): - RegexGuide.from_regex(self.pattern, self.tokenizer) + def peakmem_regex_to_index(self, pattern_name): + Index(self.pattern, self.vocabulary) + + +class MemoryStabilityBenchmark: + params = [1, 10_000] + + def setup(self, num): + self.vocabulary = Vocabulary.from_pretrained("gpt2") + self.index = Index(".*", self.vocabulary) + self.process = psutil.Process(os.getpid()) + + def _memory_usage(self): + return self.process.memory_info().rss / 1024**2 + + def peakmem_guides_per_index(self, num_guides): + initial = self._memory_usage() + objects = [Guide(self.index) for i in range(num_guides)] + final = self._memory_usage() + + assert len(objects) == num_guides + assert final - initial < 5 diff --git a/benchmarks/common.py b/benchmarks/common.py deleted file mode 100644 index b56677e7..00000000 --- a/benchmarks/common.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import List, Tuple, Union - -import torch -from datasets.fingerprint import Hasher -from transformers import AutoTokenizer, PreTrainedTokenizer - - -def get_llama_tokenizer_types(): - """Get all the Llama tokenizer types/classes that need work-arounds. - - When they can't be imported, a dummy class is created. 
- - """ - try: - from transformers.models.llama import LlamaTokenizer - except ImportError: - - class LlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.llama import LlamaTokenizerFast - except ImportError: - - class LlamaTokenizerFast: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizer - except ImportError: - - class CodeLlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizerFast - except ImportError: - - class CodeLlamaTokenizerFast: # type: ignore - pass - - return ( - LlamaTokenizer, - LlamaTokenizerFast, - CodeLlamaTokenizer, - CodeLlamaTokenizerFast, - ) - - -class TransformerTokenizer: - """Represents a tokenizer for models in the `transformers` library.""" - - def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs): - self.tokenizer = tokenizer - self.eos_token_id = self.tokenizer.eos_token_id - self.eos_token = self.tokenizer.eos_token - - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.pad_token_id = self.eos_token_id - else: - self.pad_token_id = self.tokenizer.pad_token_id - self.pad_token = self.tokenizer.pad_token - - self.special_tokens = set(self.tokenizer.all_special_tokens) - - self.vocabulary = self.tokenizer.get_vocab() - self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) - - def encode( - self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple[torch.LongTensor, torch.LongTensor]: - kwargs["padding"] = True - kwargs["return_tensors"] = "pt" - output = self.tokenizer(prompt, **kwargs) - return output["input_ids"], output["attention_mask"] - - def decode(self, token_ids: torch.LongTensor) -> List[str]: - text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) - return text - - def convert_token_to_string(self, token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = self.tokenizer.convert_tokens_to_string([token]) - - if self.is_llama: - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def __hash__(self): - return hash(Hasher.hash(self.tokenizer)) - - def __eq__(self, other): - if isinstance(other, type(self)): - if hasattr(self, "model_name") and hasattr(self, "kwargs"): - return ( - other.model_name == self.model_name and other.kwargs == self.kwargs - ) - else: - return other.tokenizer == self.tokenizer - return NotImplemented - - def __getstate__(self): - state = {"tokenizer": self.tokenizer} - return state - - def __setstate__(self, state): - self.__init__(state["tokenizer"]) - - -def setup_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - return TransformerTokenizer(tokenizer) diff --git a/pyproject.toml b/pyproject.toml index 3bde97b9..ec436699 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "interegular", "jsonschema", ] dynamic = ["version"] @@ -39,15 +38,8 @@ test = [ "pytest-mock", "coverage[toml]>=5.1", "diff-cover", - "accelerate", - "beartype<0.16.0", - "huggingface_hub", - "torch", "numpy", "scipy", - "transformers", - "datasets", - "pillow", "asv", "psutil", "setuptools-rust", @@ -95,7 +87,6 @@ module = [ "jsonschema.*", "pydantic.*", "pytest", - "interegular.*", "setuptools.*", "setuptools_rust.*", ] diff --git a/python/outlines_core/fsm/__init__.py 
b/python/outlines_core/fsm/__init__.py index e69de29b..f4769f65 100644 --- a/python/outlines_core/fsm/__init__.py +++ b/python/outlines_core/fsm/__init__.py @@ -0,0 +1 @@ +from .outlines_core_rs import Guide, Index, Vocabulary diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py deleted file mode 100644 index 18b4523c..00000000 --- a/python/outlines_core/fsm/guide.py +++ /dev/null @@ -1,309 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Protocol, Set, Tuple, Union - -import interegular -import torch -from outlines_core.fsm.regex import ( - create_fsm_index_tokenizer, - make_byte_level_fsm, - make_deterministic_fsm, -) - -from .outlines_core_rs import Index - - -@dataclass(frozen=True) -class Write: - """Write instruction. - - Attributes - ---------- - tokens - The sequence of tokens to be added to the current sequence by the - generation process. - - """ - - tokens: List[int] - - -@dataclass(frozen=True) -class Generate: - """Generate instruction - - Attributes - ---------- - tokens - The tokens that lead to a valid completion if generated. A value - of ``None`` indicates that all tokens are allowed. - """ - - tokens: Optional[List[int]] - - -Instruction = Union[Write, Generate] - - -class Guide(Protocol): - """Base definition of a generation guide. - - A generation guide defines the behavior of a finite-state machine that guides - a text generation procedure. Unlike the DFAs built from regular expressions - guides can also emit a `Write` instructions which tells the model that it can - append a sequence of tokens (or token word) instead of generating it. - - """ - - initial_state: Any - - def get_next_instruction(self, state: Any) -> Instruction: - ... - - def get_next_state(self, state: Any, token_id: int) -> Any: - ... - - def is_final_state(self, state: Any) -> bool: - ... - - def copy(self) -> "Guide": - ... - - -class StopAtEOSGuide(Guide): - """Guide to generate tokens until the EOS token has been generated.""" - - final_state = 1 - start_state = 0 # TODO: remove start_state, use only initial_state - initial_state = 0 - - def __init__(self, tokenizer): - """Initialize the generation guide. - - model - The logit generator used to generate the next token. - - """ - self.eos_token_id = tokenizer.eos_token_id - self.vocabulary = tokenizer.vocabulary.values() - - def get_next_instruction(self, state: int) -> Instruction: - if self.is_final_state(state): - return Write([self.eos_token_id]) - return Generate(None) - - def get_next_state(self, state: int, token_id: int) -> int: - if token_id == self.eos_token_id or state == self.final_state: - return self.final_state - - return self.initial_state - - def is_final_state(self, state: int): - return state == self.final_state - - def copy(self): - return self - - -def create_states_mapping( - regex_string: str, - tokenizer, - regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, - frozen_tokens: List[str] = [], -) -> Tuple[Index, Set[int], Set[int]]: - """Create the variables related to the mapping between states and tokens from a regex string. - - The parameters of the function are used for caching purpose. - - Parameters - ---------- - regex_string: - The regular expression string to generate a states mapping for. - tokenizer: - The model's tokenizer. - regex_parser: - A function that parses a regex string into an `interegular` Pattern object. 
- frozen_tokens: - A list of tokens that should be kept as-is when expanding the token-level FSM - into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: - A set of token ids that correspond to empty strings. - final_states: - A set of final states in the FSM. - """ - regex_fsm = regex_parser(regex_string).to_fsm() - return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens) - - -def create_states_mapping_from_fsm( - fsm: interegular.fsm.FSM, - tokenizer, - frozen_tokens: List[str] = [], -) -> Tuple[Index, Set[int], Set[int]]: - """Create the variables related to the mapping between states and tokens from an FSM. - - The parameters of the function are used for caching purpose. - - Parameters - ---------- - fsm: - An FSM for the regular expression. - tokenizer: - The model's tokenizer. - frozen_tokens: - A list of tokens that should be kept as-is when expanding the token-level FSM - into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: - A set of token ids that correspond to empty strings. - final_states: - A set of final states in the FSM. - """ - byte_fsm = make_byte_level_fsm( - fsm.reduce(), keep_utf8=True, frozen_tokens=frozen_tokens - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer - ) - - return states_to_token_maps, empty_token_ids, regex_fsm.finals - - -class RegexGuide(Guide): - """Guide to generate text in the language of a regular expression.""" - - initial_state = 0 - - def __init__( - self, states_to_token_maps, empty_token_ids, eos_tensor, initial_state - ): - self.states_to_token_maps = states_to_token_maps - self.empty_token_ids = empty_token_ids - self.eos_tensor = eos_tensor - self.initial_state = initial_state - - @classmethod - def from_regex( - cls, - regex_string: str, - tokenizer, - _create_states_mapping=create_states_mapping, - device=None, - regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, - frozen_tokens: List[str] = [], - ): - ( - states_to_token_maps, - empty_token_ids, - fsm_finals, - ) = _create_states_mapping( - regex_string, - tokenizer, - regex_parser=regex_parser, - frozen_tokens=frozen_tokens, - ) - eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device) - initial_state = states_to_token_maps.get_initial_state() - return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state) - - @classmethod - def from_interegular_fsm( - cls, - interegular_fsm: interegular.fsm.FSM, - tokenizer, - _create_states_mapping_from_fsm=create_states_mapping_from_fsm, - device=None, - frozen_tokens: List[str] = [], - ): - ( - states_to_token_maps, - empty_token_ids, - fsm_finals, - ) = _create_states_mapping_from_fsm( - interegular_fsm, tokenizer, frozen_tokens=frozen_tokens - ) - eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device) - initial_state = states_to_token_maps.get_initial_state() - return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state) - 
- def get_next_instruction(self, state: int) -> Instruction: - """Return the next instruction for guided generation. - - The initialization of the guide builds an index which maps FSM states to a - map from authorized tokens to the state in which the guide needs to move - if said token is generated. Therefore the authorized tokens at the - current state are the keys of the map returned by the value of the index - for current state. - - If the current state is not contained in the end this means that we are - in a final state of the guide. We only authorize EOS tokens in the final - state. - - Parameters - ---------- - state - The current state of the guide. - - Returns - ------- - A `Generate` instance that contains the model and the allowed token ids. - - """ - if state == -1: - return Write(self.eos_tensor) - next_tokens_mask = self.states_to_token_maps.get_allowed_tokens(state) - # TODO: Create the Write and Generate objects within Rust instead? - if next_tokens_mask is None: - return Write(self.eos_tensor) - - return Generate(torch.tensor(next_tokens_mask)) - - def get_next_state(self, state: int, token_id: int) -> int: - """Update the state of the guide. - - We use the index to determine to which state the guide should transition - given the token that was just generated. - - Parameters - ---------- - state - The current state of the guide. - token_id - The id of the token that was just generated. - - Returns - ------- - The new state of the guide. - - """ - if state == -1: - return -1 - next_state = self.states_to_token_maps.get_next_state(state, token_id) - if next_state is None: - return -1 - else: - return next_state - - def is_final_state(self, state: int) -> bool: - """Determine whether the current state of the guide is a final state.""" - return state == -1 or self.states_to_token_maps.is_final_state(state) - - def copy(self): - return self - - def get_index_dict(self): - """Returns the Index as a Python Dict object.""" - return self.states_to_token_maps.get_transitions() diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index dae645e0..38272922 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -1,57 +1,9 @@ -from typing import Dict, List, Optional, Set, Tuple - -class FSMInfo: - initial: int - finals: Set[int] - transitions: Dict[Tuple[int, int], int] - alphabet_anything_value: int - alphabet_symbol_mapping: Dict[str, int] - - def __init__( - self, - initial: int, - finals: Set[int], - transitions: Dict[Tuple[int, int], int], - alphabet_anything_value: int, - alphabet_symbol_mapping: Dict[str, int], - ) -> None: ... +from typing import Dict, List, Optional, Set, Tuple, Union def build_regex_from_schema( json: str, whitespace_pattern: Optional[str] = None ) -> str: ... def to_regex(json: Dict, whitespace_pattern: Optional[str] = None) -> str: ... -def _walk_fsm( - fsm_transitions: Dict[Tuple[int, int], int], - fsm_initial: int, - fsm_finals: Set[int], - token_transition_keys: List[int], - start_state: int, - full_match: bool, -) -> List[int]: ... -def state_scan_tokens( - fsm_transitions: Dict[Tuple[int, int], int], - fsm_initial: int, - fsm_finals: Set[int], - vocabulary: Vocabulary, - vocabulary_transition_keys: Dict[str, List[int]], - start_state: int, -) -> Set[Tuple[int, int]]: ... -def get_token_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - token_str: str, -) -> List[int]: ... 
-def get_vocabulary_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - vocabulary: Vocabulary, - frozen_tokens: Set[str], -) -> Dict[str, List[int]]: ... -def create_fsm_index_end_to_end( - fsm_info: FSMInfo, - vocabulary: Vocabulary, - frozen_tokens: frozenset[str], -) -> Dict[int, Dict[int, int]]: ... BOOLEAN: str DATE: str @@ -67,29 +19,72 @@ WHITESPACE: str EMAIL: str URI: str -class Vocabulary: - """ - Vocabulary of an LLM. - """ +class Guide: + def __init__(self, index: Index): + """Creates a stateful Guide object based on an Index.""" + def get_state(self) -> int: + """Retrieves current state id of the Guide.""" + ... + def get_tokens(self) -> List[int]: + """Gets the list of allowed tokens for the current state.""" + ... + def advance(self, token_id: int) -> List[int]: + """Advances the guide to the next state for the given token id and returns the list of allowed tokens.""" + ... + def is_finished(self) -> bool: + """Checks if the automaton is in a final state.""" + ... + def __repr__(self) -> str: + """Gets the debug string representation of the guide.""" + ... + def __str__(self) -> str: + """Gets the string representation of the guide.""" + def __eq__(self, other: object) -> bool: + """Compares whether two guides are the same.""" + ... +class Vocabulary: + def __init__(self, eos_token_id: int, map: Dict[Union[str, bytes], List[int]]): + """Creates a vocabulary from a map of tokens to token ids and eos token id.""" + ... @staticmethod - def from_dict(map: Dict[str, List[int]]) -> "Vocabulary": - """ - Creates a vocabulary from a dictionary of tokens to token IDs. - """ + def from_pretrained( + model: str, revision: Optional[str], token: Optional[str] + ) -> "Vocabulary": + """Creates the vocabulary of a pre-trained model.""" + ... + def insert(self, token: Union[str, bytes], token_id: int): + """Inserts a new token with token_id, or extends the list of token ids if the token is already present.""" + ... + def remove(self, token: Union[str, bytes]): + """Removes a token from vocabulary.""" + ... + def get_eos_token_id(self) -> Optional[int]: + """Gets the end of sentence token id.""" + ... + def get(self, token: Union[str, bytes]) -> Optional[List[int]]: + """Gets the list of token ids for a token, if it is present in the vocabulary.""" ... def __repr__(self) -> str: - """ - Gets the debug string representation of the vocabulary. - """ + """Gets the debug string representation of the vocabulary.""" ... def __str__(self) -> str: - """ - Gets the string representation of the vocabulary. - """ + """Gets the string representation of the vocabulary.""" + ... + def __eq__(self, other: object) -> bool: + """Compares whether two vocabularies are the same.""" + ... + def __len__(self) -> int: + """Returns the number of tokens in the Vocabulary, excluding the EOS token.""" + ... + def __deepcopy__(self, memo: dict) -> "Vocabulary": + """Makes a deep copy of the Vocabulary.""" ... class Index: + def __init__(self, regex: str, vocabulary: "Vocabulary"): + """Creates an index from a regex and vocabulary.""" + ... def get_allowed_tokens(self, state: int) -> Optional[List[int]]: """Returns allowed tokens in this state.""" ... @@ -99,9 +94,23 @@ class Index: def is_final_state(self, state: int) -> bool: """Determines whether the current state is a final state.""" ... - def get_index_dict(self) -> Dict[int, Dict[int, int]]: + def get_final_states(self) -> List[int]: + """Get all final states.""" + ... + def get_transitions(self) -> Dict[int, Dict[int, int]]: + """Returns the Index transitions as a Python Dict object.""" ...
def get_initial_state(self) -> int: """Returns the ID of the initial state of the input FSM automata.""" ... + def __repr__(self) -> str: + """Gets the debug string representation of the index.""" + ... + def __str__(self) -> str: + """Gets the string representation of the index.""" + def __eq__(self, other: object) -> bool: + """Compares whether two indexes are the same.""" + ... + def __deepcopy__(self, memo: dict) -> "Index": + """Makes a deep copy of the Index.""" + ... diff --git a/python/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py deleted file mode 100644 index e4b93b7e..00000000 --- a/python/outlines_core/fsm/regex.py +++ /dev/null @@ -1,482 +0,0 @@ -import re -from functools import lru_cache -from typing import ( - Dict, - FrozenSet, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, -) - -from interegular.fsm import ( - FSM, - Alphabet, - State, - TransitionKey, - _AnythingElseCls, - anything_else, -) - -from .outlines_core_rs import ( # noqa: F401 - FSMInfo, - Index, - Vocabulary, - _walk_fsm, - create_fsm_index_end_to_end, - get_token_transition_keys, - get_vocabulary_transition_keys, - state_scan_tokens, -) - - -class BetterAlphabet(Alphabet): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert anything_else in self._symbol_mapping - self.anything_value = self._symbol_mapping[anything_else] - - def __getitem__(self, item): - return self._symbol_mapping.get(item, self.anything_value) - - def copy(self): - return BetterAlphabet(self._symbol_mapping.copy()) - - -class BetterFSM(FSM): - flat_transition_map: Dict[Tuple[int, int], int] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not isinstance(self.alphabet, BetterAlphabet): - self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) - - flat_transition_map = {} - for from_state, trans_map in self.map.items(): - for trans_key, to_state in trans_map.items(): - flat_transition_map[(from_state, trans_key)] = to_state - - self.__dict__["flat_transition_map"] = flat_transition_map - self.__dict__["_fsm_info"] = None - - def copy(self): - return BetterFSM( - alphabet=self.alphabet.copy(), - states=self.states.copy(), - initial=self.initial, - finals=self.finals.copy(), - map=self.map.copy(), - __no_validation__=True, - ) - - @property - def fsm_info(self): - if self._fsm_info is None: - anything_value = self.alphabet.anything_value - self.__dict__["_fsm_info"] = FSMInfo( - self.initial, - self.finals, - self.flat_transition_map, - anything_value, - # TODO FIXME: Perform this conversion in Rust? 
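To summarize the new surface declared in outlines_core_rs.pyi above, here is a minimal usage sketch; the regex, the "gpt2" model, and the pick-the-first-allowed-token step are illustrative placeholders, not part of this change:

    from outlines_core.fsm import Guide, Index, Vocabulary

    vocabulary = Vocabulary.from_pretrained("gpt2")   # tokens plus the EOS token id
    index = Index(r"0|[1-9][0-9]*", vocabulary)       # compile the regex against the vocabulary
    guide = Guide(index)                              # stateful walker over the index

    chosen = []
    while not guide.is_finished():
        allowed = guide.get_tokens()                  # token ids legal in the current state
        token_id = allowed[0]                         # stand-in for real sampling/masking
        chosen.append(token_id)
        guide.advance(token_id)                       # move to the next state

In a real decoding loop the allowed ids would be used to mask model logits rather than picked directly.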
- { - k: v - for k, v in self.alphabet._symbol_mapping.items() - if not isinstance(k, _AnythingElseCls) - }, - ) - - return self._fsm_info - - -TransitionTrie = Dict[TransitionKey, "Union[TransitionTrie, State, None]"] - - -def add_to_transition_trie( - trie: TransitionTrie, - key_seq: Sequence[TransitionKey], - value: Union[State, None], -): - for key in key_seq[:-1]: - trie = cast(TransitionTrie, trie.setdefault(key, {})) - assert isinstance(trie, dict), "key sequence of incompatible length" - trie[key_seq[-1]] = value - - -# merge default_trie into the trie, only updating entries not present in the trie -def transition_trie_setdefault( - trie: TransitionTrie, - default_trie: TransitionTrie, -): - for key, default_value in default_trie.items(): - dest_value = trie.get(key) - if isinstance(dest_value, dict) and isinstance(default_value, dict): - transition_trie_setdefault(dest_value, default_value) - elif key not in trie: - trie[key] = default_value - - -def byte_symbol(byte: int) -> str: - return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte) - - -def make_byte_level_fsm( - fsm: FSM, keep_utf8: bool = False, frozen_tokens: List[str] = [] -) -> FSM: - """Convert an FSM to a byte-level FSM, expanding multi-byte characters as - sequences of single-byte transitions. - - Parameters - ---------- - fsm: (`interegular.FSM`): - The token-level FSM to convert to a byte-level FSM. - keep_utf8: (`bool`, *optional*): - If set to True, the original utf-8 characters are kept as-is. Defaults to - False. NOTE: we're representing bytes as strings to keep it type-compatible. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is in the byte-level FSM. That is, - these tokens will not be expanded into byte-level transitions. Defaults to - an empty list. - - Returns - ------- - `interegular.FSM`: A byte-level FSM. 
- """ - - anything_else_key = fsm.alphabet[anything_else] - symbol_mapping: Dict[Union[str, _AnythingElseCls], TransitionKey] = {} - map: Dict[State, Dict[TransitionKey, State]] = {} - states: List[State] = list(fsm.states) - - # identify all multi-byte characters in the alphabet and build a mapping - # from the original transition keys to sequences of new keys for each byte - key_to_key_seqs: Dict[TransitionKey, Set[Tuple[TransitionKey, ...]]] = {} - all_key_seqs: Set[Tuple[TransitionKey, ...]] = set() - all_bytes: Set[int] = set() - max_key = max(fsm.alphabet.values()) - for symbol, transition_key in fsm.alphabet.items(): - assert symbol == anything_else or symbol in frozen_tokens or len(symbol) == 1 - if symbol == anything_else or symbol in frozen_tokens or ord(symbol) < 0x80: - symbol_mapping[symbol] = transition_key - else: - if keep_utf8: - symbol_mapping[symbol] = transition_key - key_list: List[TransitionKey] = [] - for byte in symbol.encode("utf-8"): - symbol = byte_symbol(byte) - if symbol not in symbol_mapping: - symbol_mapping[symbol] = max_key = TransitionKey(max_key + 1) - all_bytes.add(byte) - key_list.append(symbol_mapping[symbol]) - key_seq = tuple(key_list) - key_to_key_seqs.setdefault(transition_key, set()).add(key_seq) - all_key_seqs.add(key_seq) - - # add all remaining multi-byte utf-8 bytes to the alphabet - # (this is required to represent `anything_else`) - utf8_ranges = { - 1: (0x80, 0xC0), # continuation bytes - 2: (0xC0, 0xE0), # 2-byte sequences - 3: (0xE0, 0xF0), # 3-byte sequences - 4: (0xF0, 0xF8), # 4-byte sequences - } - utf8_all_keys: Dict[int, Set[TransitionKey]] = { - n: set() for n in utf8_ranges.keys() - } - for n, (start, end) in utf8_ranges.items(): - range_key = max_key = TransitionKey(max_key + 1) - for byte in range(start, end): - byte_key = symbol_mapping.setdefault(byte_symbol(byte), range_key) - utf8_all_keys[n].add(byte_key) - - # cache of intermediate transition states by transitions from that state - state_cache: Dict[FrozenSet[Tuple[TransitionKey, State]], State] = {} - - # helper function to create multi-step transitions between states - max_state = max(fsm.states) - - def create_seq_transitions( - seq_transitions_trie: TransitionTrie, - ) -> Dict[TransitionKey, State]: - nonlocal max_state - result: Dict[TransitionKey, State] = {} - - for next_key, next_trie in seq_transitions_trie.items(): - if isinstance(next_trie, dict): - next_transitions = create_seq_transitions(next_trie) - if not next_transitions: - continue - cache_key = frozenset(next_transitions.items()) - next_state = state_cache.get(cache_key) - if next_state is None: - next_state = max_state = State(max_state + 1) - map[next_state] = next_transitions - state_cache[cache_key] = next_state - states.append(next_state) - result[next_key] = next_state - elif next_trie is not None: - result[next_key] = next_trie - - return result - - # create new states and transitions - for state, transitions in fsm.map.items(): - seq_transitions_trie: TransitionTrie = {} - state_map: Dict[TransitionKey, State] = {} - - for transition_key, to_state in transitions.items(): - if transition_key in key_to_key_seqs: - if keep_utf8: - state_map[transition_key] = to_state - for key_seq in key_to_key_seqs[transition_key]: - add_to_transition_trie(seq_transitions_trie, key_seq, to_state) - else: # keep single-byte transitions as is - state_map[transition_key] = to_state - - # handle multi-byte anything_else sequences - if anything_else_key in transitions: - for key_seq in all_key_seqs: - 
add_to_transition_trie(seq_transitions_trie, key_seq, None) - - anything_else_trie: TransitionTrie = {} - cont_trie: Union[TransitionTrie, State] = transitions[anything_else_key] - for n in range(2, 5): - cont_trie = {key: cont_trie for key in utf8_all_keys[1]} - for key in utf8_all_keys[n]: - anything_else_trie[key] = cont_trie - - transition_trie_setdefault(seq_transitions_trie, anything_else_trie) - - # create new states and transitions - next_transitions = create_seq_transitions(seq_transitions_trie) - state_map.update(next_transitions) - map[state] = state_map - - return FSM( - alphabet=Alphabet(symbol_mapping), - states=states, - initial=fsm.initial, - finals=fsm.finals, - map=map, - ) - - -def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: - """Construct an equivalent FSM with deterministic state labels.""" - old_to_new_trans_keys = { - trans_key: i - for i, (trans_key, _) in enumerate( - sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) - ) - } - - new_symbol_mapping = { - symbol: old_to_new_trans_keys[trans_key] - for symbol, trans_key in fsm.alphabet._symbol_mapping.items() - } - - new_alphabet = BetterAlphabet(new_symbol_mapping) - - new_map = { - from_state: { - old_to_new_trans_keys[trans_key]: to_state - for trans_key, to_state in trans_map.items() - } - for from_state, trans_map in fsm.map.items() - } - - old_to_new_states = {} - old_to_new_states[fsm.initial] = 0 - - i = 0 - seen = {fsm.initial} - old_state_queue = [fsm.initial] - while old_state_queue: - old_state = old_state_queue.pop(-1) - transitions = new_map[old_state] - sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) - for _, old_state in sorted_transitions: - if old_state not in seen: - old_state_queue.append(old_state) - seen.add(old_state) - if old_state not in old_to_new_states: - i += 1 - old_to_new_states[old_state] = i - - new_map = dict( - sorted( - ( - ( - old_to_new_states[from_state], - dict( - sorted( - ( - (trans_key, old_to_new_states[to_state]) - for trans_key, to_state in trans_map.items() - ), - key=lambda v: v[0], - ) - ), - ) - for from_state, trans_map in new_map.items() - ), - key=lambda v: v[0], - ) - ) - - new_initial = 0 - new_finals = frozenset( - sorted(old_to_new_states[old_state] for old_state in fsm.finals) - ) - new_states = frozenset(sorted(new_map.keys())) - - new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) - - return new_fsm, old_to_new_states - - -re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") - -# The ".?" prefix and suffix is to handle special cases in some model vocabularies. This -# includes Gemma models (which use "▁�" as a token), NorwAI models (which use ".�" as a -# token), Salamandra models (which use ".�" and "�?" as tokens) and OpenCoder models -# (which use "�s" as a token). -re_replacement_seq = re.compile(r"^.?�+.?$") - - -# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode -@lru_cache() -def gpt2_bytes_to_unicode(): - """ - Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control - characters the bpe code barfs on. - - The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab - if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for - decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. 
To avoid that, we want lookup - tables between utf-8 bytes and unicode strings. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -@lru_cache() -def gpt2_unicode_to_bytes(): - return {v: k for k, v in gpt2_bytes_to_unicode().items()} - - -@lru_cache -def reduced_vocabulary( - tokenizer, -) -> Tuple[Dict[str, List[int]], Set[int]]: - """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" - # TODO FIXME: See if we can get the underlying Rust tokenizers from HF and - # do all this in Rust - empty_token_ids = set() - vocabulary: Dict[str, List[int]] = {} - for token, token_idx in tokenizer.vocabulary.items(): - if token in tokenizer.special_tokens: - continue - - token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string( - token - ) - - if token_str: - if isinstance(token, bytes): - # Handle BPE tokenizers where the tokens are directly stored as bytes - # https://github.com/QwenLM/Qwen/blob/main/tokenization_note.md#regular-tokens - token_str = "".join(byte_symbol(b) for b in token) - - elif "\ufffd" in token_str and not re_replacement_seq.match(token): - # invalid utf-8 sequences are replaced with � (\ufffd), but there - # might also be tokens specifically for �, ��, ���, etc. - - if re_llama_byte_token.match(token): - # llama-like tokenizers have <0xXX> tokens for all - # bytes >= 0x80 and represent all incomplete utf-8 - # sequences using such tokens - token_bytes = [int(token[3:5], 16)] - else: - # gpt2-like tokenizers have multi-byte tokens that can - # have a mix of full and incomplete utf-8 characters, - # for example, b` \xf0` can be one token; these tokenizers - # map each byte to a valid utf-8 character - token_bytes = cast( - List[int], [gpt2_unicode_to_bytes().get(c) for c in token] - ) - if None in token_bytes: - raise RuntimeError( - f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" - ) - token_str = "".join(byte_symbol(b) for b in token_bytes) - - assert isinstance(token_str, str) - - vocabulary.setdefault(token_str, []).append(token_idx) - else: - empty_token_ids.add(token_idx) - - return vocabulary, empty_token_ids - - -def create_fsm_index_tokenizer( - fsm: BetterFSM, - tokenizer, - frozen_tokens: Optional[Iterable[str]] = None, -) -> Tuple[Index, Set[int]]: - """Construct an FMS index from a tokenizer. - - This uses the end-to-end approach of `create_fsm_index_end_to_end`. - - Parameters - ---------- - fsm: (`BetterFSM`): - A cache-friendly FSM. Other interegular FSMs can also be used, but caching - may not work as expected. - tokenizer: (`Tokenizer`): - The model's tokenizer. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is when expanding the token-level - FSM into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: (`Dict[int, Dict[int, int]]`): - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: (`Set[int]`): - A set of token ids that correspond to empty strings. - - .. warning:: - - `fsm` needs to be deterministically ordered so that future caching makes sense. 
- """ - tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer) - - states_to_token_subsets = Index( # type: ignore - fsm.fsm_info, - Vocabulary.from_dict(tokens_to_token_ids), - tokenizer.eos_token_id, - frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(), - ) - - return states_to_token_subsets, empty_token_ids diff --git a/src/error.rs b/src/error.rs index d7905ab7..ecb3bbb0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,21 +3,17 @@ use thiserror::Error; pub type Result = std::result::Result; #[derive(Error, Debug)] -#[error("{0}")] -pub struct TokenizersError(pub tokenizers::Error); - -impl PartialEq for TokenizersError { - fn eq(&self, other: &Self) -> bool { - self.0.to_string() == other.0.to_string() - } -} - -#[derive(Error, Debug, PartialEq)] pub enum Error { - #[error("The vocabulary does not allow us to build a sequence that matches the input")] - IndexError, + // Index Errors + #[error("Failed to build DFA {0}")] + IndexDfaError(#[from] Box), + #[error("Index failed since anchored universal start state doesn't exist")] + DfaHasNoStartState, + // Vocabulary Errors + #[error("EOS token should not be inserted into Vocabulary")] + EOSTokenDisallowed, #[error(transparent)] - TokenizersError(#[from] TokenizersError), + TokenizersError(#[from] tokenizers::Error), #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")] UnsupportedTokenizer { model: String, reason: String }, #[error("Unable to locate EOS token for {model}")] diff --git a/src/index.rs b/src/index.rs index 5fcc3e93..f9d9f01a 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,132 +1,302 @@ /// Construct an Index. -use crate::prelude::{State, TransitionKey}; -use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens}; +use crate::prelude::*; use crate::vocabulary::Vocabulary; use crate::{Error, Result}; use bincode::{Decode, Encode}; -use rustc_hash::{FxHashMap, FxHashSet}; - -#[derive(Debug)] -pub struct FSMInfo { - pub(crate) initial: State, - pub(crate) finals: FxHashSet, - pub(crate) transitions: FxHashMap<(State, TransitionKey), State>, - pub(crate) alphabet_anything_value: TransitionKey, - pub(crate) alphabet_symbol_mapping: FxHashMap, -} - -impl FSMInfo { - pub fn new( - initial: State, - finals: FxHashSet, - transitions: FxHashMap<(State, TransitionKey), State>, - alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: FxHashMap, - ) -> Self { - Self { - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, - } - } -} +use regex_automata::dfa::{dense::DFA, Automaton}; +use regex_automata::util::primitives::StateID as AutomataStateId; +use regex_automata::Anchored; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -#[derive(Debug, Encode, Decode)] +#[derive(Clone, Debug, PartialEq, Encode, Decode)] pub struct Index { - initial: u32, - finals: FxHashSet, - states_to_token_subsets: FxHashMap>, - eos_token_id: u32, + /// The ID of the initial state in the automaton, processing begins from this state. + initial_state: StateId, + /// A collection of states considered as terminal states. + final_states: HashSet, + /// A mapping of state transitions, defined by tokens ids and their corresponding state changes. 
+ /// + /// ### Example + /// ```ignore + /// transitions = { + /// 1: {10: 2, 15: 3}, + /// 2: {20: 4, 25: 3}, + /// 3: {30: 4}, + /// 4: {40: 4}, + /// } + /// +--------------------------------------+ + /// | State 1 | + /// | Initial State | + /// +--------------------------------------+ + /// | | + /// + | + /// Token ID 10 | + /// +-----------------------+ | + /// | State 2 | | + /// +-----------------------+ | + /// | | | + /// | + + + /// | Token ID 25 Token ID 15 + /// | +------------------------+ + /// | | State 3 | + /// | +------------------------+ + /// | | + /// + + + /// Token ID 20 Token ID 30 + /// +--------------------------------------+ + /// | State 4 | + /// | Final state | + /// +--------------------------------------+ + /// ``` + transitions: HashMap>, + /// The token ID reserved for the "end-of-sequence" token. + eos_token_id: TokenId, } - +/// The `Index` structure is designed to efficiently map tokens from a given vocabulary +/// to state transitions within a finite-state automaton. +/// +/// ## Usage: +/// The `Index` is typically constructed by combining a vocabulary and regular expressions. +/// Once built, it can be used to efficiently evaluate token sequences or to validate input data. +/// +/// ## Example: +/// ```rust +/// use outlines_core::prelude::*; +/// +/// # fn run() -> Result<(), outlines_core::Error> { +/// let regex = "0|[1-9][0-9]*"; +/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?; +/// let index = Index::new(regex, &vocabulary)?; +/// +/// let initial_state = index.initial_state(); +/// println!("Initial state is {}", initial_state); +/// println!("Is initial state a final state? {}", index.is_final_state(&initial_state)); +/// +/// let allowed_tokens = index.allowed_tokens(&initial_state).unwrap(); +/// println!("Allowed tokens at initial state are {:?}", allowed_tokens); +/// +/// let token_id = allowed_tokens.first().unwrap(); +/// println!("Next state for the token_id {} is {:?}", token_id, index.next_state(&initial_state, token_id)); +/// +/// println!("Final states are {:?}", index.final_states()); +/// println!("Index has exactly {} transitions", index.transitions().len()); +/// # Ok(()) +/// # } +/// +/// ``` +/// +/// ## Performance: +/// - **Complexity**: +/// The `Index` can accommodate large vocabularies and complex regular expressions. +/// However, its size may grow significantly with the complexity of the input. +/// - **Construction Cost**: +/// Building the `Index` involves processing the vocabulary and regular expressions, +/// which may require a considerable amount of time and computational resources. 
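The same structure is visible from Python through the bindings added later in this diff; a small inspection sketch (the regex and the "gpt2" model are illustrative):

    from outlines_core.fsm import Index, Vocabulary

    vocabulary = Vocabulary.from_pretrained("gpt2")
    index = Index(r"0|[1-9][0-9]*", vocabulary)

    initial = index.get_initial_state()
    transitions = index.get_transitions()   # {state: {token_id: next_state}}
    print(len(transitions[initial]))        # number of token ids allowed at the start
    print(index.get_final_states())         # states in which the EOS token becomes allowed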
impl Index { - pub fn new( - fsm_info: &FSMInfo, - vocabulary: &Vocabulary, - eos_token_id: u32, - frozen_tokens: FxHashSet, - ) -> Result { - let mut states_to_token_subsets: FxHashMap> = FxHashMap::default(); - let mut seen: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter([fsm_info.initial]); - - let vocabulary_transition_keys = get_vocabulary_transition_keys( - &fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - vocabulary, - &frozen_tokens, - ); - - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - - let token_ids_end_states = state_scan_tokens( - &fsm_info.transitions, - fsm_info.initial, - &fsm_info.finals, - vocabulary, - &vocabulary_transition_keys, - start_state, - ); - - for (token_id, end_state) in &token_ids_end_states { - let inner_map = states_to_token_subsets.entry(start_state).or_default(); - inner_map.insert(*token_id, *end_state); - - if !seen.contains(end_state) { - next_states.insert(*end_state); - } - } + /// Builds an `Index` from regular expression and vocabulary tokens. + pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result { + let eos_token_id = vocabulary.eos_token_id(); + let dfa = DFA::new(regex).map_err(Box::new)?; + let start_state = match dfa.universal_start_state(Anchored::Yes) { + Some(s) => s, + None => return Err(Error::DfaHasNoStartState), + }; - if fsm_info.finals.contains(&start_state) && !token_ids_end_states.is_empty() { - let inner_map = states_to_token_subsets.entry(start_state).or_default(); - inner_map.insert(eos_token_id, start_state); + let mut transitions: HashMap> = HashMap::default(); + let mut final_states: HashSet = HashSet::default(); + + let mut seen: HashSet = HashSet::from_iter([start_state]); + let mut next_states: Vec = vec![start_state]; + + while let Some(current_state) = next_states.pop() { + if dfa.is_match_state(dfa.next_eoi_state(current_state)) { + final_states.insert(current_state.as_u32()); } - seen.insert(start_state); + 'token_loop: for (token, ids) in vocabulary.tokens().iter() { + if ids.contains(&eos_token_id) { + continue; + } + + let mut next_state = current_state; + for transition_byte in token { + next_state = dfa.next_state(next_state, *transition_byte); + if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) { + continue 'token_loop; + } + } + + let is_intermediate_state = !dfa.is_match_state(next_state); + let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state)); + if is_intermediate_state || is_full_match_state { + for token_id in ids { + transitions + .entry(current_state.as_u32()) + .or_default() + .insert(*token_id, next_state.as_u32()); + } + } + if !seen.contains(&next_state) { + seen.insert(next_state); + next_states.push(next_state); + } + } } - let is_valid = states_to_token_subsets - .values() - .flat_map(|token_id_end_states| token_id_end_states.values()) - .any(|end_state| fsm_info.finals.contains(end_state)); - - if is_valid { - Ok(Self { - initial: fsm_info.initial, - finals: fsm_info.finals.clone(), - states_to_token_subsets, - eos_token_id, - }) - } else { - Err(Error::IndexError) + // Populate `transitions` with mappings from `final_states` to `eos_token_id` + for &final_state in &final_states { + transitions + .entry(final_state) + .or_default() + .insert(eos_token_id, final_state); } + + Ok(Self { + initial_state: start_state.as_u32(), + final_states, + transitions, + eos_token_id, + }) } - pub(crate) fn allowed_tokens(&self, state: u32) -> 
Option> { - self.states_to_token_subsets - .get(&state) - .map_or_else(|| None, |res| Some(res.keys().cloned().collect())) + /// Returns the ID of the initial state in the automaton. + pub fn initial_state(&self) -> StateId { + self.initial_state } - pub(crate) fn next_state(&self, state: u32, token_id: u32) -> Option { - if token_id == self.eos_token_id { + /// Returns set of final states. + pub fn final_states(&self) -> &HashSet { + &self.final_states + } + + /// Returns state transitions map of tokens ids and their corresponding transition states. + pub fn transitions(&self) -> &HashMap> { + &self.transitions + } + + /// Checks if state is in final states set or not. + pub fn is_final_state(&self, state: &StateId) -> bool { + self.final_states.contains(state) + } + + /// Lists allowed tokens for a give state ID or `None` if it is not found in `Index`. + pub fn allowed_tokens(&self, state: &StateId) -> Option> { + self.transitions + .get(state) + .map(|res| res.keys().cloned().collect()) + } + + /// Returns transition state for a given state and token id or `None` otherwise. + pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option { + if token_id == &self.eos_token_id { return None; } - Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?) + Some(*self.transitions.get(state)?.get(token_id)?) } +} + +impl std::fmt::Display for Index { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Index object with transitions:")?; + for (state_id, token_ids) in self.transitions.iter() { + writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; - pub(crate) fn initial(&self) -> u32 { - self.initial + #[test] + fn index_from_regex() { + let regex = "0|[1-9][0-9]*"; + let eos_token_id = 4; + let mut vocabulary = Vocabulary::new(eos_token_id); + for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] { + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); + } + let index = Index::new(regex, &vocabulary).expect("Index failed"); + let initial_state = index.initial_state(); + assert_eq!(initial_state, 40); + assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56])); + assert!(!index.is_final_state(&initial_state)); + + let expected = HashMap::from_iter([ + (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])), + (48, HashMap::from_iter([(4, 48)])), + (40, HashMap::from_iter([(3, 48), (2, 56)])), + (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])), + ]); + assert_eq!(index.transitions(), &expected); + + let allowed_tokens = index + .allowed_tokens(&initial_state) + .expect("No allowed tokens"); + let token_id = allowed_tokens.first().expect("No first tokens"); + + let state = 48; + assert_eq!(index.next_state(&initial_state, token_id), Some(state)); + assert!(index.is_final_state(&state)); + + assert_eq!(index.next_state(&state, &eos_token_id), None); + assert_eq!(index.next_state(&state, token_id), None); } - pub(crate) fn is_final(&self, state: u32) -> bool { - self.finals.contains(&state) + #[test] + fn index_from_regex_initital_in_allowed() { + let regex = "`\\n(\\.\\n)?`\\n"; + let mut vocabulary = Vocabulary::new(104); + for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] { + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); + } + + let index = Index::new(regex, &vocabulary).expect("Index failed"); + let allowed = index + .allowed_tokens(&index.initial_state()) + .expect("No allowed 
tokens"); + assert!(allowed.contains(&101)); } - pub(crate) fn transitions(&self) -> &FxHashMap> { - &self.states_to_token_subsets + #[test] + fn index_from_regex_multibyte() { + let regex = "😇| [😈-😍][😇-😎]*"; + let mut vocabulary = Vocabulary::new(8); + for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)] + { + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); + } + for (token, token_id) in [ + (vec![32, 240, 159, 152], 7), + (vec![32, 240, 159, 152, 141], 6), + (vec![240, 159, 152, 141], 4), + ] { + vocabulary + .try_insert(token, token_id as u32) + .expect("Insert failed"); + } + + let index = Index::new(regex, &vocabulary).expect("Index failed"); + assert_eq!(index.final_states(), &HashSet::from_iter([208, 128])); + + let expected = HashMap::from_iter([ + ( + 208, + HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]), + ), + ( + 80, + HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]), + ), + (128, HashMap::from_iter([(8, 128)])), + ]); + assert_eq!(index.transitions(), &expected); } } diff --git a/src/lib.rs b/src/lib.rs index cb56f86f..538152f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ pub mod index; pub mod json_schema; pub mod prelude; pub mod primitives; -pub mod regex; pub mod vocabulary; pub use error::{Error, JsonSchemaParserError, Result}; diff --git a/src/prelude.rs b/src/prelude.rs index d42516b9..3d38dcc6 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,4 +1,7 @@ +pub use tokenizers::FromPretrainedParameters; + pub use super::{ - primitives::{State, Token, TokenId, TransitionKey}, + index::Index, + primitives::{StateId, Token, TokenId}, vocabulary::Vocabulary, }; diff --git a/src/primitives.rs b/src/primitives.rs index e12bf036..f5571fec 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -1,11 +1,8 @@ -/// Interegular transition key. -pub type TransitionKey = u32; - /// Token content. -pub type Token = String; +pub type Token = Vec; /// Token identifier. pub type TokenId = u32; -/// Interegular state. -pub type State = u32; +/// State id. +pub type StateId = u32; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index f9c4936c..b3c5a6f6 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,98 +1,168 @@ -use crate::index::{FSMInfo, Index}; +use std::sync::Arc; + +use crate::index::Index; use crate::json_schema; use crate::prelude::*; -use crate::regex::get_token_transition_keys; -use crate::regex::get_vocabulary_transition_keys; -use crate::regex::state_scan_tokens; -use crate::regex::walk_fsm; use bincode::config; +use bincode::{Decode, Encode}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyAny, PyDict}; use pyo3::wrap_pyfunction; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use serde_json::Value; +use tokenizers::FromPretrainedParameters; + +macro_rules! 
type_name { + ($obj:expr) => { + // Safety: obj is always initialized and tp_name is a C-string + unsafe { std::ffi::CStr::from_ptr((&*(&*$obj.as_ptr()).ob_type).tp_name) } + }; +} -#[pyclass(name = "FSMInfo")] -pub struct PyFSMInfo { - #[pyo3(get)] - initial: State, - #[pyo3(get)] - finals: FxHashSet, - #[pyo3(get)] - transitions: FxHashMap<(State, TransitionKey), State>, - #[pyo3(get)] - alphabet_anything_value: TransitionKey, - #[pyo3(get)] - alphabet_symbol_mapping: FxHashMap, +#[pyclass(name = "Guide", module = "outlines_core.fsm.outlines_core_rs")] +#[derive(Clone, Debug, PartialEq, Encode, Decode)] +pub struct PyGuide { + state: StateId, + index: PyIndex, } -impl From for PyFSMInfo { - fn from(fsm_info: FSMInfo) -> Self { - PyFSMInfo { - initial: fsm_info.initial, - finals: fsm_info.finals, - transitions: fsm_info.transitions, - alphabet_anything_value: fsm_info.alphabet_anything_value, - alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping, +#[pymethods] +impl PyGuide { + #[new] + fn __new__(index: PyIndex) -> Self { + PyGuide { + state: index.get_initial_state(), + index, } } -} -// FIXME: could be costly, confirm if FSMInfo will actually be part of the interface -impl From<&PyFSMInfo> for FSMInfo { - fn from(fsm_info: &PyFSMInfo) -> Self { - FSMInfo { - initial: fsm_info.initial, - finals: fsm_info.finals.clone(), - transitions: fsm_info.transitions.clone(), - alphabet_anything_value: fsm_info.alphabet_anything_value, - alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping.clone(), + fn get_state(&self) -> StateId { + self.state + } + + fn get_tokens(&self) -> PyResult> { + self.index + .get_allowed_tokens(self.state) + // Since Guide advances only through the states offered by the Index, it means + // None here shouldn't happen and it's an issue at Index creation step + .ok_or(PyErr::new::(format!( + "No allowed tokens available for the state {}", + self.state + ))) + } + + fn advance(&mut self, token_id: TokenId) -> PyResult> { + match self.index.get_next_state(self.state, token_id) { + Some(new_state) => { + self.state = new_state; + self.get_tokens() + } + None => Err(PyErr::new::(format!( + "No next state found for the current state: {} with token ID: {token_id}", + self.state + ))), } } -} -#[pymethods] -impl PyFSMInfo { - #[new] - fn new( - initial: State, - finals: FxHashSet, - transitions: FxHashMap<(State, TransitionKey), State>, - alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: FxHashMap, - ) -> Self { - FSMInfo::new( - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, + fn is_finished(&self) -> bool { + self.index.is_final_state(self.state) + } + + fn __repr__(&self) -> String { + format!( + "Guide object with the state={:#?} and {:#?}", + self.state, self.index + ) + } + + fn __str__(&self) -> String { + format!( + "Guide object with the state={} and {}", + self.state, self.index.0 ) - .into() + } + + fn __eq__(&self, other: &PyGuide) -> bool { + self == other + } + + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
+ .getattr("Guide")?; + let binary_data: Vec = + bincode::encode_to_vec(self, config::standard()).map_err(|e| { + PyErr::new::(format!("Serialization of Guide failed: {}", e)) + })?; + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) + }) + } + + #[staticmethod] + fn from_binary(binary_data: Vec) -> PyResult { + let (guide, _): (PyGuide, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::(format!("Deserialization of Guide failed: {}", e)) + })?; + Ok(guide) } } #[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] -pub struct PyIndex(Index); +#[derive(Clone, Debug, PartialEq, Encode, Decode)] +pub struct PyIndex(Arc); #[pymethods] impl PyIndex { #[new] - fn new( - py: Python<'_>, - fsm_info: &PyFSMInfo, - vocabulary: &PyVocabulary, - eos_token_id: u32, - frozen_tokens: FxHashSet, - ) -> PyResult { + fn __new__(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { py.allow_threads(|| { - Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) - .map(PyIndex) + Index::new(regex, &vocabulary.0) + .map(|x| PyIndex(Arc::new(x))) .map_err(Into::into) }) } + fn get_allowed_tokens(&self, state: StateId) -> Option> { + self.0.allowed_tokens(&state) + } + + fn get_next_state(&self, state: StateId, token_id: TokenId) -> Option { + self.0.next_state(&state, &token_id) + } + + fn is_final_state(&self, state: StateId) -> bool { + self.0.is_final_state(&state) + } + + fn get_final_states(&self) -> HashSet { + self.0.final_states().clone() + } + + fn get_transitions(&self) -> HashMap> { + self.0.transitions().clone() + } + + fn get_initial_state(&self) -> StateId { + self.0.initial_state() + } + fn __repr__(&self) -> String { + format!("{:#?}", self.0) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } + + fn __eq__(&self, other: &PyIndex) -> bool { + *self.0 == *other.0 + } + + fn __deepcopy__(&self, _py: Python<'_>, _memo: Py) -> Self { + PyIndex(Arc::new((*self.0).clone())) + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
@@ -111,27 +181,7 @@ impl PyIndex { bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { PyErr::new::(format!("Deserialization of Index failed: {}", e)) })?; - Ok(PyIndex(index)) - } - - fn get_allowed_tokens(&self, state: u32) -> Option> { - self.0.allowed_tokens(state) - } - - fn get_next_state(&self, state: u32, token_id: u32) -> Option { - self.0.next_state(state, token_id) - } - - fn is_final_state(&self, state: u32) -> bool { - self.0.is_final(state) - } - - fn get_transitions(&self) -> FxHashMap> { - self.0.transitions().clone() - } - - fn get_initial_state(&self) -> u32 { - self.0.initial() + Ok(PyIndex(Arc::new(index))) } } @@ -153,142 +203,95 @@ pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyR .map_err(|e| PyValueError::new_err(e.to_string())) } -#[pyfunction(name = "_walk_fsm")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" -)] -pub fn walk_fsm_py( - fsm_transitions: FxHashMap<(State, TransitionKey), State>, - fsm_initial: State, - fsm_finals: FxHashSet, - token_transition_keys: Vec, - start_state: State, - full_match: bool, -) -> PyResult> { - Ok(walk_fsm( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &token_transition_keys, - start_state, - full_match, - )) -} - -#[pyfunction(name = "state_scan_tokens")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" -)] -pub fn state_scan_tokens_py( - fsm_transitions: FxHashMap<(State, TransitionKey), State>, - fsm_initial: State, - fsm_finals: FxHashSet, - vocabulary: &PyVocabulary, - vocabulary_transition_keys: FxHashMap>, - start_state: State, -) -> PyResult> { - Ok(state_scan_tokens( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &vocabulary.0, - &vocabulary_transition_keys, - start_state, - )) -} - -#[pyfunction(name = "get_token_transition_keys")] -#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] -pub fn get_token_transition_keys_py( - alphabet_symbol_mapping: FxHashMap, - alphabet_anything_value: TransitionKey, - token_str: String, -) -> PyResult> { - Ok(get_token_transition_keys( - &alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - )) -} +#[pyclass(name = "Vocabulary", module = "outlines_core.fsm.outlines_core_rs")] +#[derive(Clone, Debug, Encode, Decode)] +pub struct PyVocabulary(Vocabulary); -#[pyfunction(name = "get_vocabulary_transition_keys")] -#[pyo3( - text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" -)] -pub fn get_vocabulary_transition_keys_py( - alphabet_symbol_mapping: FxHashMap, - alphabet_anything_value: TransitionKey, - vocabulary: &PyVocabulary, - frozen_tokens: FxHashSet, -) -> PyResult>> { - Ok(get_vocabulary_transition_keys( - &alphabet_symbol_mapping, - alphabet_anything_value, - &vocabulary.0, - &frozen_tokens, - )) -} +#[pymethods] +impl PyVocabulary { + #[new] + fn __new__(py: Python<'_>, eos_token_id: TokenId, map: Py) -> PyResult { + if let Ok(dict) = map.extract::>>(py) { + return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?)); + } + if let Ok(dict) = map.extract::, Vec>>(py) { + return Ok(PyVocabulary(Vocabulary::try_from((eos_token_id, dict))?)); + } -#[pyfunction(name = "create_fsm_index_end_to_end")] -#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] -pub fn create_fsm_index_end_to_end_py<'py>( - py: Python<'py>, - fsm_info: &PyFSMInfo, - vocabulary: 
&PyVocabulary, - frozen_tokens: FxHashSet, -) -> PyResult> { - let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter(vec![fsm_info.initial]); - - let vocabulary_transition_keys = get_vocabulary_transition_keys( - &fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - &vocabulary.0, - &frozen_tokens, - ); - - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - - // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens( - &fsm_info.transitions, - fsm_info.initial, - &fsm_info.finals, - &vocabulary.0, - &vocabulary_transition_keys, - start_state, - ); - - for (token_id, end_state) in token_ids_end_states { - if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { - existing_dict.set_item(token_id, end_state)?; - } else { - let new_dict = PyDict::new_bound(py); - new_dict.set_item(token_id, end_state)?; - states_to_token_subsets.set_item(start_state, new_dict)?; - } + let message = "Expected a dict with keys of type str or bytes and values of type list[int]"; + let tname = type_name!(map).to_string_lossy(); + if tname == "dict" { + Err(PyErr::new::(format!( + "Dict keys or/and values of the wrong types. {message}" + ))) + } else { + Err(PyErr::new::(format!( + "{message}, got {tname}" + ))) + } + } - if !seen.contains(&end_state) { - next_states.insert(end_state); - } + #[staticmethod] + #[pyo3(signature = (model, revision=None, token=None))] + fn from_pretrained( + model: String, + revision: Option, + token: Option, + ) -> PyResult { + let mut params = FromPretrainedParameters::default(); + if let Some(r) = revision { + params.revision = r } + if token.is_some() { + params.token = token + } + let v = Vocabulary::from_pretrained(model.as_str(), Some(params))?; + Ok(PyVocabulary(v)) + } - seen.insert(start_state); + fn insert(&mut self, py: Python<'_>, token: Py, token_id: TokenId) -> PyResult<()> { + if let Ok(t) = token.extract::(py) { + return Ok(self.0.try_insert(t, token_id)?); + } + if let Ok(t) = token.extract::(py) { + return Ok(self.0.try_insert(t, token_id)?); + } + Err(PyErr::new::(format!( + "Expected a token of type str or bytes, got {:?}", + type_name!(token) + ))) } - Ok(states_to_token_subsets) -} + fn remove(&mut self, py: Python<'_>, token: Py) -> PyResult<()> { + if let Ok(t) = token.extract::(py) { + self.0.remove(t); + return Ok(()); + } + if let Ok(t) = token.extract::(py) { + self.0.remove(t); + return Ok(()); + } + Err(PyErr::new::(format!( + "Expected a token of type str or bytes, got {:?}", + type_name!(token) + ))) + } -#[pyclass(name = "Vocabulary")] -pub struct PyVocabulary(Vocabulary); + fn get(&self, py: Python<'_>, token: Py) -> PyResult>> { + if let Ok(t) = token.extract::(py) { + return Ok(self.0.token_ids(t.into_bytes()).cloned()); + } + if let Ok(t) = token.extract::(py) { + return Ok(self.0.token_ids(&t).cloned()); + } + Err(PyErr::new::(format!( + "Expected a token of type str or bytes, got {:?}", + type_name!(token) + ))) + } -#[pymethods] -impl PyVocabulary { - #[staticmethod] - fn from_dict(map: FxHashMap>) -> PyVocabulary { - PyVocabulary(Vocabulary::from(map)) + fn get_eos_token_id(&self) -> TokenId { + self.0.eos_token_id() } fn __repr__(&self) -> String { @@ -298,16 +301,49 @@ impl PyVocabulary { fn __str__(&self) -> String { format!("{}", self.0) } + + fn __eq__(&self, other: &PyVocabulary) -> bool { + self.0 == 
other.0 + } + + fn __len__(&self) -> usize { + self.0.tokens().len() + } + + fn __deepcopy__(&self, _py: Python<'_>, _memo: Py) -> Self { + PyVocabulary(self.0.clone()) + } + + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? + .getattr("Vocabulary")?; + let binary_data: Vec = + bincode::encode_to_vec(self, config::standard()).map_err(|e| { + PyErr::new::(format!( + "Serialization of Vocabulary failed: {}", + e + )) + })?; + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) + }) + } + + #[staticmethod] + fn from_binary(binary_data: Vec) -> PyResult { + let (guide, _): (PyVocabulary, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::(format!( + "Deserialization of Vocabulary failed: {}", + e + )) + })?; + Ok(guide) + } } #[pymodule] fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens_py, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys_py, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?; - m.add("BOOLEAN", json_schema::BOOLEAN)?; m.add("DATE", json_schema::DATE)?; m.add("DATE_TIME", json_schema::DATE_TIME)?; @@ -327,7 +363,7 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/regex.rs b/src/regex.rs deleted file mode 100644 index 24687f1e..00000000 --- a/src/regex.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::prelude::*; -use rustc_hash::{FxHashMap, FxHashSet}; - -pub fn walk_fsm( - fsm_transitions: &FxHashMap<(State, TransitionKey), State>, - _fsm_initial: State, - fsm_finals: &FxHashSet, - token_transition_keys: &[TransitionKey], - start_state: State, - full_match: bool, -) -> Vec { - let mut state = start_state; - let mut accepted_states = Vec::new(); - let mut last_final_idx = 0; - - for (i, &trans_key) in token_transition_keys.iter().enumerate() { - match fsm_transitions.get(&(state, trans_key)) { - Some(&new_state) => { - state = new_state; - if fsm_finals.contains(&state) { - last_final_idx = i + 1; - } - accepted_states.push(state); - } - None => { - if !full_match && last_final_idx > 0 { - return accepted_states[..last_final_idx].to_vec(); - } - return Vec::new(); - } - } - } - - if full_match && last_final_idx != token_transition_keys.len() { - return Vec::new(); - } - - accepted_states -} - -pub fn state_scan_tokens( - fsm_transitions: &FxHashMap<(State, TransitionKey), State>, - fsm_initial: State, - fsm_finals: &FxHashSet, - vocabulary: &Vocabulary, - vocabulary_transition_keys: &FxHashMap>, - start_state: State, -) -> FxHashSet<(TokenId, State)> { - let mut res = FxHashSet::default(); - - for (token, token_ids) in vocabulary.iter() { - let token_transition_keys = &vocabulary_transition_keys[token]; - let state_seq = walk_fsm( - fsm_transitions, - fsm_initial, - fsm_finals, - token_transition_keys, - start_state, - false, - ); - - if state_seq.len() < token_transition_keys.len() { - continue; - } - - for &token_id in token_ids { - res.insert((token_id, *state_seq.last().unwrap())); - } - } - - res -} - -pub fn get_token_transition_keys( - alphabet_symbol_mapping: &FxHashMap, - alphabet_anything_value: TransitionKey, - token_str: &str, -) -> Vec { - 
let mut token_transition_keys = Vec::new(); - let mut i = 0; - let chars: Vec = token_str.chars().collect(); - - while i < chars.len() { - let symbol; - if chars[i] == '\0' && i != chars.len() - 1 { - if i + 2 < chars.len() { - symbol = format!("\0{}{}", chars[i + 1], chars[i + 2]); - i += 3; - } else { - symbol = chars[i].to_string(); - i += 1; - } - } else { - symbol = chars[i].to_string(); - i += 1; - } - - let transition_key = *alphabet_symbol_mapping - .get(&symbol) - .unwrap_or(&alphabet_anything_value); - token_transition_keys.push(transition_key); - } - - token_transition_keys -} - -pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: &FxHashMap, - alphabet_anything_value: TransitionKey, - vocabulary: &Vocabulary, - frozen_tokens: &FxHashSet, -) -> FxHashMap> { - let mut vocab_transition_keys = FxHashMap::default(); - - for item in vocabulary.iter() { - let token_str = item.0.clone(); - - let mut token_transition_keys; - - // Since these tokens are not expanded into byte-level transitions, we - // can simply get their transition keys directly. - if frozen_tokens.contains(&token_str) { - token_transition_keys = Vec::new(); - token_transition_keys.push( - *alphabet_symbol_mapping - .get(&token_str) - .unwrap_or(&alphabet_anything_value), - ) - } else { - token_transition_keys = get_token_transition_keys( - alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - ); - } - - vocab_transition_keys.insert(token_str, token_transition_keys); - } - - vocab_transition_keys -} diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 0b3eaa2e..012574da 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,9 +1,10 @@ -use rustc_hash::FxHashMap; +use bincode::{Decode, Encode}; +use rustc_hash::FxHashMap as HashMap; use tokenizers::normalizers::Sequence; -use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; +use tokenizers::{NormalizerWrapper, Tokenizer}; -use crate::{error, prelude::*}; +use crate::prelude::*; use crate::{Error, Result}; use locator::{HFLocator, Locator}; @@ -16,28 +17,52 @@ mod processor; /// /// ## Examples /// +/// ### Create a vocabulary from a pretrained model. /// ```rust -/// # use outlines_core::prelude::*; -/// # -/// let vocabulary = Vocabulary::new(None) -/// .insert("blah", 0) -/// .insert("1a", 1) -/// .insert("2", 2) -/// .insert("0", 3); +/// use outlines_core::prelude::*; +/// +/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None); +/// ``` +/// +/// ### Create a vocabulary from a pretrained model with some additional parameters. +/// ``` rust +/// use outlines_core::prelude::*; +/// +/// let params = FromPretrainedParameters { +/// revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e".to_string(), +/// ..Default::default() +/// }; +/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", Some(params)); +/// +/// ``` +/// +/// ### Create an empty vocabulary and manually insert some tokens. 
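+/// Note that `try_insert` returns a `Result`: inserting a token under the
+/// vocabulary's own EOS token id is rejected with `Error::EOSTokenDisallowed`.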
+/// ```rust
+/// use outlines_core::prelude::*;
+///
+/// let eos_token_id = 1;
+/// let mut vocabulary = Vocabulary::new(eos_token_id);
+///
+/// vocabulary.try_insert("token", 0).expect("New token inserted");
+/// assert_eq!(vocabulary.token_ids("token"), Some(&vec![0]));
+/// assert_eq!(vocabulary.tokens().len(), 1);
+/// assert_eq!(vocabulary.eos_token_id(), eos_token_id);
+///
+/// vocabulary.remove("token");
+/// assert_eq!(vocabulary.token_ids("token"), None);
 /// ```
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug, Default, PartialEq, Encode, Decode)]
 pub struct Vocabulary {
-    // TODO: Option is temp for back compatibility
-    eos_token_id: Option<TokenId>,
-    tokens: FxHashMap<Token, Vec<TokenId>>,
+    eos_token_id: TokenId,
+    tokens: HashMap<Token, Vec<TokenId>>,
 }
 
 impl Vocabulary {
     /// Creates an empty vocabulary.
-    pub fn new(eos_token_id: Option<TokenId>) -> Self {
+    pub fn new(eos_token_id: TokenId) -> Self {
         Self {
             eos_token_id,
-            tokens: FxHashMap::default(),
+            tokens: HashMap::default(),
         }
     }
 
@@ -55,8 +80,7 @@ impl Vocabulary {
         model: &str,
         parameters: Option<FromPretrainedParameters>,
     ) -> Result<Self> {
-        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())
-            .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?;
+        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())?;
         Self::filter_prepend_normalizers(&mut tokenizer);
 
         // Locate eos_token_id in defined locations.
@@ -69,10 +93,10 @@ impl Vocabulary {
         };
 
         // Start building the vocabulary from eos_token_id and added tokens.
-        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        let mut vocabulary = Vocabulary::new(eos_token_id);
         for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
-            if !added_token.special {
-                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
+            if !added_token.special && id != &eos_token_id {
+                vocabulary.try_insert(added_token.content.clone(), *id)?
             }
         }
 
@@ -84,27 +108,46 @@ impl Vocabulary {
            });
        };
        for (token, token_id) in tokenizer.get_vocab(false) {
-            let token_bytes = processor.process(token)?;
-            // TODO: lossy is temp:
-            //  - in python in was handled by byte_symbol function
-            //  - interface needs to be redefined to treat Token type as bytes: Vec<u8>
-            let processed_token = String::from_utf8_lossy(&token_bytes);
-            vocabulary = vocabulary.insert(processed_token, token_id);
+            if token_id != eos_token_id {
+                let processed_token = processor.process(&token)?;
+                vocabulary.try_insert(processed_token, token_id)?;
+            }
        }
 
        Ok(vocabulary)
    }
 
-    /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
-    pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
-        self.tokens.get(token)
+    /// Returns all tokens with their token ids in the vocabulary.
+    pub fn tokens(&self) -> &HashMap<Token, Vec<TokenId>> {
+        &self.tokens
+    }
+
+    /// Returns the token ids of the provided token, if it is present in the vocabulary.
+    pub fn token_ids(&self, token: impl AsRef<[u8]>) -> Option<&Vec<TokenId>> {
+        self.tokens.get(token.as_ref())
    }
 
    /// Gets the identifier of the special end of the sentence token.
-    pub fn eos_token_id(&self) -> Option<TokenId> {
+    pub fn eos_token_id(&self) -> TokenId {
        self.eos_token_id
    }
 
+    /// Inserts a token into the vocabulary with the specified identifier.
+    pub fn try_insert(&mut self, token: impl Into<Token>, id: TokenId) -> Result<(), Error> {
+        if id == self.eos_token_id {
+            return Err(Error::EOSTokenDisallowed);
+        }
+        let token = token.into();
+        self.tokens.entry(token).or_default().push(id);
+        Ok(())
+    }
+
+    /// Removes the given token from the vocabulary.
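+    /// All token ids mapped to the token are dropped, so a subsequent
+    /// `token_ids` lookup for it returns `None`.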
+ pub fn remove(&mut self, token: impl Into) { + let token = token.into(); + self.tokens.remove(&token); + } + /// Filters out `Prepend` kind of tokenizer's normalizers. fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) { // Main concern is prepend normalizers, for example https://github.com/google/sentencepiece @@ -136,143 +179,122 @@ impl Vocabulary { } } -impl Vocabulary { - /// Inserts a token to the vocabulary with the specified identifier. - pub fn insert(mut self, token: impl Into, id: TokenId) -> Vocabulary { - self.insert_in_place(token, id); - self - } - - /// Extends the vocabulary with tokens and their identifiers. - pub fn extend, I: IntoIterator>( - mut self, - tokens_and_ids: impl IntoIterator, - ) -> Vocabulary { - self.extend_in_place(tokens_and_ids); - self - } -} - -impl Vocabulary { - /// Inserts a token to the vocabulary with the specified identifier, in place. - pub fn insert_in_place(&mut self, token: impl Into, id: TokenId) { - // TODO: return error if eos token id is inserted - let token = token.into(); - self.tokens.entry(token).or_default().push(id); - } - - /// Extends the vocabulary with tokens and their identifiers, in place. - pub fn extend_in_place, I: IntoIterator>( - &mut self, - tokens_and_ids: impl IntoIterator, - ) { - for (token, ids) in tokens_and_ids.into_iter() { - let token = token.into(); - self.tokens.entry(token).or_default().extend(ids); - } - } -} - -impl std::ops::Deref for Vocabulary { - type Target = FxHashMap>; - - fn deref(&self) -> &FxHashMap> { - &self.tokens - } -} - impl std::fmt::Display for Vocabulary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for (index, (token, token_ids)) in self.iter().enumerate() { - if index != (self.len() - 1) { - writeln!(f, "{:?} -> {:?}", token, token_ids)?; - } else { - write!(f, "{:?} -> {:?}", token, token_ids)?; - } + writeln!( + f, + "Vocabulary object with eos_token_id={:?} and the following tokens to token_ids:", + self.eos_token_id + )?; + for (token, token_ids) in self.tokens.iter() { + writeln!( + f, + "{:?} -> {:?}", + token + .iter() + .map(|b| format!("0x{:02X}", b)) + .collect::>(), + token_ids + )?; } Ok(()) } } -impl From>> for Vocabulary { - fn from(tokens: FxHashMap>) -> Vocabulary { - Vocabulary { - eos_token_id: None, - tokens, +impl TryFrom<(TokenId, HashMap>)> for Vocabulary { + type Error = Error; + + fn try_from(values: (TokenId, HashMap>)) -> Result { + let (eos_token_id, tokens) = values; + if tokens.iter().any(|(_, ids)| ids.contains(&eos_token_id)) { + return Err(Error::EOSTokenDisallowed); } + Ok(Vocabulary { + eos_token_id, + tokens, + }) } } -impl FromIterator<(T, I)> for Vocabulary -where - T: Into, - I: IntoIterator, -{ - fn from_iter>(tokens_and_ids: A) -> Self { - Vocabulary::new(None).extend(tokens_and_ids) +impl TryFrom<(TokenId, HashMap>)> for Vocabulary { + type Error = Error; + + fn try_from(values: (TokenId, HashMap>)) -> Result { + let (eos_token_id, tokens) = values; + Ok(Vocabulary { + eos_token_id, + tokens: tokens + .into_iter() + .map(|(k, v)| { + if v.contains(&eos_token_id) { + Err(Error::EOSTokenDisallowed) + } else { + Ok((k.as_bytes().to_vec(), v)) + } + }) + .collect::>, _>>()?, + }) } } #[cfg(test)] mod tests { use super::*; + use rustc_hash::FxHashSet as HashSet; #[test] - fn insert() { - let vocabulary = Vocabulary::new(None) - .insert("blah", 0) - .insert("1a", 1) - .insert("2", 2) - .insert("0", 3); - - assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah"], &[0]); - assert_eq!(vocabulary["1a"], &[1]); 
- assert_eq!(vocabulary["2"], &[2]); - assert_eq!(vocabulary["0"], &[3]); - } + fn basic_interface() { + let eos_token_id = 3; + let mut vocabulary = Vocabulary::new(eos_token_id); - #[test] - fn extend() { - let vocabulary = Vocabulary::new(None).extend([ - ("blah", vec![0]), - ("1a", vec![1]), - ("2", vec![2]), - ("0", vec![3]), - ]); - - assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah"], &[0]); - assert_eq!(vocabulary["1a"], &[1]); - assert_eq!(vocabulary["2"], &[2]); - assert_eq!(vocabulary["0"], &[3]); - } + match vocabulary.try_insert("eos-token", eos_token_id) { + Err(Error::EOSTokenDisallowed) => {} + _ => unreachable!(), + } - #[test] - fn new_empty_vocabulary() { - let vocabulary = Vocabulary::new(None); - assert!(vocabulary.eos_token_id.is_none()); + // New empty vocabulary. + assert_eq!(vocabulary.eos_token_id, eos_token_id); assert!(vocabulary.tokens.is_empty()); + + for (token, id) in [("zero", 0), ("one", 1), ("two", 2)] { + vocabulary.try_insert(token, id).expect("Insert failed"); + assert_eq!(vocabulary.token_ids(token), Some(&vec![id])); + } + assert_eq!(vocabulary.tokens.len(), 3); + assert_eq!(vocabulary.tokens().len(), 3); + + // Confirm different types. + vocabulary.try_insert(b"four", 4).expect("Insert failed"); + assert_eq!(vocabulary.token_ids("four"), Some(&vec![4])); + + vocabulary + .try_insert(b"five".to_vec(), 5) + .expect("Insert failed"); + assert_eq!(vocabulary.token_ids("five"), Some(&vec![5])); + + vocabulary + .try_insert("six".to_string(), 6) + .expect("Insert failed"); + assert_eq!(vocabulary.token_ids("six"), Some(&vec![6])); + + vocabulary.remove(b"four"); + assert_eq!(vocabulary.token_ids("four"), None); + + vocabulary.remove(b"five".to_vec()); + assert_eq!(vocabulary.token_ids("five"), None); + + vocabulary.remove("six".to_string()); + assert_eq!(vocabulary.token_ids("six"), None); } #[test] fn new_empty_vocabulary_from_hashmap() { - let map = FxHashMap::default(); - let vocabulary = Vocabulary::from(map); - assert!(vocabulary.eos_token_id.is_none()); + let map: HashMap> = HashMap::default(); + let vocabulary = Vocabulary::try_from((1_u32, map)).expect("Vocabulary failed"); + assert_eq!(vocabulary.eos_token_id, 1); assert!(vocabulary.tokens.is_empty()); } - #[test] - fn new_vocabulary_from_iterator() { - let token: Token = "abc".to_string(); - let id: Vec = vec![1]; - let it = vec![(token, id)]; - let vocabulary = Vocabulary::from_iter(it); - assert!(vocabulary.eos_token_id.is_none()); - assert!(!vocabulary.tokens.is_empty()); - } - #[test] fn supported_pretrained_models() { // Support is expected for these: @@ -292,7 +314,6 @@ mod tests { let vocabulary = Vocabulary::from_pretrained(model, None); match vocabulary { Ok(v) => { - assert!(v.eos_token_id().is_some()); assert_eq!(v.eos_token_id, v.eos_token_id()); assert!(!v.tokens.is_empty()); } @@ -309,9 +330,6 @@ mod tests { let v_eos = vocabulary.eos_token_id; assert_eq!(v_eos, vocabulary.eos_token_id()); - assert!(v_eos.is_some()); - - let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 50256); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), @@ -319,11 +337,12 @@ mod tests { ); let token = "Ġal"; - assert!(vocabulary.token_to_ids(token).is_none()); + let btoken = token.as_bytes().to_vec(); + assert!(vocabulary.token_ids(&btoken).is_none()); assert!(tokenizer.token_to_id(token).is_some()); for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] { - let v_ids = vocabulary.token_to_ids(v_token); + let v_ids = vocabulary.token_ids(v_token.as_bytes()); 
assert!(v_ids.is_some()); for v_id in v_ids.unwrap() { let t_token = tokenizer @@ -342,32 +361,37 @@ mod tests { let v_eos = vocabulary.eos_token_id; assert_eq!(v_eos, vocabulary.eos_token_id()); - assert!(v_eos.is_some()); - - let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 2); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), "" ); - for (v_token, t_token_expected) in [ - ("abc", "abc"), - (" al", "▁al"), - (" O", "▁O"), - (" ", "▁▁▁"), - // TODO: won't pass since first we need to change token's type to bytes - // ("<0xFF>", "ÿ"), - // ("<0x20>", "▁"), - ] { - let v_ids = vocabulary.token_to_ids(v_token); + let tests: &[(Vec, &[&str])] = &[ + ("abc".as_bytes().to_vec(), &["abc"]), + (" al".as_bytes().to_vec(), &["▁al"]), + (" O".as_bytes().to_vec(), &["▁O"]), + (" ".as_bytes().to_vec(), &["▁▁▁"]), + (" ".as_bytes().to_vec(), &["▁", "<0x20>"]), + ("a".as_bytes().to_vec(), &["a", "<0x61>"]), + (vec![0xFF], &["<0xFF>"]), + (vec![0x20], &["▁", "<0x20>"]), + ]; + for (v_token, t_tokens_expected) in tests { + let v_ids = vocabulary.token_ids(v_token); assert!(v_ids.is_some()); - for v_id in v_ids.unwrap() { - let t_token = tokenizer - .id_to_token(*v_id) - .expect("Token id not found in tokenizer"); - assert_eq!(&t_token, t_token_expected); - } + + let t_tokens = v_ids + .unwrap() + .iter() + .map(|v_id| { + tokenizer + .id_to_token(*v_id) + .expect("Token id not found in tokenizer") + }) + .collect::>(); + let expected = HashSet::from_iter(t_tokens_expected.iter().map(|s| s.to_string())); + assert_eq!(t_tokens, expected) } } diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs index 7426f249..f4b95e5b 100644 --- a/src/vocabulary/processor.rs +++ b/src/vocabulary/processor.rs @@ -90,21 +90,13 @@ pub(crate) enum TokenProcessorLevel { /// Modifications to be applied by `TokenProcessor`of `ByteFallback` level. #[derive(Debug, Clone, PartialEq)] pub(crate) struct Mods { - spacechar: char, -} - -impl Default for Mods { - /// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level. - fn default() -> Self { - Self { spacechar: ' ' } - } + spacechar: String, } impl Mods { - /// Apply default modifications to each token. - fn apply_default(&self, token: String) -> String { - let to = Self::default().spacechar.to_string(); - token.replace(self.spacechar, &to) + /// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level. + fn apply_default(&self, token: &str) -> String { + token.replace(&self.spacechar, " ") } } @@ -116,7 +108,7 @@ struct ReplaceDecoder { } impl ReplaceDecoder { - fn space_replacement(&self) -> Option { + fn space_replacement(&self) -> Option { if self.content != " " { return None; } @@ -126,7 +118,7 @@ impl ReplaceDecoder { let char = chars.next(); if let Some(replacement) = char { if chars.next().is_none() { - return Some(replacement); + return Some(replacement.to_string()); } } None @@ -157,7 +149,7 @@ impl TokenProcessor { }), DecoderWrapper::Sequence(decoding_sequence) => { let mut is_byte_fallback = false; - let mut spacechar = ' '; + let mut spacechar = ' '.to_string(); for decoder in decoding_sequence.get_decoders() { match decoder { @@ -190,7 +182,7 @@ impl TokenProcessor { } /// Operates on each token based on the level of `TokenProcessor`. 
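+    /// At the `Byte` level each character of the token is mapped back to the byte it
+    /// represents; at the `ByteFallback` level `<0xXX>` tokens are decoded into raw bytes
+    /// and the tokenizer's space placeholder is replaced with a regular space.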
- pub(crate) fn process(&self, token: String) -> Result> { + pub(crate) fn process(&self, token: &str) -> Result> { match &self.level { TokenProcessorLevel::Byte => token .chars() @@ -275,7 +267,7 @@ mod tests { ('þ', 0xFE), ('ÿ', 0xFF), ] { - let processed = processor.process(ch.to_string()).expect("Not processed"); + let processed = processor.process(&ch.to_string()).expect("Not processed"); assert_eq!(processed, [byte]); } } @@ -285,8 +277,10 @@ mod tests { let model = "hf-internal-testing/llama-tokenizer"; let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); - let spacechar = '▁'; - let mods = Mods { spacechar }; + let spacechar = '▁'.to_string(); + let mods = Mods { + spacechar: spacechar.clone(), + }; assert_eq!(processor.level, TokenProcessorLevel::ByteFallback(mods)); @@ -294,7 +288,7 @@ mod tests { ("abc", vec![0x61, 0x62, 0x63]), ("<0x61>", vec![0x61]), ("<0x61>a", vec![0x3C, 0x30, 0x78, 0x36, 0x31, 0x3E, 0x61]), - (&spacechar.to_string(), vec![0x20]), + (&spacechar, vec![0x20]), ( &format!("{}{}abc", spacechar, spacechar), vec![0x20, 0x20, 0x61, 0x62, 0x63], @@ -304,7 +298,7 @@ mod tests { vec![0x20, 0x20, 0x20], ), ] { - let processed = processor.process(input.to_string()).expect("Not processed"); + let processed = processor.process(input).expect("Not processed"); assert_eq!(processed, expected); } } @@ -328,7 +322,7 @@ mod tests { let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); for token in ["𝒜𝒷𝒸𝒟𝓔", "🦄🌈🌍🔥🎉", "京东购物"] { - let result = processor.process(token.to_string()); + let result = processor.process(token); match result { Err(Error::ByteProcessorFailed) => {} _ => unreachable!(), @@ -342,7 +336,7 @@ mod tests { let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); - let result = processor.process("<0x6y>".to_string()); + let result = processor.process("<0x6y>"); match result { Err(Error::ByteFallbackProcessorFailed) => {} _ => unreachable!(), diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 905bfded..c9262a39 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -1,216 +1,150 @@ -import interegular -import pytest -from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write - - -def assert_expected_tensor_ids(tensor, ids): - assert len(tensor) == len(ids) - norm_tensor = sorted(map(int, tensor)) - norm_ids = sorted(map(int, tensor)) - assert norm_tensor == norm_ids, (norm_tensor, norm_ids) - - -def test_stop_at_eos(): - class MockTokenizer: - vocabulary = {"a": 1, "eos": 2} - eos_token_id = 2 - - fsm = StopAtEOSGuide(MockTokenizer()) - - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert instruction.tokens is None - - instruction = fsm.get_next_instruction(fsm.final_state) - assert isinstance(instruction, Write) - assert instruction.tokens == [2] - - assert fsm.get_next_state(fsm.start_state, 2) == fsm.final_state - assert fsm.get_next_state(fsm.start_state, 1) == fsm.start_state - assert fsm.is_final_state(fsm.start_state) is False - assert fsm.is_final_state(fsm.final_state) is True - - -def test_regex_vocabulary_error(): - class MockTokenizer: - vocabulary = {"a": 1} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - - with 
pytest.raises(ValueError, match="The vocabulary"): - RegexGuide.from_regex(regex_str, MockTokenizer()) - - -def test_from_regex(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token +import copy +import pickle +from typing import Dict, List, Union - regex_str = "[1-9]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - assert fsm.get_index_dict() == {0: {1: 1}} - - instruction = fsm.get_next_instruction(-1) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) - - instruction = fsm.get_next_instruction(3) - assert isinstance(instruction, Write) - assert_expected_tensor_ids(instruction.tokens, [3]) - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - - assert fsm.get_next_state(state=0, token_id=1) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_from_fsm(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_interegular_fsm( - interegular.parse_pattern(regex_str).to_fsm(), tokenizer - ) - - assert fsm.get_index_dict() == {0: {1: 1}} - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - - assert fsm.get_next_state(state=0, token_id=1) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_regex_multi_byte_llama_like(): - class MockTokenizer: - vocabulary = { - "1": 1, - "a": 2, - "eos": 3, - "😍": 4, - "<0xF0>": 5, - "<0x9F>": 6, - "<0x98>": 7, - "<0x88>": 8, # 😈 - "\ufffd": 9, - "\ufffd\ufffd": 10, - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - if token[0] == "<": - return "\ufffd" - return token - - regex_str = "[😁-😎]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - assert fsm.get_index_dict() == { - 0: {5: 1, 4: 2}, - 1: {6: 3}, - 3: {7: 4}, - 4: {8: 2}, - } - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [5, 4]) - - assert fsm.get_next_state(state=0, token_id=5) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_regex_multi_byte_gpt2_like(): - class MockTokenizer: - vocabulary = { - "1": 1, - "a": 2, - "eos": 3, - "😍": 4, - " ": 5, - "\ufffd": 6, - "\ufffd\ufffd": 7, - "ðŁĺ": 8, - "Ī": 9, # '😈' - "Ġð": 10, - "ŁĺĪ": 11, # ' 😈' - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - if self.vocabulary[token] >= 8: - return "\ufffd" - return token - - regex_str = " [😁-😎]" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - assert fsm.get_index_dict() == { - 0: {5: 1, 10: 2}, - 1: {8: 5, 4: 3}, - 2: {11: 3}, - 5: {9: 3}, - } - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [5, 10]) - - assert fsm.get_next_state(state=0, token_id=5) == 1 
- assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - -def test_regex_final_state(): - """Make sure that the FSM stays in the final state as we keep generating""" - - class MockTokenizer: - vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} - special_tokens = {"eos"} - eos_token_id = 104 - - def convert_token_to_string(self, token): - return token - - regex_str = r"`\n(\.\n)?`\n" - tokenizer = MockTokenizer() - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - state = fsm.get_next_state(state=4, token_id=103) - assert state == 5 - assert fsm.is_final_state(state) - - state = fsm.get_next_state(state=5, token_id=103) - assert fsm.is_final_state(state) +import pytest +from outlines_core.fsm import Guide, Index, Vocabulary + + +@pytest.fixture(scope="session") +def index() -> Index: + eos_token_id = 3 + # types here only to please mypy checks + tokens: Dict[Union[str, bytes], List[int]] = {"1": [1], "2": [2]} + regex = r"[1-9]" + + vocabulary = Vocabulary(eos_token_id, tokens) + return Index(regex, vocabulary) + + +def test_interface(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + regex = r"[1-9]" + + vocabulary = Vocabulary(eos_token_id, tokens) + index = Index(regex, vocabulary) + guide = Guide(index) + + assert guide.get_state() == index.get_initial_state() == 12 + assert guide.get_tokens() == [1] + + assert guide.advance(1) == [vocabulary.get_eos_token_id()] + assert guide.is_finished() + assert guide.get_state() == 20 + assert guide.get_tokens() == [eos_token_id] + + with pytest.raises( + ValueError, + match="No next state found for the current state", + ): + # No advancement is possible for state with allowed tokens == eos + assert guide.advance(eos_token_id) + # As well as with any other random token id + assert guide.advance(4) + + +def test_regex_final_state_walk(): + # Make sure that the Guide can walk to the final state correctly. 
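+    # With r"`\n(\.\n)?`\n" every state admits at most two token ids, so each
+    # advance() call below must return exactly the next expected id(s) until EOS.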
+ eos_token_id = 104 + tokens = {b"\n": [103], b".": [102], b"`": [101]} + regex = r"`\n(\.\n)?`\n" + + vocabulary = Vocabulary(eos_token_id, tokens) + index = Index(regex, vocabulary) + guide = Guide(index) + + assert guide.get_tokens() == [101] + assert guide.advance(101) == [103] + assert sorted(guide.advance(103)) == [101, 102] + assert guide.advance(101) == [103] + assert guide.advance(103) == [vocabulary.get_eos_token_id()] + assert guide.is_finished() + + +def test_token_trans_keys_identical(): + tokens = {"a": [1], "b": [2], "z": [3]} + eos_token_id = 4 + regex = r"z[ab]z" + + vocabulary = Vocabulary(eos_token_id, tokens) + index = Index(regex, vocabulary) + + guide1 = Guide(index) + guide2 = Guide(index) + + assert guide1.advance(3) == guide2.advance(3) + # `a` and `b` have similar transitions to `z` + assert guide1.advance(1) == guide2.advance(2) + assert guide1.advance(3) == guide2.advance(3) == [eos_token_id] + assert guide1.is_finished() + assert guide2.is_finished() + + +def test_str_and_bytes_produce_the_same(): + tokens1 = {"a": [1], "b": [2], "z": [3]} + tokens2 = {b"a": [1], b"b": [2], b"z": [3]} + eos_token_id = 4 + regex = r"z[ab]z" + + vocabulary1 = Vocabulary(eos_token_id, tokens1) + vocabulary2 = Vocabulary(eos_token_id, tokens2) + index1 = Index(regex, vocabulary1) + index2 = Index(regex, vocabulary2) + guide1 = Guide(index1) + guide2 = Guide(index2) + + assert guide1.advance(3) == guide2.advance(3) + # `a` and `b` have similar transitions to `z` + assert guide1.advance(1) == guide2.advance(2) + assert guide1.advance(3) == guide2.advance(3) == [eos_token_id] + assert guide1.is_finished() + assert guide2.is_finished() + + +def test_pickling(index): + guide = Guide(index) + serialized = pickle.dumps(guide) + deserialized = pickle.loads(serialized) + assert sorted(deserialized.get_tokens()) == sorted(guide.get_tokens()) + + +@pytest.mark.parametrize( + "model, revision", + [ + ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), + ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), + ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), + ( + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "783fd50eb82d7f57758de033861f54d62dde234f", + ), + ], +) +def test_pickling_from_pretrained_with_revision(model, revision): + regex = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + vocabulary = Vocabulary.from_pretrained(model, revision=revision) + index = Index(regex, vocabulary) + assert len(index.get_transitions()) == 810 + + guide = Guide(index) + serialized = pickle.dumps(guide) + deserialized = pickle.loads(serialized) + assert sorted(deserialized.get_tokens()) == 
sorted(guide.get_tokens()) + + +def test_equality(index): + guide1 = Guide(index) + guide2 = Guide(index) + assert guide1 == guide2 + + # confirm that equality is about inner index, not reference difference + index2 = copy.deepcopy(index) + guide3 = Guide(index2) + assert guide3 == guide2 == guide1 + + # progress one of the guides, confirm different state == different guide + guide1.advance(guide1.get_tokens()[-1]) + assert guide1 != guide2 + assert guide3 == guide2 diff --git a/tests/fsm/test_index.py b/tests/fsm/test_index.py new file mode 100644 index 00000000..5b560889 --- /dev/null +++ b/tests/fsm/test_index.py @@ -0,0 +1,65 @@ +import copy +import gc +import pickle +from typing import Dict, List, Union + +import pytest +from outlines_core.fsm import Index, Vocabulary + + +@pytest.fixture(scope="session") +def index() -> Index: + eos_token_id = 3 + # types here only to please mypy checks + tokens: Dict[Union[str, bytes], List[int]] = {"1": [1], "2": [2]} + regex = r"[1-9]" + + vocabulary = Vocabulary(eos_token_id, tokens) + return Index(regex, vocabulary) + + +def test_basic_interface(index): + init_state = index.get_initial_state() + assert init_state == 12 + assert index.is_final_state(init_state) is False + + allowed_tokens = index.get_allowed_tokens(init_state) + assert allowed_tokens == [1, 2] + + next_state = index.get_next_state(init_state, allowed_tokens[-1]) + assert next_state == 20 + assert index.is_final_state(next_state) is True + assert index.get_final_states() == {20} + + expected_transitions = { + 12: { + 1: 20, + 2: 20, + }, + 20: { + 3: 20, + }, + } + assert index.get_transitions() == expected_transitions + + +def test_pickling(index): + serialized = pickle.dumps(index) + deserialized = pickle.loads(serialized) + assert deserialized == index + + +def test_deepcopy(index): + index2 = copy.deepcopy(index) + assert index2 == index + + copy_index2 = copy.deepcopy(index2) + assert copy_index2 == index2 + + index2_id = id(index2) + del index2 + gc.collect() + is_deleted = not any(id(o) == index2_id for o in gc.get_objects()) + assert is_deleted + + assert copy_index2 == index diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 5dd8b5d4..36a269eb 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -2,7 +2,6 @@ import re from typing import Literal, Union -import interegular import pytest from outlines_core.fsm.json_schema import build_regex_from_schema, to_regex from pydantic import BaseModel, Field @@ -53,10 +52,7 @@ class Model(BaseModel): n: int json_schema = json.dumps(Model.model_json_schema()) - pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) - - # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() - interegular.parse_pattern(pattern).to_fsm() + build_regex_from_schema(json_schema, whitespace_pattern=None) def test_match_object(): diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py deleted file mode 100644 index 1007d8b0..00000000 --- a/tests/fsm/test_regex.py +++ /dev/null @@ -1,587 +0,0 @@ -from typing import List, Tuple, Union - -import interegular -import pytest -import torch -from datasets.fingerprint import Hasher -from outlines_core.fsm.outlines_core_rs import Vocabulary -from outlines_core.fsm.regex import ( - BetterAlphabet, - BetterFSM, - _walk_fsm, - create_fsm_index_end_to_end, - create_fsm_index_tokenizer, - get_token_transition_keys, - get_vocabulary_transition_keys, - make_byte_level_fsm, - make_deterministic_fsm, - 
reduced_vocabulary, -) -from transformers import AutoTokenizer, PreTrainedTokenizer - - -def get_llama_tokenizer_types(): - """Get all the Llama tokenizer types/classes that need work-arounds. - - When they can't be imported, a dummy class is created. - - """ - try: - from transformers.models.llama import LlamaTokenizer - except ImportError: - - class LlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.llama import LlamaTokenizerFast - except ImportError: - - class LlamaTokenizerFast: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizer - except ImportError: - - class CodeLlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizerFast - except ImportError: - - class CodeLlamaTokenizerFast: # type: ignore - pass - - return ( - LlamaTokenizer, - LlamaTokenizerFast, - CodeLlamaTokenizer, - CodeLlamaTokenizerFast, - ) - - -class TransformerTokenizer: - """Represents a tokenizer for models in the `transformers` library.""" - - def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs): - self.tokenizer = tokenizer - self.eos_token_id = self.tokenizer.eos_token_id - self.eos_token = self.tokenizer.eos_token - - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.pad_token_id = self.eos_token_id - else: - self.pad_token_id = self.tokenizer.pad_token_id - self.pad_token = self.tokenizer.pad_token - - self.special_tokens = set(self.tokenizer.all_special_tokens) - - self.vocabulary = self.tokenizer.get_vocab() - self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) - - def encode( - self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple[torch.LongTensor, torch.LongTensor]: - kwargs["padding"] = True - kwargs["return_tensors"] = "pt" - output = self.tokenizer(prompt, **kwargs) - return output["input_ids"], output["attention_mask"] - - def decode(self, token_ids: torch.LongTensor) -> List[str]: - text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) - return text - - def convert_token_to_string(self, token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = self.tokenizer.convert_tokens_to_string([token]) - - if self.is_llama: - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def __hash__(self): - return hash(Hasher.hash(self.tokenizer)) - - def __eq__(self, other): - if isinstance(other, type(self)): - if hasattr(self, "model_name") and hasattr(self, "kwargs"): - return ( - other.model_name == self.model_name and other.kwargs == self.kwargs - ) - else: - return other.tokenizer == self.tokenizer - return NotImplemented - - def __getstate__(self): - state = {"tokenizer": self.tokenizer} - return state - - def __setstate__(self, state): - self.__init__(state["tokenizer"]) - - -def identity(s): - return s - - -def to_bytes(s): - return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] - - -def merge_symbols(byte_hexs): - return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) - - -def token_str_to_trans_key(fsm, input_string): - return get_token_transition_keys( - fsm.fsm_info.alphabet_symbol_mapping, - fsm.fsm_info.alphabet_anything_value, - input_string, - ) - - -def walk_fsm_from_token_str_rust( - fsm, - input_string: str, - start_state: int, - full_match: bool = True, -): - return _walk_fsm( - 
fsm.fsm_info.transitions, - fsm.fsm_info.initial, - fsm.fsm_info.finals, - token_str_to_trans_key(fsm, input_string), - start_state, - full_match=full_match, - ) - - -def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: - new_fsm = make_byte_level_fsm(fsm, keep_utf8) - return BetterFSM( - alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), - states=new_fsm.states, - initial=new_fsm.initial, - finals=new_fsm.finals, - map=new_fsm.map, - ) - - -def test_walk_fsm(): - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "0", regex_fsm.initial, full_match=True) - ) - assert res == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "00", regex_fsm.initial, full_match=False - ) - ) - assert res == (1,) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "!", regex_fsm.initial, full_match=True) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "00", regex_fsm.initial, full_match=True - ) - ) - assert res == tuple() - - # This should fail, because state `1` reads nothing - res = tuple(walk_fsm_from_token_str_rust(regex_fsm, "0", 1, full_match=True)) - assert res == tuple() - - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "1", regex_fsm.initial, full_match=True) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "1", regex_fsm.initial, full_match=False - ) - ) - assert res == (2,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "12", regex_fsm.initial, full_match=True - ) - ) - assert res == (2, 3) - - pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") - fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) - - res = tuple(walk_fsm_from_token_str_rust(fsm, "x ", fsm.initial, full_match=False)) - assert res == (2,) - - start_state = list(fsm.finals)[0] - res = tuple(walk_fsm_from_token_str_rust(fsm, "!", start_state, full_match=False)) - assert res == tuple() - - -@pytest.mark.parametrize( - "transform", - [ - identity, - to_bytes, - ], -) -def test_walk_fsm_multi_bytes(transform): - regex_pattern = interegular.parse_pattern("😂|[😇-😍][😈-😍]*") - str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, merge_symbols(transform("😂")), regex_fsm.initial, full_match=True - ) - ) - assert res[-1:] == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, - merge_symbols(transform("😂😂")), - regex_fsm.initial, - full_match=False, - ) - ) - assert res[-1:] == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True - ) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, - merge_symbols(transform("😂😂")), - regex_fsm.initial, - full_match=True, - ) - ) - assert res == tuple() - - -def test_create_fsm_index_end_to_end(): - regex_str = "0|[1-9][0-9]*" - - regex_pattern = interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - tokens_to_token_ids = { - "blah": [0], - "1a": [1], - "2": [2], - "0": [3], - "": [4], - } - - res = 
create_fsm_index_end_to_end( - regex_fsm.fsm_info, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - assert res == {0: {2: 2, 3: 1}, 2: {2: 2, 3: 2}} - - -def test_create_fsm_index_end_to_end_multi_byte(): - regex_str = "😇| [😈-😍][😇-😎]*" - - regex_pattern = interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - - tokens_to_token_ids = { - "blah": [0], - "😈a": [1], - "😇": [2], - "😍": [3], - merge_symbols(("F0", "9F", "98", "8D")): [4], # '😍' - " 😍": [5], - merge_symbols((" ", "F0", "9F", "98", "8D")): [6], # ' 😍' - merge_symbols((" ", "F0", "9F", "98")): [7], # ' 😍' incomplete - "": [8], - } - - res = create_fsm_index_end_to_end( - byte_fsm.fsm_info, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - assert res == {0: {5: 3, 6: 3, 7: 7, 2: 2}, 3: {2: 3, 3: 3, 4: 3}} - - -@pytest.mark.parametrize( - "hf_tokenizer_uri, revision", - [ - ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), - ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), - ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), - ( - "NousResearch/Hermes-2-Pro-Llama-3-8B", - "783fd50eb82d7f57758de033861f54d62dde234f", - ), - ], -) -def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision): - # The combined regular expressions of a lexer state in a Python grammar - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - regex_pattern = interegular.parse_pattern(regex_str) - # Not reduced, so that there are many states - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) - bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - - num_fsm_states = len(regex_fsm.states) - assert num_fsm_states == 220 - - num_bytes_fsm_states = len(bytes_fsm.states) - assert num_bytes_fsm_states == 235 - - tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) - tokenizer = TransformerTokenizer(tokenizer) - - states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( - bytes_fsm, tokenizer - ) - - assert not empty_token_ids - assert len(states_to_token_subsets.get_transitions()) / num_fsm_states > 0.94 - - -@pytest.mark.parametrize( - "regex,string,should_accept", - [ - ("[a-c]+", "😀", False), - ("[^a-c]+", "😀", True), - ("😀+", "😀😀😀", True), - ("😀+", "a", False), - ("[😀-😍]{2}", "😈😈", True), - ("[😀-😍]{2}", "aa", False), - ("[^😀-😍]{2}", "aa", True), - ("[^😀-😍]{2}", "😈😈", False), - ("[^😀-😍]{2}", "😎😎", True), - ("[^😀-😍]{2}", "😎😓", True), - ("[^😀-😍]{2}", "😎😈", False), - ("[😀-🙌]{2}", "😎😈", True), - ("[^😀-🙌]{2}", "😎😈", False), - 
("[^😀-🙌]{2}", "🙏🙏", True), - ("[^😀-🙌]{2}", "🙏😎", False), - ], -) -def test_make_byte_level_fsm(regex, string, should_accept): - str_fsm = interegular.parse_pattern(regex).to_fsm() - str_accepts = str_fsm.accepts(string) - assert str_accepts == should_accept - - byte_fsm = make_byte_level_fsm(str_fsm) - byte_accepts = byte_fsm.accepts(to_bytes(string)) # type: ignore - assert byte_accepts == str_accepts - - mix_fsm = make_byte_level_fsm(str_fsm, keep_utf8=True) - mix_accepts = mix_fsm.accepts(to_bytes(string)) # type: ignore - assert mix_accepts == str_accepts - - mix_accepts_utf8 = mix_fsm.accepts(string) # type: ignore - assert mix_accepts_utf8 == str_accepts - - def advance(fsm, state, seq): - for symbol in seq: - if state is None: - return None - key = fsm.alphabet[symbol] - state = fsm.map[state].get(key) - return state - - # verify each state along the pattern - str_state = str_fsm.initial - byte_state = byte_fsm.initial - mix_state = byte_fsm.initial - for symbol in string: - str_state = advance(str_fsm, str_state, symbol) - byte_state = advance(byte_fsm, byte_state, to_bytes(symbol)) - mix_state_utf8 = advance(mix_fsm, mix_state, symbol) - mix_state = advance(mix_fsm, mix_state, to_bytes(symbol)) - assert byte_state == str_state - assert mix_state == str_state - assert mix_state_utf8 == str_state - - -@pytest.mark.skip(reason="Only for local profiling") -def test_regex_index_performance(): - from line_profiler import LineProfiler # type: ignore [import] - - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - regex_pattern = interegular.parse_pattern(regex_str) - # Not reduced, so that there are many states - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) - - num_fsm_states = len(regex_fsm.states) - assert num_fsm_states == 220 - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer = TransformerTokenizer(tokenizer) - - res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) - assert len(res) > 1 - - profiler = LineProfiler(create_fsm_index_end_to_end) - - profiler.runctx( - "create_fsm_index_tokenizer(regex_fsm, tokenizer)", - globals(), - locals(), - ) - profiler.dump_stats("line-profiler-create_fsm_index.pkl") - profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) - - -def test_token_trans_keys_identical(): - """assert two tokens w/ identical behavior wrt FSM have same trans key seq""" - - class MockTokenizer: - vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4} - special_tokens = {"eos"} - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - tokenizer = MockTokenizer() - - pattern = r"z[ab]z" - regex_pattern = interegular.parse_pattern(pattern) - 
interegular_fsm = regex_pattern.to_fsm().reduce() - regex_fsm, _ = make_deterministic_fsm(interegular_fsm) - tokens_to_token_ids, _ = reduced_vocabulary(tokenizer) - token_str_to_tranition_keys = get_vocabulary_transition_keys( - regex_fsm.fsm_info.alphabet_symbol_mapping, - regex_fsm.fsm_info.alphabet_anything_value, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - # `a` and `b` both are workable, but `z` has distinct transition rules - assert interegular_fsm.accepts("zaz") - assert interegular_fsm.accepts("zbz") - assert token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"] - assert not token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] - - -def test_token_trans_keys_walk_fsm(): - """assert _walk_fsm works using transition keys""" - - class MockTokenizer: - vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4} - special_tokens = {"eos"} - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - tokenizer = MockTokenizer() - - pattern = r"a[bc]z" - regex_pattern = interegular.parse_pattern(pattern) - interegular_fsm = regex_pattern.to_fsm().reduce() - regex_fsm, _ = make_deterministic_fsm(interegular_fsm) - tokens_to_token_ids, _ = reduced_vocabulary(tokenizer) - token_str_to_tranition_keys = get_vocabulary_transition_keys( - regex_fsm.fsm_info.alphabet_symbol_mapping, - regex_fsm.fsm_info.alphabet_anything_value, - Vocabulary.from_dict(tokens_to_token_ids), - frozenset(), - ) - - # verify initial state valid only for "ab" and "ac" using transition key seq - token_acceptance = {"ab": True, "ac": True, "az": False} - for token, should_accept in token_acceptance.items(): - token_trans_key_seq = token_str_to_tranition_keys[token] - state_seq = _walk_fsm( - regex_fsm.fsm_info.transitions, - regex_fsm.fsm_info.initial, - regex_fsm.fsm_info.finals, - token_trans_key_seq, - regex_fsm.initial, - False, - ) - is_accepted = len(state_seq) >= len(token_trans_key_seq) - assert should_accept == is_accepted - - -@pytest.mark.parametrize( - "rare_token", - [ - "�", - "��", - "�.", - ".�", - ".�.", - "▁�", - "�▁", - "▁�▁", - "?�", - "�?", - "?�?", - ], -) -def test_reduced_vocabulary_with_rare_tokens(rare_token): - """Assert reduced_vocabulary works with rare tokens. - - See [1] and [2] for context. 
- - [1]: https://github.com/dottxt-ai/outlines/pull/763 - [2]: https://github.com/dottxt-ai/outlines/pull/948 - [3]: https://github.com/dottxt-ai/outlines/pull/1153 - """ - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - tokenizer = TransformerTokenizer(tokenizer=tokenizer) - tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1 - reduced_vocabulary(tokenizer) - - -def test_reduced_vocabulary_with_byte_tokens(): - class MockTokenizer: - vocabulary = { - "string": 1, - b"\xa1": 2, # Qwen-Style - "eos": 3, - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return b"\xef\xbf\xbd".decode() - - tokens_to_token_ids = reduced_vocabulary(MockTokenizer()) - - # See fsm.regex.get_token_transition_keys() - # FSM transition keys represents bytes as - assert tokens_to_token_ids[0] == {"string": [1], "\x00A1": [2]} diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py deleted file mode 100644 index d3c38365..00000000 --- a/tests/fsm/test_serialization.py +++ /dev/null @@ -1,56 +0,0 @@ -import pickle - -import pytest -from outlines_core.fsm.guide import RegexGuide -from transformers import AutoTokenizer - -from tests.fsm.test_regex import TransformerTokenizer - - -def test_serialization(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - serialized = pickle.dumps(fsm) - deserialized = pickle.loads(serialized) - - assert fsm.eos_tensor == deserialized.eos_tensor - assert fsm.initial_state == deserialized.initial_state - - -@pytest.mark.parametrize( - "hf_tokenizer_uri, revision", - [ - ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), - ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), - ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), - ( - "NousResearch/Hermes-2-Pro-Llama-3-8B", - "783fd50eb82d7f57758de033861f54d62dde234f", - ), - ], -) -def test_complex_serialization(hf_tokenizer_uri, revision): - # The combined regular expressions of a lexer state in a Python grammar - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) - tokenizer = TransformerTokenizer(tokenizer) - - fsm = RegexGuide.from_regex(regex_str, tokenizer) - - serialized = pickle.dumps(fsm) - deserialized = pickle.loads(serialized) - - assert fsm.eos_tensor == deserialized.eos_tensor - assert 
fsm.initial_state == deserialized.initial_state diff --git a/tests/fsm/test_statistical.py b/tests/fsm/test_statistical.py index 20ef28cd..ba4a7bf7 100644 --- a/tests/fsm/test_statistical.py +++ b/tests/fsm/test_statistical.py @@ -1,27 +1,12 @@ from typing import Callable, List, Optional import numpy as np -from outlines_core.fsm.guide import RegexGuide +from outlines_core.fsm import Guide, Index, Vocabulary from pytest import approx from scipy.stats import ks_2samp def test_generate_length(): - class MockTokenizer: - vocabulary = {"0": 1, "1": 2, "eos": 3} - inverse_vocabulary = {1: "0", 2: "1", 3: ""} - special_tokens = {"eos"} - eos_token_id = 3 - - def length(self): - return len(self.vocabulary) - - def convert_token_to_string(self, token): - return token - - def decode(self, token): - return self.inverse_vocabulary[token] - class NextToken: def __init__( self, @@ -43,17 +28,20 @@ def __call__( next_t = [self.rng.choice(self.states, p=prob / np.sum(prob))] return tokens + next_t if tokens is not None else next_t - def generate(model, tokenizer, regex_str) -> Optional[List[int]]: - n_tokens = tokenizer.length() + def generate(model, regex_str) -> Optional[List[int]]: + vocabulary = Vocabulary(3, {"0": [1], "1": [2]}) + index = Index(regex_str, vocabulary) + guide = Guide(index) - fsm = RegexGuide.from_regex(regex_str, tokenizer) - state: int = fsm.initial_state + n_tokens = len(vocabulary) + 1 # include eos token in count tokens = None - while state != -1: - allowed = fsm.get_next_instruction(state).tokens + allowed = guide.get_tokens() + while True: mask: List[int] = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)] tokens = model(tokens, mask=mask) - state = fsm.get_next_state(state, tokens[-1]) + if tokens[-1] == 3: + break + allowed = guide.advance(tokens[-1]) return tokens def prob_non_markov(tokens: List[int]) -> np.array: @@ -75,17 +63,16 @@ def prob_markov(token: List[int]) -> np.array: n_samples: int = 250 regex_str: str = r"11[01]+|0[01]*" - tokenizer = MockTokenizer() model1 = NextToken(prob_markov, p0, states, 30127) model2 = NextToken(prob_non_markov, p0, states, 24601) lengths1: np.array = np.zeros((n_samples,)) lengths2: np.array = np.zeros((n_samples,)) for i in range(n_samples): - out1: List[int] = generate(model1, tokenizer, regex_str) - lengths1[i] = len(out1) - 1 # take off the eos token - out2: List[int] = generate(model2, tokenizer, regex_str) - lengths2[i] = len(out2) - 1 # take off the eos token + out1: List[int] = generate(model1, regex_str) + lengths1[i] = len(out1) - 1 + out2: List[int] = generate(model2, regex_str) + lengths2[i] = len(out2) - 1 # 2 sample KS test to check that lengths has the same distribution as # L = 1 + 2*X + Y, where X ~ Bern(0.75) and Y ~ Neg-Binom(1, 0.3) diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py new file mode 100644 index 00000000..e44e2dae --- /dev/null +++ b/tests/fsm/test_vocabulary.py @@ -0,0 +1,116 @@ +import copy +import pickle + +import pytest +from outlines_core.fsm import Vocabulary + + +@pytest.fixture(scope="session") +def vocabulary(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + return Vocabulary(eos_token_id, tokens) + + +def test_basic_vocabulary_interface(vocabulary): + assert vocabulary.get_eos_token_id() == 3 + assert vocabulary.get("1") == vocabulary.get(b"1") == [1] + assert len(vocabulary) == 2 + + vocabulary.insert("b", 4) + assert vocabulary.get("b") == [4] + assert len(vocabulary) == 3 + + vocabulary.insert(b"b", 5) + assert vocabulary.get("b") == 
vocabulary.get(b"b") == [4, 5] + assert len(vocabulary) == 3 + + vocabulary.remove("b") + assert vocabulary.get("b") is None + + # second remove doesn't fail too + vocabulary.remove("b") + assert vocabulary.get("b") is None + + assert vocabulary.get("a") == [2] + vocabulary.remove(b"a") + assert vocabulary.get("a") is None + + +def test_string_and_bytes_as_tokens(): + eos_token_id = 3 + tokens = {"1": [1], "a": [2]} + btokens = {b"1": [1], b"a": [2]} + vocabulary = Vocabulary(eos_token_id, tokens) + bvocabulary = Vocabulary(eos_token_id, btokens) + + assert ( + vocabulary.get_eos_token_id() == bvocabulary.get_eos_token_id() == eos_token_id + ) + assert vocabulary.get(b"1") == vocabulary.get("1") == [1] + assert bvocabulary.get(b"1") == bvocabulary.get("1") == [1] + assert len(vocabulary) == len(bvocabulary) == 2 + + +def test_do_not_supports_other_types(): + eos_token_id = 0 + + with pytest.raises( + TypeError, + match=r"Expected a dict with keys of type str or bytes and values of type list\[int\], got", + ): + Vocabulary(eos_token_id, 1) + + with pytest.raises( + TypeError, + match="Dict keys or/and values of the wrong types", + ): + Vocabulary(eos_token_id, {1: [1], 2: [2]}) + + +def test_get_bad_type(vocabulary): + with pytest.raises( + TypeError, + match="Expected a token of type str or bytes, got", + ): + vocabulary.get(1) + + +def test_insert_bad_type(vocabulary): + with pytest.raises( + TypeError, + match="Expected a token of type str or bytes, got", + ): + vocabulary.insert(1, 6) + + +def test_insert_eos_token(vocabulary): + with pytest.raises( + ValueError, match="EOS token should not be inserted into Vocabulary" + ): + vocabulary.insert("eos-token", 3) + + +def test_from_pretrained(): + vocabulary = Vocabulary.from_pretrained("gpt2") + assert vocabulary.get_eos_token_id() == 50256 + + +def test_pickling(vocabulary): + serialized = pickle.dumps(vocabulary) + deserialized = pickle.loads(serialized) + assert deserialized == vocabulary + + +def test_deepcopy(vocabulary): + vocabulary2 = copy.deepcopy(vocabulary) + assert vocabulary2 == vocabulary + + copy_vocabulary2 = copy.deepcopy(vocabulary2) + assert copy_vocabulary2 == vocabulary2 + + vocabulary2.insert("new", 4) + assert vocabulary2 != copy_vocabulary2 + assert len(vocabulary2) - 1 == len(copy_vocabulary2) + assert copy_vocabulary2 == vocabulary + assert len(copy_vocabulary2) == len(vocabulary)