diff --git a/outlines/fsm/parsing.py b/outlines/fsm/parsing.py index 9ebc2af55..19deb975e 100644 --- a/outlines/fsm/parsing.py +++ b/outlines/fsm/parsing.py @@ -38,6 +38,7 @@ from outlines.fsm.regex import ( fsm_union, get_sub_fsms_from_seq, + get_token_transitions, make_deterministic_fsm, walk_fsm, ) @@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None) text_part = text[start_pos:] + text_transitions = get_token_transitions( + self.fsm.fsm_info.alphabet_symbol_mapping, + self.fsm.fsm_info.alphabet_anything_value, + text_part, + ) + state_seq = walk_fsm( self.fsm, - text_part, + text_transitions, start_state, full_match=self.match_whole, ) diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index b68e31897..6e2b81412 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -87,14 +87,11 @@ def fsm_info(self): ((k, z) for k, v in self.trans_key_to_states.items() for z in v), dtype=np.dtype("int64, int64"), ) - alphabet_symbol_mapping_items = np.fromiter( - ( - it - for it in self.alphabet._symbol_mapping.items() - if it[0] != anything_else - ), - dtype=np.dtype("U2, int64"), - ) + alphabet_symbol_mapping_items = [ + (k, v) + for k, v in self.alphabet._symbol_mapping.items() + if k != anything_else + ] nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64")) self.__dict__["_fsm_info"] = create_fsm_info( self.initial, @@ -110,7 +107,7 @@ def fsm_info(self): nb_int_list_type = numba.types.ListType(numba.int64) nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) -nb_unichar_2_type = numba.types.UnicodeCharSeq(2) +nb_unicode_type = numba.types.unicode_type @numba.njit(cache=True) @@ -136,7 +133,7 @@ def create_fsm_info( # use 2-char strings so that we can represent incomplete utf-8 sequences # as 2-hex-digit pairs - alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64) + alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64) for symbol_and_trans_key in alphabet_symbol_mapping_items: alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] @@ -199,7 +196,7 @@ def transition_trie_setdefault( def byte_symbol(byte: int) -> str: - return f"{byte:02X}" if byte >= 0x80 else chr(byte) + return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte) def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM: @@ -415,11 +412,9 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: @numba.njit(nogil=True, cache=True) def _walk_fsm( fsm_transitions: Dict[Tuple[int, int], int], - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - input_string: Sequence[str], + token_trans_key_seq: Sequence[int], start_state: int, full_match: bool = True, ) -> List[int]: @@ -427,9 +422,9 @@ def _walk_fsm( accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) last_final_idx: int = numba.uint64(0) - for i, symbol in enumerate(input_string): - trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value) - + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. 
+ for i, trans_key in enumerate(token_trans_key_seq): new_state = fsm_transitions.get((state, trans_key)) if new_state is None: @@ -453,7 +448,7 @@ def _walk_fsm( def walk_fsm( fsm: BetterFSM, - input_string: Sequence[str], + token_trans_key_seq: Sequence[int], start_state: int, full_match: bool = True, ) -> List[int]: @@ -463,13 +458,11 @@ def walk_fsm( accepted_states: List[int] = [] last_final_idx: int = 0 - alphabet_symbol_mapping = fsm.alphabet._symbol_mapping - alphabet_anything_value = fsm.alphabet.anything_value fsm_transitions = fsm.flat_transition_map - for i, symbol in enumerate(input_string): - trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value) - + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. + for i, trans_key in enumerate(token_trans_key_seq): new_state = fsm_transitions.get((state, trans_key)) if new_state is None: @@ -655,24 +648,25 @@ def state_scan_tokens( alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - vocabulary: List[Tuple[Sequence[str], Sequence[int]]], + vocabulary: List[Tuple[str, Sequence[int]]], + token_trans_key_seqs: List[Sequence[int]], start_state: int, ) -> Set[Tuple[int, int]]: res = set() - for token, token_ids in vocabulary: + for (token, token_ids), token_trans_key_seq in zip( + vocabulary, token_trans_key_seqs + ): state_seq = _walk_fsm( fsm_transitions, - alphabet_symbol_mapping, - alphabet_anything_value, fsm_initial, fsm_finals, - token, + token_trans_key_seq, start_state, False, ) - if state_seq is not None and len(state_seq) < len(token): + if state_seq is not None and len(state_seq) < len(token_trans_key_seq): continue for token_id in token_ids: @@ -681,9 +675,51 @@ def state_scan_tokens( return res +@numba.njit(cache=True, nogil=True) +def get_token_transitions( + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + token_str: str, +) -> Sequence[int]: + trans_key_seq = [] + i = 0 + while i < len(token_str): + if token_str[i] == "\x00" and i != len(token_str) - 1: + symbol = token_str[i : i + 3] + i += 3 + else: + symbol = token_str[i] + i += 1 + + trans_key_seq.append( + alphabet_symbol_mapping.get(symbol, alphabet_anything_value) + ) + + trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64) + for j in range(len(trans_key_seq)): + trans_key_seq_array[j] = trans_key_seq[j] + return trans_key_seq_array + + +@numba.njit(cache=True, nogil=True) +def get_tokens_trans_keys( + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + vocabulary: List[Tuple[str, Sequence[int]]], +) -> List[Sequence[int]]: + tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:]) + for token_str, _ in vocabulary: + trans_key_seq_array = get_token_transitions( + alphabet_symbol_mapping, alphabet_anything_value, token_str + ) + tokens_trans_keys.append(trans_key_seq_array) + + return tokens_trans_keys + + def create_fsm_index_end_to_end( fsm_info: FSMInfo, - vocabulary: List[Tuple[Sequence[str], Sequence[int]]], + vocabulary: List[Tuple[str, Sequence[int]]], ) -> Dict[int, Set[Tuple[int, int]]]: """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.""" @@ -699,6 +735,12 @@ def create_fsm_index_end_to_end( desc="Compiling FSM index for all state transitions", ) + tokens_trans_key_seqs = get_tokens_trans_keys( + fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + vocabulary, + ) + while next_states: start_state = next_states.pop() @@ 
-709,6 +751,7 @@ def create_fsm_index_end_to_end( fsm_info.initial, fsm_info.finals, vocabulary, + tokens_trans_key_seqs, start_state, ) @@ -771,7 +814,7 @@ def gpt2_unicode_to_bytes(): @lru_cache def reduced_vocabulary( tokenizer: "Tokenizer", -) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]: +) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" empty_token_ids = set() vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} @@ -804,7 +847,7 @@ def reduced_vocabulary( raise RuntimeError( f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" ) - token_str = tuple(byte_symbol(b) for b in token_bytes) + token_str = "".join(byte_symbol(b) for b in token_bytes) vocabulary.setdefault(token_str, []).append(token_idx) else: @@ -813,15 +856,14 @@ def reduced_vocabulary( vocabulary_nb = numba.typed.List.empty_list( numba.types.Tuple( ( - nb_unichar_2_type[:], + nb_unicode_type, numba.int64[:], ) ) ) - for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + for token_str, token_ids in vocabulary.items(): token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_tuple_np, token_ids_np)) + vocabulary_nb.append((token_str, token_ids_np)) return vocabulary_nb, empty_token_ids diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 4041c54fb..8e18c33e7 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -26,7 +26,7 @@ """ import math -from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Optional, Type, Union import numpy as np import torch @@ -36,37 +36,12 @@ from outlines.fsm.guide import CFGGuide, Guide, RegexGuide from outlines.fsm.json_schema import build_regex_from_schema from outlines.integrations.utils import convert_json_schema_to_str +from outlines.models.llamacpp import LlamaCppTokenizer if TYPE_CHECKING: from llama_cpp import Llama -class LlamaCppTokenizer: - def __init__(self, model: "Llama"): - self.eos_token_id = model.token_eos() - self.eos_token = model.tokenizer().decode([self.eos_token_id]) - self.pad_token_id = self.eos_token_id - self.special_tokens: Set[int] = set() - - self.vocabulary: Dict[str, int] = dict() - - tokenizer = model.tokenizer() - - self.decode = tokenizer.decode - - # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved - try: - self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() - except AttributeError: - # ### - for t in range(model.n_vocab()): - token_piece = model.tokenizer().decode([t]) - self.vocabulary[token_piece] = t - - def convert_token_to_string(self, token: str) -> str: - return token - - class LogitsProcessor: """Bias LlamaCpp generation using a finite state machine. 
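
# --- Reviewer note: illustrative sketch, not part of the patch. The regex.py
# hunks above rework walk_fsm to consume precomputed transition-key sequences
# instead of raw symbol strings; the snippet below mirrors the new call
# pattern used in outlines/fsm/parsing.py. The pattern r"a[bc]z" and the
# token "ab" are assumptions chosen only for illustration.
import interegular

from outlines.fsm.regex import (
    get_token_transitions,
    make_deterministic_fsm,
    walk_fsm,
)

regex_fsm, _ = make_deterministic_fsm(
    interegular.parse_pattern(r"a[bc]z").to_fsm().reduce()
)

# Map the token's symbols to FSM transition keys up front...
token_trans_keys = get_token_transitions(
    regex_fsm.fsm_info.alphabet_symbol_mapping,
    regex_fsm.fsm_info.alphabet_anything_value,
    "ab",
)

# ...then walk the FSM with the precomputed keys (partial match, as in
# state_scan_tokens).
state_seq = walk_fsm(regex_fsm, token_trans_keys, regex_fsm.initial, full_match=False)
assert len(state_seq) == 2  # "ab" advances the FSM through two states
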
diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 5920f08d6..b57eae4db 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,15 +1,97 @@ import dataclasses +import pickle import warnings -from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union +from typing import ( + TYPE_CHECKING, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TypedDict, + Union, +) from typing_extensions import Unpack from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: from llama_cpp import Llama, LogitsProcessorList +class LlamaCppTokenizer(Tokenizer): + def __init__(self, model: "Llama"): + self.eos_token_id = model.token_eos() + self.eos_token = model.tokenizer().decode([self.eos_token_id]) + self.pad_token_id = self.eos_token_id + self.special_tokens: Set[int] = set() + + self.vocabulary: Dict[str, int] = dict() + + self.tokenizer = model.tokenizer() + + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + try: + self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() + except AttributeError: + # ### + for t in range(model.n_vocab()): + token_piece = model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t + + self._hash = None + + def decode(self, token_ids: List[int]) -> List[str]: + decoded_bytes = self.tokenizer.detokenize(token_ids) + return [decoded_bytes.decode("utf-8", errors="ignore")] + + def encode( + self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True + ) -> Tuple[List[int], List[int]]: + if isinstance(prompt, list): + raise NotImplementedError( + "llama-cpp-python tokenizer doesn't support batch tokenization" + ) + token_ids = self.tokenizer.tokenize( + prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special + ) + # generate attention mask, missing from llama-cpp-python + attention_mask = [ + 1 if token_id != self.pad_token_id else 0 for token_id in token_ids + ] + return token_ids, attention_mask + + def convert_token_to_string(self, token: str) -> str: + return token + + def __eq__(self, other): + if not isinstance(other, LlamaCppTokenizer): + return False + return self.__getstate__() == other.__getstate__() + + def __hash__(self): + # cache object hash + if self._hash is None: + self._hash = hash(pickle.dumps(self)) + return self._hash + + def __getstate__(self): + """Create a stable representation for outlines.caching""" + return ( + self.vocabulary, + self.eos_token_id, + self.eos_token, + self.pad_token_id, + self.special_tokens, + ) + + def __setstate__(self, state): + raise NotImplementedError("Cannot load a pickled llamacpp tokenizer") + + class LlamaCppParams(TypedDict, total=False): suffix: Optional[str] temperature: float diff --git a/tests/fsm/test_parsing.py b/tests/fsm/test_parsing.py index 4e093a994..b624fddee 100644 --- a/tests/fsm/test_parsing.py +++ b/tests/fsm/test_parsing.py @@ -9,7 +9,14 @@ from outlines.fsm.parsing import PartialLark, PartialPythonIndenter -def test_partial_parsing(): +@pytest.fixture +def cleanup_lark_import(): + yield + # Clean up lark.lark.LarkOptions._defaults + importlib.reload(lark.lark) + + +def test_partial_parsing(cleanup_lark_import): lp = PartialLark.open_from_package( "tests", "partial_python.lark", @@ -136,11 +143,8 @@ def test_partial_parsing(): assert len(parser_state.state_stack) == 4 assert parser_state.value_stack[-1].type == "LPAR" - # Clean up lark.lark.LarkOptions._defaults - 
importlib.reload(lark.lark) - -def test_sequential_parse_example(): +def test_sequential_parse_example(cleanup_lark_import): input_tokens = [ "x ", "= ", diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index 2fc8a5384..f1bf0f06a 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -1,5 +1,3 @@ -from typing import Sequence - import interegular import numba import numpy as np @@ -12,9 +10,11 @@ create_fsm_index_tokenizer, fsm_union, get_sub_fsms_from_seq, + get_tokens_trans_keys, make_byte_level_better_fsm, make_byte_level_fsm, make_deterministic_fsm, + reduced_vocabulary, walk_fsm, ) from outlines.models.transformers import TransformerTokenizer @@ -25,22 +25,50 @@ def identity(s): def to_bytes(s): - return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")] + return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] + + +def merge_symbols(byte_hexs): + return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) + + +def token_str_to_trans_key(fsm, input_string): + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple((numba.types.unicode_type, numba.int64[:])) + ) + vocabulary_nb.append((input_string, np.fromiter([1], dtype=np.dtype("int64")))) + return get_tokens_trans_keys( + fsm.fsm_info.alphabet_symbol_mapping, + fsm.fsm_info.alphabet_anything_value, + vocabulary_nb, + )[0] + + +def walk_fsm_from_token_str( + fsm, + input_string: str, + start_state: int, + full_match: bool = True, +): + return walk_fsm( + fsm, + token_str_to_trans_key(fsm, input_string), + start_state, + full_match, + ) -def walk_fsm_numba( +def walk_fsm_from_token_str_numba( fsm, - input_string: Sequence[str], + input_string: str, start_state: int, full_match: bool = True, ): return _walk_fsm( fsm.fsm_info.transitions, - fsm.fsm_info.alphabet_symbol_mapping, - fsm.fsm_info.alphabet_anything_value, fsm.fsm_info.initial, fsm.fsm_info.finals, - input_string, + token_str_to_trans_key(fsm, input_string), start_state, full_match=full_match, ) @@ -49,8 +77,8 @@ def walk_fsm_numba( @pytest.mark.parametrize( "function", [ - walk_fsm, - walk_fsm_numba, + walk_fsm_from_token_str, + walk_fsm_from_token_str_numba, ], ) def test_walk_fsm(function): @@ -99,8 +127,8 @@ def test_walk_fsm(function): @pytest.mark.parametrize( "function", [ - walk_fsm, - walk_fsm_numba, + walk_fsm_from_token_str, + walk_fsm_from_token_str_numba, ], ) @pytest.mark.parametrize( @@ -115,19 +143,37 @@ def test_walk_fsm_multi_bytes(function, transform): str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) - res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True)) + res = tuple( + function( + regex_fsm, merge_symbols(transform("😂")), regex_fsm.initial, full_match=True + ) + ) assert res[-1:] == (1,) res = tuple( - function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False) + function( + regex_fsm, + merge_symbols(transform("😂😂")), + regex_fsm.initial, + full_match=False, + ) ) assert res[-1:] == (1,) - res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True)) + res = tuple( + function( + regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True + ) + ) assert res == tuple() res = tuple( - function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True) + function( + regex_fsm, + merge_symbols(transform("😂😂")), + regex_fsm.initial, + full_match=True, + ) ) assert res == tuple() @@ -194,14 +240,14 
@@ def test_get_sub_fsms_from_seq(): assert fsm.accepts("+=") assert fsm.accepts("+") - state_seq = walk_fsm(fsm, "def", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial) state_seq.insert(0, fsm.fsm_info.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(0, False, True), (2, True, True)] # Make sure the old-to-new state map is correct - def_state_seq = walk_fsm(def_fsm, "def", fsm.initial) + def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial) def_state_seq.insert(0, fsm.fsm_info.initial) def_old_to_new_states = fsms_to_trans_finals[0][2] @@ -210,13 +256,13 @@ def test_get_sub_fsms_from_seq(): for old_state, new_state in zip(def_state_seq, state_seq) ) - state_seq = walk_fsm(fsm, "ef", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(2, True, True)] - name_state_seq = walk_fsm(name_fsm, "ef", fsm.initial) + name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial) name_state_seq.insert(0, fsm.initial) name_old_to_new_states = fsms_to_trans_finals[2][2] @@ -225,13 +271,13 @@ def test_get_sub_fsms_from_seq(): for old_state, new_state in zip(name_state_seq, state_seq) ) - state_seq = walk_fsm(fsm, "match", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(1, False, True), (2, True, True)] - match_state_seq = walk_fsm(match_fsm, "match", fsm.initial) + match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial) match_state_seq.insert(0, fsm.initial) match_old_to_new_states = fsms_to_trans_finals[1][2] @@ -240,25 +286,25 @@ def test_get_sub_fsms_from_seq(): for old_state, new_state in zip(match_state_seq, state_seq) ) - state_seq = walk_fsm(fsm, "defa", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(2, True, True)] - state_seq = walk_fsm(fsm, "de", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(0, True, False), (2, True, True)] - state_seq = walk_fsm(fsm, "+", fsm.initial, False) + state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(3, True, False), (4, False, True)] - state_seq = walk_fsm(fsm, "+=", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) @@ -304,15 +350,15 @@ def test_create_fsm_index_end_to_end(): vocabulary_nb = numba.typed.List.empty_list( numba.types.Tuple( ( - numba.types.UnicodeCharSeq(2)[:], + numba.types.unicode_type, numba.int64[:], ) ) ) for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_tuple_np, token_ids_np)) + vocabulary_nb.append((token, token_ids_np)) res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) @@ -331,23 +377,25 @@ def 
test_create_fsm_index_end_to_end_multi_byte(): "😈a": numba.typed.List([1]), "😇": numba.typed.List([2]), "😍": numba.typed.List([3]), - ("F0", "9F", "98", "8D"): numba.typed.List([4]), # '😍' + merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]), # '😍' " 😍": numba.typed.List([5]), - (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]), # ' 😍' - (" ", "F0", "9F", "98"): numba.typed.List([7]), # ' 😍' incomplete + merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]), # ' 😍' + merge_symbols((" ", "F0", "9F", "98")): numba.typed.List( + [7] + ), # ' 😍' incomplete "": numba.typed.List([8]), } vocabulary_nb = numba.typed.List.empty_list( numba.types.Tuple( ( - numba.types.UnicodeCharSeq(2)[:], + numba.types.unicode_type, numba.int64[:], ) ) ) for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token_tuple_np = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) vocabulary_nb.append((token_tuple_np, token_ids_np)) @@ -356,7 +404,16 @@ def test_create_fsm_index_end_to_end_multi_byte(): assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}} -def test_create_fsm_index_tokenizer(): +@pytest.mark.parametrize( + "hf_tokenizer_uri", + [ + "gpt2", + "microsoft/phi-2", + "Qwen/Qwen1.5-0.5B-Chat", + "NousResearch/Hermes-2-Pro-Llama-3-8B", + ], +) +def test_create_fsm_index_tokenizer(hf_tokenizer_uri): # The combined regular expressions of a lexer state in a Python grammar regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" @@ -371,7 +428,7 @@ def test_create_fsm_index_tokenizer(): num_bytes_fsm_states = len(bytes_fsm.states) assert num_bytes_fsm_states == 235 - tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri) tokenizer = TransformerTokenizer(tokenizer) states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( @@ -521,3 +578,83 @@ def build_regex(): ) profiler.dump_stats("line-profiler-build-json-regex.pkl") profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) + + +def test_token_trans_keys_identical(): + """assert two tokens w/ identical behavior wrt FSM have same trans key seq""" + + class MockTokenizer: + vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"z[ab]z" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = 
reduced_vocabulary(tokenizer) + token_trans_keys = get_tokens_trans_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + ) + + token_str_trans_key_seq = { + token_str: trans_key_seq + for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + } + # `a` and `b` both are workable, but `z` has distinct transition rules + assert interegular_fsm.accepts("zaz") + assert interegular_fsm.accepts("zbz") + assert (token_str_trans_key_seq["a"] == token_str_trans_key_seq["b"]).all() + assert not (token_str_trans_key_seq["a"] == token_str_trans_key_seq["z"]).all() + + +def test_token_trans_keys_walk_fsm(): + """assert _walk_fsm works using transition keys""" + + class MockTokenizer: + vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"a[bc]z" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = reduced_vocabulary(tokenizer) + token_trans_keys = get_tokens_trans_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + ) + + token_str_trans_key_seq = { + token_str: trans_key_seq + for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + } + + # verify initial state valid only for "ab" and "ac" using transition key seq + token_acceptance = {"ab": True, "ac": True, "az": False} + for token, should_accept in token_acceptance.items(): + token_trans_key_seq = token_str_trans_key_seq[token] + state_seq = _walk_fsm( + regex_fsm.fsm_info.transitions, + regex_fsm.fsm_info.initial, + regex_fsm.fsm_info.finals, + token_trans_key_seq, + regex_fsm.fsm_info.initial, + False, + ) + is_accepted = len(state_seq) >= len(token_trans_key_seq) + assert should_accept == is_accepted diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py new file mode 100644 index 000000000..ef7e40eed --- /dev/null +++ b/tests/generate/conftest.py @@ -0,0 +1,24 @@ +from importlib import reload + +import pytest + + +@pytest.fixture +def temp_cache_dir(): + import os + import tempfile + + import outlines.caching + import outlines.fsm.guide + + with tempfile.TemporaryDirectory() as tempdir: + os.environ["OUTLINES_CACHE_DIR"] = tempdir + outlines.caching.get_cache.cache_clear() + reload(outlines) + reload(outlines.fsm.guide) + cache_status = outlines.caching._caching_enabled + try: + outlines.caching._caching_enabled = True + yield + finally: + outlines.caching._caching_enabled = cache_status diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 75d0e4cef..b7eb8b3cb 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -279,3 +279,58 @@ def test_llama_cpp_pre_tokenizer_remains_broken(): model = models.llamacpp(repo, model_path) with pytest.raises(RuntimeError): generate.choice(model, ["skirt", "dress", "pen", "jacket"]) + + +def test_RegexGuide_caching(model, temp_cache_dir): + import llama_cpp + + import outlines.caching + from outlines.fsm.guide import create_states_mapping + + assert outlines.caching._caching_enabled + + regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + prompt = "What is the IP address of the Google DNS servers? 
" + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + generator = generate.regex(model, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.llamacpp( + "Qwen/Qwen1.5-0.5B-Chat-GGUF", + "*q2*.gguf", + tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( + "Qwen/Qwen1.5-0.5B-Chat" + ), + ) + generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert ( + generator.logits_processor.fsm.states_to_token_maps + != generator_2.logits_processor.fsm.states_to_token_maps + ) + + generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (1, 2) + assert ( + generator_2.logits_processor.fsm.states_to_token_maps + == generator_3.logits_processor.fsm.states_to_token_maps + ) + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index cee3ca312..da08bed71 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -1,7 +1,6 @@ import datetime import re from enum import Enum -from importlib import reload from typing import List, Union import pytest @@ -15,27 +14,6 @@ from outlines.samplers import beam_search, greedy, multinomial -@pytest.fixture -def temp_cache_dir(): - import os - import tempfile - - import outlines.caching - import outlines.fsm.guide - - with tempfile.TemporaryDirectory() as tempdir: - os.environ["OUTLINES_CACHE_DIR"] = tempdir - outlines.caching.get_cache.cache_clear() - reload(outlines) - reload(outlines.fsm.guide) - cache_status = outlines.caching._caching_enabled - try: - outlines.caching._caching_enabled = True - yield - finally: - outlines.caching._caching_enabled = cache_status - - def test_transformers_integration_text(): rng = torch.Generator() rng.manual_seed(10000) # Choosen so is generated