Skip to content

Commit

Permalink
Index token -> transition-key sequences for an efficient FSM walk
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed May 27, 2024
1 parent 3f499bc commit 53d8a8d
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def _walk_fsm(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
Expand All @@ -428,19 +428,7 @@ def _walk_fsm(
# By default, each symbol is a unicode character
# Except, if the character, input_string[i] == '\x00', then the next two
# in input_string characters are a hex representation of the byte
i = 0
while i < len(input_string):
# if null-byte prefixed its a hex representation
# unless its the last character, then its a trailing null byte symbol
if input_string[i] == "\x00" and i != len(input_string) - 1:
symbol = input_string[i : i + 3]
i += 3
else:
symbol = input_string[i]
i += 1

trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand Down Expand Up @@ -677,27 +665,26 @@ def state_scan_tokens(
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[str, Sequence[int]]],
token_trans_key_seqs: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary:
for (token, token_ids), token_trans_key_seq in zip(
vocabulary, token_trans_key_seqs
):
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
token_trans_key_seq,
start_state,
False,
)

if token == "\x00":
token_length = 1
else:
token_length = len(token) - 2 * token.count("\x00")
if state_seq is not None and len(state_seq) < token_length:
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
continue

for token_id in token_ids:
Expand All @@ -706,6 +693,37 @@ def state_scan_tokens(
return res


@numba.njit(cache=True, nogil=True)
def get_tokens_trans_keys(
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    vocabulary: List[Tuple[str, Sequence[int]]],
) -> List[Sequence[int]]:
    """Precompute each vocabulary token's FSM transition-key sequence.

    For every token string in ``vocabulary``, map its symbols through
    ``alphabet_symbol_mapping`` (falling back to ``alphabet_anything_value``
    for symbols not in the mapping) and collect the resulting transition
    keys into an int64 array.  This lets the FSM walk consume precomputed
    keys instead of re-deriving them from the token string on every walk.

    Parameters
    ----------
    alphabet_symbol_mapping
        Mapping from an FSM alphabet symbol (str) to its transition key.
    alphabet_anything_value
        Transition key used for any symbol absent from the mapping.
    vocabulary
        Pairs of (token string, token ids); only the strings are read here.

    Returns
    -------
    A numba typed list containing one int64 array of transition keys per
    token, in the same order as ``vocabulary``.
    """
    tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
    for token_str, _ in vocabulary:
        trans_key_seq = []
        i = 0
        while i < len(token_str):
            # A null byte introduces a 3-character hex byte escape
            # (e.g. "\x00" followed by two hex digits), unless it is the
            # token's final character — then it is a literal trailing
            # null-byte symbol on its own.
            if token_str[i] == "\x00" and i != len(token_str) - 1:
                symbol = token_str[i : i + 3]
                i += 3
            else:
                symbol = token_str[i]
                i += 1

            trans_key_seq.append(
                alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
            )

        # Copy the collected keys element-by-element into a fixed int64
        # array: the numba typed list requires homogeneous array elements.
        trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
        for j in range(len(trans_key_seq)):
            trans_key_seq_array[j] = trans_key_seq[j]

        tokens_trans_keys.append(trans_key_seq_array)

    return tokens_trans_keys


def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[str, Sequence[int]]],
Expand All @@ -724,6 +742,12 @@ def create_fsm_index_end_to_end(
desc="Compiling FSM index for all state transitions",
)

tokens_trans_key_seqs = get_tokens_trans_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

Expand All @@ -734,6 +758,7 @@ def create_fsm_index_end_to_end(
fsm_info.initial,
fsm_info.finals,
vocabulary,
tokens_trans_key_seqs,
start_state,
)

Expand Down

0 comments on commit 53d8a8d

Please sign in to comment.