From c0b47a45121e1c865524a309eef4cf2331129233 Mon Sep 17 00:00:00 2001 From: Dan Saattrup Nielsen <47701536+saattrupdan@users.noreply.github.com> Date: Fri, 1 Mar 2024 08:53:30 +0100 Subject: [PATCH] Return FSM final state if already in final state (#718) Fixes #716 --- outlines/fsm/fsm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index 5544d4925..e15f0dc7f 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -78,7 +78,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: The new state of the FSM. """ - if token_id == self.eos_token_id: + if token_id == self.eos_token_id or state == self.final_state: return self.final_state return self.first_state @@ -172,7 +172,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: The new state of the FSM. """ - if token_id == self.eos_token_id: + if token_id == self.eos_token_id or state == self.final_state: return self.final_state last_token_to_end_state = self.states_to_token_maps[state] @@ -354,7 +354,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState: ------- The new state of the FSM. """ - if token_id == self.tokenizer.eos_token_id: + if token_id == self.tokenizer.eos_token_id or state == self.final_state: return self.final_state self.generation += self.tokenizer.decode([token_id])[0]