Commit

Remove CFG logic

rlouf committed Aug 15, 2024
1 parent 00b3d33 · commit 2087e6a
Showing 6 changed files with 6 additions and 973 deletions.
24 changes: 1 addition & 23 deletions outlines/fsm/fsm.py
@@ -1,7 +1,7 @@
 import warnings
 from typing import TYPE_CHECKING, Iterable, NewType, Optional
 
-from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide
+from outlines.fsm.guide import RegexGuide, StopAtEOSGuide
 
 if TYPE_CHECKING:
     from outlines.models.tokenizer import Tokenizer
@@ -45,25 +45,3 @@ def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
 
     def next_state(self, state: FSMState, token_id: int) -> FSMState:
         return FSMState(self.get_next_state(state, token_id))
-
-
-class CFGFSM(CFGGuide):
-    """FSM to generate text that is in the language of a context-free grammar."""
-
-    def __init__(self, cfg_string: str, tokenizer):
-        warnings.warn(
-            UserWarning(
-                "The `CFGFSM` interface is deprecated and will be removed on 2024-06-01. Please use `CFGGuide` instead."
-            )
-        )
-        super().__init__(cfg_string, tokenizer)
-
-    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
-        return self.get_next_instruction(state).tokens
-
-    def next_state(self, state: FSMState, token_id: int) -> FSMState:
-        return FSMState(self.get_next_state(state, token_id))
-
-    def copy(self) -> "CFGFSM":
-        """Create a copy of the FSM."""
-        return CFGFSM(self.cfg_string, self.tokenizer)
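For context, the removed `CFGFSM` was only a deprecation shim over `CFGGuide`. A hypothetical pre-removal usage sketch follows; the grammar string, the `tokenizer` object, and the sampling step are illustrative assumptions, not taken from the diff:

# Hypothetical usage of the deprecated CFGFSM removed above; the grammar,
# the tokenizer, and the "sampling" step are assumptions for illustration.
from outlines.fsm.fsm import CFGFSM  # pre-removal import path

arithmetic_grammar = r"""
?start: NUMBER ("+" NUMBER)*
%import common.NUMBER
"""

fsm = CFGFSM(arithmetic_grammar, tokenizer)   # emitted the UserWarning above
state = 0
allowed = list(fsm.allowed_token_ids(state))  # token ids legal at this state
token_id = allowed[0]                         # stand-in for a real sampling step
state = fsm.next_state(state, token_id)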
179 changes: 0 additions & 179 deletions outlines/fsm/guide.py
@@ -13,9 +13,7 @@
 
 import interegular
 import torch
-from lark import Lark
 
-from outlines import grammars
 from outlines.caching import cache
 from outlines.fsm.regex import (
     create_fsm_index_tokenizer,
@@ -298,180 +296,3 @@ def is_final_state(self, state: int) -> bool:
 
     def copy(self):
         return self
-
-
-class CFGGuide(Guide):
-    """Guide to generate text that is in the language of a context-free grammar."""
-
-    def __init__(self, cfg_string: str, tokenizer):
-        self.cfg_string = cfg_string
-        self.tokenizer = tokenizer
-
-        self.parser = Lark(
-            cfg_string,
-            parser="lalr",
-            lexer="contextual",
-            propagate_positions=False,
-            maybe_placeholders=False,
-            regex=True,
-            import_paths=[grammars.GRAMMAR_PATH],
-        )
-        self.terminal_regexps = dict()
-        for terminal in self.parser.terminals:
-            if terminal.pattern is not None:
-                self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
-        self.terminal_regexps["$END"] = tokenizer.eos_token
-
-        self.generation = ""
-        self.reset_state = False
-        self.allow_eos = False
-        self.regex_fsm: RegexGuide
-
-        self.check_last = False
-        self.proposal_last: List[int] = []
-        self.regex_fsm_last: RegexGuide
-
-        self.start_state = 0
-        self.final_state = -1
-
-    def get_next_instruction(self, state: int) -> Instruction:
-        """Generate an instruction for the next step.
-
-        Upon initialization, the CFG incremental parser is used to determine the
-        first regex and construct the first FSM to generate the first terminal.
-
-        This FSM is used for proposals until either:
-
-        - The FSM is exhausted, and its only remaining option is the EOS token,
-          in which case we feed the generated terminal to the
-          CFG incremental parser and allow it to propose the next regex
-          corresponding to the next set of valid terminals.
-        - The current FSM can be exhausted, but the EOS token is not the only
-          remaining option. In this case we allow proposal of current terminal
-          extensions, store the current FSM and its state, then also use the CFG
-          parser to propose a new regex corresponding to terminating the current
-          terminal and starting the next one. The model can then sample from
-          either of these sets to determine whether to extend the current
-          terminal or terminate it and start the next one.
-
-        The CFG incremental parser is allowed to propose the EOS token from any
-        accepting state, and once it is generated, the FSM will continue to
-        always generate the EOS token.
-
-        Parameters
-        ----------
-        state
-            The current state of the FSM.
-
-        Returns
-        -------
-        A list that contains the tokens to mask.
-
-        """
-        if self.is_final_state(state):
-            return Write([self.tokenizer.eos_token_id])
-
-        proposal: List[int] = []
-        if self.generation != "":
-            if self.check_last:
-                proposer = self.regex_fsm_last
-            else:
-                proposer = self.regex_fsm
-
-            instruction = proposer.get_next_instruction(state)
-
-            assert instruction.tokens is not None
-
-            if isinstance(instruction, Write):
-                proposal += instruction.tokens
-            else:
-                proposal += instruction.tokens
-
-            if self.tokenizer.eos_token_id not in proposal:
-                return Generate(proposal)
-
-            self.check_last = False
-            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
-            if len(proposal) > 0:
-                self.check_last = True
-                self.proposal_last = proposal.copy()
-                self.regex_fsm_last = proposer
-
-        interactive = self.parser.parse_interactive(self.generation)
-        interactive.exhaust_lexer()
-
-        options = {self.terminal_regexps[x] for x in interactive.accepts()}
-        # add %ignore terminals
-        options |= {self.terminal_regexps[x] for x in self.parser.lexer_conf.ignore}
-
-        if self.terminal_regexps["$END"] in options:
-            options.remove(self.terminal_regexps["$END"])
-            if len(options) == 0:
-                return Write([self.tokenizer.eos_token_id])
-            self.allow_eos = True
-            options.add("")
-            assert len(options) > 1
-
-        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
-        self.regex_fsm = RegexGuide(regex_string, self.tokenizer)
-        self.reset_state = True
-
-        instruction = self.regex_fsm.get_next_instruction(self.start_state)
-
-        assert instruction.tokens is not None
-
-        if isinstance(instruction, Write):
-            proposal += instruction.tokens
-        else:
-            proposal += instruction.tokens
-
-        if self.allow_eos:
-            self.allow_eos = False
-        else:
-            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
-            assert len(proposal) > 0
-
-        return Generate(proposal)
-
-    def get_next_state(self, state: int, token_id: int) -> int:
-        """Update the state of the guide.
-
-        Transitions the underlying regex FSM to its next state.
-        If at max tokens or EOS token, transition permanently to the final state.
-        Update stored partial generations for subsequent incremental parsing.
-
-        Parameters
-        ----------
-        state
-            The current state of the FSM.
-        token_id
-            The id of the token that was just generated.
-
-        Returns
-        -------
-        The new state of the FSM.
-
-        """
-
-        # We need to return the final state when in the final state because we
-        # then generate EOS tokens instead of stopping the generation.
-        if token_id == self.tokenizer.eos_token_id or state == self.final_state:
-            return self.final_state
-
-        self.generation += self.tokenizer.decode([token_id])[0]
-
-        if self.check_last:
-            if token_id in self.proposal_last:
-                return self.regex_fsm_last.get_next_state(state, token_id)
-            self.check_last = False
-
-        if self.reset_state:
-            self.reset_state = False
-            state = self.start_state
-
-        return self.regex_fsm.get_next_state(state, token_id)
-
-    def is_final_state(self, state: int) -> bool:
-        return state == self.final_state
-
-    def copy(self) -> "CFGGuide":
-        """Create a copy of the FSM."""
-        return CFGGuide(self.cfg_string, self.tokenizer)
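The docstrings above describe the decoding loop the removed guide supported. A minimal sketch of that loop under assumed HF-style `model` and `tokenizer` objects (greedy selection as a stand-in for sampling, device handling omitted; the function name is hypothetical, not the library's sampler):

# A minimal sketch of how the removed CFGGuide drove constrained decoding:
# at each step the guide narrows the vocabulary to the tokens that keep the
# generated text inside the grammar's language.
import torch

def cfg_constrained_decode(model, tokenizer, guide, input_ids, max_tokens=128):
    state = guide.start_state  # 0, per the removed class
    for _ in range(max_tokens):
        logits = model(input_ids).logits[0, -1]       # assumed HF-style output
        instruction = guide.get_next_instruction(state)
        mask = torch.full_like(logits, float("-inf"))
        mask[list(instruction.tokens)] = 0.0          # allow only legal tokens
        token_id = int(torch.argmax(logits + mask))   # greedy stand-in for sampling
        state = guide.get_next_state(state, token_id)
        input_ids = torch.cat([input_ids, torch.tensor([[token_id]])], dim=1)
        if guide.is_final_state(state):
            break
    return guide.generation  # the guide accumulates the decoded text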
8 changes: 3 additions & 5 deletions outlines/models/transformers.py
Expand Up @@ -10,8 +10,6 @@
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from outlines.processors import OutlinesLogitsProcessor

__all__ = ["transformers"]


@@ -217,7 +215,7 @@ def generate(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
+        logits_processor,
         sampling_parameters: SamplingParameters,
     ) -> Union[str, List[str], List[List[str]]]:
         """Generate text using `transformers`.
@@ -275,7 +273,7 @@ def stream(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
+        logits_processor,
         sampling_parameters: SamplingParameters,
     ) -> Iterator[Union[str, List[str]]]:
         """
@@ -319,7 +317,7 @@ def _get_generation_kwargs(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
+        logits_processor,
         sampling_parameters: SamplingParameters,
     ) -> dict:
         """
(Diffs for the remaining changed files are not shown.)
