diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index e347a31cb..0b35c978f 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -414,7 +414,7 @@ def _walk_fsm( fsm_transitions: Dict[Tuple[int, int], int], fsm_initial: int, fsm_finals: Set[int], - token_trans_key_seq: Sequence[int], + token_transition_keys: Sequence[int], start_state: int, full_match: bool = True, ) -> List[int]: @@ -424,7 +424,7 @@ def _walk_fsm( # Iterate over token transition key sequence. The transition key # sequence represents the FSM traversal rules of the tokens symbols. - for i, trans_key in enumerate(token_trans_key_seq): + for i, trans_key in enumerate(token_transition_keys): new_state = fsm_transitions.get((state, trans_key)) if new_state is None: @@ -448,7 +448,7 @@ def _walk_fsm( def walk_fsm( fsm: BetterFSM, - token_trans_key_seq: Sequence[int], + token_transition_keys: Sequence[int], start_state: int, full_match: bool = True, ) -> List[int]: @@ -462,7 +462,7 @@ def walk_fsm( # Iterate over token transition key sequence. The transition key # sequence represents the FSM traversal rules of the tokens symbols. - for i, trans_key in enumerate(token_trans_key_seq): + for i, trans_key in enumerate(token_transition_keys): new_state = fsm_transitions.get((state, trans_key)) if new_state is None: @@ -703,10 +703,10 @@ def get_token_transition_keys( alphabet_symbol_mapping.get(symbol, alphabet_anything_value) ) - tok_trans_array = np.empty(len(token_transition_keys), dtype=np.int64) + token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64) for j in range(len(token_transition_keys)): - tok_trans_array[j] = token_transition_keys[j] - return tok_trans_array + token_transition_keys_array[j] = token_transition_keys[j] + return token_transition_keys_array @numba.njit(cache=True, nogil=True) @@ -718,14 +718,14 @@ def get_vocabulary_transition_keys( """ Calculate the sequence transition keys for each token str within a vocabulary """ - tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:]) + vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:]) for token_str, _ in vocabulary: - trans_key_seq_array = get_token_transition_keys( + token_transition_keys = get_token_transition_keys( alphabet_symbol_mapping, alphabet_anything_value, token_str ) - tokens_trans_keys.append(trans_key_seq_array) + vocab_transition_keys.append(token_transition_keys) - return tokens_trans_keys + return vocab_transition_keys def create_fsm_index_end_to_end(