diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 53d20b70b..b4616a183 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -22,7 +22,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.10"]
+        python-version: ["3.10", "3.12"]
     steps:
     - uses: actions/checkout@v3
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/benchmarks/bench_cfg_guide.py b/benchmarks/bench_cfg_guide.py
index 14dc31c73..477104435 100644
--- a/benchmarks/bench_cfg_guide.py
+++ b/benchmarks/bench_cfg_guide.py
@@ -7,8 +7,6 @@
 from outlines.fsm.guide import CFGGuide
 from outlines.models.transformers import TransformerTokenizer
 
-from .common import ensure_numba_compiled
-
 random.seed(42)
 
 
@@ -30,9 +28,6 @@ class CFGGuideBenchmark:
 
     def setup(self, grammar_name):
         self.tokenizer = get_tiny_tokenizer()
-        ensure_numba_compiled(
-            self.tokenizer
-        )  # numba not currently used, but will be in the future
         self.prebuilt_cfg_guide = CFGGuide(
             benched_grammars[grammar_name], self.tokenizer
         )
diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py
index 8d1ceeb24..8990b015c 100644
--- a/benchmarks/bench_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -2,7 +2,7 @@
 from outlines.fsm.guide import RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 
-from .common import ensure_numba_compiled, setup_tokenizer  # noqa: E402
+from .common import setup_tokenizer  # noqa: E402
 
 simple_schema = """{
     "$defs": {
@@ -69,7 +69,6 @@ class JsonSchemaBenchmark:
     def setup(self, schema_name):
         self.tokenizer = setup_tokenizer()
         self.schema = schemas[schema_name]
-        ensure_numba_compiled(self.tokenizer)
 
     @cache_disabled()
     def time_json_schema_to_regex(self, schema_name):
diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py
deleted file mode 100644
index 2713707e5..000000000
--- a/benchmarks/bench_numba_compile.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import importlib
-
-import interegular
-import numba
-
-from outlines.caching import cache_disabled
-from outlines.fsm import regex
-
-from .common import setup_tokenizer
-
-
-class NumbaCompileBenchmark:
-    def setup(self):
-        self.tokenizer = setup_tokenizer()
-        self.regex = regex
-        original_njit = numba.njit
-
-        def mock_njit(*args, **kwargs):
-            kwargs["cache"] = False
-            return original_njit(*args, **kwargs)
-
-        self.original_njit = original_njit
-        numba.njit = mock_njit
-        importlib.reload(self.regex)
-        self.regex_pattern, _ = self.regex.make_deterministic_fsm(
-            interegular.parse_pattern("a").to_fsm().reduce()
-        )
-
-    def teardown(self):
-        numba.njit = self.original_njit
-
-    @cache_disabled()
-    def time_compile_numba(self):
-        self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer)
diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index 099f94df2..7aaef6bac 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -1,7 +1,7 @@
 from outlines.caching import cache_disabled
 from outlines.fsm.guide import RegexGuide
 
-from .common import ensure_numba_compiled, setup_tokenizer
+from .common import setup_tokenizer
 
 regex_samples = {
     "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
@@ -21,7 +21,6 @@ class RegexGuideBenchmark:
 
     def setup(self, pattern_name):
         self.tokenizer = setup_tokenizer()
-        ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
     @cache_disabled()
@@ -34,7 +33,6 @@ class MemoryRegexGuideBenchmark:
 
     def setup(self, pattern_name):
         self.tokenizer = setup_tokenizer()
-        ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
     @cache_disabled()
diff --git a/benchmarks/common.py b/benchmarks/common.py
index 7d999ea9b..e920888f5 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -1,14 +1,8 @@
 from transformers import AutoTokenizer
 
-from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
 
 
 def setup_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     return TransformerTokenizer(tokenizer)
-
-
-def ensure_numba_compiled(tokenizer):
-    RegexGuide("a", tokenizer)
-    return True
diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py
deleted file mode 100644
index bfcf55c03..000000000
--- a/outlines/fsm/fsm.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import warnings
-from typing import TYPE_CHECKING, Iterable, NewType, Optional
-
-from outlines.fsm.guide import RegexGuide, StopAtEOSGuide
-
-if TYPE_CHECKING:
-    from outlines.models.tokenizer import Tokenizer
-
-FSMState = NewType("FSMState", int)
-
-
-class StopAtEosFSM(StopAtEOSGuide):
-    """FSM to generate text until EOS has been generated."""
-
-    def __init__(self, tokenizer: "Tokenizer"):
-        warnings.warn(
-            UserWarning(
-                "The `StopAtTokenFSM` interface is deprecated and will be removed on 2024-06-01. Please use `StopAtEOSGuide` instead."
-            )
-        )
-        super().__init__(tokenizer)
-
-    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
-        next_instruction = self.get_next_instruction(state)
-        return next_instruction.tokens
-
-    def next_state(self, state: FSMState, token_id: int) -> FSMState:
-        return FSMState(self.get_next_state(state, token_id))
-
-
-class RegexFSM(RegexGuide):
-    """FSM to generate text that is in the language of a regular expression."""
-
-    def __init__(self, regex_string: str, tokenizer):
-        warnings.warn(
-            UserWarning(
-                "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
- ) - ) - super().__init__(regex_string, tokenizer) - - def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: - next_instruction = self.get_next_instruction(state) - return next_instruction.tokens - - def next_state(self, state: FSMState, token_id: int) -> FSMState: - return FSMState(self.get_next_state(state, token_id)) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index b7b121fe6..a3a2c7369 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -1,72 +1,31 @@ import collections import copy import warnings -from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - List, - Optional, - Protocol, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Generator, Union -import interegular import torch from lark.indenter import DedentError from lark.lexer import UnexpectedCharacters, UnexpectedToken +from outlines_core.fsm.guide import Generate +from outlines_core.fsm.guide import Guide as CoreGuide +from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide +from outlines_core.fsm.guide import Write +from outlines_core.fsm.guide import ( + create_states_mapping as uncached_create_states_mapping, +) from outlines import grammars from outlines.caching import cache from outlines.fsm.parsing import PartialLark, PartialParserState -from outlines.fsm.regex import ( - create_fsm_index_tokenizer, - make_byte_level_fsm, - make_deterministic_fsm, -) if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer -@dataclass(frozen=True) -class Write: - """Write instruction. - - Attributes - ---------- - tokens - The sequence of tokens to be added to the current sequence by the - generation process. - - """ - - tokens: List[int] - - -@dataclass(frozen=True) -class Generate: - """Generate instruction - - Attributes - ---------- - tokens - The tokens that lead to a valid completion if generated. A value - of ``None`` indicates that all tokens are allowed. - """ - - tokens: Optional[List[int]] - - Instruction = Union[Write, Generate] -class Guide(Protocol): +class Guide(CoreGuide): """Base definition of a generation guide. A generation guide defines the behavior of a finite-state machine that guides @@ -78,18 +37,6 @@ class Guide(Protocol): initial_state: Any - def get_next_instruction(self, state: Any) -> Instruction: - ... - - def get_next_state(self, state: Any, token_id: int) -> Any: - ... - - def is_final_state(self, state: Any) -> bool: - ... - - def copy(self) -> "Guide": - ... - class StopAtEOSGuide(Guide): """Guide to generate tokens until the EOS token has been generated.""" @@ -127,64 +74,15 @@ def copy(self): @cache() -def create_states_mapping( - regex_string: str, - tokenizer: "Tokenizer", - regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, - frozen_tokens: List[str] = [], -) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]: - """Create the variables related to the mapping between states and tokens - The parameters of the function are used for caching purpose. - - Parameters - ---------- - regex_string: (`str`): - The regular expression string to generate a states mapping for. - tokenizer: (`Tokenizer`): - The model's tokenizer. - regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*): - A function that parses a regex string into an `interegular` Pattern object. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is when expanding the token-level FSM - into a byte-level FSM. 
Defaults to an empty list. - - Returns - ------- - states_to_token_maps: (`Dict[int, Dict[int, int]]`): - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: (`Set[int]`): - A set of token ids that correspond to empty strings. - final_states: (`set`): - A set of final states in the FSM. - """ - regex_pattern = regex_parser(regex_string) - byte_fsm = make_byte_level_fsm( - regex_pattern.to_fsm().reduce(), keep_utf8=True, frozen_tokens=frozen_tokens - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer, frozen_tokens=frozen_tokens - ) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) +def create_states_mapping(regex_string, tokenizer): + return uncached_create_states_mapping(regex_string, tokenizer) - return states_to_token_maps, empty_token_ids, regex_fsm.finals - -class RegexGuide(Guide): - """Guide to generate text in the language of a regular expression.""" - - initial_state = 0 +class RegexGuide(CoreRegexGuide): + """ + Guide to generate text in the language of a regular expression. + CoreRegexGuide with outlines cache + """ def __init__(self, regex_string: str, tokenizer: "Tokenizer"): ( @@ -196,119 +94,6 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"): self.final_states = fsm_finals | {-1} self._cache_state_to_token_tensor() - def get_next_instruction(self, state: int) -> Instruction: - """Return the next instruction for guided generation. - - The initialization of the guide builds an index which maps FSM states to a - map from authorized tokens to the state in which the guide needs to move - if said token is generated. Therefore the authorized tokens at the - current state are the keys of the map returned by the value of the index - for current state. - - If the current state is not contained in the end this means that we are - in a final state of the guide. We only authorize EOS tokens in the final - state. - - Parameters - ---------- - state - The current state of the guide. - - Returns - ------- - A `Generate` instance that contains the model and the allowed token ids. - - """ - next_tokens_mask = self.states_to_token_mask.get(state) - if next_tokens_mask is None: - return Write(torch.tensor([self.eos_token_id])) - - return Generate(next_tokens_mask) - - def get_next_state(self, state: int, token_id: int) -> int: - """Update the state of the guide. - - We use the index to determine to which state the guide should transition - given the token that was just generated. - - Parameters - ---------- - state - The current state of the guide. - token_id - The id of the token that was just generated. - - Returns - ------- - The new state of the guide. 
- - """ - if token_id == self.eos_token_id or state not in self.states_to_token_maps: - return -1 - - last_token_to_end_state = self.states_to_token_maps[state] - next_state = last_token_to_end_state.get(token_id) - if next_state is None: - next_state = -1 - - return next_state - - @classmethod - def from_interegular_fsm( - cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer" - ): - from_interegular_instance = cls.__new__(cls) - - def create_states_mapping_from_interegular_fsm( - fsm: interegular.fsm.FSM, - ) -> Tuple[dict, set]: - """Create the variables related to the mapping between states and tokens - The parameters of the function are used for caching purpose - """ - byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer - ) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) - for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - - return states_to_token_maps, empty_token_ids - - ( - from_interegular_instance.states_to_token_maps, - from_interegular_instance.empty_token_ids, - ) = create_states_mapping_from_interegular_fsm(interegular_fsm) - from_interegular_instance.eos_token_id = tokenizer.eos_token_id - from_interegular_instance._cache_state_to_token_tensor() - return from_interegular_instance - - def _cache_state_to_token_tensor(self): - """ - cache state -> token int tensor - this increases performance of mask construction substantially - """ - self.states_to_token_mask = { - state: torch.tensor(list(next_tokens_to_end_states.keys())) - for state, next_tokens_to_end_states in self.states_to_token_maps.items() - } - - def is_final_state(self, state: int) -> bool: - """Determine whether the current state of the guide is a final state.""" - return state in self.final_states - - def copy(self): - return self - CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"]) diff --git a/outlines/fsm/parsing.py b/outlines/fsm/parsing.py index f780fb46e..92d3cc166 100644 --- a/outlines/fsm/parsing.py +++ b/outlines/fsm/parsing.py @@ -34,8 +34,7 @@ ) from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser - -from outlines.fsm.regex import ( +from outlines_core.fsm.regex import ( fsm_union, get_sub_fsms_from_seq, get_token_transition_keys, diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py deleted file mode 100644 index 6b105a7b9..000000000 --- a/outlines/fsm/regex.py +++ /dev/null @@ -1,1003 +0,0 @@ -import re -from collections import namedtuple -from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Dict, - FrozenSet, - Generator, - List, - Sequence, - Set, - Tuple, - Union, - cast, -) - -import numba -import numpy as np -from interegular.fsm import ( - FSM, - Alphabet, - OblivionError, - State, - TransitionKey, - _AnythingElseCls, - anything_else, -) -from numba.typed.typedobjectutils import _nonoptional -from tqdm import tqdm - -if TYPE_CHECKING: - from outlines.models.tokenizer import Tokenizer - - -class BetterAlphabet(Alphabet): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert anything_else in 
self._symbol_mapping - self.anything_value = self._symbol_mapping[anything_else] - - def __getitem__(self, item): - return self._symbol_mapping.get(item, self.anything_value) - - def copy(self): - return BetterAlphabet(self._symbol_mapping.copy()) - - -class BetterFSM(FSM): - flat_transition_map: Dict[Tuple[int, int], int] - trans_key_to_states: Dict[int, List[int]] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not isinstance(self.alphabet, BetterAlphabet): - self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) - - flat_transition_map = {} - trans_key_to_states = {} - for from_state, trans_map in self.map.items(): - for trans_key, to_state in trans_map.items(): - flat_transition_map[(from_state, trans_key)] = to_state - trans_key_to_states.setdefault(trans_key, set()).add(from_state) - - self.__dict__["trans_key_to_states"] = trans_key_to_states - self.__dict__["flat_transition_map"] = flat_transition_map - self.__dict__["_fsm_info"] = None - - def copy(self): - return BetterFSM( - alphabet=self.alphabet.copy(), - states=self.states.copy(), - initial=self.initial, - finals=self.finals.copy(), - map=self.map.copy(), - __no_validation__=True, - ) - - @property - def fsm_info(self): - if self._fsm_info is None: - flat_transition_map_items = np.fromiter( - ((a[0], a[1], b) for a, b in self.flat_transition_map.items()), - dtype=np.dtype("int64, int64, int64"), - ) - trans_key_to_states_items = np.fromiter( - ((k, z) for k, v in self.trans_key_to_states.items() for z in v), - dtype=np.dtype("int64, int64"), - ) - alphabet_symbol_mapping_items = [ - (k, v) - for k, v in self.alphabet._symbol_mapping.items() - if k != anything_else - ] - nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64")) - self.__dict__["_fsm_info"] = create_fsm_info( - self.initial, - nb_finals, - flat_transition_map_items, - trans_key_to_states_items, - self.alphabet.anything_value, - alphabet_symbol_mapping_items, - ) - - return self._fsm_info - - -nb_int_list_type = numba.types.ListType(numba.int64) -nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) -nb_unicode_type = numba.types.unicode_type - - -@numba.njit(cache=True) -def create_fsm_info( - py_initial, - py_finals, - flat_transition_map_items, - trans_key_to_states_items, - py_anything_value, - alphabet_symbol_mapping_items, -): - trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type) - for trans_key_and_state in trans_key_to_states_items: - trans_key_to_states.setdefault( - trans_key_and_state[0], numba.typed.List.empty_list(numba.int64) - ).append(trans_key_and_state[1]) - - flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64) - for trans_key_and_state in flat_transition_map_items: - flat_transition_map[ - (trans_key_and_state[0], trans_key_and_state[1]) - ] = trans_key_and_state[2] - - # use 2-char strings so that we can represent incomplete utf-8 sequences - # as 2-hex-digit pairs - alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64) - for symbol_and_trans_key in alphabet_symbol_mapping_items: - alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] - - initial = numba.int64(py_initial) - - finals = set() - for final in py_finals: - finals.add(final) - - anything_value = numba.int64(py_anything_value) - - return FSMInfo( - initial, - finals, - flat_transition_map, - trans_key_to_states, - anything_value, - alphabet_symbol_map, - ) - - -FSMInfo = namedtuple( - "FSMInfo", - [ - "initial", - "finals", - 
"transitions", - "trans_key_to_states", - "alphabet_anything_value", - "alphabet_symbol_mapping", - ], -) - - -TransitionTrie = Dict[TransitionKey, "Union[TransitionTrie, State, None]"] - - -def add_to_transition_trie( - trie: TransitionTrie, - key_seq: Sequence[TransitionKey], - value: Union[State, None], -): - for key in key_seq[:-1]: - trie = cast(TransitionTrie, trie.setdefault(key, {})) - assert isinstance(trie, dict), "key sequence of incompatible length" - trie[key_seq[-1]] = value - - -# merge default_trie into the trie, only updating entries not present in the trie -def transition_trie_setdefault( - trie: TransitionTrie, - default_trie: TransitionTrie, -): - for key, default_value in default_trie.items(): - dest_value = trie.get(key) - if isinstance(dest_value, dict) and isinstance(default_value, dict): - transition_trie_setdefault(dest_value, default_value) - elif key not in trie: - trie[key] = default_value - - -def byte_symbol(byte: int) -> str: - return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte) - - -def make_byte_level_fsm( - fsm: FSM, keep_utf8: bool = False, frozen_tokens: List[str] = [] -) -> FSM: - """Convert an FSM to a byte-level FSM, expanding multi-byte characters as - sequences of single-byte transitions. - - Parameters - ---------- - fsm: (`interegular.FSM`): - The token-level FSM to convert to a byte-level FSM. - keep_utf8: (`bool`, *optional*): - If set to True, the original utf-8 characters are kept as-is. Defaults to - False. NOTE: we're representing bytes as strings to keep it type-compatible. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is in the byte-level FSM. That is, - these tokens will not be expanded into byte-level transitions. Defaults to - an empty list. - - Returns - ------- - `interegular.FSM`: A byte-level FSM. 
- """ - - anything_else_key = fsm.alphabet[anything_else] - symbol_mapping: Dict[Union[str, _AnythingElseCls], TransitionKey] = {} - map: Dict[State, Dict[TransitionKey, State]] = {} - states: List[State] = list(fsm.states) - - # identify all multi-byte characters in the alphabet and build a mapping - # from the original transition keys to sequences of new keys for each byte - key_to_key_seqs: Dict[TransitionKey, Set[Tuple[TransitionKey, ...]]] = {} - all_key_seqs: Set[Tuple[TransitionKey, ...]] = set() - all_bytes: Set[int] = set() - max_key = max(fsm.alphabet.values()) - for symbol, transition_key in fsm.alphabet.items(): - assert symbol == anything_else or symbol in frozen_tokens or len(symbol) == 1 - if symbol == anything_else or symbol in frozen_tokens or ord(symbol) < 0x80: - symbol_mapping[symbol] = transition_key - else: - if keep_utf8: - symbol_mapping[symbol] = transition_key - key_list: List[TransitionKey] = [] - for byte in symbol.encode("utf-8"): - symbol = byte_symbol(byte) - if symbol not in symbol_mapping: - symbol_mapping[symbol] = max_key = TransitionKey(max_key + 1) - all_bytes.add(byte) - key_list.append(symbol_mapping[symbol]) - key_seq = tuple(key_list) - key_to_key_seqs.setdefault(transition_key, set()).add(key_seq) - all_key_seqs.add(key_seq) - - # add all remaining multi-byte utf-8 bytes to the alphabet - # (this is required to represent `anything_else`) - utf8_ranges = { - 1: (0x80, 0xC0), # continuation bytes - 2: (0xC0, 0xE0), # 2-byte sequences - 3: (0xE0, 0xF0), # 3-byte sequences - 4: (0xF0, 0xF8), # 4-byte sequences - } - utf8_all_keys: Dict[int, Set[TransitionKey]] = { - n: set() for n in utf8_ranges.keys() - } - for n, (start, end) in utf8_ranges.items(): - range_key = max_key = TransitionKey(max_key + 1) - for byte in range(start, end): - byte_key = symbol_mapping.setdefault(byte_symbol(byte), range_key) - utf8_all_keys[n].add(byte_key) - - # cache of intermediate transition states by transitions from that state - state_cache: Dict[FrozenSet[Tuple[TransitionKey, State]], State] = {} - - # helper function to create multi-step transitions between states - max_state = max(fsm.states) - - def create_seq_transitions( - seq_transitions_trie: TransitionTrie, - ) -> Dict[TransitionKey, State]: - nonlocal max_state - result: Dict[TransitionKey, State] = {} - - for next_key, next_trie in seq_transitions_trie.items(): - if isinstance(next_trie, dict): - next_transitions = create_seq_transitions(next_trie) - if not next_transitions: - continue - cache_key = frozenset(next_transitions.items()) - next_state = state_cache.get(cache_key) - if next_state is None: - next_state = max_state = State(max_state + 1) - map[next_state] = next_transitions - state_cache[cache_key] = next_state - states.append(next_state) - result[next_key] = next_state - elif next_trie is not None: - result[next_key] = next_trie - - return result - - # create new states and transitions - for state, transitions in fsm.map.items(): - seq_transitions_trie: TransitionTrie = {} - state_map: Dict[TransitionKey, State] = {} - - for transition_key, to_state in transitions.items(): - if transition_key in key_to_key_seqs: - if keep_utf8: - state_map[transition_key] = to_state - for key_seq in key_to_key_seqs[transition_key]: - add_to_transition_trie(seq_transitions_trie, key_seq, to_state) - else: # keep single-byte transitions as is - state_map[transition_key] = to_state - - # handle multi-byte anything_else sequences - if anything_else_key in transitions: - for key_seq in all_key_seqs: - 
add_to_transition_trie(seq_transitions_trie, key_seq, None) - - anything_else_trie: TransitionTrie = {} - cont_trie: Union[TransitionTrie, State] = transitions[anything_else_key] - for n in range(2, 5): - cont_trie = {key: cont_trie for key in utf8_all_keys[1]} - for key in utf8_all_keys[n]: - anything_else_trie[key] = cont_trie - - transition_trie_setdefault(seq_transitions_trie, anything_else_trie) - - # create new states and transitions - next_transitions = create_seq_transitions(seq_transitions_trie) - state_map.update(next_transitions) - map[state] = state_map - - return FSM( - alphabet=Alphabet(symbol_mapping), - states=states, - initial=fsm.initial, - finals=fsm.finals, - map=map, - ) - - -def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: - new_fsm = make_byte_level_fsm(fsm, keep_utf8) - return BetterFSM( - alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), - states=new_fsm.states, - initial=new_fsm.initial, - finals=new_fsm.finals, - map=new_fsm.map, - ) - - -def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: - """Construct an equivalent FSM with deterministic state labels.""" - old_to_new_trans_keys = { - trans_key: i - for i, (trans_key, _) in enumerate( - sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) - ) - } - - new_symbol_mapping = { - symbol: old_to_new_trans_keys[trans_key] - for symbol, trans_key in fsm.alphabet._symbol_mapping.items() - } - - new_alphabet = BetterAlphabet(new_symbol_mapping) - - new_map = { - from_state: { - old_to_new_trans_keys[trans_key]: to_state - for trans_key, to_state in trans_map.items() - } - for from_state, trans_map in fsm.map.items() - } - - old_to_new_states = {} - old_to_new_states[fsm.initial] = 0 - - i = 0 - seen = {fsm.initial} - old_state_queue = [fsm.initial] - while old_state_queue: - old_state = old_state_queue.pop(-1) - transitions = new_map[old_state] - sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) - for _, old_state in sorted_transitions: - if old_state not in seen: - old_state_queue.append(old_state) - seen.add(old_state) - if old_state not in old_to_new_states: - i += 1 - old_to_new_states[old_state] = i - - new_map = dict( - sorted( - ( - ( - old_to_new_states[from_state], - dict( - sorted( - ( - (trans_key, old_to_new_states[to_state]) - for trans_key, to_state in trans_map.items() - ), - key=lambda v: v[0], - ) - ), - ) - for from_state, trans_map in new_map.items() - ), - key=lambda v: v[0], - ) - ) - - new_initial = 0 - new_finals = frozenset( - sorted(old_to_new_states[old_state] for old_state in fsm.finals) - ) - new_states = frozenset(sorted(new_map.keys())) - - new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) - - return new_fsm, old_to_new_states - - -@numba.njit(nogil=True, cache=True) -def _walk_fsm( - fsm_transitions: Dict[Tuple[int, int], int], - fsm_initial: int, - fsm_finals: Set[int], - token_transition_keys: Sequence[int], - start_state: int, - full_match: bool = True, -) -> List[int]: - state = start_state - accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) - last_final_idx: int = numba.uint64(0) - - # Iterate over token transition key sequence. The transition key - # sequence represents the FSM traversal rules of the tokens symbols. 
- for i, trans_key in enumerate(token_transition_keys): - new_state = fsm_transitions.get((state, trans_key)) - - if new_state is None: - if not full_match and last_final_idx > 0: - return accepted_states[:last_final_idx] - - return numba.typed.List.empty_list(numba.int64) - - state = new_state - - if state in fsm_finals: - last_final_idx = numba.uint64(i + 1) - - accepted_states.append(_nonoptional(state)) - - if full_match and last_final_idx - 1 != i: - return numba.typed.List.empty_list(numba.int64) - - return accepted_states - - -def walk_fsm( - fsm: BetterFSM, - token_transition_keys: Sequence[int], - start_state: int, - full_match: bool = True, -) -> List[int]: - fsm_finals = fsm.finals - - state = start_state - accepted_states: List[int] = [] - last_final_idx: int = 0 - - fsm_transitions = fsm.flat_transition_map - - # Iterate over token transition key sequence. The transition key - # sequence represents the FSM traversal rules of the tokens symbols. - for i, trans_key in enumerate(token_transition_keys): - new_state = fsm_transitions.get((state, trans_key)) - - if new_state is None: - if not full_match and last_final_idx > 0: - return accepted_states[:last_final_idx] - - return [] - - state = new_state - - if state in fsm_finals: - last_final_idx = i + 1 - - accepted_states.append(state) - - if full_match and last_final_idx - 1 != i: - return [] - - return accepted_states - - -def fsm_union( - fsms: Sequence[FSM], -) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: - """Construct an FSM representing the union of the FSMs in `fsms`. - - This is an updated version of `interegular.fsm.FSM.union` made to return an - extra map of component FSMs to the sets of state transitions that - correspond to them in the new FSM. 
- - """ - - alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) - - indexed_fsms = tuple(enumerate(fsms)) - - initial = {i: fsm.initial for (i, fsm) in indexed_fsms} - - # Dedicated function accepting a "superset" and returning the next - # "superset" obtained by following this transition in the new FSM - def follow(current_state, new_transition: int): - next = {} - for i, f in indexed_fsms: - old_transition = new_to_old[i][new_transition] - if ( - i in current_state - and current_state[i] in f.map - and old_transition in f.map[current_state[i]] - ): - next[i] = f.map[current_state[i]][old_transition] - if not next: - raise OblivionError - return next - - states = [initial] - finals: Set[int] = set() - map: Dict[int, Dict[int, int]] = {} - - # Map component FSMs to their new state-to-state transitions, finals, and a - # map translating component FSM states to aggregate FSM states - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ] = {} - - i = 0 - while i < len(states): - state = states[i] - - # Add to the finals of the aggregate FSM whenever we hit a final in a - # component FSM - if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): - finals.add(i) - - # Compute the map for this state - map[i] = {} - for transition in alphabet.by_transition: - try: - next = follow(state, transition) - except OblivionError: - # Reached an oblivion state; don't list it - continue - else: - try: - # TODO: Seems like this could--and should--be avoided - j = states.index(next) - except ValueError: - j = len(states) - states.append(next) - - map[i][transition] = j - - for fsm_id, fsm_state in next.items(): - ( - fsm_transitions, - fsm_finals, - fsm_old_to_new, - ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) - old_from = state[fsm_id] - old_to = fsm_state - fsm_old_to_new.setdefault(old_from, set()).add(i) - fsm_old_to_new.setdefault(old_to, set()).add(j) - fsm_transitions.add((i, j)) - if fsm_state in fsms[fsm_id].finals: - fsm_finals.add(j) - - i += 1 - - fsm = FSM( - alphabet=alphabet, - states=range(len(states)), - initial=0, - finals=finals, - map=map, - __no_validation__=True, - ) - - fsm, old_to_new_states = make_deterministic_fsm(fsm) - _fsms_to_trans_finals = { - fsm_id: ( - {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, - {old_to_new_states[s] for s in finals}, - { - old_state: {old_to_new_states[new_state] for new_state in new_states} - for old_state, new_states in old_to_new.items() - }, - ) - for fsm_id, (transitions, finals, old_to_new) in sorted( - fsms_to_trans_finals.items(), key=lambda x: x[0] - ) - } - - return ( - fsm, - _fsms_to_trans_finals, - ) - - -def get_sub_fsms_from_seq( - state_seq: Sequence[int], - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ], -) -> Generator[Tuple[int, bool, bool], None, None]: - """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. - - Parameters - ---------- - state_seq - A state sequence. - fsms_to_trans_finals - A map from FSM indices to tuples containing sets of their state transitions - and sets of the final/accept states. - - Returns - ------- - A generator returning tuples containing each sub-FSM index (in the order - they were union-ed to construct `fsm`) and booleans indicating whether or - not there is another valid transition from the last state in the sequence - for the associated sub-FSM (i.e. 
if the FSM can continue - accepting/matching) and whether or not the sequence ends in a final state - of the sub-FSM. - """ - state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) - last_fsm_state = state_seq[-1] - yield from ( - ( - # The sub-FMS index - fsm_idx, - # Is there another possible transition in this sub-FSM? - any(last_fsm_state == from_s for (from_s, to_s) in transitions), - # Is this sub-FSM in a final state? - state_seq[-1] in finals, - ) - for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() - if state_seq_transitions.issubset(transitions) - ) - - -@numba.njit(cache=True, nogil=True) -def state_scan_tokens( - fsm_transitions: Dict[Tuple[int, int], int], - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - fsm_initial: int, - fsm_finals: Set[int], - vocabulary: List[Tuple[str, Sequence[int]]], - vocabulary_transition_keys: List[Sequence[int]], - start_state: int, -) -> Set[Tuple[int, int]]: - res = set() - - for (token, token_ids), token_transition_keys in zip( - vocabulary, vocabulary_transition_keys - ): - state_seq = _walk_fsm( - fsm_transitions, - fsm_initial, - fsm_finals, - token_transition_keys, - start_state, - False, - ) - - if state_seq is not None and len(state_seq) < len(token_transition_keys): - continue - - for token_id in token_ids: - res.add((token_id, state_seq[-1])) - - return res - - -@numba.njit(cache=True, nogil=True) -def get_token_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - token_str: str, -) -> Sequence[int]: - """ - Get the sequence of transition keys for an individual string - with respect to an FSMs alphabet symbol mapping - - This requires parsing the null-byte prefix rules of a byte-fsm: - - If two characters are prefixed by \x00, they are the grouped as a hex-byte - - Otherwise they are a standalone utf-8 character - """ - token_transition_keys = [] - i = 0 - while i < len(token_str): - if token_str[i] == "\x00" and i != len(token_str) - 1: - symbol = token_str[i : i + 3] - i += 3 - else: - symbol = token_str[i] - i += 1 - - token_transition_keys.append( - alphabet_symbol_mapping.get(symbol, alphabet_anything_value) - ) - - token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64) - for j in range(len(token_transition_keys)): - token_transition_keys_array[j] = token_transition_keys[j] - return token_transition_keys_array - - -@numba.njit(cache=True, nogil=True) -def get_vocabulary_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - vocabulary: List[Tuple[str, Sequence[int]]], - frozen_tokens: List[str] = numba.typed.List.empty_list(numba.types.unicode_type), -) -> List[Sequence[int]]: - """ - Calculate the sequence transition keys for each token str within a vocabulary - - Parameters - ---------- - alphabet_symbol_mapping: (`Dict[str, int]`): - A mapping from an alphabet symbol in a FSM to its corresponding transition key. - alphabet_anything_value: (`int`): - The transition key for the anything_else symbol in the FSM. - vocabulary: (`List[Tuple[str, Sequence[int]]]`): - A list of tuples, each containing a token and a list of equivalent token ids. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that are kept as-is when transforming the FSM. - Defaults to an empty list. - - Returns - ------- - `List[Sequence[int]]`: - A list of token transition keys for each token in the vocabulary. 
- """ - vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:]) - for token_str, _ in vocabulary: - # Since these tokens are not expanded into byte-level transitions, we can - # simply get their transition keys directly. - if token_str in frozen_tokens: - token_transition_keys = np.array( - [alphabet_symbol_mapping[token_str]], dtype=np.int64 - ) - else: - token_transition_keys = get_token_transition_keys( - alphabet_symbol_mapping, alphabet_anything_value, token_str - ) - vocab_transition_keys.append(token_transition_keys) - - return vocab_transition_keys - - -def create_fsm_index_end_to_end( - fsm_info: FSMInfo, - vocabulary: List[Tuple[str, Sequence[int]]], - frozen_tokens: List[str] = [], -) -> Dict[int, Set[Tuple[int, int]]]: - """Create an FSM state-to-vocabulary map/index through end-to-end token parsing. - - Parameters - ---------- - fsm_info: (`interegular.FSMInfo`): - The FSM information object containing the FSM's alphabet, transitions, initial - and final states, and other relevant information. - vocabulary: (`List[Tuple[str, Sequence[int]]]`): - A list of tuples, each containing a token and a list of equivalent token ids. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that are kept as-is when transforming the FSM. - - Returns - ------- - `Dict[int, Set[Tuple[int, int]]]`: - A mapping from FSM states to sets of tuples containing token ids and the end - states of the FSM after parsing the token. - """ - - # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this - # code, too. - states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {} - seen: Set[int] = set() - next_states = {fsm_info.initial} - - pbar = tqdm( - total=len(set(fsm_info.transitions.values())) - + 1, # all transitions plus initial - desc="Compiling FSM index for all state transitions", - ) - - vocabulary_transition_keys = get_vocabulary_transition_keys( - fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - vocabulary, - frozen_tokens=( - numba.typed.List(frozen_tokens) - if len(frozen_tokens) > 0 - else numba.typed.List.empty_list(numba.types.unicode_type) - ), - ) - - while next_states: - start_state = next_states.pop() - - token_ids_end_states = state_scan_tokens( - fsm_info.transitions, - fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - fsm_info.initial, - fsm_info.finals, - vocabulary, - vocabulary_transition_keys, - start_state, - ) - - for token_id_and_end_state in token_ids_end_states: - states_to_token_subsets.setdefault(start_state, set()).add( - token_id_and_end_state - ) - end_state = token_id_and_end_state[1] - if end_state not in seen: - next_states.add(end_state) - - if start_state not in seen: - pbar.update(1) - seen.add(start_state) - - pbar.close() - - return states_to_token_subsets - - -re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") - -# The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*" -# suffix is required to handle the NorwAI tokenizer. -re_replacement_seq = re.compile(r"^▁*�+\.*$") - - -# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode -@lru_cache() -def gpt2_bytes_to_unicode(): - """ - Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control - characters the bpe code barfs on. - - The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab - if you want to avoid UNKs. 
When you're at something like a 10B token dataset you end up needing around 5K for - decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup - tables between utf-8 bytes and unicode strings. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -@lru_cache() -def gpt2_unicode_to_bytes(): - return {v: k for k, v in gpt2_bytes_to_unicode().items()} - - -# TODO: Cannot cache typed collections to disk, yet. See -# https://github.com/numba/numba/issues/4698 -@lru_cache -def reduced_vocabulary( - tokenizer: "Tokenizer", -) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: - """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" - empty_token_ids = set() - vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} - for token, token_idx in tokenizer.vocabulary.items(): - if token in tokenizer.special_tokens: - continue - - token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string( - token - ) - - if token_str: - if isinstance(token, bytes): - # Handle BPE tokenizers where the tokens are directly stored as bytes - # https://github.com/QwenLM/Qwen/blob/main/tokenization_note.md#regular-tokens - token_str = "".join(byte_symbol(b) for b in token) - - elif "\ufffd" in token_str and not re_replacement_seq.match(token): - # invalid utf-8 sequences are replaced with � (\ufffd), but there - # might also be tokens specifically for �, ��, ���, etc. - - if re_llama_byte_token.match(token): - # llama-like tokenizers have <0xXX> tokens for all - # bytes >= 0x80 and represent all incomplete utf-8 - # sequences using such tokens - token_bytes = [int(token[3:5], 16)] - else: - # gpt2-like tokenizers have multi-byte tokens that can - # have a mix of full and incomplete utf-8 characters, - # for example, b` \xf0` can be one token; these tokenizers - # map each byte to a valid utf-8 character - token_bytes = cast( - List[int], [gpt2_unicode_to_bytes().get(c) for c in token] - ) - if None in token_bytes: - raise RuntimeError( - f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" - ) - token_str = "".join(byte_symbol(b) for b in token_bytes) - - vocabulary.setdefault(token_str, []).append(token_idx) - else: - empty_token_ids.add(numba.int64(token_idx)) - - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - nb_unicode_type, - numba.int64[:], - ) - ) - ) - for token_str, token_ids in vocabulary.items(): - token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_str, token_ids_np)) - - return vocabulary_nb, empty_token_ids - - -def create_fsm_index_tokenizer( - fsm: BetterFSM, tokenizer: "Tokenizer", frozen_tokens: List[str] = [] -) -> Tuple[Dict[int, Dict[int, int]], Set[int]]: - """Construct an FMS index from a tokenizer. - - This uses the end-to-end approach of `create_fsm_index_end_to_end`. - - Parameters - ---------- - fsm: (`BetterFSM`): - A cache-friendly FSM. Other interegular FSMs can also be used, but caching - may not work as expected. - tokenizer: (`Tokenizer`): - The model's tokenizer. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is when expanding the token-level - FSM into a byte-level FSM. Defaults to an empty list. 
- - Returns - ------- - states_to_token_maps: (`Dict[int, Dict[int, int]]`): - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: (`Set[int]`): - A set of token ids that correspond to empty strings. - - .. warning:: - - `fsm` needs to be deterministically ordered so that future caching makes sense. - """ - vocabulary, empty_token_ids = reduced_vocabulary(tokenizer) - - states_to_token_subsets = create_fsm_index_end_to_end( - fsm.fsm_info, vocabulary, frozen_tokens - ) - - # Allow transitions to EOS from all terminals FSM states that are - # reachable - # TODO: Do we really need this anymore? - for state in fsm.fsm_info.finals: - subset = states_to_token_subsets.get(state) - if subset is not None: - subset.add((tokenizer.eos_token_id, state)) - - # Convert to token-to-end-state maps - states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()} - - return states_to_token_subsets, empty_token_ids diff --git a/pyproject.toml b/pyproject.toml index 937b0c9c7..7229afa83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "outlines" authors= [{name = "Outlines Developers"}] description = "Probabilistic Generative Model Programming" -requires-python = ">=3.8" +requires-python = ">=3.9" license = {text = "Apache-2.0"} keywords=[ "machine learning", @@ -32,7 +32,6 @@ dependencies = [ "cloudpickle", "diskcache", "pydantic>=2.0", - "numba", "referencing", "jsonschema", "requests", @@ -41,6 +40,7 @@ dependencies = [ "typing_extensions", "pycountry", "airportsdata", + "outlines_core", ] dynamic = ["version"] @@ -94,7 +94,6 @@ write_to = "outlines/_version.py" testpaths = ["tests"] filterwarnings = [ "error", - "ignore::numba.core.errors.NumbaPendingDeprecationWarning", "ignore::pydantic.warnings.PydanticDeprecatedSince20", "ignore::FutureWarning:transformers.*", "ignore::FutureWarning:huggingface_hub.*", @@ -129,7 +128,6 @@ module = [ "lark.*", "interegular.*", "datasets.*", - "numba.*", "requests.*", "responses.*", "vllm.*", @@ -137,6 +135,7 @@ module = [ "fastapi.*", "pycountry.*", "airportsdata.*", + "outlines_core.*", ] ignore_missing_imports = true diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py deleted file mode 100644 index 94166fd95..000000000 --- a/tests/fsm/test_fsm.py +++ /dev/null @@ -1,92 +0,0 @@ -import pytest - -from outlines.fsm.fsm import RegexFSM, StopAtEosFSM - - -def assert_expected_tensor_ids(tensor, ids): - assert len(tensor) == len(ids) - norm_tensor = sorted(map(int, tensor)) - norm_ids = sorted(map(int, tensor)) - assert norm_tensor == norm_ids, (norm_tensor, norm_ids) - - -def test_stop_at_eos(): - class MockTokenizer: - vocabulary = {"a": 1, "eos": 2} - eos_token_id = 2 - - with pytest.warns(UserWarning): - fsm = StopAtEosFSM(MockTokenizer()) - - assert fsm.allowed_token_ids(fsm.start_state) is None - assert fsm.allowed_token_ids(fsm.final_state) == [2] - assert fsm.next_state(fsm.start_state, 2) == fsm.final_state - assert fsm.next_state(fsm.start_state, 1) == fsm.start_state - assert fsm.is_final_state(fsm.start_state) is False - assert fsm.is_final_state(fsm.final_state) is True - - -def test_regex_vocabulary_error(): - class MockTokenizer: - vocabulary = {"a": 1} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - - with 
pytest.raises(ValueError, match="The vocabulary"): - RegexFSM(regex_str, MockTokenizer()) - - -def test_regex(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = RegexFSM(regex_str, tokenizer) - - assert fsm.states_to_token_maps == {0: {1: 1}} - assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1]) - assert fsm.next_state(state=0, token_id=1) == 1 - assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - - -def test_regex_final_state(): - """Make sure that the FSM stays in the final state as we keep generating""" - - class MockTokenizer: - vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} - special_tokens = {"eos"} - eos_token_id = 104 - - def convert_token_to_string(self, token): - return token - - regex_str = r"`\n(\.\n)?`\n" - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = RegexFSM(regex_str, tokenizer) - - state = fsm.next_state(state=4, token_id=103) - assert state == 5 - assert fsm.is_final_state(state) - - state = fsm.next_state(state=5, token_id=103) - assert fsm.is_final_state(state) diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py deleted file mode 100644 index 1789c4a7c..000000000 --- a/tests/fsm/test_regex.py +++ /dev/null @@ -1,742 +0,0 @@ -import interegular -import numba -import numpy as np -import pytest -from transformers import AutoTokenizer - -from outlines.fsm.regex import ( - _walk_fsm, - create_fsm_index_end_to_end, - create_fsm_index_tokenizer, - fsm_union, - get_sub_fsms_from_seq, - get_token_transition_keys, - get_vocabulary_transition_keys, - make_byte_level_better_fsm, - make_byte_level_fsm, - make_deterministic_fsm, - reduced_vocabulary, - walk_fsm, -) -from outlines.models.transformers import TransformerTokenizer - - -def identity(s): - return s - - -def to_bytes(s): - return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] - - -def merge_symbols(byte_hexs): - return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) - - -def token_str_to_trans_key(fsm, input_string): - return get_token_transition_keys( - fsm.fsm_info.alphabet_symbol_mapping, - fsm.fsm_info.alphabet_anything_value, - input_string, - ) - - -def walk_fsm_from_token_str( - fsm, - input_string: str, - start_state: int, - full_match: bool = True, -): - return walk_fsm( - fsm, - token_str_to_trans_key(fsm, input_string), - start_state, - full_match, - ) - - -def walk_fsm_from_token_str_numba( - fsm, - input_string: str, - start_state: int, - full_match: bool = True, -): - return _walk_fsm( - fsm.fsm_info.transitions, - fsm.fsm_info.initial, - fsm.fsm_info.finals, - token_str_to_trans_key(fsm, input_string), - start_state, - full_match=full_match, - ) - - -@pytest.mark.parametrize( - "function", - [ - walk_fsm_from_token_str, - walk_fsm_from_token_str_numba, - ], -) -def test_walk_fsm(function): - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple(function(regex_fsm, "0", regex_fsm.initial, full_match=True)) - assert res == (1,) - - res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=False)) - assert res == (1,) - - res = tuple(function(regex_fsm, "!", 
regex_fsm.initial, full_match=True))
-    assert res == tuple()
-
-    res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=True))
-    assert res == tuple()
-
-    # This should fail, because state `1` reads nothing
-    res = tuple(function(regex_fsm, "0", 1, full_match=True))
-    assert res == tuple()
-
-    regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+")
-    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
-
-    res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=True))
-    assert res == tuple()
-
-    res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=False))
-    assert res == (2,)
-
-    res = tuple(function(regex_fsm, "12", regex_fsm.initial, full_match=True))
-    assert res == (2, 3)
-
-    pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)")
-    fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())
-
-    res = tuple(function(fsm, "x ", fsm.initial, full_match=False))
-    assert res == (2,)
-
-    start_state = list(fsm.finals)[0]
-    res = tuple(function(fsm, "!", start_state, full_match=False))
-    assert res == tuple()
-
-
-@pytest.mark.parametrize(
-    "function",
-    [
-        walk_fsm_from_token_str,
-        walk_fsm_from_token_str_numba,
-    ],
-)
-@pytest.mark.parametrize(
-    "transform",
-    [
-        identity,
-        to_bytes,
-    ],
-)
-def test_walk_fsm_multi_bytes(function, transform):
-    regex_pattern = interegular.parse_pattern("😂|[😇-😍][😈-😍]*")
-    str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
-    regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True)
-
-    res = tuple(
-        function(
-            regex_fsm, merge_symbols(transform("😂")), regex_fsm.initial, full_match=True
-        )
-    )
-    assert res[-1:] == (1,)
-
-    res = tuple(
-        function(
-            regex_fsm,
-            merge_symbols(transform("😂😂")),
-            regex_fsm.initial,
-            full_match=False,
-        )
-    )
-    assert res[-1:] == (1,)
-
-    res = tuple(
-        function(
-            regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True
-        )
-    )
-    assert res == tuple()
-
-    res = tuple(
-        function(
-            regex_fsm,
-            merge_symbols(transform("😂😂")),
-            regex_fsm.initial,
-            full_match=True,
-        )
-    )
-    assert res == tuple()
-
-
-def test_get_sub_fsms_from_seq():
-    name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
-    name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce())
-
-    def_pattern = interegular.parse_pattern("def")
-    def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce())
-
-    match_pattern = interegular.parse_pattern("match")
-    match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce())
-
-    peq_pattern = interegular.parse_pattern(r"\+=")
-    peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce())
-
-    plus_pattern = interegular.parse_pattern(r"\+")
-    plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce())
-
-    fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm]
-
-    fsm, fsms_to_trans_finals = fsm_union(fsms)
-
-    assert fsms_to_trans_finals == {
-        0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}),
-        1: (
-            {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)},
-            {8},
-            {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}},
-        ),
-        2: (
-            {
-                (0, 2),
-                (0, 3),
-                (0, 4),
-                (2, 2),
-                (3, 2),
-                (3, 9),
-                (4, 2),
-                (4, 5),
-                (5, 2),
-                (5, 6),
-                (6, 2),
-                (6, 7),
-                (7, 2),
-                (7, 8),
-                (8, 2),
-                (9, 2),
-                (9, 10),
-                (10, 2),
-            },
-            {2, 3, 4, 5, 6, 7, 8, 9, 10},
-            {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}},
-        ),
-        3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}),
-        4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}),
-    }
-
-    assert not fsm.accepts("1a")
-    assert fsm.accepts("a1")
-    assert fsm.accepts("def")
-    assert fsm.accepts("match")
-    assert fsm.accepts("+=")
-    assert fsm.accepts("+")
-
-    state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial)
-    state_seq.insert(0, fsm.fsm_info.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(0, False, True), (2, True, True)]
-
-    # Make sure the old-to-new state map is correct
-    def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial)
-    def_state_seq.insert(0, fsm.fsm_info.initial)
-
-    def_old_to_new_states = fsms_to_trans_finals[0][2]
-    assert all(
-        new_state in def_old_to_new_states[old_state]
-        for old_state, new_state in zip(def_state_seq, state_seq)
-    )
-
-    state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial)
-    state_seq.insert(0, fsm.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(2, True, True)]
-
-    name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial)
-    name_state_seq.insert(0, fsm.initial)
-
-    name_old_to_new_states = fsms_to_trans_finals[2][2]
-    assert all(
-        new_state in name_old_to_new_states[old_state]
-        for old_state, new_state in zip(name_state_seq, state_seq)
-    )
-
-    state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial)
-    state_seq.insert(0, fsm.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(1, False, True), (2, True, True)]
-
-    match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial)
-    match_state_seq.insert(0, fsm.initial)
-
-    match_old_to_new_states = fsms_to_trans_finals[1][2]
-    assert all(
-        new_state in match_old_to_new_states[old_state]
-        for old_state, new_state in zip(match_state_seq, state_seq)
-    )
-
-    state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial)
-    state_seq.insert(0, fsm.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(2, True, True)]
-
-    state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial)
-    state_seq.insert(0, fsm.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(0, True, False), (2, True, True)]
-
-    state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False)
-    state_seq.insert(0, fsm.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(3, True, False), (4, False, True)]
-
-    state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial)
-    state_seq.insert(0, fsm.initial)
-
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(3, False, True)]
-
-    # Test some overlapping patterns
-    join_fsms = [
-        interegular.parse_pattern(r"JOIN").to_fsm().reduce(),
-        interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(),
-    ]
-    fsm, fsms_to_trans_finals = fsm_union(join_fsms)
-
-    # Matching "OI"
-    state_seq = [1, 2, 3]
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(0, True, False), (1, True, False)]
-
-    # Matching "N"
-    state_seq = [3, 4]
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(0, False, True), (1, True, False)]
-
-    # Matching " "
-    state_seq = [4, 5]
-    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
-    assert res == [(1, True, False)]
-
-
-def test_create_fsm_index_end_to_end():
-    regex_str = "0|[1-9][0-9]*"
-
-    regex_pattern = interegular.parse_pattern(regex_str)
-    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
-
-    vocabulary = {
-        "blah": numba.typed.List([0]),
-        "1a": numba.typed.List([1]),
-        "2": numba.typed.List([2]),
-        "0": numba.typed.List([3]),
-        "": numba.typed.List([4]),
-    }
-
-    vocabulary_nb = numba.typed.List.empty_list(
-        numba.types.Tuple(
-            (
-                numba.types.unicode_type,
-                numba.int64[:],
-            )
-        )
-    )
-    for token_tuple, token_ids in vocabulary.items():
-        token = merge_symbols(token_tuple)
-        token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token, token_ids_np))
-
-    res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
-
-    assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}}
-
-
-def test_create_fsm_index_end_to_end_multi_byte():
-    regex_str = "😇| [😈-😍][😇-😎]*"
-
-    regex_pattern = interegular.parse_pattern(regex_str)
-    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
-    byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)
-
-    vocabulary = {
-        "blah": numba.typed.List([0]),
-        "😈a": numba.typed.List([1]),
-        "😇": numba.typed.List([2]),
-        "😍": numba.typed.List([3]),
-        merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]),  # '😍'
-        " 😍": numba.typed.List([5]),
-        merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]),  # ' 😍'
-        merge_symbols((" ", "F0", "9F", "98")): numba.typed.List(
-            [7]
-        ),  # ' 😍' incomplete
-        "": numba.typed.List([8]),
-    }
-
-    vocabulary_nb = numba.typed.List.empty_list(
-        numba.types.Tuple(
-            (
-                numba.types.unicode_type,
-                numba.int64[:],
-            )
-        )
-    )
-    for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = merge_symbols(token_tuple)
-        token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
-
-    res = create_fsm_index_end_to_end(byte_fsm.fsm_info, vocabulary_nb)
-
-    assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}}
-
-
-@pytest.mark.parametrize(
-    "hf_tokenizer_uri, revision",
-    [
-        ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"),
-        ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"),
-        ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"),
-        (
-            "NousResearch/Hermes-2-Pro-Llama-3-8B",
-            "783fd50eb82d7f57758de033861f54d62dde234f",
-        ),
-    ],
-)
-def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision):
-    # The combined regular expressions of a lexer state in a Python grammar
-    regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
-
-    regex_pattern = interegular.parse_pattern(regex_str)
-    # Not reduced, so that there are many states
-    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm())
-    bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)
-
-    num_fsm_states = len(regex_fsm.states)
-    assert num_fsm_states == 220
-
-    num_bytes_fsm_states = len(bytes_fsm.states)
-    assert num_bytes_fsm_states == 235
-
-    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision)
-    tokenizer = TransformerTokenizer(tokenizer)
-
-    states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(
-        bytes_fsm, tokenizer
-    )
-
-    assert not empty_token_ids
-    assert len(states_to_token_subsets) / num_fsm_states > 0.94
-
-
-@pytest.mark.parametrize(
-    "regex,string,should_accept",
-    [
-        ("[a-c]+", "😀", False),
-        ("[^a-c]+", "😀", True),
-        ("😀+", "😀😀😀", True),
-        ("😀+", "a", False),
-        ("[😀-😍]{2}", "😈😈", True),
-        ("[😀-😍]{2}", "aa", False),
-        ("[^😀-😍]{2}", "aa", True),
-        ("[^😀-😍]{2}", "😈😈", False),
-        ("[^😀-😍]{2}", "😎😎", True),
-        ("[^😀-😍]{2}", "😎😓", True),
-        ("[^😀-😍]{2}", "😎😈", False),
-        ("[😀-🙌]{2}", "😎😈", True),
-        ("[^😀-🙌]{2}", "😎😈", False),
-        ("[^😀-🙌]{2}", "🙏🙏", True),
-        ("[^😀-🙌]{2}", "🙏😎", False),
-    ],
-)
-def test_make_byte_level_fsm(regex, string, should_accept):
-    str_fsm = interegular.parse_pattern(regex).to_fsm()
-    str_accepts = str_fsm.accepts(string)
-    assert str_accepts == should_accept
-
-    byte_fsm = make_byte_level_fsm(str_fsm)
-    byte_accepts = byte_fsm.accepts(to_bytes(string))  # type: ignore
-    assert byte_accepts == str_accepts
-
-    mix_fsm = make_byte_level_fsm(str_fsm, keep_utf8=True)
-    mix_accepts = mix_fsm.accepts(to_bytes(string))  # type: ignore
-    assert mix_accepts == str_accepts
-
-    mix_accepts_utf8 = mix_fsm.accepts(string)  # type: ignore
-    assert mix_accepts_utf8 == str_accepts
-
-    def advance(fsm, state, seq):
-        for symbol in seq:
-            if state is None:
-                return None
-            key = fsm.alphabet[symbol]
-            state = fsm.map[state].get(key)
-        return state
-
-    # verify each state along the pattern
-    str_state = str_fsm.initial
-    byte_state = byte_fsm.initial
-    mix_state = byte_fsm.initial
-    for symbol in string:
-        str_state = advance(str_fsm, str_state, symbol)
-        byte_state = advance(byte_fsm, byte_state, to_bytes(symbol))
-        mix_state_utf8 = advance(mix_fsm, mix_state, symbol)
-        mix_state = advance(mix_fsm, mix_state, to_bytes(symbol))
-        assert byte_state == str_state
-        assert mix_state == str_state
-        assert mix_state_utf8 == str_state
-
-
-@pytest.mark.skip(reason="Only for local profiling")
-def test_regex_index_performance():
-    from line_profiler import LineProfiler  # type: ignore [import]
-
-    regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
-
-    regex_pattern = interegular.parse_pattern(regex_str)
-    # Not reduced, so that there are many states
-    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm())
-
-    num_fsm_states = len(regex_fsm.states)
-    assert num_fsm_states == 220
-
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer = TransformerTokenizer(tokenizer)
-
-    # Pre-compile Numba functions
-    res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
-    assert len(res) > 1
-
-    profiler = LineProfiler(create_fsm_index_end_to_end)
-
-    profiler.runctx(
-        "create_fsm_index_tokenizer(regex_fsm, tokenizer)",
-        globals(),
-        locals(),
-    )
-    profiler.dump_stats("line-profiler-create_fsm_index.pkl")
-    profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True)
-
-
-@pytest.mark.skip(reason="Only for local profiling")
-def test_json_index_performance():
-    import json
-    from enum import Enum
-
-    from line_profiler import LineProfiler  # type: ignore [import]
-    from pydantic import BaseModel, constr
-
-    import outlines
-
-    class Weapon(str, Enum):
-        sword = "sword"
-        axe = "axe"
-        mace = "mace"
-        spear = "spear"
-        bow = "bow"
-        crossbow = "crossbow"
-
-    class Armor(str, Enum):
-        leather = "leather"
-        chainmail = "chainmail"
-        plate = "plate"
-
-    class Character(BaseModel):
-        name: constr(max_length=10)
-        # TODO: Add support for conint
-        age: int  # conint(int, ge=18, le=100)
-        armor: Armor
-        weapon: Weapon
-        # TODO: Add support for conint
-        strength: int  # conint(int, ge=0, le=100)
-
-    model = outlines.models.transformers("gpt2", device="cuda")
-    json_schema = json.dumps(Character.model_json_schema())
-
-    def build_regex():
-        regex_str = outlines.index.json_schema.build_regex_from_object(json_schema)
-        outlines.generate.regex(model, regex_str)
-
-    profiler = LineProfiler(create_fsm_index_end_to_end)
-    profiler.add_function(create_fsm_index_tokenizer)
-    profiler.add_function(outlines.index.index.RegexFSM.__init__)
-
-    profiler.runctx(
-        "build_regex()",
-        globals(),
-        locals(),
-    )
-    profiler.dump_stats("line-profiler-build-json-regex.pkl")
-    profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True)
-
-
-def test_token_trans_keys_identical():
-    """assert two tokens w/ identical behavior wrt FSM have same trans key seq"""
-
-    class MockTokenizer:
-        vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4}
-        special_tokens = {"eos"}
-        eos_token_id = 4
-
-        def convert_token_to_string(self, token):
-            return token
-
-    tokenizer = MockTokenizer()
-
-    pattern = r"z[ab]z"
-    regex_pattern = interegular.parse_pattern(pattern)
-    interegular_fsm = regex_pattern.to_fsm().reduce()
-    regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
-    vocabulary, _ = reduced_vocabulary(tokenizer)
-    token_trans_keys = get_vocabulary_transition_keys(
-        regex_fsm.fsm_info.alphabet_symbol_mapping,
-        regex_fsm.fsm_info.alphabet_anything_value,
-        vocabulary,
-        numba.typed.List.empty_list(numba.types.unicode_type),
-    )
-
-    token_str_to_tranition_keys = {
-        token_str: trans_key_seq
-        for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys)
-    }
-    # `a` and `b` both are workable, but `z` has distinct transition rules
-    assert interegular_fsm.accepts("zaz")
-    assert interegular_fsm.accepts("zbz")
-    assert (token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"]).all()
-    assert not (
-        token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"]
-    ).all()
-
-
-def test_token_trans_keys_walk_fsm():
-    """assert _walk_fsm works using transition keys"""
-
-    class MockTokenizer:
-        vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4}
-        special_tokens = {"eos"}
-        eos_token_id = 4
-
-        def convert_token_to_string(self, token):
-            return token
-
-    tokenizer = MockTokenizer()
-
-    pattern = r"a[bc]z"
-    regex_pattern = interegular.parse_pattern(pattern)
-    interegular_fsm = regex_pattern.to_fsm().reduce()
-    regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
-    vocabulary, _ = reduced_vocabulary(tokenizer)
-    token_trans_keys = get_vocabulary_transition_keys(
-        regex_fsm.fsm_info.alphabet_symbol_mapping,
-        regex_fsm.fsm_info.alphabet_anything_value,
-        vocabulary,
-        numba.typed.List.empty_list(numba.types.unicode_type),
-    )
-
-    token_str_trans_key_seq = {
-        token_str: trans_key_seq
-        for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys)
-    }
-
-    # verify initial state valid only for "ab" and "ac" using transition key seq
-    token_acceptance = {"ab": True, "ac": True, "az": False}
-    for token, should_accept in token_acceptance.items():
-        token_trans_key_seq = token_str_trans_key_seq[token]
-        state_seq = _walk_fsm(
-            regex_fsm.fsm_info.transitions,
-            regex_fsm.fsm_info.initial,
-            regex_fsm.fsm_info.finals,
-            token_trans_key_seq,
-            regex_fsm.fsm_info.initial,
-            False,
-        )
-        is_accepted = len(state_seq) >= len(token_trans_key_seq)
-        assert should_accept == is_accepted
-
-
-def test_numba_leading_null_byte_UnicodeCharSeq_remains_broken():
-    """Assert numba UnicodeCharSeq w/ leading \x00 is still broken"""
-    # EXPLANATION:
-    # https://github.com/dottxt-ai/outlines/pull/930#issuecomment-2143535968
-
-    # from https://github.com/numba/numba/issues/9542
-    d = numba.typed.typeddict.Dict.empty(numba.types.UnicodeCharSeq(1), numba.int64)
-    d["一"] = 10  # \xe4\xb8\x80
-    with pytest.raises(KeyError):
-        str(d)
-
-    # most characters are fine, but "\x00" is converted to ""
-    l = np.fromiter(["\x99", "\x00"], dtype=np.dtype("U2"))
-    assert str(l[0]) == "\x99"  # fine
-    assert str(l[1]) == ""  # 1-byte null converted to 0-bytes
-
-
-@pytest.mark.parametrize("input_key", ["一", "\x00"])
-def test_numba_leading_null_byte_unicode_type_sane(input_key):
-    """Assert numba unicode_type w/ leading \x00 is working"""
-    # EXPLANATION:
-    # https://github.com/dottxt-ai/outlines/pull/930#issuecomment-2143535968
-
-    # from https://github.com/numba/numba/issues/9542
-    d = numba.typed.typeddict.Dict.empty(numba.types.unicode_type, numba.int64)
-    d["一"] = 10  # \xe4\xb8\x80
-    str(d)  # assert successfully interprets
-
-
-@pytest.mark.parametrize(
-    "rare_token",
-    [
-        "�",
-        "��",
-        "�.",
-        "�..",
-        "▁�",
-        "▁▁�",
-        "▁�.",
-        "▁�.",
-        "▁▁�..",
-    ],
-)
-def test_reduced_vocabulary_with_rare_tokens(rare_token):
-    """Assert reduced_vocabulary works with rare tokens.
-
-    See [1] and [2] for context.
-
-    [1]: https://github.com/dottxt-ai/outlines/pull/763
-    [2]: https://github.com/dottxt-ai/outlines/pull/948
-    [3]: https://github.com/dottxt-ai/outlines/pull/1153
-    """
-    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
-    tokenizer = TransformerTokenizer(tokenizer=tokenizer)
-    tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1
-    reduced_vocabulary(tokenizer)
-
-
-def test_reduced_vocabulary_with_byte_tokens():
-    class MockTokenizer:
-        vocabulary = {
-            "string": 1,
-            b"\xa1": 2,  # Qwen-Style
-            "eos": 3,
-        }
-        special_tokens = {"eos"}
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return b"\xef\xbf\xbd".decode()
-
-    reduced_vocab = reduced_vocabulary(MockTokenizer())
-
-    # See fsm.regex.get_token_transition_keys()
-    # FSM transition keys represents bytes as <null_prefix><hex_byte>
-    assert reduced_vocab[0][1][0] == "\x00A1"
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index cdd57e0c6..1d26a9ee4 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -5,11 +5,11 @@
 
 import pytest
 import torch
+from outlines_core.fsm.regex import reduced_vocabulary
 from pydantic import BaseModel, constr
 
 import outlines.generate as generate
 import outlines.models as models
-from outlines.fsm.regex import reduced_vocabulary
 from outlines.models.transformers import Transformers, TransformerTokenizer
 from outlines.samplers import beam_search, greedy, multinomial
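With the hunk above, the integration tests obtain reduced_vocabulary from outlines_core.fsm.regex instead of the removed outlines.fsm.regex module. For reference, the deleted test_reduced_vocabulary_with_byte_tokens documents the transition-symbol convention for raw byte tokens: a non-UTF-8 byte such as b"\xa1" is encoded as a NUL prefix followed by the uppercase hex of the byte ("\x00A1"). The snippet below is a minimal, self-contained sketch of that convention only; the helper name is hypothetical and is not an outlines or outlines_core API, and it covers just the non-ASCII case asserted by the deleted test.

# Hypothetical helper, for illustration only; not part of this patch.
# It mirrors the encoding the deleted test asserts for the Qwen-style
# byte token b"\xa1" (reduced_vocab[0][1][0] == "\x00A1").
def non_ascii_byte_symbol(byte: int) -> str:
    """Encode a non-ASCII byte as a NUL-prefixed uppercase hex pair."""
    return f"\x00{byte:02X}"


assert non_ascii_byte_symbol(0xA1) == "\x00A1"

The NUL prefix presumably keeps these synthetic byte symbols from colliding with ordinary single-character symbols in the FSM alphabet.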