diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index 5544d4925..e15f0dc7f 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -78,7 +78,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: The new state of the FSM. """ - if token_id == self.eos_token_id: + if token_id == self.eos_token_id or state == self.final_state: return self.final_state return self.first_state @@ -172,7 +172,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: The new state of the FSM. """ - if token_id == self.eos_token_id: + if token_id == self.eos_token_id or state == self.final_state: return self.final_state last_token_to_end_state = self.states_to_token_maps[state] @@ -354,7 +354,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: ------- The new state of the FSM. """ - if token_id == self.tokenizer.eos_token_id: + if token_id == self.tokenizer.eos_token_id or state == self.final_state: return self.final_state self.generation += self.tokenizer.decode([token_id])[0]