diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index a96fa8e24..e8cd4c918 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -516,10 +516,9 @@ def create_fsm_index_end_to_end(
             start_state,
         )
 
+        states_to_token_subsets[start_state] = set(token_ids_end_states)
+
         for token_id_and_end_state in token_ids_end_states:
-            states_to_token_subsets.setdefault(start_state, set()).add(
-                token_id_and_end_state
-            )
             end_state = token_id_and_end_state[1]
             if end_state not in seen:
                 next_states.add(end_state)
@@ -572,13 +571,9 @@ def create_fsm_index_tokenizer(
 
     states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
 
-    # Allow transitions to EOS from all terminals FSM states that are
-    # reachable
-    # TODO: Do we really need this anymore?
+    # Allow transitions to EOS from all terminals FSM states that are reachable
     for state in fsm.fsm_info.finals:
-        subset = states_to_token_subsets.get(state)
-        if subset is not None:
-            subset.add((tokenizer.eos_token_id, state))
+        states_to_token_subsets[state].add((tokenizer.eos_token_id, state))
 
     # Convert to token-to-end-state maps
     states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py
index 88862e5d4..77cb3a58e 100644
--- a/tests/fsm/test_fsm.py
+++ b/tests/fsm/test_fsm.py
@@ -46,7 +46,7 @@ def convert_token_to_string(self, token):
     tokenizer = MockTokenizer()
     fsm = RegexFSM(regex_str, tokenizer)
 
-    assert fsm.states_to_token_maps == {0: {1: 1}}
+    assert fsm.states_to_token_maps == {0: {1: 1}, 1: {3: 1}}
     assert fsm.allowed_token_ids(state=0) == [1]
     assert fsm.next_state(state=0, token_id=1) == 1
     assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == fsm.final_state
@@ -241,3 +241,27 @@ def decode(self, token_ids):
     state = fsm.next_state(state=state, token_id=4)
     assert fsm.generation == "(aa)"
     assert fsm.is_final_state(state)
+
+
+def test_regression_regex_missing_final_state():
+    class MockTokenizer:
+        vocabulary = {'`': 101, '.': 102, '\n': 103, "eos": 104}
+        special_tokens = {"eos"}
+        eos_token_id = 104
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = r'`\n(\.\n)?`\n'
+    tokenizer = MockTokenizer()
+    fsm = RegexFSM(regex_str, tokenizer)
+
+    assert fsm.states_to_token_maps == {
+        0: {101: 1},
+        1: {103: 2},
+        2: {102: 3, 101: 4},
+        3: {103: 6},
+        4: {103: 5},
+        5: {104: 5},
+        6: {101: 4},
+    }
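
For readers skimming the patch, here is a minimal standalone sketch of the bug being fixed. It uses plain dicts and sets rather than the `outlines` API, and the state/token numbers are borrowed from the new regression test: the old index built entries lazily with `setdefault`, so a final state with no outgoing token transitions never got a key, and the `get`-guarded EOS pass then silently skipped it.

```python
# Standalone sketch of the indexing bug; plain dicts/sets, not the outlines API.
# State 4 emits "\n" (token 103) to reach final state 5; state 5 matches no token.
scanned = {4: [(103, 5)], 5: []}
finals = {5}
eos_token_id = 104

# Old behavior: entries were created lazily inside the token loop, so state 5
# (whose scan result is empty) never appears as a key.
old_index = {}
for state, pairs in scanned.items():
    for pair in pairs:
        old_index.setdefault(state, set()).add(pair)
for state in finals:
    subset = old_index.get(state)
    if subset is not None:  # None for state 5 -> EOS transition silently dropped
        subset.add((eos_token_id, state))
assert 5 not in old_index  # the "missing final state"

# New behavior: every scanned state gets an entry up front, so the EOS
# self-transition can be added unconditionally.
new_index = {state: set(pairs) for state, pairs in scanned.items()}
for state in finals:
    new_index[state].add((eos_token_id, state))
assert new_index[5] == {(104, 5)}  # matches `5: {104: 5}` in the regression test
```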