
Commit

njitize create_fsm_index_end_to_end
use updated interegular

revert

fix

re-add type hint and comment

use list of lists instead of dict for numba, get rid of reflected set warning by using a list of finals

don't specify git version for interegular

include empty states

remove unused def

fix initial state bug

fix bug involving end states allowing continuation

oops, that's not part of this PR

fix test by making it conform to the new type
Andrew Lapp committed Jan 11, 2024
1 parent 8b1ff9a commit 903a084
Showing 4 changed files with 32 additions and 116 deletions.
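
The commit message mentions two numba-related changes: replacing the dict/set index with a list of lists, and passing the FSM's final states as a list to silence numba's reflected-set warning. The snippet below is a minimal, standalone sketch of that second pattern, assuming only that fsm_info.finals behaves like a plain Python set of ints; the count_finals helper is hypothetical and exists only to show the typed list crossing the nopython boundary without warnings.

import numba

# Plain Python set standing in for fsm_info.finals (an assumption for this sketch).
finals = {1, 2}

# Copy the set into a numba.typed.List so the njit-compiled consumer never
# receives a reflected (plain Python) set, which triggers numba's
# reflected-set deprecation warning.
finals_list = numba.typed.List.empty_list(numba.int64)
for final in finals:
    finals_list.append(final)


@numba.njit(cache=True, nogil=True)
def count_finals(finals_list):
    # Hypothetical helper: typed lists iterate cleanly in nopython mode.
    n = 0
    for _ in finals_list:
        n += 1
    return n


assert count_finals(finals_list) == 2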
38 changes: 20 additions & 18 deletions outlines/fsm/regex.py
@@ -491,40 +491,37 @@ def state_scan_tokens(
     return res
 
 
+@numba.njit(cache=True, nogil=True)
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
     vocabulary: Dict[str, List[int]],
-) -> Dict[int, Set[Tuple[int, int]]]:
+) -> List[List[Tuple[int, int]]]:
     """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
 
-    # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
-    # code, too.
-    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
-    seen: Set[int] = set()
-    next_states = {fsm_info.initial}
+    finals_list = numba.typed.List.empty_list(numba.int64)
+    for final in fsm_info.finals:
+        finals_list.append(final)
 
-    while next_states:
-        start_state = next_states.pop()
+    all_states = set(fsm_info.transitions.values())
+    all_states.add(fsm_info.initial)
+    num_states = max(all_states)
+    states_to_token_subsets = numba.typed.List(
+        [numba.typed.List.empty_list(nb_int_pair_type) for _ in range(num_states + 1)]
+    )
 
+    for state_id in all_states:
         token_ids_end_states = state_scan_tokens(
             fsm_info.transitions,
             fsm_info.alphabet_symbol_mapping,
             fsm_info.alphabet_anything_value,
             fsm_info.initial,
-            fsm_info.finals,
+            finals_list,
             vocabulary,
-            start_state,
+            state_id,
         )
 
         for token_id_and_end_state in token_ids_end_states:
-            states_to_token_subsets.setdefault(start_state, set()).add(
-                token_id_and_end_state
-            )
-            end_state = token_id_and_end_state[1]
-            if end_state not in seen:
-                next_states.add(end_state)
-
-        seen.add(start_state)
+            states_to_token_subsets[state_id].append(token_id_and_end_state)
 
     return states_to_token_subsets
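
The rewritten function above pre-allocates one typed list of (token_id, end_state) pairs per FSM state so the whole scan can run under @numba.njit. A rough, standalone sketch of that container layout follows; nb_int_pair_type is defined elsewhere in outlines/fsm/regex.py, so the UniTuple definition here is an assumption, and the appended values are made up for illustration.

import numba
from numba import types

# Assumed to mirror the nb_int_pair_type used in outlines/fsm/regex.py:
# each entry is a (token_id, end_state) pair.
nb_int_pair_type = types.UniTuple(types.int64, 2)

# One typed list per FSM state id (0..num_states): state ids index the outer
# list directly, and states with no scanned tokens simply keep empty lists.
num_states = 2
index = numba.typed.List(
    [numba.typed.List.empty_list(nb_int_pair_type) for _ in range(num_states + 1)]
)

index[0].append((3, 1))  # from state 0, token 3 ends in state 1
index[2].append((2, 2))  # from state 2, token 2 ends in state 2

# Callers that still want the old Dict[int, Set[Tuple[int, int]]] shape can
# rebuild it, which is what the create_fsm_index_tokenizer hunk below does.
as_dict = {state: set(pairs) for state, pairs in enumerate(index) if pairs}
assert as_dict == {0: {(3, 1)}, 2: {(2, 2)}}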

@@ -571,6 +568,11 @@ def create_fsm_index_tokenizer(
     vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
 
     states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
+    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {
+        state: set(tokens_id_end_state_list)
+        for state, tokens_id_end_state_list in enumerate(states_to_token_subsets)
+        if tokens_id_end_state_list
+    }
 
     # Allow transitions to EOS from all terminals FSM states that are
     # reachable
15 changes: 7 additions & 8 deletions outlines/serve/vllm.py
@@ -29,7 +29,7 @@ def _patched_apply_logits_processors(
                 logits_row = logits[logits_row_idx]
                 token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                 for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
+                    logits_row = logits_processor(seq_id, token_ids, logits_row)
                 logits[logits_row_idx] = logits_row
                 logits_row_idx += 1
         else:
@@ -56,21 +56,20 @@ def __init__(self, regex_string, llm):
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
 
-    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
+    def __call__(
+        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
 
-        state_id = hash(tuple(input_ids))
-
         if len(input_ids) == 0:  # Initialize the fsm states
             self.fsm_state: DefaultDict[int, int] = defaultdict(int)
         else:
-            prev_state_id = hash(tuple(input_ids[:-1]))
             last_token = input_ids[-1]
-            self.fsm_state[state_id] = self.fsm.next_state(
-                self.fsm_state[prev_state_id], last_token
+            self.fsm_state[seq_id] = self.fsm.next_state(
+                self.fsm_state[seq_id], last_token
             )
 
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[state_id])
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
 
         mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
         mask[allowed_tokens] = 0
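
The vllm.py change keys the per-sequence FSM state by vLLM's seq_id instead of a hash of the generated prefix, so each sequence advances its own state by its last token. Below is a stripped-down sketch of that bookkeeping with no vLLM dependency; DummyFSM and SeqIdBiasSketch are hypothetical stand-ins (the real processor wraps RegexFSM), kept only to show the __call__(seq_id, input_ids, scores) shape.

import math
from collections import defaultdict
from typing import DefaultDict, List

import torch


class DummyFSM:
    # Hypothetical stand-in exposing the two methods the processor relies on.
    def next_state(self, state: int, token_id: int) -> int:
        return state + 1

    def allowed_token_ids(self, state: int) -> List[int]:
        return [0, 1]


class SeqIdBiasSketch:
    """Sketch of per-sequence FSM tracking keyed by seq_id."""

    def __init__(self, fsm):
        self.fsm = fsm
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)  # seq_id -> FSM state

    def __call__(
        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
    ) -> torch.Tensor:
        if input_ids:
            # Advance only this sequence's state by its most recent token.
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[seq_id], input_ids[-1]
            )
        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[self.fsm.allowed_token_ids(self.fsm_state[seq_id])] = 0
        return scores + mask


processor = SeqIdBiasSketch(DummyFSM())
biased = processor(seq_id=7, input_ids=[5], scores=torch.zeros(4))
assert biased[0].item() == 0 and biased[2].item() == -math.inf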
5 changes: 5 additions & 0 deletions tests/fsm/test_regex.py
@@ -257,6 +257,11 @@ def test_create_fsm_index_end_to_end():
     vocabulary_nb.update(vocabulary)
 
     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
+    res = {
+        state: set(tokens_id_end_state_list)
+        for state, tokens_id_end_state_list in enumerate(res)
+        if tokens_id_end_state_list
+    }
 
     assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}}
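
Because the njit-compiled function now returns a position-indexed list (with empty entries for states that contribute no tokens), the test rebuilds the old dict shape before asserting. For this test's expected values, the raw return is roughly equivalent to the list below; the exact number of empty entries depends on the FSM's highest state id, so treat it as illustrative only.

# Illustrative only: a list-of-lists result equivalent to the asserted dict.
raw = [
    [(2, 2), (3, 1)],  # state 0
    [],                # state 1: no (token_id, end_state) pairs scanned
    [(2, 2), (3, 2)],  # state 2
]
assert {
    state: set(pairs) for state, pairs in enumerate(raw) if pairs
} == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}}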

90 changes: 0 additions & 90 deletions tests/generate/test_vllm_regex_logits_process.py

This file was deleted.
