Skip to content

Commit

Permalink
Adding all states, even if they are empty (Fixes dottxt-ai#605) from …
Browse files Browse the repository at this point in the history
…viktor-ferenczi/issue-605

Fixed test_regex to expect the final state

This test case fails now, which is expected until the fix is applied.

Regression test case

It reproduces the case where state 5 is missing from the generated `fsm.states_to_token_maps`.
  • Loading branch information
viktor-ferenczi committed Feb 5, 2024
1 parent 244c914 commit 870d918
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
13 changes: 4 additions & 9 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,9 @@ def create_fsm_index_end_to_end(
start_state,
)

states_to_token_subsets[start_state] = set(token_ids_end_states)

for token_id_and_end_state in token_ids_end_states:
states_to_token_subsets.setdefault(start_state, set()).add(
token_id_and_end_state
)
end_state = token_id_and_end_state[1]
if end_state not in seen:
next_states.add(end_state)
Expand Down Expand Up @@ -572,13 +571,9 @@ def create_fsm_index_tokenizer(

states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)

# Allow transitions to EOS from all terminals FSM states that are
# reachable
# TODO: Do we really need this anymore?
# Allow transitions to EOS from all terminals FSM states that are reachable
for state in fsm.fsm_info.finals:
subset = states_to_token_subsets.get(state)
if subset is not None:
subset.add((tokenizer.eos_token_id, state))
states_to_token_subsets[state].add((tokenizer.eos_token_id, state))

# Convert to token-to-end-state maps
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
Expand Down
26 changes: 25 additions & 1 deletion tests/fsm/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def convert_token_to_string(self, token):
tokenizer = MockTokenizer()
fsm = RegexFSM(regex_str, tokenizer)

assert fsm.states_to_token_maps == {0: {1: 1}}
assert fsm.states_to_token_maps == {0: {1: 1}, 1: {3: 1}}
assert fsm.allowed_token_ids(state=0) == [1]
assert fsm.next_state(state=0, token_id=1) == 1
assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == fsm.final_state
Expand Down Expand Up @@ -241,3 +241,27 @@ def decode(self, token_ids):
state = fsm.next_state(state=state, token_id=4)
assert fsm.generation == "(aa)"
assert fsm.is_final_state(state)


def test_regression_regex_missing_final_state():
class MockTokenizer:
vocabulary = {'`': 101, '.': 102, '\n': 103, "eos": 104}
special_tokens = {"eos"}
eos_token_id = 104

def convert_token_to_string(self, token):
return token

regex_str = r'`\n(\.\n)?`\n'
tokenizer = MockTokenizer()
fsm = RegexFSM(regex_str, tokenizer)

assert fsm.states_to_token_maps == {
0: {101: 1},
1: {103: 2},
2: {102: 3, 101: 4},
3: {103: 6},
4: {103: 5},
5: {104: 5},
6: {101: 4},
}

0 comments on commit 870d918

Please sign in to comment.