diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py
index 4a7fce8c..bfcf55c0 100644
--- a/outlines/fsm/fsm.py
+++ b/outlines/fsm/fsm.py
@@ -1,7 +1,7 @@
 import warnings
 from typing import TYPE_CHECKING, Iterable, NewType, Optional
 
-from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide
+from outlines.fsm.guide import RegexGuide, StopAtEOSGuide
 
 if TYPE_CHECKING:
     from outlines.models.tokenizer import Tokenizer
@@ -45,25 +45,3 @@ def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
 
     def next_state(self, state: FSMState, token_id: int) -> FSMState:
         return FSMState(self.get_next_state(state, token_id))
-
-
-class CFGFSM(CFGGuide):
-    """FSM to generate text that is in the language of a context-free grammar."""
-
-    def __init__(self, cfg_string: str, tokenizer):
-        warnings.warn(
-            UserWarning(
-                "The `CFGFSM` interface is deprecated and will be removed on 2024-06-01. Please use `CFGGuide` instead."
-            )
-        )
-        super().__init__(cfg_string, tokenizer)
-
-    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
-        return self.get_next_instruction(state).tokens
-
-    def next_state(self, state: FSMState, token_id: int) -> FSMState:
-        return FSMState(self.get_next_state(state, token_id))
-
-    def copy(self) -> "CFGFSM":
-        """Create a copy of the FSM."""
-        return CFGFSM(self.cfg_string, self.tokenizer)
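For context on the deletion above: `CFGFSM` was only a thin adapter from the legacy `FSM` method names (`allowed_token_ids`, `next_state`) onto the newer `Guide` interface (`get_next_instruction`, `get_next_state`). A minimal sketch of that adapter pattern over the `RegexGuide` that survives this diff (the class name here is illustrative, not part of the library):

```python
from outlines.fsm.guide import RegexGuide


class LegacyRegexFSM(RegexGuide):
    """Adapter from the Guide API back to the legacy FSM API,
    following the same pattern as the deleted CFGFSM."""

    def allowed_token_ids(self, state):
        # Both Write and Generate instructions carry the permitted token ids.
        return self.get_next_instruction(state).tokens

    def next_state(self, state, token_id):
        return self.get_next_state(state, token_id)
```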
diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index 2e441514..c846c441 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -13,9 +13,7 @@
 
 import interegular
 import torch
-from lark import Lark
 
-from outlines import grammars
 from outlines.caching import cache
 from outlines.fsm.regex import (
     create_fsm_index_tokenizer,
@@ -298,180 +296,3 @@ def is_final_state(self, state: int) -> bool:
 
     def copy(self):
         return self
-
-
-class CFGGuide(Guide):
-    """Guide to generate text that is in the language of a context-free grammar."""
-
-    def __init__(self, cfg_string: str, tokenizer):
-        self.cfg_string = cfg_string
-        self.tokenizer = tokenizer
-
-        self.parser = Lark(
-            cfg_string,
-            parser="lalr",
-            lexer="contextual",
-            propagate_positions=False,
-            maybe_placeholders=False,
-            regex=True,
-            import_paths=[grammars.GRAMMAR_PATH],
-        )
-        self.terminal_regexps = dict()
-        for terminal in self.parser.terminals:
-            if terminal.pattern is not None:
-                self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
-        self.terminal_regexps["$END"] = tokenizer.eos_token
-
-        self.generation = ""
-        self.reset_state = False
-        self.allow_eos = False
-        self.regex_fsm: RegexGuide
-
-        self.check_last = False
-        self.proposal_last: List[int] = []
-        self.regex_fsm_last: RegexGuide
-
-        self.start_state = 0
-        self.final_state = -1
-
-    def get_next_instruction(self, state: int) -> Instruction:
-        """Generate an instruction for the next step.
-
-        Upon initialization, the CFG incremental parser is used to determine the
-        first regex and construct the first FSM to generate the first terminal.
-
-        This FSM is used for proposals until either:
-
-        - The FSM is exhausted, and its only remaining option is the EOS token,
-          in which case we feed the generated terminal to the
-          CFG incremental parser and allow it to propose the next regex
-          corresponding to the next set of valid terminals.
-        - The current FSM can be exhausted, but the EOS token is not the only
-          remaining option. In this case we allow proposal of current terminal
-          extensions, store the current FSM and its state, then also use the CFG
-          parser to propose a new regex corresponding to terminating the current
-          terminal and starting the next one. The model can then sample from
-          either of these sets to determine whether to extend the current
-          terminal or terminate it and start the next one.
-
-        The CFG incremental parser is allowed to propose the EOS token from any accepting state,
-        and once it is generated, the FSM will continue to always generate the EOS token.
-
-        Parameters
-        ----------
-        state
-            The current state of the FSM.
-
-        Returns
-        -------
-        A list that contains the tokens to mask.
-
-        """
-        if self.is_final_state(state):
-            return Write([self.tokenizer.eos_token_id])
-
-        proposal: List[int] = []
-        if self.generation != "":
-            if self.check_last:
-                proposer = self.regex_fsm_last
-            else:
-                proposer = self.regex_fsm
-
-            instruction = proposer.get_next_instruction(state)
-
-            assert instruction.tokens is not None
-
-            if isinstance(instruction, Write):
-                proposal += instruction.tokens
-            else:
-                proposal += instruction.tokens
-
-            if self.tokenizer.eos_token_id not in proposal:
-                return Generate(proposal)
-
-            self.check_last = False
-            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
-            if len(proposal) > 0:
-                self.check_last = True
-                self.proposal_last = proposal.copy()
-                self.regex_fsm_last = proposer
-
-        interactive = self.parser.parse_interactive(self.generation)
-        interactive.exhaust_lexer()
-
-        options = {self.terminal_regexps[x] for x in interactive.accepts()}
-        # add %ignore terminals
-        options |= {self.terminal_regexps[x] for x in self.parser.lexer_conf.ignore}
-
-        if self.terminal_regexps["$END"] in options:
-            options.remove(self.terminal_regexps["$END"])
-            if len(options) == 0:
-                return Write([self.tokenizer.eos_token_id])
-            self.allow_eos = True
-            options.add("")
-            assert len(options) > 1
-
-        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
-        self.regex_fsm = RegexGuide(regex_string, self.tokenizer)
-        self.reset_state = True
-
-        instruction = self.regex_fsm.get_next_instruction(self.start_state)
-
-        assert instruction.tokens is not None
-
-        if isinstance(instruction, Write):
-            proposal += instruction.tokens
-        else:
-            proposal += instruction.tokens
-
-        if self.allow_eos:
-            self.allow_eos = False
-        else:
-            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
-            assert len(proposal) > 0
-
-        return Generate(proposal)
-
-    def get_next_state(self, state: int, token_id: int) -> int:
-        """Update the state of the guide.
-
-        Transitions the underlying regex FSM to its next state.
-        If at max tokens or EOS token, transition permanently to the final state.
-        Update stored partial generations for subsequent incremental parsing.
-
-        Parameters
-        ----------
-        state
-            The current state of the FSM.
-        token_id
-            The id of the token that was just generated.
-
-        Returns
-        -------
-        The new state of the FSM.
-        """
-
-        # We need to return the final state when in the final state because we
-        # then generate EOS tokens instead of stopping the generation.
-        if token_id == self.tokenizer.eos_token_id or state == self.final_state:
-            return self.final_state
-
-        self.generation += self.tokenizer.decode([token_id])[0]
-
-        if self.check_last:
-            if token_id in self.proposal_last:
-                return self.regex_fsm_last.get_next_state(state, token_id)
-            self.check_last = False
-
-        if self.reset_state:
-            self.reset_state = False
-            state = self.start_state
-
-        return self.regex_fsm.get_next_state(state, token_id)
-
-    def is_final_state(self, state: int) -> bool:
-        return state == self.final_state
-
-    def copy(self) -> "CFGGuide":
-        """Create a copy of the FSM."""
-        return CFGGuide(self.cfg_string, self.tokenizer)
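The core of the `CFGGuide` deleted above is the step that asks Lark's interactive LALR parser which terminals may come next and unions their regexes into the pattern for the next `RegexGuide`. A condensed sketch of just that step, using the same Lark calls as the deleted code (the `$END`/EOS bookkeeping shown above is omitted):

```python
from lark import Lark


def next_terminals_regex(parser: Lark, generation: str, terminal_regexps: dict) -> str:
    """Union the regexes of every terminal the parser accepts after `generation`."""
    interactive = parser.parse_interactive(generation)
    interactive.exhaust_lexer()  # feed the partial generation to the lexer
    options = {terminal_regexps[name] for name in interactive.accepts()}
    # %ignore terminals (e.g. whitespace) are always legal between tokens
    options |= {terminal_regexps[name] for name in parser.lexer_conf.ignore}
    return r"(" + r"|".join(r"(" + opt + r")" for opt in options) + r")"
```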
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 949d3c9c..10f8f248 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -10,8 +10,6 @@
     import torch
     from transformers import PreTrainedModel, PreTrainedTokenizer
 
-    from outlines.processors import OutlinesLogitsProcessor
-
 __all__ = ["transformers"]
@@ -217,7 +215,7 @@ def generate(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
+        logits_processor,
         sampling_parameters: SamplingParameters,
     ) -> Union[str, List[str], List[List[str]]]:
         """Generate text using `transformers`.
@@ -275,7 +273,7 @@ def stream(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
+        logits_processor,
         sampling_parameters: SamplingParameters,
     ) -> Iterator[Union[str, List[str]]]:
         """
@@ -319,7 +317,7 @@ def _get_generation_kwargs(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
+        logits_processor,
         sampling_parameters: SamplingParameters,
     ) -> dict:
         """
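With the `Optional["OutlinesLogitsProcessor"]` annotations dropped, `logits_processor` is duck-typed here: any object with the usual `transformers`-style call signature should work. A minimal sketch (the class is illustrative, not part of either library):

```python
import torch


class PassthroughLogitsProcessor:
    """No-op processor showing the call signature `logits_processor` is expected to have."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        return scores  # a real processor would mask or rescale logits here
```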
diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py
index 21163b70..94166fd9 100644
--- a/tests/fsm/test_fsm.py
+++ b/tests/fsm/test_fsm.py
@@ -1,6 +1,6 @@
 import pytest
 
-from outlines.fsm.fsm import CFGFSM, RegexFSM, StopAtEosFSM
+from outlines.fsm.fsm import RegexFSM, StopAtEosFSM
 
 
 def assert_expected_tensor_ids(tensor, ids):
@@ -90,263 +90,3 @@ def convert_token_to_string(self, token):
 
     state = fsm.next_state(state=5, token_id=103)
     assert fsm.is_final_state(state)
-
-
-def test_cfg():
-    class MockTokenizer:
-        vocabulary = {"{": 1, "}": 2, "[": 3, "]": 4, "eos": 5}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 5
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: expr
-    expr: "{" expr "}" | "[" expr "]" |
-    """
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = CFGFSM(cfg_str, tokenizer)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 3, 5])
-    state = fsm.next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "{"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
-    state = fsm.next_state(state=state, token_id=3)
-    assert fsm.generation == "{["
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3, 4])
-    state = fsm.next_state(state=state, token_id=4)
-    assert fsm.generation == "{[]"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2])
-    state = fsm.next_state(state=state, token_id=2)
-    assert fsm.generation == "{[]}"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [5])
-    state = fsm.next_state(state=state, token_id=5)
-    assert fsm.generation == "{[]}"
-    assert fsm.is_final_state(state)
-
-
-def test_cfg_early_termination():
-    class MockTokenizer:
-        vocabulary = {"(": 1, ")": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: expr+
-    expr: "(" subexpr ")"
-    subexpr: expr |
-    """
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = CFGFSM(cfg_str, tokenizer)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1])
-    state = fsm.next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "("
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2])
-    state = fsm.next_state(state=state, token_id=2)
-    assert fsm.generation == "()"
-    assert not fsm.is_final_state(state)
-
-    # possible to continue or terminate
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3])
-    state = fsm.next_state(state=state, token_id=3)  # feed eos
-    assert fsm.generation == "()"
-    assert fsm.is_final_state(state)
-
-    # once eos generated, can only terminate
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3])
-
-
-def test_cfg_ignore_directive():
-    class MockTokenizer:
-        vocabulary = {"a": 1, " ": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: LETTER+
-    LETTER: "a"
-    WS: " "
-    %ignore WS
-    """
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = CFGFSM(cfg_str, tokenizer)
-
-    state = 0
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1, 2])
-    state = fsm.next_state(state=0, token_id=2)
-    assert fsm.generation == " "
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1, 2])
-    state = fsm.next_state(state=0, token_id=1)
-    assert fsm.generation == " a"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
-    state = fsm.next_state(state=state, token_id=2)
-    assert fsm.generation == " a "
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
-    state = fsm.next_state(state=state, token_id=2)
-    assert fsm.generation == " a  "
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
-    state = fsm.next_state(state=state, token_id=1)
-    assert fsm.generation == " a  a"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
-    state = fsm.next_state(state=state, token_id=3)
-    assert fsm.generation == " a  a"
-    assert fsm.is_final_state(state)
-
-    # once eos generated, can only terminate
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3])
-
-
-def test_cfg_multitoken_terminal():
-    class MockTokenizer:
-        vocabulary = {"a": 1, "b": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: S
-    S: "aa" | "bb"
-    """
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = CFGFSM(cfg_str, tokenizer)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 2])
-    assert fsm.reset_state  # starting new regex
-    state = fsm.next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "a"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1])
-    assert not fsm.reset_state  # continuing current regex
-    state = fsm.next_state(state=state, token_id=1)
-    assert fsm.generation == "aa"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3])
-    assert not fsm.reset_state  # completing current regex
-    state = fsm.next_state(state=state, token_id=3)
-    assert fsm.generation == "aa"
-    assert fsm.is_final_state(state)
-
-
-def test_cfg_allow_both_extend_and_shift_terminal():
-    class MockTokenizer:
-        vocabulary = {"(": 1, ")": 2, "a": 3, "eos": 4}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 4
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: s
-    s: "(" s ")" | /a+/
-    """
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = CFGFSM(cfg_str, tokenizer)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 3])
-    state = fsm.next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "("
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3])
-    state = fsm.next_state(state=state, token_id=3)
-    assert fsm.generation == "(a"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2, 3])
-    state = fsm.next_state(state=state, token_id=3)
-    assert fsm.generation == "(aa"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2, 3])
-    state = fsm.next_state(state=state, token_id=2)
-    assert fsm.generation == "(aa)"
-    assert not fsm.is_final_state(state)
-
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [4])
-    state = fsm.next_state(state=state, token_id=4)
-    assert fsm.generation == "(aa)"
-    assert fsm.is_final_state(state)
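All of the deleted tests above drive the legacy FSM interface the same way. For reference, the loop they emulate looks roughly like this sketch, where `choose` is a hypothetical stand-in for model sampling restricted to the allowed ids:

```python
def drive(fsm, tokenizer, choose):
    """Run a legacy FSM to completion: mask, pick a token, advance, stop at EOS."""
    state, token_ids = fsm.start_state, []
    while not fsm.is_final_state(state):
        allowed = fsm.allowed_token_ids(state)
        token_id = choose(allowed)  # hypothetical sampling over `allowed`
        state = fsm.next_state(state, token_id)
        if token_id != tokenizer.eos_token_id:
            token_ids.append(token_id)
    return token_ids
```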
diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py
index 20ba7589..9a66bc04 100644
--- a/tests/fsm/test_guide.py
+++ b/tests/fsm/test_guide.py
@@ -1,6 +1,6 @@
 import pytest
 
-from outlines.fsm.guide import CFGGuide, Generate, RegexGuide, StopAtEOSGuide, Write
+from outlines.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write
 
 
 def assert_expected_tensor_ids(tensor, ids):
@@ -188,301 +188,3 @@ def convert_token_to_string(self, token):
 
     state = fsm.get_next_state(state=5, token_id=103)
     assert fsm.is_final_state(state)
-
-
-def test_cfg():
-    class MockTokenizer:
-        vocabulary = {"{": 1, "}": 2, "[": 3, "]": 4, "eos": 5}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 5
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: expr
-    expr: "{" expr "}" | "[" expr "]" |
-    """
-    tokenizer = MockTokenizer()
-    fsm = CFGGuide(cfg_str, tokenizer)
-
-    instruction = fsm.get_next_instruction(fsm.start_state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 3, 5])
-    state = fsm.get_next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "{"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2, 3])
-    state = fsm.get_next_state(state=state, token_id=3)
-    assert fsm.generation == "{["
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 3, 4])
-    state = fsm.get_next_state(state=state, token_id=4)
-    assert fsm.generation == "{[]"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [2])
-    state = fsm.get_next_state(state=state, token_id=2)
-    assert fsm.generation == "{[]}"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Write)
-    assert_expected_tensor_ids(instruction.tokens, [5])
-    state = fsm.get_next_state(state=state, token_id=5)
-    assert fsm.generation == "{[]}"
-    assert fsm.is_final_state(state)
-
-
-def test_cfg_early_termination():
-    class MockTokenizer:
-        vocabulary = {"(": 1, ")": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: expr+
-    expr: "(" subexpr ")"
-    subexpr: expr |
-    """
-    tokenizer = MockTokenizer()
-    fsm = CFGGuide(cfg_str, tokenizer)
-
-    instruction = fsm.get_next_instruction(fsm.start_state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1])
-    state = fsm.get_next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "("
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2])
-    state = fsm.get_next_state(state=state, token_id=2)
-    assert fsm.generation == "()"
-    assert not fsm.is_final_state(state)
-
-    # possible to continue or terminate
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 3])
-    state = fsm.get_next_state(state=state, token_id=3)  # feed eos
-    assert fsm.generation == "()"
-    assert fsm.is_final_state(state)
-
-    # once eos generated, can only terminate
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Write)
-    assert_expected_tensor_ids(instruction.tokens, [3])
-
-
-def test_cfg_ignore_directive():
-    class MockTokenizer:
-        vocabulary = {"a": 1, " ": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: LETTER+
-    LETTER: "a"
-    WS: " "
-    %ignore WS
-    """
-    tokenizer = MockTokenizer()
-    fsm = CFGGuide(cfg_str, tokenizer)
-
-    state = 0
-
-    instruction = fsm.get_next_instruction(0)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2])
-    state = fsm.get_next_state(state=0, token_id=2)
-    assert fsm.generation == " "
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(0)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2])
-    state = fsm.get_next_state(state=0, token_id=1)
-    assert fsm.generation == " a"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2, 3])
-    state = fsm.get_next_state(state=state, token_id=2)
-    assert fsm.generation == " a "
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2, 3])
-    state = fsm.get_next_state(state=state, token_id=2)
-    assert fsm.generation == " a  "
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2, 3])
-    state = fsm.get_next_state(state=state, token_id=1)
-    assert fsm.generation == " a  a"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2, 3])
-    state = fsm.get_next_state(state=state, token_id=3)
-    assert fsm.generation == " a  a"
-    assert fsm.is_final_state(state)
-
-    # once eos generated, can only terminate
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Write)
-    assert_expected_tensor_ids(instruction.tokens, [3])
-
-
-def test_cfg_multitoken_terminal():
-    class MockTokenizer:
-        vocabulary = {"a": 1, "b": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: S
-    S: "aa" | "bb"
-    """
-    tokenizer = MockTokenizer()
-    fsm = CFGGuide(cfg_str, tokenizer)
-
-    instruction = fsm.get_next_instruction(fsm.start_state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 2])
-    assert fsm.reset_state  # starting new regex
-    state = fsm.get_next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "a"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1])
-    assert not fsm.reset_state  # continuing current regex
-    state = fsm.get_next_state(state=state, token_id=1)
-    assert fsm.generation == "aa"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Write)
-    assert_expected_tensor_ids(instruction.tokens, [3])
-    assert not fsm.reset_state  # completing current regex
-    state = fsm.get_next_state(state=state, token_id=3)
-    assert fsm.generation == "aa"
-    assert fsm.is_final_state(state)
-
-
-def test_cfg_allow_both_extend_and_shift_terminal():
-    class MockTokenizer:
-        vocabulary = {"(": 1, ")": 2, "a": 3, "eos": 4}
-        special_tokens = {"eos"}
-        eos_token = "eos"
-        eos_token_id = 4
-
-        def convert_token_to_string(self, token):
-            return token
-
-        @property
-        def inverse_vocabulary(self):
-            return {v: k for k, v in self.vocabulary.items()}
-
-        def decode(self, token_ids):
-            return [self.inverse_vocabulary[t] for t in token_ids]
-
-    cfg_str = """
-    start: s
-    s: "(" s ")" | /a+/
-    """
-    tokenizer = MockTokenizer()
-    fsm = CFGGuide(cfg_str, tokenizer)
-
-    instruction = fsm.get_next_instruction(fsm.start_state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 3])
-    state = fsm.get_next_state(state=fsm.start_state, token_id=1)
-    assert fsm.generation == "("
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [1, 3])
-    state = fsm.get_next_state(state=state, token_id=3)
-    assert fsm.generation == "(a"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [2, 3])
-    state = fsm.get_next_state(state=state, token_id=3)
-    assert fsm.generation == "(aa"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Generate)
-    assert_expected_tensor_ids(instruction.tokens, [2, 3])
-    state = fsm.get_next_state(state=state, token_id=2)
-    assert fsm.generation == "(aa)"
-    assert not fsm.is_final_state(state)
-
-    instruction = fsm.get_next_instruction(state)
-    assert isinstance(instruction, Write)
-    assert_expected_tensor_ids(instruction.tokens, [4])
-    state = fsm.get_next_state(state=state, token_id=4)
-    assert fsm.generation == "(aa)"
-    assert fsm.is_final_state(state)
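Unlike their `test_fsm.py` counterparts, these tests also assert the instruction type: `Generate` means "sample any of these ids", while `Write` is returned when the guide has a forced continuation (for example, EOS only). A small sketch of consuming both, assuming the `.tokens` interface shown in the diff:

```python
from outlines.fsm.guide import Generate, Write


def allowed_ids(guide, state):
    instruction = guide.get_next_instruction(state)
    if isinstance(instruction, Write):
        # Forced continuation: the guide has exactly one path left (e.g. EOS).
        return list(instruction.tokens)
    assert isinstance(instruction, Generate)
    return list(instruction.tokens)  # any of these ids is a legal next token
```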
diff --git a/tests/fsm/test_parsing.py b/tests/fsm/test_parsing.py
deleted file mode 100644
index 3f4c1ba4..00000000
--- a/tests/fsm/test_parsing.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import importlib
-from copy import copy
-
-import lark.lark
-import pytest
-from lark.indenter import DedentError
-from lark.lexer import UnexpectedCharacters, UnexpectedToken
-
-from outlines.fsm.parsing import PartialLark, PartialPythonIndenter
-
-
-@pytest.fixture
-def cleanup_lark_import():
-    yield
-    # Clean up lark.lark.LarkOptions._defaults
-    importlib.reload(lark.lark)
-
-
-def test_partial_parsing(cleanup_lark_import):
-    lp = PartialLark.open_from_package(
-        "tests",
-        "partial_python.lark",
-        ["fsm"],
-        parser="lalr",
-        postlex=PartialPythonIndenter(),
-        start="file_input",
-        deterministic=True,
-    )
-
-    # End with a potentially unfinished NAME
-    parser_state = lp.parse("x")
-    assert parser_state.state_stack == [0]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert last_token.value.fsm_state_seq == (0, 15)
-    assert last_token.value.is_not_finished is True
-    assert not parser_state.value_stack
-
-    # End with an ignored token
-    parser_state = lp.parse("x ")
-    assert parser_state.state_stack == [0, 692]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert last_token.value.fsm_state_seq == (0, 1)
-    assert last_token.value.is_not_finished is True
-    assert not parser_state.value_stack
-
-    # Could be a complete `=` or the start of a `==`
-    parser_state = lp.parse("x =")
-    assert parser_state.state_stack == [0, 692]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert any(
-        term_info.terminal_name == "EQUAL"
-        for term_info in last_token.value.terminals_and_info
-    )
-    assert not parser_state.value_stack
-
-    parser_state = lp.parse("x = '")
-    assert parser_state.state_stack == [0, 58, 59]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert last_token.value.fsm_state_seq == (0, 6)
-    assert last_token.value.is_not_finished is True
-    assert not parser_state.value_stack
-
-    parser_state = lp.parse("x = 'hi")
-    assert parser_state.state_stack == [0, 58, 59]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert last_token.value.fsm_state_seq == (0, 6, 6, 6)
-    assert last_token.value.is_not_finished is True
-    assert not parser_state.value_stack
-
-    parser_state = lp.parse("x = ('hi")
-    assert parser_state.state_stack == [0, 58, 59, 254]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert last_token.value.fsm_state_seq == (0, 6, 6, 6)
-    assert last_token.value.is_not_finished is True
-    assert not parser_state.value_stack
-
-    parser_state = lp.parse("def")
-    assert parser_state.state_stack == [0]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-    assert last_token.value.fsm_state_seq == (0, 26, 99, 100)
-    assert last_token.value.is_not_finished is True
-    assert not parser_state.value_stack
-
-    # Now, try something incremental
-    last_lexer_state = parser_state.lexer.state
-    last_lexer_state.text += " blah()"
-    lp.parse_from_state(parser_state, is_end=False)
-    last_token = parser_state.lexer.state.last_token
-    assert not parser_state.value_stack
-
-    last_lexer_state = parser_state.lexer.state
-    last_valid_token = last_lexer_state.last_token
-    assert last_valid_token.type == "RPAR"
-    assert not parser_state.value_stack
-
-    # Something incremental and a little more complicated
-    parser_state = lp.parse("x = 1\ndef foo(x):\n  ")
-    assert parser_state.state_stack == [0, 94, 600, 601, 602, 607, 608, 269]
-    last_lexer_state = parser_state.lexer.state
-    last_lexer_state.text += "  return x"
-
-    lp.parse_from_state(parser_state, is_end=False)
-    assert parser_state.state_stack == [
-        0,
-        94,
-        600,
-        601,
-        602,
-        607,
-        608,
-        269,
-        764,
-        95,
-        305,
-    ]
-    last_token = parser_state.lexer.state.last_token
-    assert last_token.type == "partial"
-
-    with pytest.raises(UnexpectedToken):
-        lp.parse("def \n")
-
-    with pytest.raises(UnexpectedToken):
-        lp.parse("def hot no")
-
-    lp = PartialLark.open_from_package(
-        "tests",
-        "partial_python.lark",
-        ["fsm"],
-        parser="lalr",
-        postlex=PartialPythonIndenter(),
-        start="file_input",
-        use_value_stack=True,
-    )
-    parser_state = lp.parse("x = ('hi")
-    lp.parse_from_state(parser_state, is_end=False)
-    assert len(parser_state.state_stack) == 4
-    assert parser_state.value_stack[-1].type == "LPAR"
-
-
-def test_sequential_parse_example(cleanup_lark_import):
-    input_tokens = [
-        "x ",
-        "= ",
-        "1",
-        "\nde",
-        "f ",
-        "foo(",
-        "x)",
-        ":\n",
-        "  ",
-        "  return x",
-        " + 1",
-        "\n",
-        "z ",
-        "= ",
-        "foo(",
-        '"hi',
-        '")\n',
-    ]
-    vocab = sorted(set(input_tokens))
-
-    lp = PartialLark.open_from_package(
-        "tests",
-        "partial_python.lark",
-        ["fsm"],
-        parser="lalr",
-        postlex=PartialPythonIndenter(),
-        start="file_input",
-        deterministic=True,
-    )
-    parser_state = lp.parse("")
-
-    token_seq = ""
-    for i, token in enumerate(input_tokens):
-        token_seq += token
-
-        lex_state = parser_state.lexer.state
-        lex_state.text = token_seq
-
-        lp.parse_from_state(parser_state, is_end=False)
-
-        next_vocab = set()
-        for test_token in vocab:
-            ps = copy(parser_state)
-            ls = ps.lexer.state
-            ls.text = token_seq + test_token
-
-            if i + 1 < len(input_tokens) and test_token == input_tokens[i + 1]:
-                lp.parse_from_state(ps, is_end=False)
-                next_vocab.add(test_token)
-            else:
-                try:
-                    lp.parse_from_state(ps, is_end=False)
-                    next_vocab.add(test_token)
-                except (EOFError, UnexpectedToken, UnexpectedCharacters, DedentError):
-                    pass
-
-        if i + 1 == len(input_tokens):
-            assert all(tk in next_vocab for tk in ["\n", "\nde", "  ", " + 1"])
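The deleted parsing tests double as the only worked example of `PartialLark`'s incremental drive. Distilled from them, and assuming `outlines.fsm.parsing` retains these entry points, the pattern is: parse a prefix, append to the lexer state's text, then resume without re-parsing:

```python
from outlines.fsm.parsing import PartialLark, PartialPythonIndenter

lp = PartialLark.open_from_package(
    "tests",
    "partial_python.lark",
    ["fsm"],
    parser="lalr",
    postlex=PartialPythonIndenter(),
    start="file_input",
)
parser_state = lp.parse("x = 1\n")             # parse a prefix, keep parser state
parser_state.lexer.state.text += "y = 2\n"     # append newly generated text
lp.parse_from_state(parser_state, is_end=False)  # resume incrementally
```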