From 4618f3f577c65cde275c4f0a2a0004d2eea72cf3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 31 May 2024 17:18:35 -0500 Subject: [PATCH] Incorporate Trie into fsm index calculation --- outlines/fsm/regex.py | 43 ++++--- outlines/fsm/vocab_trie.py | 241 +++++++++++++++++++++++++++++++++++++ tests/fsm/test_regex.py | 17 ++- 3 files changed, 276 insertions(+), 25 deletions(-) create mode 100644 outlines/fsm/vocab_trie.py diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index 6e2b81412..adcfb3993 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -28,6 +28,8 @@ from numba.typed.typedobjectutils import _nonoptional from tqdm import tqdm +from outlines.fsm.vocab_trie import VocabTrie + if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -649,29 +651,38 @@ def state_scan_tokens( fsm_initial: int, fsm_finals: Set[int], vocabulary: List[Tuple[str, Sequence[int]]], - token_trans_key_seqs: List[Sequence[int]], + vocab_trie: VocabTrie, start_state: int, ) -> Set[Tuple[int, int]]: res = set() - for (token, token_ids), token_trans_key_seq in zip( - vocabulary, token_trans_key_seqs - ): + # Initialize the stack with tokens having no prefixes + stack = numba.typed.List() + for token_transitions_seq in vocab_trie.get_children(): + stack.append(token_transitions_seq) + + # Process the tokens using the stack + while len(stack) > 0: + token_transition_seq = stack.pop() state_seq = _walk_fsm( fsm_transitions, fsm_initial, fsm_finals, - token_trans_key_seq, + token_transition_seq, start_state, False, ) - if state_seq is not None and len(state_seq) < len(token_trans_key_seq): + if state_seq is not None and len(state_seq) < len(token_transition_seq): continue - for token_id in token_ids: + for token_id in vocab_trie.get_token_ids(token_transition_seq): res.add((token_id, state_seq[-1])) + # Add successors to the stack + for new_token in vocab_trie.get_children(token_transition_seq): + stack.append(new_token) + return res @@ -702,7 +713,7 @@ def get_token_transitions( @numba.njit(cache=True, nogil=True) -def get_tokens_trans_keys( +def get_all_token_transitions( alphabet_symbol_mapping: Dict[str, int], alphabet_anything_value: int, vocabulary: List[Tuple[str, Sequence[int]]], @@ -729,18 +740,20 @@ def create_fsm_index_end_to_end( seen: Set[int] = set() next_states = {fsm_info.initial} + all_token_transitions = get_all_token_transitions( + fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + vocabulary, + ) + + vocab_trie = VocabTrie(all_token_transitions, vocabulary) + pbar = tqdm( total=len(set(fsm_info.transitions.values())) + 1, # all transitions plus initial desc="Compiling FSM index for all state transitions", ) - tokens_trans_key_seqs = get_tokens_trans_keys( - fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - vocabulary, - ) - while next_states: start_state = next_states.pop() @@ -751,7 +764,7 @@ def create_fsm_index_end_to_end( fsm_info.initial, fsm_info.finals, vocabulary, - tokens_trans_key_seqs, + vocab_trie, start_state, ) diff --git a/outlines/fsm/vocab_trie.py b/outlines/fsm/vocab_trie.py new file mode 100644 index 000000000..52d11b0cf --- /dev/null +++ b/outlines/fsm/vocab_trie.py @@ -0,0 +1,241 @@ +import operator +from typing import List, Optional, Sequence, Tuple + +import numpy as np +from numba import njit, typed, types +from numba.cpython.hashing import ( + _Py_uhash_t, + _PyHASH_XXPRIME_1, + _PyHASH_XXPRIME_2, + _PyHASH_XXPRIME_5, + _PyHASH_XXROTATE, + process_return, +) +from numba.experimental import 
jitclass, structref
+from numba.extending import overload
+from numba.typed import Dict
+
+###########################
+# Dict With Int[:] Key Impl
+###########################
+
+
+# Register type
+@structref.register
+class IntArrayDictType(types.StructRef):
+    """
+    Represents a dictionary using int64[:] as keys,
+    intended for byte-level FSM representation with int64[:] transition-key sequences.
+    """
+
+    def preprocess_fields(self, fields):
+        return tuple(
+            (name, typ.dtype if isinstance(typ, types.TypeRef) else typ)
+            for name, typ in fields
+        )
+
+
+class IntArrayDict(structref.StructRefProxy):
+    """Python-side proxy for the numba IntArrayDict struct."""
+
+    @property
+    def wrapped_dict(self):
+        return IntArrayDict_get_wrapped_dict(self)  # noqa: F821
+
+
+structref.define_proxy(IntArrayDict, IntArrayDictType, ["wrapped_dict"])
+
+
+@njit
+def hash_key(key):
+    """
+    XXH64 hash for int64[:] keys,
+    adapted from https://github.com/numba/numba/blob/556545/numba/cpython/hashing.py
+    """
+    acc = _PyHASH_XXPRIME_5
+    for i in range(key.shape[0]):
+        x = key[i]
+        lane = hash(x)
+        if lane == _Py_uhash_t(-1):
+            return -1
+        acc += lane * _PyHASH_XXPRIME_2
+        acc = _PyHASH_XXROTATE(acc)
+        acc *= _PyHASH_XXPRIME_1
+
+    acc += key.shape[0] ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539))
+
+    if acc == _Py_uhash_t(-1):
+        return process_return(1546275796)
+
+    return process_return(acc)
+
+
+@overload(IntArrayDict)
+def custom_int_array_dict_constructor(value_type):
+    if isinstance(value_type, types.Type):
+
+        def impl(value_type):
+            wrapped_dictionary = Dict.empty(types.intp, value_type)
+            return IntArrayDict(wrapped_dictionary)
+
+        return impl
+
+
+@overload(operator.getitem)
+def ol_int_array_dict_getitem(inst, key):
+    if isinstance(inst, IntArrayDictType):
+
+        def impl(inst, key):
+            return inst.wrapped_dict[hash_key(key)]
+
+        return impl
+
+
+@overload(operator.setitem)
+def ol_int_array_dict_setitem(inst, key, value):
+    if isinstance(inst, IntArrayDictType):
+
+        def impl(inst, key, value):
+            inst.wrapped_dict[hash_key(key)] = value
+
+        return impl
+
+
+@overload(operator.contains)
+def ol_int_array_dict_contains(inst, key):
+    if isinstance(inst, IntArrayDictType):
+
+        def impl(inst, key):
+            return hash_key(key) in inst.wrapped_dict
+
+        return impl
+
+
+#################
+# Vocab Trie Impl
+#################
+
+nb_int64_array_type = types.int64[:]
+
+# use intp keys as that is the hash type,
+# but the true key type is nb_int64_array_type
+IntArrayToIntType = IntArrayDictType(
+    (("wrapped_dict", types.DictType(types.intp, types.int64)),)
+)
+IntArrayToIntArrayType = IntArrayDictType(
+    (("wrapped_dict", types.DictType(types.intp, nb_int64_array_type)),)
+)
+
+
+@jitclass(
+    [
+        ("token_to_token_key", IntArrayToIntType),
+        ("token_key_to_token", types.DictType(types.int64, nb_int64_array_type)),
+        (
+            "token_key_to_child_token_keys",
+            types.DictType(types.int64, nb_int64_array_type),
+        ),
+        ("token_to_token_ids", IntArrayToIntArrayType),
+    ],
+)
+class VocabTrie:
+    """
+    VocabTrie: Class for efficient traversal of the vocabulary
+    Bidirectional mapping between trie node ID and int64[:] transition-sequence token
+    - token_to_token_key: Dict[int64[:], int]
+    - token_key_to_token: Dict[int, int64[:]]
+    Allow retrieval of children in trie
+    - token_key_to_child_token_keys: Dict[int, int64[:]]
+    Allow retrieval of token_ids for a given token
+    - token_to_token_ids: Dict[int64[:], int64[:]]
+    Trie structure:
+    Only members of the vocabulary are included as nodes, no intermediates.
+    Structured to guarantee that recursive calls to get_children()
+    will return every token exactly once.
+    Given a vocabulary of ["a", "ab", "abc", "ac", "ace", "apple"],
+    the children of "a" are "ab", "ac", "apple".
+    "abc" and "ace" are excluded because they have intermediate parents in the vocabulary.
+    """
+
+    def __init__(
+        self,
+        all_token_transitions: List[Sequence[int]],
+        vocabulary: List[Tuple[str, Sequence[int]]],
+    ):
+        self.token_to_token_key = IntArrayDict(
+            typed.Dict.empty(types.intp, types.int64)
+        )
+        self.token_key_to_token = typed.Dict.empty(
+            key_type=types.int64, value_type=nb_int64_array_type
+        )
+        self.token_key_to_child_token_keys = typed.Dict.empty(
+            key_type=types.int64, value_type=nb_int64_array_type
+        )
+        self.token_to_token_ids = IntArrayDict(
+            typed.Dict.empty(types.intp, nb_int64_array_type)
+        )
+
+        self._insert(all_token_transitions, vocabulary)
+
+    def _insert(
+        self,
+        all_token_transitions: List[Sequence[int]],
+        vocabulary: List[Tuple[str, Sequence[int]]],
+    ) -> None:
+        # Initialize an empty array for the root token key to store child token keys
+        self.token_key_to_child_token_keys[-1] = np.empty((0,), types.int64)
+
+        # It's necessary to insert shorter transition sequences (prefixes) first
+        sorted_idx_transition_seq = sorted(
+            enumerate(all_token_transitions), key=lambda x: len(x[1])
+        )
+
+        for idx, token_transitions in sorted_idx_transition_seq:
+            token_ids = vocabulary[idx][1]
+            if token_transitions not in self.token_to_token_key:
+                # create bimapping between token and token_key (the token's trie node key)
+                self.token_to_token_key[token_transitions] = idx
+                self.token_key_to_token[idx] = token_transitions
+
+                # find parent token key
+                parent_token_key = -1  # root token
+                for i in range(len(token_transitions) - 1, -1, -1):
+                    prefix_token = token_transitions[:i]
+
+                    if prefix_token in self.token_to_token_key:
+                        parent_token_key = self.token_to_token_key[prefix_token]
+                        break
+                # register the current token as a child of its parent
+                self.token_key_to_child_token_keys[parent_token_key] = np.append(
+                    self.token_key_to_child_token_keys[parent_token_key],
+                    np.array([idx]),
+                )
+
+                # map the current token to an empty list of children
+                self.token_key_to_child_token_keys[idx] = np.empty((0,), types.int64)
+
+                # store the current token's token ids
+                self.token_to_token_ids[token_transitions] = token_ids
+
+            else:
+                # duplicate transition sequence: append to the existing token ids
+                self.token_to_token_ids[token_transitions] = np.append(
+                    self.token_to_token_ids[token_transitions], token_ids
+                )
+
+    def get_children(self, token_transitions: Optional[Sequence[int]] = None):
+        """
+        Get the transition sequences of all child tokens of the given token.
+        If token_transitions is None, get the children of the root.
+ """ + if token_transitions is None: + token_key = -1 + else: + token_key = self.token_to_token_key[token_transitions] + + child_token_keys = self.token_key_to_child_token_keys[token_key] + + return [self.token_key_to_token[token_key] for token_key in child_token_keys] + + def get_token_ids(self, token): + return self.token_to_token_ids[token] diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index f1bf0f06a..14f606f47 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -9,8 +9,9 @@ create_fsm_index_end_to_end, create_fsm_index_tokenizer, fsm_union, + get_all_token_transitions, get_sub_fsms_from_seq, - get_tokens_trans_keys, + get_token_transitions, make_byte_level_better_fsm, make_byte_level_fsm, make_deterministic_fsm, @@ -33,15 +34,11 @@ def merge_symbols(byte_hexs): def token_str_to_trans_key(fsm, input_string): - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple((numba.types.unicode_type, numba.int64[:])) - ) - vocabulary_nb.append((input_string, np.fromiter([1], dtype=np.dtype("int64")))) - return get_tokens_trans_keys( + return get_token_transitions( fsm.fsm_info.alphabet_symbol_mapping, fsm.fsm_info.alphabet_anything_value, - vocabulary_nb, - )[0] + input_string, + ) def walk_fsm_from_token_str( @@ -598,7 +595,7 @@ def convert_token_to_string(self, token): interegular_fsm = regex_pattern.to_fsm().reduce() regex_fsm, _ = make_deterministic_fsm(interegular_fsm) vocabulary, _ = reduced_vocabulary(tokenizer) - token_trans_keys = get_tokens_trans_keys( + token_trans_keys = get_all_token_transitions( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, vocabulary, @@ -633,7 +630,7 @@ def convert_token_to_string(self, token): interegular_fsm = regex_pattern.to_fsm().reduce() regex_fsm, _ = make_deterministic_fsm(interegular_fsm) vocabulary, _ = reduced_vocabulary(tokenizer) - token_trans_keys = get_tokens_trans_keys( + token_trans_keys = get_all_token_transitions( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, vocabulary,
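
Illustrative note (not part of the patch): the sketch below is a minimal pure-Python model of the VocabTrie semantics documented above, intended only to make the longest-prefix parenting rule concrete. Plain tuples stand in for the numba int64[:] keys, the vocabulary is passed as (transition_sequence, token_ids) pairs rather than the (token, token_ids) plus all_token_transitions split used in the patch, and the name PurePyVocabTrie and the toy values are hypothetical.

from typing import Dict, List, Optional, Sequence, Tuple


class PurePyVocabTrie:
    """Pure-Python model of VocabTrie: vocabulary-only nodes, longest-prefix parents."""

    def __init__(self, vocabulary: List[Tuple[Sequence[int], Sequence[int]]]):
        self.key_of: Dict[Tuple[int, ...], int] = {}    # token -> node key
        self.token_of: Dict[int, Tuple[int, ...]] = {}  # node key -> token
        self.children: Dict[int, List[int]] = {-1: []}  # -1 is the implicit root
        self.token_ids: Dict[Tuple[int, ...], List[int]] = {}

        # Insert shorter sequences first so every parent exists before its children.
        for idx in sorted(range(len(vocabulary)), key=lambda i: len(vocabulary[i][0])):
            transitions, ids = vocabulary[idx]
            token = tuple(transitions)
            if token in self.key_of:
                self.token_ids[token].extend(ids)  # same transitions, extra token ids
                continue
            self.key_of[token] = idx
            self.token_of[idx] = token
            self.children[idx] = []
            self.token_ids[token] = list(ids)

            # Parent is the longest proper prefix that is itself in the vocabulary.
            parent = -1
            for i in range(len(token) - 1, -1, -1):
                if token[:i] in self.key_of:
                    parent = self.key_of[token[:i]]
                    break
            self.children[parent].append(idx)

    def get_children(self, token: Optional[Sequence[int]] = None) -> List[Tuple[int, ...]]:
        key = -1 if token is None else self.key_of[tuple(token)]
        return [self.token_of[k] for k in self.children[key]]

    def get_token_ids(self, token: Sequence[int]) -> List[int]:
        return self.token_ids[tuple(token)]


# Toy vocabulary: transition sequences paired with token ids (both hypothetical).
vocab = [((1,), [10]), ((1, 2), [11]), ((1, 2, 3), [12]), ((1, 3), [13])]
trie = PurePyVocabTrie(vocab)
assert trie.get_children() == [(1,)]                # only (1,) has no prefix in the vocabulary
assert trie.get_children((1,)) == [(1, 2), (1, 3)]  # (1, 2, 3) hangs off (1, 2), not (1,)
assert trie.get_children((1, 2)) == [(1, 2, 3)]

The asserts mirror the docstring example: only tokens whose transition sequence has no proper prefix in the vocabulary hang off the root, and (1, 2, 3) is reached through its nearest vocabulary ancestor (1, 2) rather than directly from (1,).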
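
Continuing the sketch above, this is roughly why the trie pays off in state_scan_tokens: a child's transition sequence always extends its parent's, so if the parent cannot be walked to completion from start_state, the entire subtree can be skipped without being visited. Here walk_fsm is a simplified, hypothetical stand-in for _walk_fsm with full_match=False, and transitions is assumed to be a plain {(state, transition_key): next_state} dict rather than the numba structures used in the patch.

def walk_fsm(transitions, start_state, token):
    """Partial walk: stop at the first missing transition (full_match=False analogue)."""
    states = []
    state = start_state
    for trans_key in token:
        state = transitions.get((state, trans_key))
        if state is None:
            break
        states.append(state)
    return states


def scan_tokens(transitions, trie, start_state):
    res = set()
    stack = list(trie.get_children())  # tokens whose sequences have no vocabulary prefix
    while stack:
        token = stack.pop()
        states = walk_fsm(transitions, start_state, token)
        if len(states) < len(token):
            # The token does not match from start_state; every descendant only
            # extends this sequence, so the entire subtree is pruned here.
            continue
        for token_id in trie.get_token_ids(token):
            res.add((token_id, states[-1]))
        stack.extend(trie.get_children(token))  # descend only through matching prefixes
    return res


# Tiny FSM accepting the key sequence 1 -> 2 (-> 3): (1, 3) fails the walk and is
# dropped, while (1, 2, 3) is visited only because its parent (1, 2) matched.
transitions = {(0, 1): 5, (5, 2): 6, (6, 3): 7}
assert scan_tokens(transitions, trie, start_state=0) == {(10, 5), (11, 6), (12, 7)}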