From 0ad30e070e0fb928220718723cd5c0e346c1f2ef Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Tue, 4 Jun 2024 16:59:59 -0500
Subject: [PATCH] Revert "Incorporate Trie into fsm index calculation"

This reverts commit 591ad2a188608224b4303a55fb4c56366b4427bd.
---
 outlines/fsm/regex.py      |  43 +++----
 outlines/fsm/vocab_trie.py | 241 -------------------------------------
 tests/fsm/test_regex.py    |  17 +--
 3 files changed, 25 insertions(+), 276 deletions(-)
 delete mode 100644 outlines/fsm/vocab_trie.py

diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index adcfb3993..6e2b81412 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -28,8 +28,6 @@
 from numba.typed.typedobjectutils import _nonoptional
 from tqdm import tqdm
 
-from outlines.fsm.vocab_trie import VocabTrie
-
 if TYPE_CHECKING:
     from outlines.models.tokenizer import Tokenizer
 
@@ -651,38 +649,29 @@ def state_scan_tokens(
     fsm_initial: int,
     fsm_finals: Set[int],
     vocabulary: List[Tuple[str, Sequence[int]]],
-    vocab_trie: VocabTrie,
+    token_trans_key_seqs: List[Sequence[int]],
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()
 
-    # Initialize the stack with tokens having no prefixes
-    stack = numba.typed.List()
-    for token_transitions_seq in vocab_trie.get_children():
-        stack.append(token_transitions_seq)
-
-    # Process the tokens using the stack
-    while len(stack) > 0:
-        token_transition_seq = stack.pop()
+    for (token, token_ids), token_trans_key_seq in zip(
+        vocabulary, token_trans_key_seqs
+    ):
         state_seq = _walk_fsm(
             fsm_transitions,
             fsm_initial,
             fsm_finals,
-            token_transition_seq,
+            token_trans_key_seq,
             start_state,
             False,
         )
 
-        if state_seq is not None and len(state_seq) < len(token_transition_seq):
+        if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
             continue
 
-        for token_id in vocab_trie.get_token_ids(token_transition_seq):
+        for token_id in token_ids:
             res.add((token_id, state_seq[-1]))
 
-        # Add successors to the stack
-        for new_token in vocab_trie.get_children(token_transition_seq):
-            stack.append(new_token)
-
     return res
 
 
@@ -713,7 +702,7 @@ def get_token_transitions(
 
 
 @numba.njit(cache=True, nogil=True)
-def get_all_token_transitions(
+def get_tokens_trans_keys(
     alphabet_symbol_mapping: Dict[str, int],
     alphabet_anything_value: int,
     vocabulary: List[Tuple[str, Sequence[int]]],
@@ -740,20 +729,18 @@ def create_fsm_index_end_to_end(
     seen: Set[int] = set()
     next_states = {fsm_info.initial}
 
-    all_token_transitions = get_all_token_transitions(
-        fsm_info.alphabet_symbol_mapping,
-        fsm_info.alphabet_anything_value,
-        vocabulary,
-    )
-
-    vocab_trie = VocabTrie(all_token_transitions, vocabulary)
-
     pbar = tqdm(
         total=len(set(fsm_info.transitions.values()))
         + 1,  # all transitions plus initial
        desc="Compiling FSM index for all state transitions",
     )
 
+    tokens_trans_key_seqs = get_tokens_trans_keys(
+        fsm_info.alphabet_symbol_mapping,
+        fsm_info.alphabet_anything_value,
+        vocabulary,
+    )
+
     while next_states:
         start_state = next_states.pop()
 
@@ -764,7 +751,7 @@ def create_fsm_index_end_to_end(
             fsm_info.initial,
             fsm_info.finals,
             vocabulary,
-            vocab_trie,
+            tokens_trans_key_seqs,
             start_state,
         )
 
diff --git a/outlines/fsm/vocab_trie.py b/outlines/fsm/vocab_trie.py
deleted file mode 100644
index 52d11b0cf..000000000
--- a/outlines/fsm/vocab_trie.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import operator
-from typing import List, Optional, Sequence, Tuple
-
-import numpy as np
-from numba import njit, typed, types
-from numba.cpython.hashing import (
-    _Py_uhash_t,
-    _PyHASH_XXPRIME_1,
-    _PyHASH_XXPRIME_2,
-    _PyHASH_XXPRIME_5,
-    _PyHASH_XXROTATE,
-    process_return,
-)
-from numba.experimental import jitclass, structref
-from numba.extending import overload
-from numba.typed import Dict
-
-###########################
-# Dict With Int[:] Key Impl
-###########################
-
-
-# Register type
-@structref.register
-class IntArrayDictType(types.StructRef):
-    """
-    Represents a dictionary using int64[:] as keys,
-    intended for byte-level FSM representation with int64[:] transition.
-    """
-
-    def preprocess_fields(self, fields):
-        return tuple(
-            (name, typ.dtype if isinstance(typ, types.TypeRef) else typ)
-            for name, typ in fields
-        )
-
-
-class IntArrayDict(structref.StructRefProxy):
-    """Python proxy"""
-
-    @property
-    def wrapped_dict(self):
-        return IntArrayDict_get_wrapped_dict(self)  # noqa: F821
-
-
-structref.define_proxy(IntArrayDict, IntArrayDictType, ["wrapped_dict"])
-
-
-@njit
-def hash_key(key):
-    """
-    XXH64 Hash for int64[:] keys
-    adapted from https://github.com/numba/numba/blob/556545/numba/cpython/hashing.py
-    """
-    acc = _PyHASH_XXPRIME_5
-    for i in range(key.shape[0]):
-        x = key[i]
-        lane = hash(x)
-        if lane == _Py_uhash_t(-1):
-            return -1
-        acc += lane * _PyHASH_XXPRIME_2
-        acc = _PyHASH_XXROTATE(acc)
-        acc *= _PyHASH_XXPRIME_1
-
-    acc += key.shape[0] ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539))
-
-    if acc == _Py_uhash_t(-1):
-        return process_return(1546275796)
-
-    return process_return(acc)
-
-
-@overload(IntArrayDict)
-def custom_int_array_dict_constructor(value_type):
-    if isinstance(value_type, types.Type):
-
-        def impl(value_type):
-            wrapped_dictionary = Dict.empty(types.intp, value_type)
-            return IntArrayDict(wrapped_dictionary)
-
-        return impl
-
-
-@overload(operator.getitem)
-def ol_int_array_dict_getitem(inst, key):
-    if isinstance(inst, IntArrayDictType):
-
-        def impl(inst, key):
-            return inst.wrapped_dict[hash_key(key)]
-
-        return impl
-
-
-@overload(operator.setitem)
-def ol_int_array_dict_setitem(inst, key, value):
-    if isinstance(inst, IntArrayDictType):
-
-        def impl(inst, key, value):
-            inst.wrapped_dict[hash_key(key)] = value
-
-        return impl
-
-
-@overload(operator.contains)
-def ol_int_array_dict_contains(inst, key):
-    if isinstance(inst, IntArrayDictType):
-
-        def impl(inst, key):
-            return hash_key(key) in inst.wrapped_dict
-
-        return impl
-
-
-#################
-# Vocab Trie Impl
-#################
-
-nb_int64_array_type = types.int64[:]
-
-# use intp keys as that is the hash type,
-# but the true key type is nb_int64_array_type
-IntArrayToIntType = IntArrayDictType(
-    (("wrapped_dict", types.DictType(types.intp, types.int64)),)
-)
-IntArrayToIntArrayType = IntArrayDictType(
-    (("wrapped_dict", types.DictType(types.intp, nb_int64_array_type)),)
-)
-
-
-@jitclass(
-    [
-        ("token_to_token_key", IntArrayToIntType),
-        ("token_key_to_token", types.DictType(types.int64, nb_int64_array_type)),
-        (
-            "token_key_to_child_token_keys",
-            types.DictType(types.int64, nb_int64_array_type),
-        ),
-        ("token_to_token_ids", IntArrayToIntArrayType),
-    ],
-)
-class VocabTrie:
-    """
-    VocabTrie: Class for efficient traversal of the vocabulary
-    Bidirectional mapping between trie node ID and nb_unichar_2_type token
-    - token_to_token_key: Dict[nb_unichar_2_array_type, int]
-    - token_key_to_token: Dict[int, nb_unichar_2_array_type]
-    Allow retrieval of children in trie
-    - token_key_to_child_token_keys: Dict[int, int64[:]]
-    Allow retrieval of of token_ids for a given token
-    - token_to_token_ids: Dict[nb_unichar_2_array_type, int64[:]]
-    Trie structure:
-        Only members of the vocabulary are included as nodes, no intermediates.
-        Structured to guarantee that recursive calls to get_children()
-        will return every token once, only once.
-        Given a vocabulary of ["a", "ab", "abc", "ac", "ace", "apple"],
-        the children of "a" are "ab", "ac", "apple".
-        "abc" and "ace" are excluded because they have intermediate parents in the vocabulary.
-    """
-
-    def __init__(
-        self,
-        all_token_transitions: List[Sequence[int]],
-        vocabulary: List[Tuple[str, Sequence[int]]],
-    ):
-        self.token_to_token_key = IntArrayDict(
-            typed.Dict.empty(types.intp, types.int64)
-        )
-        self.token_key_to_token = typed.Dict.empty(
-            key_type=types.int64, value_type=nb_int64_array_type
-        )
-        self.token_key_to_child_token_keys = typed.Dict.empty(
-            key_type=types.int64, value_type=nb_int64_array_type
-        )
-        self.token_to_token_ids = IntArrayDict(
-            typed.Dict.empty(types.intp, nb_int64_array_type)
-        )
-
-        self._insert(all_token_transitions, vocabulary)
-
-    def _insert(
-        self,
-        all_token_transitions: List[Sequence[int]],
-        vocabulary: List[Tuple[str, Sequence[int]]],
-    ) -> None:
-        # Initialize an empty array for the root token key to store child token keys
-        self.token_key_to_child_token_keys[-1] = np.empty((0,), types.int64)
-
-        # It's necessary to insert shorter transition sequences (prefixes) first
-        sorted_idx_transition_seq = sorted(
-            enumerate(all_token_transitions), key=lambda x: len(x[1])
-        )
-
-        for idx, token_transitions in sorted_idx_transition_seq:
-            token_ids = vocabulary[idx][1]
-            if token_transitions not in self.token_to_token_key:
-                # create bimapping between token and token_key (tokens trie node key)
-                self.token_to_token_key[token_transitions] = idx
-                self.token_key_to_token[idx] = token_transitions
-
-                # find parent token key
-                parent_token_key = -1  # root token
-                for i in range(len(token_transitions) - 1, -1, -1):
-                    prefix_token = token_transitions[:i]
-
-                    if prefix_token in self.token_to_token_key:
-                        parent_token_key = self.token_to_token_key[prefix_token]
-                        break
-                # map parent token to current token
-                self.token_key_to_child_token_keys[parent_token_key] = np.append(
-                    self.token_key_to_child_token_keys[parent_token_key],
-                    np.array([idx]),
-                )
-
-                # map current token to empty list of children
-                self.token_key_to_child_token_keys[idx] = np.empty((0,), types.int64)
-
-                # set current tokens token ids
-                self.token_to_token_ids[token_transitions] = token_ids
-
-            else:
-                # if exists, append to current tokens token ids
-                self.token_to_token_ids[token_transitions] = np.append(
-                    self.token_to_token_ids[token_transitions], token_ids
-                )
-
-    def get_children(self, token_transitions: Optional[Sequence[int]] = None):
-        """
-        Get the token_ids of all children for the given token_id.
-        If token_id is None, get the root children.
-        """
-        if token_transitions is None:
-            token_key = -1
-        else:
-            token_key = self.token_to_token_key[token_transitions]
-
-        child_token_keys = self.token_key_to_child_token_keys[token_key]
-
-        return [self.token_key_to_token[token_key] for token_key in child_token_keys]
-
-    def get_token_ids(self, token):
-        return self.token_to_token_ids[token]
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index eeeafcb07..1e14182b5 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -9,9 +9,8 @@
     create_fsm_index_end_to_end,
     create_fsm_index_tokenizer,
     fsm_union,
-    get_all_token_transitions,
     get_sub_fsms_from_seq,
-    get_token_transitions,
+    get_tokens_trans_keys,
     make_byte_level_better_fsm,
     make_byte_level_fsm,
     make_deterministic_fsm,
@@ -34,11 +33,15 @@ def merge_symbols(byte_hexs):
 
 
 def token_str_to_trans_key(fsm, input_string):
-    return get_token_transitions(
+    vocabulary_nb = numba.typed.List.empty_list(
+        numba.types.Tuple((numba.types.unicode_type, numba.int64[:]))
+    )
+    vocabulary_nb.append((input_string, np.fromiter([1], dtype=np.dtype("int64"))))
+    return get_tokens_trans_keys(
         fsm.fsm_info.alphabet_symbol_mapping,
         fsm.fsm_info.alphabet_anything_value,
-        input_string,
-    )
+        vocabulary_nb,
+    )[0]
 
 
 def walk_fsm_from_token_str(
@@ -595,7 +598,7 @@ def convert_token_to_string(self, token):
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
     vocabulary, _ = reduced_vocabulary(tokenizer)
-    token_trans_keys = get_all_token_transitions(
+    token_trans_keys = get_tokens_trans_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
         vocabulary,
@@ -630,7 +633,7 @@ def convert_token_to_string(self, token):
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
     vocabulary, _ = reduced_vocabulary(tokenizer)
-    token_trans_keys = get_all_token_transitions(
+    token_trans_keys = get_tokens_trans_keys(
        regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
         vocabulary,
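
Note: the restored flow above precomputes one transition-key sequence per vocabulary entry (get_tokens_trans_keys) and then, for each FSM start state, walks the FSM once per token (state_scan_tokens). The following is a simplified, pure-Python sketch of that idea; the toy FSM, alphabet, and vocabulary are invented for illustration, and the real functions in outlines/fsm/regex.py are numba-compiled and operate on typed containers, so this is not the library's code:

def walk_fsm(transitions, start_state, trans_key_seq):
    """Follow one transition key at a time; return the visited states."""
    state_seq = []
    state = start_state
    for trans_key in trans_key_seq:
        next_state = transitions.get((state, trans_key))
        if next_state is None:
            break
        state_seq.append(next_state)
        state = next_state
    return state_seq


def state_scan_tokens(transitions, vocabulary, token_trans_key_seqs, start_state):
    """Pair each (token, token_ids) entry with its transition-key sequence and
    record (token_id, end_state) for tokens the FSM accepts from start_state."""
    res = set()
    for (token, token_ids), trans_key_seq in zip(vocabulary, token_trans_key_seqs):
        state_seq = walk_fsm(transitions, start_state, trans_key_seq)
        if len(state_seq) < len(trans_key_seq):
            continue  # token not fully matched from this state
        for token_id in token_ids:
            res.add((token_id, state_seq[-1]))
    return res


# Toy FSM over symbols "a" and "b": state 0 --a--> 1 --b--> 2
transitions = {(0, 0): 1, (1, 1): 2}  # keys are (state, transition key)
alphabet = {"a": 0, "b": 1}           # symbol -> transition key
vocabulary = [("a", [10]), ("ab", [11]), ("b", [12])]
token_trans_key_seqs = [[alphabet[c] for c in tok] for tok, _ in vocabulary]

print(state_scan_tokens(transitions, vocabulary, token_trans_key_seqs, 0))
# {(10, 1), (11, 2)} since "b" is not matchable from state 0

Compared with the reverted trie traversal, this is a plain linear scan over the whole vocabulary for every start state, which is exactly what the revert returns to.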
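
Note: the deleted VocabTrie docstring describes a trie whose nodes are only vocabulary members and whose parent relation is "longest strict prefix that is itself in the vocabulary". Below is a rough pure-Python rendering of that parent rule, keyed on strings rather than the int64[:] transition sequences the jitclass actually uses, with the docstring's own example vocabulary (illustrative only):

def build_children(vocab):
    """Each token's parent is its longest strict prefix that is itself in the
    vocabulary; tokens with no such prefix hang off the root ("")."""
    children = {"": []}
    for token in sorted(vocab, key=len):  # insert shorter tokens (prefixes) first
        parent = ""
        for i in range(len(token) - 1, 0, -1):
            if token[:i] in children:
                parent = token[:i]
                break
        children[parent].append(token)
        children[token] = []
    return children


children = build_children(["a", "ab", "abc", "ac", "ace", "apple"])
print(children["a"])   # ['ab', 'ac', 'apple']
print(children["ab"])  # ['abc']
print(children["ac"])  # ['ace']

This reproduces the docstring's example: the children of "a" are "ab", "ac" and "apple", while "abc" and "ace" attach to "ab" and "ac" because those intermediate parents are in the vocabulary.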