Stop generation at every FSM final state
A recent change replaced the set of FSM final states with the single state -1,
which is used to represent that an EOS token has been generated. This could
explain the issue reported in #605.
rlouf committed Mar 7, 2024
1 parent 9fd0f46 commit c1851df
Showing 2 changed files with 81 additions and 30 deletions.
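
In short: a regex-compiled FSM can have several accepting states, so treating the EOS sentinel -1 as the only final state makes generation run past valid stopping points. A minimal sketch of the before/after behavior (illustrative only; the state numbers and function names below are hypothetical, not code from this commit):

# Illustrative sketch of the bug this commit fixes; the state numbers
# below are hypothetical, not taken from the outlines codebase.

final_states = {2, 3}  # a regex FSM may have several accepting states
EOS_SENTINEL = -1      # pseudo-state entered once EOS has been generated

def is_final_state_before(state: int) -> bool:
    # Before: only the EOS sentinel counted as final, so generation
    # never stopped at the FSM's real accepting states.
    return state == EOS_SENTINEL

def is_final_state_after(state: int) -> bool:
    # After: every accepting state, plus the sentinel, stops generation.
    return state in final_states | {EOS_SENTINEL}

assert not is_final_state_before(2)  # the regression: accepting state ignored
assert is_final_state_after(2)
assert is_final_state_after(EOS_SENTINEL)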
56 changes: 40 additions & 16 deletions outlines/fsm/fsm.py
@@ -15,12 +15,8 @@


class FSM(Protocol):
-    first_state: FSMState = FSMState(0)
-    final_state: FSMState = FSMState(-1)
-
    def is_final_state(self, state: FSMState) -> bool:
-        """Determine whether the current state of the FSM is a final state."""
-        return state == self.final_state
+        ...

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        ...
@@ -32,12 +28,14 @@ def copy(self) -> "FSM":
        ...


-class StopAtEosFSM(FSM):
+class StopAtEosFSM:
    """FSM to generate text until EOS has been generated."""

    def __init__(self, tokenizer: "Tokenizer"):
        self.eos_token_id = tokenizer.eos_token_id
        self.vocabulary = tokenizer.vocabulary.values()
+        self.start_state: FSMState = FSMState(0)
+        self.final_state: FSMState = FSMState(1)

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.
@@ -81,21 +79,25 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
        if token_id == self.eos_token_id or state == self.final_state:
            return self.final_state

-        return self.first_state
+        return self.start_state
+
+    def is_final_state(self, state: FSMState) -> bool:
+        """Determine whether the current state of the FSM is a final state."""
+        return state == self.final_state

    def copy(self) -> "StopAtEosFSM":
        """Create a copy of the FSM."""
        return self


-class RegexFSM(FSM):
+class RegexFSM:
    """FSM to generate text that is in the language of a regular expression."""

    def __init__(self, regex_string: str, tokenizer):
        @cache()
        def create_states_mapping(
            regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int], ...]
-        ) -> Tuple[dict, set]:
+        ) -> Tuple[dict, set, set]:
            """Create the variables related to the mapping between states and tokens
            The parameters of the function are used for caching purpose
            """
@@ -116,13 +118,19 @@ def create_states_mapping(
                    "The vocabulary does not allow us to build a sequence that matches the input regex"
                )

-            return states_to_token_maps, empty_token_ids
+            return states_to_token_maps, empty_token_ids, regex_fsm.finals

-        self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
+        (
+            self.states_to_token_maps,
+            self.empty_token_ids,
+            fsm_finals,
+        ) = create_states_mapping(
            regex_string, tuple(sorted(tokenizer.vocabulary.items()))
        )
        self.vocabulary = list(tokenizer.vocabulary.values())
        self.eos_token_id = tokenizer.eos_token_id
+        self.start_state = FSMState(0)
+        self.final_states = fsm_finals | {-1}

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.
@@ -172,13 +180,17 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
        The new state of the FSM.
        """
-        if token_id == self.eos_token_id or state == self.final_state:
-            return self.final_state
+        if token_id == self.eos_token_id:
+            return FSMState(-1)
+        elif (
+            state in self.final_states
+        ):  # Necessary because we keep generating EOS tokens when finished
+            return state

        last_token_to_end_state = self.states_to_token_maps[state]
        next_state = last_token_to_end_state.get(token_id)
        if next_state is None:
-            return self.final_state
+            return FSMState(-1)

        return FSMState(next_state)

@@ -222,6 +234,9 @@ def create_states_mapping_from_interegular_fsm(
        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
        return from_interegular_instance

+    def is_final_state(self, state: FSMState) -> bool:
+        return state in self.final_states
+
    def copy(self) -> "RegexFSM":
        """Create a copy of the FSM."""
        return self
@@ -258,6 +273,9 @@ def __init__(self, cfg_string: str, tokenizer):
        self.proposal_last: List[int] = []
        self.regex_fsm_last: RegexFSM

+        self.start_state = FSMState(0)
+        self.final_state = FSMState(-1)
+
    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.
@@ -328,7 +346,7 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
            self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
            self.reset_state = True

-        proposal += self.regex_fsm.allowed_token_ids(self.first_state)
+        proposal += self.regex_fsm.allowed_token_ids(self.start_state)
        if self.allow_eos:
            self.allow_eos = False
        else:
@@ -354,6 +372,9 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
        -------
        The new state of the FSM.
        """
+
+        # We need to return the final state when in the final state because we
+        # then generate EOS tokens instead of stopping the generation.
        if token_id == self.tokenizer.eos_token_id or state == self.final_state:
            return self.final_state

@@ -366,10 +387,13 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:

        if self.reset_state:
            self.reset_state = False
-            state = self.first_state
+            state = self.start_state

        return self.regex_fsm.next_state(state, token_id)

+    def is_final_state(self, state: FSMState) -> bool:
+        return state == self.final_state
+
    def copy(self) -> "CFGFSM":
        """Create a copy of the FSM."""
        return CFGFSM(self.cfg_string, self.tokenizer)
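
For context, the per-class is_final_state methods added above are what a sampling loop consults to decide when to stop. A hypothetical driver (not part of this diff) might look like the following, where `sample` stands in for the model's token sampler:

# Hypothetical driver loop, not part of this diff: `fsm` is any object
# implementing the FSM protocol above, and `sample` picks one token id
# among those currently allowed.

def generate_tokens(fsm, sample, max_tokens=64):
    state, token_ids = fsm.start_state, []
    while len(token_ids) < max_tokens:
        token_id = sample(fsm.allowed_token_ids(state))
        token_ids.append(token_id)
        state = fsm.next_state(state, token_id)
        if fsm.is_final_state(state):  # now true at every final state
            break
    return token_ids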
55 changes: 41 additions & 14 deletions tests/fsm/test_fsm.py
@@ -10,11 +10,11 @@ class MockTokenizer:

    fsm = StopAtEosFSM(MockTokenizer())

-    assert fsm.allowed_token_ids(fsm.first_state) == [1, 2]
+    assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
    assert fsm.allowed_token_ids(fsm.final_state) == [2]
-    assert fsm.next_state(fsm.first_state, 2) == fsm.final_state
-    assert fsm.next_state(fsm.first_state, 1) == fsm.first_state
-    assert fsm.is_final_state(fsm.first_state) is False
+    assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
+    assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
+    assert fsm.is_final_state(fsm.start_state) is False
    assert fsm.is_final_state(fsm.final_state) is True


@@ -49,10 +49,37 @@ def convert_token_to_string(self, token):
    assert fsm.states_to_token_maps == {0: {1: 1}}
    assert fsm.allowed_token_ids(state=0) == [1]
    assert fsm.next_state(state=0, token_id=1) == 1
-    assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == fsm.final_state
+    assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1

    assert fsm.is_final_state(0) is False
-    assert fsm.is_final_state(fsm.final_state) is True
+
+    for state in fsm.final_states:
+        assert fsm.is_final_state(state) is True
+
+
+def test_regex_final_state():
+    """Make sure that the FSM stays in the final state as we keep generating"""
+
+    class MockTokenizer:
+        vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
+        special_tokens = {"eos"}
+        eos_token_id = 104
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = r"`\n(\.\n)?`\n"
+    tokenizer = MockTokenizer()
+    fsm = RegexFSM(regex_str, tokenizer)
+
+    state = fsm.next_state(state=4, token_id=103)
+    assert state == 5
+    assert fsm.is_final_state(state)
+
+    state = fsm.next_state(state=5, token_id=103)
+    assert state == 5
+
+    assert fsm.is_final_state(-1)


def test_cfg():
@@ -79,8 +106,8 @@ def decode(self, token_ids):
    tokenizer = MockTokenizer()
    fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 3, 5}
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3, 5}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
    assert fsm.generation == "{"
    assert not fsm.is_final_state(state)

@@ -130,8 +157,8 @@ def decode(self, token_ids):
    tokenizer = MockTokenizer()
    fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1}
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
    assert fsm.generation == "("
    assert not fsm.is_final_state(state)

@@ -236,9 +263,9 @@ def decode(self, token_ids):
    tokenizer = MockTokenizer()
    fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 2}
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 2}
    assert fsm.reset_state  # starting new regex
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
    assert fsm.generation == "a"
    assert not fsm.is_final_state(state)

@@ -279,8 +306,8 @@ def decode(self, token_ids):
    tokenizer = MockTokenizer()
    fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 3}
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
    assert fsm.generation == "("
    assert not fsm.is_final_state(state)

