Add basic parser-driven masking utilities
brandonwillard committed Jul 5, 2023
1 parent e88f13b commit 2e07973
Showing 4 changed files with 465 additions and 0 deletions.
108 changes: 108 additions & 0 deletions examples/parsing.py
@@ -0,0 +1,108 @@
"""An example illustrating parser-based masking."""
import math
import time

import torch
from lark import Lark
from lark.indenter import DedentError
from lark.lexer import UnexpectedCharacters, UnexpectedToken
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
LogitsProcessor,
LogitsProcessorList,
set_seed,
)

from outlines.text.parsing import PartialPythonIndenter, copy_parser_state, parse_to_end

revision = None
checkpoint = "Salesforce/codegen-350M-mono"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

model = AutoModelForCausalLM.from_pretrained(
checkpoint, trust_remote_code=True, revision=revision
).to(device)

input_text = "def "
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)


class ParserLogitsProcessor(LogitsProcessor):
"""Bias invalid token scores according to a running parse state."""

def __init__(self):
pyparser = Lark.open_from_package(
"lark",
"python.lark",
["grammars"],
parser="lalr",
postlex=PartialPythonIndenter(),
start="file_input",
)
ip = pyparser.parse_interactive("")
self.parser_state = ip.parser_state
self.states_stack = [self.parser_state]
self.token_seq = None
self.token_idx = 0

def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
if self.token_seq is None:
self.token_seq = tokenizer.decode(input_ids[0])
self.token_idx = len(input_ids[0]) - 1
else:
self.token_idx += 1
self.token_seq += tokenizer.decode(input_ids[0][self.token_idx])

# Process the last sampled token
lex_state = self.parser_state.lexer.state
lex_state.text = self.token_seq

self.parser_state, partial_tokens = parse_to_end(self.parser_state)

print("Parsed:\n")
print(self.token_seq)

print(partial_tokens)

mask = torch.full_like(scores, -math.inf)

# Determine which tokens in the vocabulary are valid next tokens
# given the parser state.
#
# TODO: This is a very naive and slow approach. It could be done in
# parallel, but there are a few other approaches to try first, and
# those should dramatically reduce the amount of work done here.
t0 = time.perf_counter()
for test_token, token_id in tokenizer.vocab.items():
ps = copy_parser_state(self.parser_state)
ls = ps.lexer.state
ls.text = self.token_seq + test_token

try:
# TODO: The resulting states could possibly be reused?
parse_to_end(ps)
mask[0][token_id] = 0
except (UnexpectedToken, UnexpectedCharacters, DedentError):
pass

print(f"Next token masking duration: {time.perf_counter() - t0}")

return scores + mask


set_seed(20399)

outputs = model.generate(
inputs,
max_length=100,
temperature=0.1,
logits_processor=LogitsProcessorList([ParserLogitsProcessor()]),
renormalize_logits=True,
)

print(tokenizer.decode(outputs[0]))
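
Note (not part of the commit): the masking above relies only on standard logits arithmetic, in that adding negative infinity to a score removes that token once the scores are renormalized. A minimal, self-contained sketch with a toy five-token vocabulary; all values below are illustrative:

import math

import torch

scores = torch.tensor([[1.0, 2.0, 0.5, 3.0, -1.0]])
mask = torch.full_like(scores, -math.inf)

# Pretend only tokens 1 and 3 passed the parser check.
mask[0][1] = 0
mask[0][3] = 0

# Softmax over the masked scores assigns non-zero probability only to
# tokens 1 and 3; `renormalize_logits=True` asks `generate` to perform a
# comparable renormalization after the processors run.
probs = torch.softmax(scores + mask, dim=-1)
print(probs)
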
240 changes: 240 additions & 0 deletions outlines/text/parsing.py
@@ -0,0 +1,240 @@
from copy import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple

import regex
from lark.exceptions import (
    LexError,
    UnexpectedCharacters,
    UnexpectedEOF,
    UnexpectedToken,
)
from lark.indenter import PythonIndenter
from lark.lexer import BasicLexer, LexerState, Scanner, Token
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.utils import get_regexp_width

if TYPE_CHECKING:
    from lark.lexer import LexerThread
    from lark.parsers.lalr_parser import ParserState


class PartialTokenEOF(UnexpectedEOF):
    pass


class PartialScanner(Scanner):
    def __init__(self, scanner: Scanner):
        self.terminals = scanner.terminals
        self.g_regex_flags = scanner.g_regex_flags
        self.re_ = regex
        self.use_bytes = scanner.use_bytes
        self.match_whole = scanner.match_whole
        self.allowed_types = scanner.allowed_types
        self._mres = scanner._mres

    def match(self, text, pos) -> Optional[Tuple[str, Optional[str], bool]]:
        for mre in self._mres:
            m = mre.match(text, pos=pos, partial=True)
            if m:  # and ((not m.partial) or m.endpos == len(text)):
                return m.group(0), m.lastgroup, m.partial
        return None


class PartialBasicLexer(BasicLexer):
    def __init__(self, basic_lexer: BasicLexer):
        self.re = regex
        self.newline_types = basic_lexer.newline_types
        self.ignore_types = basic_lexer.ignore_types
        self.terminals = basic_lexer.terminals
        self.user_callbacks = basic_lexer.user_callbacks
        self.g_regex_flags = basic_lexer.g_regex_flags
        self.use_bytes = basic_lexer.use_bytes
        self.terminals_by_name = basic_lexer.terminals_by_name
        self.callback = getattr(basic_lexer, "callback", None)

        if basic_lexer._scanner is not None:
            self._scanner: Optional[PartialScanner] = PartialScanner(
                basic_lexer._scanner
            )
        else:
            self._scanner = None

        # This is used to determine the token type for partial matches
        self.terminal_to_regex = {}
        for name, terminal in self.terminals_by_name.items():
            self.terminal_to_regex[name] = self.re.compile(
                terminal.pattern.to_regexp(), self.g_regex_flags
            )

    def _build_scanner(self):
        super()._build_scanner()
        self._scanner = PartialScanner(self._scanner)

    def partial_matches(self, value, type_):
        partial_matches = set()

        # TODO: It's unfortunate that we have to do this costly search (again).
        # It would be better if we could *not* short-circuit the first time we
        # scan in the call to `self.match`.
        for term_name, term_regex in self.terminal_to_regex.items():
            if term_name == type_:
                # A standard lexed token result could actually indicate a
                # partial match
                regex_min, regex_max = get_regexp_width(term_regex.pattern)
                if regex_min <= len(value) < regex_max:
                    partial_matches.add(term_name)
            else:
                m = term_regex.match(value, partial=True)
                if m:
                    partial_matches.add(term_name)

        return partial_matches

    def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
        line_ctr = lex_state.line_ctr
        while line_ctr.char_pos < len(lex_state.text):
            res = self.match(lex_state.text, line_ctr.char_pos)

            if not res:
                allowed = self.scanner.allowed_types - self.ignore_types
                if not allowed:
                    allowed = {"<END-OF-FILE>"}
                raise UnexpectedCharacters(
                    lex_state.text,
                    line_ctr.char_pos,
                    line_ctr.line,
                    line_ctr.column,
                    allowed=allowed,
                    token_history=lex_state.last_token and [lex_state.last_token],
                    state=parser_state,
                    terminals_by_name=self.terminals_by_name,
                )

            value, type_, partial = res

            # Don't advance the lexing state if we're at the end; there could
            # be ambiguous token types that aren't finished.
            if line_ctr.char_pos + len(value) >= len(lex_state.text):
                partial_matches = self.partial_matches(value, type_)
                if partial_matches or partial:
                    raise PartialTokenEOF(partial_matches)

            assert isinstance(self.callback, Dict)

            if type_ not in self.ignore_types:
                t = Token(
                    type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
                )
                line_ctr.feed(value, type_ in self.newline_types)
                t.end_line = line_ctr.line
                t.end_column = line_ctr.column
                t.end_pos = line_ctr.char_pos
                if t.type in self.callback:
                    t = self.callback[t.type](t)
                    if not isinstance(t, Token):
                        raise LexError(
                            "Callbacks must return a token (returned %r)" % t
                        )
                lex_state.last_token = t
                return t

            if type_ in self.callback:
                t2 = Token(
                    type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
                )
                self.callback[type_](t2)

            line_ctr.feed(value, type_ in self.newline_types)

        raise EOFError(self)


class PartialPythonIndenter(PythonIndenter):
    """An `Indenter` that doesn't reset its state every time `process` is called."""

    def process(self, stream):
        return self._process(stream)

    def _process(self, stream):
        for token in stream:
            # These were previously *after* the `yield`, but that makes the
            # state tracking unnecessarily convoluted.
            if token.type in self.OPEN_PAREN_types:
                self.paren_level += 1
            elif token.type in self.CLOSE_PAREN_types:
                self.paren_level -= 1
                if self.paren_level < 0:
                    raise UnexpectedToken(token, [])

            if token.type == self.NL_type:
                yield from self.handle_NL(token)
            else:
                yield token

        # while len(self.indent_level) > 1:
        #     self.indent_level.pop()
        #     yield Token(self.DEDENT_type, "")

    def __copy__(self):
        res = type(self)()
        res.paren_level = self.paren_level
        res.indent_level = copy(self.indent_level)
        return res


def copy_lexer_thread(lexer_thread: "LexerThread") -> "LexerThread":
    res = copy(lexer_thread)
    res.lexer = copy(res.lexer)

    if (
        res.lexer.postlexer
        and isinstance(res.lexer.postlexer, PythonIndenter)
        and not isinstance(res.lexer.postlexer, PartialPythonIndenter)
    ):
        # Patch these methods so that the post lexer keeps its state
        # XXX: This won't really work in generality.
        postlexer = PartialPythonIndenter()
        postlexer.paren_level = res.lexer.postlexer.paren_level
        postlexer.indent_level = res.lexer.postlexer.indent_level
        res.lexer.postlexer = postlexer

    # Patch/replace the lexer objects so that they support partial matches
    lexer = res.lexer.lexer
    if not isinstance(lexer.root_lexer, PartialBasicLexer):
        lexer.root_lexer = PartialBasicLexer(lexer.root_lexer)

        basic_lexers = res.lexer.lexer.lexers
        for idx, lexer in basic_lexers.items():
            basic_lexers[idx] = PartialBasicLexer(lexer)

    res.lexer.postlexer = copy(res.lexer.postlexer)

    return res


def copy_parser_state(parser_state: "ParserState") -> "ParserState":
    res = copy(parser_state)
    res.lexer = copy_lexer_thread(res.lexer)

    return res


def copy_ip(ip: "InteractiveParser") -> "InteractiveParser":
    res = copy(ip)
    res.lexer_thread = copy_lexer_thread(res.lexer_thread)
    return res


def parse_to_end(parser_state: "ParserState") -> Tuple["ParserState", Set[str]]:
    """Continue parsing from the current parse state and return partial next tokens."""

    parser_state = copy_parser_state(parser_state)

    expected_next_tokens: Set[str] = set()
    try:
        for token in parser_state.lexer.lex(parser_state):
            parser_state.feed_token(token)
    except PartialTokenEOF as e:
        expected_next_tokens = e.expected

    return parser_state, expected_next_tokens
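
Note (not part of the commit): a minimal sketch of how the utilities above can be exercised on their own, mirroring the setup in examples/parsing.py; the partial input string is an arbitrary illustration:

from lark import Lark

from outlines.text.parsing import PartialPythonIndenter, parse_to_end

pyparser = Lark.open_from_package(
    "lark",
    "python.lark",
    ["grammars"],
    parser="lalr",
    postlex=PartialPythonIndenter(),
    start="file_input",
)

# Seed an interactive parse, then point its lexer state at a partial program.
parser_state = pyparser.parse_interactive("").parser_state
parser_state.lexer.state.text = "def foo(x"

# `parse_to_end` copies the state, consumes what it can, and returns the
# names of the terminals that the trailing characters could still complete
# (e.g. a partial NAME).
parser_state, partial_tokens = parse_to_end(parser_state)
print(partial_tokens)
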
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -44,6 +44,8 @@ test = [
    "transformers",
    "coverage[toml]>=5.1",
    "diff-cover",
    "lark",
    "regex",
]

[project.urls]
@@ -87,6 +89,8 @@ module = [
    "tiktoken.*",
    "torch",
    "transformers.*",
    "lark.*",
    "regex.*",
]
ignore_missing_imports = true
