From 2e0797394824e35b4007b2be3f5e2a777d208517 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Mon, 5 Jun 2023 23:18:49 -0500
Subject: [PATCH] Add basic parser-driven masking utilities

---
 examples/parsing.py        | 108 +++++++++++++++++
 outlines/text/parsing.py   | 240 +++++++++++++++++++++++++++++++++++++
 pyproject.toml             |   4 +
 tests/text/test_parsing.py | 113 +++++++++++++++++
 4 files changed, 465 insertions(+)
 create mode 100644 examples/parsing.py
 create mode 100644 outlines/text/parsing.py
 create mode 100644 tests/text/test_parsing.py

diff --git a/examples/parsing.py b/examples/parsing.py
new file mode 100644
index 000000000..3f070c470
--- /dev/null
+++ b/examples/parsing.py
@@ -0,0 +1,108 @@
+"""An example illustrating parser-based masking."""
+import math
+import time
+
+import torch
+from lark import Lark
+from lark.indenter import DedentError
+from lark.lexer import UnexpectedCharacters, UnexpectedToken
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LogitsProcessor,
+    LogitsProcessorList,
+    set_seed,
+)
+
+from outlines.text.parsing import PartialPythonIndenter, copy_parser_state, parse_to_end
+
+revision = None
+checkpoint = "Salesforce/codegen-350M-mono"
+device = "cuda"
+
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+model = AutoModelForCausalLM.from_pretrained(
+    checkpoint, trust_remote_code=True, revision=revision
+).to(device)
+
+input_text = "def "
+inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+
+class ParserLogitsProcessor(LogitsProcessor):
+    """Bias invalid token scores according to a running parse state."""
+
+    def __init__(self):
+        pyparser = Lark.open_from_package(
+            "lark",
+            "python.lark",
+            ["grammars"],
+            parser="lalr",
+            postlex=PartialPythonIndenter(),
+            start="file_input",
+        )
+        ip = pyparser.parse_interactive("")
+        self.parser_state = ip.parser_state
+        self.states_stack = [self.parser_state]
+        self.token_seq = None
+        self.token_idx = 0
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        if self.token_seq is None:
+            self.token_seq = tokenizer.decode(input_ids[0])
+            self.token_idx = len(input_ids[0]) - 1
+        else:
+            self.token_idx += 1
+            self.token_seq += tokenizer.decode(input_ids[0][self.token_idx])
+
+        # Process the last sampled token
+        lex_state = self.parser_state.lexer.state
+        lex_state.text = self.token_seq
+
+        self.parser_state, partial_tokens = parse_to_end(self.parser_state)
+
+        print("Parsed:\n")
+        print(self.token_seq)
+
+        print(partial_tokens)
+
+        mask = torch.full_like(scores, -math.inf)
+
+        # Determine which tokens in the vocabulary are valid next tokens
+        # given the parser state.
+        #
+        # TODO: This is a very naive and slow approach. It could be done in
+        # parallel, but there are a few other approaches to try first, and
+        # those should dramatically reduce the amount of work done here.
+        t0 = time.perf_counter()
+        for test_token, token_id in tokenizer.vocab.items():
+            ps = copy_parser_state(self.parser_state)
+            ls = ps.lexer.state
+            ls.text = self.token_seq + test_token
+
+            try:
+                # TODO: The resulting states could possibly be reused?
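+                # If lexing/parsing the extended text succeeds (or stops at
+                # a valid partial token), `test_token` is an admissible next
+                # token and its score is left unmasked.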
+                parse_to_end(ps)
+                mask[0][token_id] = 0
+            except (UnexpectedToken, UnexpectedCharacters, DedentError):
+                pass
+
+        print(f"Next token masking duration: {time.perf_counter() - t0}")
+
+        return scores + mask
+
+
+set_seed(20399)
+
+outputs = model.generate(
+    inputs,
+    max_length=100,
+    temperature=0.1,
+    logits_processor=LogitsProcessorList([ParserLogitsProcessor()]),
+    renormalize_logits=True,
+)
+
+print(tokenizer.decode(outputs[0]))
diff --git a/outlines/text/parsing.py b/outlines/text/parsing.py
new file mode 100644
index 000000000..7b1d9cbab
--- /dev/null
+++ b/outlines/text/parsing.py
@@ -0,0 +1,240 @@
+from copy import copy
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple
+
+import regex
+from lark.exceptions import (
+    LexError,
+    UnexpectedCharacters,
+    UnexpectedEOF,
+    UnexpectedToken,
+)
+from lark.indenter import PythonIndenter
+from lark.lexer import BasicLexer, LexerState, Scanner, Token
+from lark.parsers.lalr_interactive_parser import InteractiveParser
+from lark.utils import get_regexp_width
+
+if TYPE_CHECKING:
+    from lark.lexer import LexerThread
+    from lark.parsers.lalr_parser import ParserState
+
+
+class PartialTokenEOF(UnexpectedEOF):
+    pass
+
+
+class PartialScanner(Scanner):
+    def __init__(self, scanner: Scanner):
+        self.terminals = scanner.terminals
+        self.g_regex_flags = scanner.g_regex_flags
+        self.re_ = regex
+        self.use_bytes = scanner.use_bytes
+        self.match_whole = scanner.match_whole
+        self.allowed_types = scanner.allowed_types
+        self._mres = scanner._mres
+
+    def match(self, text, pos) -> Optional[Tuple[str, Optional[str], bool]]:
+        for mre in self._mres:
+            m = mre.match(text, pos=pos, partial=True)
+            if m:  # and ((not m.partial) or m.endpos == len(text)):
+                return m.group(0), m.lastgroup, m.partial
+        return None
+
+
+class PartialBasicLexer(BasicLexer):
+    def __init__(self, basic_lexer: BasicLexer):
+        self.re = regex
+        self.newline_types = basic_lexer.newline_types
+        self.ignore_types = basic_lexer.ignore_types
+        self.terminals = basic_lexer.terminals
+        self.user_callbacks = basic_lexer.user_callbacks
+        self.g_regex_flags = basic_lexer.g_regex_flags
+        self.use_bytes = basic_lexer.use_bytes
+        self.terminals_by_name = basic_lexer.terminals_by_name
+        self.callback = getattr(basic_lexer, "callback", None)
+
+        if basic_lexer._scanner is not None:
+            self._scanner: Optional[PartialScanner] = PartialScanner(
+                basic_lexer._scanner
+            )
+        else:
+            self._scanner = None
+
+        # This is used to determine the token type for partial matches
+        self.terminal_to_regex = {}
+        for name, terminal in self.terminals_by_name.items():
+            self.terminal_to_regex[name] = self.re.compile(
+                terminal.pattern.to_regexp(), self.g_regex_flags
+            )
+
+    def _build_scanner(self):
+        super()._build_scanner()
+        self._scanner = PartialScanner(self._scanner)
+
+    def partial_matches(self, value, type_):
+        partial_matches = set()
+
+        # TODO: It's unfortunate that we have to do this costly search (again).
+        # It would be better if we could *not* short-circuit the first time we
+        # scan in the call to `self.match`.
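+        # For the terminal the scanner already chose, a regex-width check
+        # tells us whether the match could still be extended; for every
+        # other terminal, a partial regex match means the value could still
+        # grow into that terminal.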
+        for term_name, term_regex in self.terminal_to_regex.items():
+            if term_name == type_:
+                # A standard lexed token result could actually indicate a
+                # partial match
+                regex_min, regex_max = get_regexp_width(term_regex.pattern)
+                if regex_min <= len(value) < regex_max:
+                    partial_matches.add(term_name)
+            else:
+                m = term_regex.match(value, partial=True)
+                if m:
+                    partial_matches.add(term_name)
+
+        return partial_matches
+
+    def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
+        line_ctr = lex_state.line_ctr
+        while line_ctr.char_pos < len(lex_state.text):
+            res = self.match(lex_state.text, line_ctr.char_pos)
+
+            if not res:
+                allowed = self.scanner.allowed_types - self.ignore_types
+                if not allowed:
+                    allowed = {"<END-OF-FILE>"}
+                raise UnexpectedCharacters(
+                    lex_state.text,
+                    line_ctr.char_pos,
+                    line_ctr.line,
+                    line_ctr.column,
+                    allowed=allowed,
+                    token_history=lex_state.last_token and [lex_state.last_token],
+                    state=parser_state,
+                    terminals_by_name=self.terminals_by_name,
+                )
+
+            value, type_, partial = res
+
+            # Don't advance the lexing state if we're at the end; there could
+            # be ambiguous token types that aren't finished.
+            if line_ctr.char_pos + len(value) >= len(lex_state.text):
+                partial_matches = self.partial_matches(value, type_)
+                if partial_matches or partial:
+                    raise PartialTokenEOF(partial_matches)
+
+            assert isinstance(self.callback, Dict)
+
+            if type_ not in self.ignore_types:
+                t = Token(
+                    type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
+                )
+                line_ctr.feed(value, type_ in self.newline_types)
+                t.end_line = line_ctr.line
+                t.end_column = line_ctr.column
+                t.end_pos = line_ctr.char_pos
+                if t.type in self.callback:
+                    t = self.callback[t.type](t)
+                    if not isinstance(t, Token):
+                        raise LexError(
+                            "Callbacks must return a token (returned %r)" % t
+                        )
+                lex_state.last_token = t
+                return t
+
+            if type_ in self.callback:
+                t2 = Token(
+                    type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
+                )
+                self.callback[type_](t2)
+
+            line_ctr.feed(value, type_ in self.newline_types)
+
+        raise EOFError(self)
+
+
+class PartialPythonIndenter(PythonIndenter):
+    """An `Indenter` that doesn't reset its state every time `process` is called."""
+
+    def process(self, stream):
+        return self._process(stream)
+
+    def _process(self, stream):
+        for token in stream:
+            # These were previously *after* the `yield`, but that makes the
+            # state tracking unnecessarily convoluted.
+            if token.type in self.OPEN_PAREN_types:
+                self.paren_level += 1
+            elif token.type in self.CLOSE_PAREN_types:
+                self.paren_level -= 1
+                if self.paren_level < 0:
+                    raise UnexpectedToken(token, [])
+
+            if token.type == self.NL_type:
+                yield from self.handle_NL(token)
+            else:
+                yield token
+
+        # while len(self.indent_level) > 1:
+        #     self.indent_level.pop()
+        #     yield Token(self.DEDENT_type, "")
+
+    def __copy__(self):
+        res = type(self)()
+        res.paren_level = self.paren_level
+        res.indent_level = copy(self.indent_level)
+        return res
+
+
+def copy_lexer_thread(lexer_thread: "LexerThread") -> "LexerThread":
+    res = copy(lexer_thread)
+    res.lexer = copy(res.lexer)
+
+    if (
+        res.lexer.postlexer
+        and isinstance(res.lexer.postlexer, PythonIndenter)
+        and not isinstance(res.lexer.postlexer, PartialPythonIndenter)
+    ):
+        # Patch these methods so that the post lexer keeps its state
+        # XXX: This won't really work in generality.
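+        # Only the indenter's mutable state (`paren_level` and
+        # `indent_level`) is carried over; any other postlexer state would
+        # be lost here.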
+        postlexer = PartialPythonIndenter()
+        postlexer.paren_level = res.lexer.postlexer.paren_level
+        postlexer.indent_level = res.lexer.postlexer.indent_level
+        res.lexer.postlexer = postlexer
+
+    # Patch/replace the lexer objects so that they support partial matches
+    lexer = res.lexer.lexer
+    if not isinstance(lexer.root_lexer, PartialBasicLexer):
+        lexer.root_lexer = PartialBasicLexer(lexer.root_lexer)
+
+        basic_lexers = res.lexer.lexer.lexers
+        for idx, lexer in basic_lexers.items():
+            basic_lexers[idx] = PartialBasicLexer(lexer)
+
+    res.lexer.postlexer = copy(res.lexer.postlexer)
+
+    return res
+
+
+def copy_parser_state(parser_state: "ParserState") -> "ParserState":
+    res = copy(parser_state)
+    res.lexer = copy_lexer_thread(res.lexer)
+
+    return res
+
+
+def copy_ip(ip: "InteractiveParser") -> "InteractiveParser":
+    res = copy(ip)
+    res.lexer_thread = copy_lexer_thread(res.lexer_thread)
+    return res
+
+
+def parse_to_end(parser_state: "ParserState") -> Tuple["ParserState", Set[str]]:
+    """Continue parsing from the current parse state.
+
+    Returns the new parse state along with the names of any terminals that
+    the end of the text partially matches.
+    """
+
+    parser_state = copy_parser_state(parser_state)
+
+    expected_next_tokens: Set[str] = set()
+    try:
+        for token in parser_state.lexer.lex(parser_state):
+            parser_state.feed_token(token)
+    except PartialTokenEOF as e:
+        expected_next_tokens = e.expected
+
+    return parser_state, expected_next_tokens
diff --git a/pyproject.toml b/pyproject.toml
index 62c7ae99a..686de1a3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,8 @@ test = [
     "transformers",
     "coverage[toml]>=5.1",
     "diff-cover",
+    "lark",
+    "regex",
 ]
 
 [project.urls]
@@ -87,6 +89,8 @@ module = [
     "tiktoken.*",
     "torch",
     "transformers.*",
+    "lark.*",
+    "regex.*",
 ]
 ignore_missing_imports = true
diff --git a/tests/text/test_parsing.py b/tests/text/test_parsing.py
new file mode 100644
index 000000000..2e69b3cb4
--- /dev/null
+++ b/tests/text/test_parsing.py
@@ -0,0 +1,113 @@
+from lark import Lark
+from lark.indenter import DedentError
+from lark.lexer import UnexpectedCharacters, UnexpectedToken
+
+from outlines.text.parsing import PartialPythonIndenter, copy_parser_state, parse_to_end
+
+
+def test_parse_to_end():
+    pyparser = Lark.open_from_package(
+        "lark",
+        "python.lark",
+        ["grammars"],
+        parser="lalr",
+        postlex=PartialPythonIndenter(),
+        start="file_input",
+    )
+
+    ip = pyparser.parse_interactive("x")
+    parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
+    assert not parser_state.value_stack
+    assert expected_next_tokens == {"NAME"}
+
+    ip = pyparser.parse_interactive("x = '")
+    parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
+    assert parser_state.value_stack[-1].type == "EQUAL"
+    assert expected_next_tokens == {"LONG_STRING", "STRING"}
+
+    ip = pyparser.parse_interactive("x = 'hi")
+    parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
+    assert parser_state.value_stack[-1].type == "EQUAL"
+    assert expected_next_tokens == {"STRING"}
+
+    ip = pyparser.parse_interactive("x = ('hi")
+    parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
+    assert parser_state.value_stack[-1].type == "LPAR"
+    assert expected_next_tokens == {"STRING"}
+
+    ip = pyparser.parse_interactive("def")
+    parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
+    assert not parser_state.value_stack
+    assert expected_next_tokens == {"NAME", "DEF"}
+
+    # Now, try something incremental
+    parser_state = copy_parser_state(parser_state)
+    last_lexer_state = parser_state.lexer.state
+    last_lexer_state.text = "def blah()"
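+    # The copied state is reused as-is, so lexing resumes from the last
+    # committed position with the extended text.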
+
+    (parser_state, expected_next_tokens) = parse_to_end(parser_state)
+
+    last_lexer_state = parser_state.lexer.state
+    last_valid_token = last_lexer_state.last_token
+    assert last_valid_token.type == "RPAR"
+    assert not expected_next_tokens
+
+
+def test_sequential_parse_example():
+    input_tokens = [
+        "x ",
+        "= ",
+        "1",
+        "\nde",
+        "f ",
+        "foo(",
+        "x)",
+        ":\n",
+        "  ",
+        "  return x",
+        " + 1",
+        "\n",
+        "z ",
+        "= ",
+        "foo(",
+        '"hi' '")',
+    ]
+    vocab = set(input_tokens)
+
+    pyparser = Lark.open_from_package(
+        "lark",
+        "python.lark",
+        ["grammars"],
+        parser="lalr",
+        postlex=PartialPythonIndenter(),
+        start="file_input",
+    )
+    ip = pyparser.parse_interactive("")
+    parser_state = ip.parser_state
+
+    token_seq = ""
+    for i, token in enumerate(input_tokens):
+        token_seq += token
+
+        lex_state = parser_state.lexer.state
+        lex_state.text = token_seq
+
+        parser_state, partial_tokens = parse_to_end(parser_state)
+
+        next_vocab = set()
+        for test_token in vocab:
+            ps = copy_parser_state(parser_state)
+            ls = ps.lexer.state
+            ls.text = token_seq + test_token
+
+            try:
+                # TODO: The resulting states could possibly be reused?
+                parse_to_end(ps)
+                next_vocab.add(test_token)
+            except (UnexpectedToken, UnexpectedCharacters, DedentError):
+                pass
+
+        if i + 1 < len(input_tokens):
+            assert input_tokens[i + 1] in next_vocab
+        else:
+            assert all(tk in next_vocab for tk in ["\n", "\nde", "  ", " + 1"])
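
A minimal usage sketch of the new `parse_to_end` helper (not part of the
patch; it assumes `lark` is installed and mirrors the `"x = 'hi"` case from
the tests above):

    from lark import Lark
    from outlines.text.parsing import PartialPythonIndenter, parse_to_end

    pyparser = Lark.open_from_package(
        "lark",
        "python.lark",
        ["grammars"],
        parser="lalr",
        postlex=PartialPythonIndenter(),
        start="file_input",
    )

    # The text ends in an unterminated string literal; the returned set
    # names the terminals that the trailing fragment could still become.
    ip = pyparser.parse_interactive("x = 'hi")
    parser_state, partial_tokens = parse_to_end(ip.parser_state)
    print(partial_tokens)  # -> {"STRING"}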