Add vocabulary pre-parsing tools
brandonwillard committed Jul 5, 2023
1 parent 2e07973 commit a034c78
Showing 3 changed files with 346 additions and 3 deletions.
181 changes: 179 additions & 2 deletions outlines/text/parsing.py
@@ -1,7 +1,13 @@
from collections import ChainMap, defaultdict
from copy import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple
from itertools import chain
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Iterable, Optional, Set, Tuple

import interegular
import regex
from interegular.fsm import FSM, anything_else
from interegular.patterns import Unsupported
from lark import Lark
from lark.exceptions import (
LexError,
UnexpectedCharacters,
@@ -11,11 +17,14 @@
from lark.indenter import PythonIndenter
from lark.lexer import BasicLexer, LexerState, Scanner, Token
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser import ParserState
from lark.utils import get_regexp_width

if TYPE_CHECKING:
from lark.lexer import LexerThread
from lark.parsers.lalr_parser import ParserState


PartialParseState = Tuple[str, int]


class PartialTokenEOF(UnexpectedEOF):
@@ -238,3 +247,171 @@ def parse_to_end(parser_state: "ParserState") -> Tuple["ParserState", Set[str]]:
expected_next_tokens = e.expected

return parser_state, expected_next_tokens


def find_partial_matches(
fsm: FSM, input_string: str
) -> Set[Tuple[Optional[int], Tuple[int, ...]]]:
"""Find the states in the finite state machine `fsm` that accept `input_string`.
Returns
-------
A set of tuples corresponding to each valid starting state in the FSM.
The first element of each tuple contains either ``None`` or an integer
indicating the position in `input_string` at which the FSM terminated. The
second element is a tuple of the states visited during execution of the
FSM.
"""
if len(input_string) == 0 or input_string[0] not in fsm.alphabet:
return set()

trans_key = fsm.alphabet[input_string[0]]

# TODO: We could probably memoize this easily (i.e. no need to recompute
# paths shared by different starting states)
def _partial_match(trans: int) -> Optional[Tuple[Optional[int], Tuple[int, ...]]]:
fsm_map = ChainMap({fsm.initial: trans}, fsm.map)
state = fsm.initial
accepted_states: Tuple[int, ...] = ()

for i, symbol in enumerate(input_string):
if anything_else in fsm.alphabet and symbol not in fsm.alphabet:
symbol = anything_else

trans_key = fsm.alphabet[symbol]

if not (state in fsm_map and trans_key in fsm_map[state]):
if state in fsm.finals:
i -= 1
break
return None

state = fsm_map[state][trans_key]

accepted_states += (state,)

terminated = state in fsm.finals
if not terminated and state == fsm.initial:
return None

        # The conditional expression binds tighter than the comma, so this returns
        # a 2-tuple whose first element is `None` when the FSM did not terminate.
        return None if not terminated else i, accepted_states

res = set()
for s_now, trans in fsm.map.items():
if trans_key in trans:
path = _partial_match(trans)
if path is not None:
res.add(path)

return res
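
# Illustrative usage sketch (not from the committed source); the values mirror
# the unit tests in tests/text/test_parsing.py:
#
#   def_fsm = interegular.parse_pattern("def").to_fsm()
#   find_partial_matches(def_fsm, "def")  # {(2, (1, 2, 3))}: terminated at index 2 via states 1, 2, 3
#   find_partial_matches(def_fsm, "ef")   # {(1, (2, 3))}: a match starting mid-pattern
#   find_partial_matches(def_fsm, "de")   # {(None, (1, 2))}: partial match, FSM not terminated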


def terminals_to_fsms(lp: Lark) -> Dict[str, FSM]:
"""Construct a ``dict`` mapping terminal symbol names to their finite state machines."""

symbol_names_and_fsms = {}
for terminal in lp.terminals:
pattern = interegular.parse_pattern(terminal.pattern.to_regexp())
# TODO: Use `pyparser.terminals[0].pattern.flags`?
try:
fsm = pattern.to_fsm()
except Unsupported:
fsm = None

symbol_names_and_fsms[terminal.name] = fsm

return symbol_names_and_fsms
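
# Illustrative usage sketch (not from the committed source), following the test
# setup below: build FSMs for the terminals of lark's bundled Python grammar.
#
#   pyparser = Lark.open_from_package(
#       "lark", "python.lark", ["grammars"],
#       parser="lalr", postlex=PartialPythonIndenter(), start="file_input",
#   )
#   symbol_names_and_fsms = terminals_to_fsms(pyparser)
#   symbol_names_and_fsms["DEF"]  # interegular FSM for the `def` keyword terminal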


def map_partial_states_to_vocab(
vocabulary: Iterable[str],
terminals_to_fsms_map: Dict[str, FSM],
map_to_antecedents: bool = False,
) -> DefaultDict[PartialParseState, Set[str]]:
"""Construct a map from partial parse states to the vocabulary elements that start in those states.
Parameters
----------
vocabulary
The vocabulary composed of strings.
terminals_to_fsms_map
Terminal symbol names mapped to FSMs, as provided by `terminals_to_fsms`.
map_to_antecedents
When ``True``, return a map with keys that are the antecedent partial
parse states. In other words, this is a map that can be used to
determine valid next tokens given a parse state.
"""

pstate_to_vocab = defaultdict(set)
for symbol_name, fsm in terminals_to_fsms_map.items():
for tk in vocabulary:
for _, states in find_partial_matches(fsm, tk):
pstate_to_vocab[(symbol_name, states[0])].add(tk)

if not map_to_antecedents:
return pstate_to_vocab

# Partially parsed states to next/transition states (for the same terminal symbol)
ts_pstate_to_substates = dict(
chain.from_iterable(
[
((symbol_name, s), {(symbol_name, v) for v in ts.values()})
for s, ts in fsm.map.items()
]
for symbol_name, fsm in terminals_to_fsms_map.items()
)
)

# Reverse the map
# TODO: We could construct this more directly.
rev_ts_pstate_to_substates = defaultdict(set)
for pstate, to_pstates in ts_pstate_to_substates.items():
for to_pstate in to_pstates:
rev_ts_pstate_to_substates[to_pstate].add(pstate)

# A version of `pstate_to_vocab` that is keyed on states that *transition to*
# the original keys of `pstate_to_vocab`.
_pstate_to_vocab: DefaultDict[PartialParseState, Set[str]] = defaultdict(set)
for pstate, vocab in pstate_to_vocab.items():
for next_pstate in rev_ts_pstate_to_substates[pstate]:
_pstate_to_vocab[next_pstate] |= vocab

return _pstate_to_vocab
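
# Illustrative usage sketch (not from the committed source), condensed from
# test_partial_match_preprocessing: map a toy vocabulary onto the partial parse
# states of the DEF/NAME/__IGNORE_0 FSMs built above.
#
#   vocabulary = {"d", "e", "ef foo", "f ", " "}
#   pstate_to_vocab = map_partial_states_to_vocab(vocabulary, symbol_names_and_fsms)
#   pstate_to_vocab[("DEF", 1)]   # {"d"}: strings whose DEF match starts in FSM state 1
#   pstate_to_vocab[("NAME", 1)]  # {"d", "e", "ef foo", "f "}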


def terminals_to_lalr_states(lp: Lark) -> DefaultDict[str, Set[int]]:
    """Construct a ``dict`` mapping terminal symbol names to the LALR parse states reached by shifting them."""
    from lark.parsers.lalr_analysis import Shift

terminals_to_states = defaultdict(set)
parse_table = lp.parser.parser.parser.parse_table
for state, tokens_to_ops in parse_table.states.items():
for token, op in tokens_to_ops.items():
if op[0] == Shift:
# `op[1]` is the state we shift to when `token` is observed
terminals_to_states[token].add(op[1])

return terminals_to_states
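
# Illustrative sketch (not from the committed source): the returned map links
# each terminal name to the LALR parse-table states reached by shifting it; the
# concrete integer ids depend on the grammar.
#
#   terminals_to_states = terminals_to_lalr_states(pyparser)
#   terminals_to_states["DEF"]  # e.g. a set of integer state ids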


def create_pmatch_parser_states(
lp: Lark,
terminals_to_states: Dict[str, Set[int]],
term_type: str,
ptoken: str,
pmatch: Tuple[int, Tuple[int, ...]],
) -> Tuple[ParserState, ...]:
    """Construct one parser state per LALR state associated with `term_type`, with the lexer positioned just past the partial match `pmatch` in `ptoken`."""
    from lark import Token
from lark.parsers.lalr_parser import ParseConf, ParserState

parse_table = lp.parser.parser.parser.parse_table
parse_conf = ParseConf(parse_table, lp._callbacks, lp.options.start[0])
lexer_thread = lp.parser._make_lexer_thread(ptoken)
lexer_state = lexer_thread.state
lexer_state.line_ctr.char_pos = pmatch[0] + 1
lexer_state.last_token = Token(term_type, "")
res = tuple(
ParserState(parse_conf, lexer_thread, [state], None)
for state in terminals_to_states[term_type]
)
return res
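
# Illustrative sketch (not from the committed source), condensed from
# test_parse_from_partial_match: resume a LALR parse from an FSM partial match
# of the string "ef foo" against the DEF terminal.
#
#   pmatch = next(
#       pm for pm in find_partial_matches(def_fsm, "ef foo") if pm[0] is not None
#   )
#   (parser_state,) = create_pmatch_parser_states(
#       pyparser, terminals_to_states, "DEF", "ef foo", pmatch
#   )
#   _, expected_next_tokens = parse_to_end(parser_state)
#   # expected_next_tokens == {"NAME"}
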
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ test = [
"diff-cover",
"lark",
"regex",
"interegular",
]

[project.urls]
@@ -91,6 +92,7 @@ module = [
"transformers.*",
"lark.*",
"regex.*",
"interegular.*",
]
ignore_missing_imports = true

166 changes: 165 additions & 1 deletion tests/text/test_parsing.py
@@ -1,8 +1,19 @@
import interegular
import pytest
from lark import Lark
from lark.indenter import DedentError
from lark.lexer import UnexpectedCharacters, UnexpectedToken

from outlines.text.parsing import PartialPythonIndenter, copy_parser_state, parse_to_end
from outlines.text.parsing import (
PartialPythonIndenter,
copy_parser_state,
create_pmatch_parser_states,
find_partial_matches,
map_partial_states_to_vocab,
parse_to_end,
terminals_to_fsms,
terminals_to_lalr_states,
)


def test_parse_to_end():
@@ -111,3 +122,156 @@ def test_sequential_parse_example():
assert input_tokens[i + 1] in next_vocab
else:
assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])


def test_partial_match():
name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
name_fsm = name_pattern.to_fsm()

def_pattern = interegular.parse_pattern("def")
def_fsm = def_pattern.to_fsm()

assert find_partial_matches(def_fsm, "def") == {(2, (1, 2, 3))}
assert find_partial_matches(def_fsm, "de") == {(None, (1, 2))}
assert find_partial_matches(def_fsm, "d") == {(None, (1,))}
assert find_partial_matches(def_fsm, "") == set()
assert find_partial_matches(def_fsm, "df") == set()
assert find_partial_matches(def_fsm, "ef") == {(1, (2, 3))}
assert find_partial_matches(def_fsm, "e") == {(None, (2,))}
assert find_partial_matches(def_fsm, "f") == {(0, (3,))}
assert find_partial_matches(def_fsm, "ef foo") == {(1, (2, 3))}

# This string has a `DEF` token in it, but should ultimately not lex one
assert find_partial_matches(def_fsm, "defb") == {(2, (1, 2, 3))}

# `NAME` can have multiple start states for this input
assert find_partial_matches(name_fsm, "d") == {(0, (1,)), (0, (2,))}
# Not this case
assert find_partial_matches(name_fsm, "1d") == {(1, (2, 2))}

assert find_partial_matches(name_fsm, "blah") == {
(3, (1, 2, 2, 2)),
(3, (2, 2, 2, 2)),
}


def test_partial_match_preprocessing():
pyparser = Lark.open_from_package(
"lark",
"python.lark",
["grammars"],
parser="lalr",
postlex=PartialPythonIndenter(),
start="file_input",
)

symbol_names_and_fsms = terminals_to_fsms(pyparser)
test_symbols = {"DEF", "NAME", "__IGNORE_0"}
symbol_names_and_fsms = {
k: v for k, v in symbol_names_and_fsms.items() if k in test_symbols
}

vocabulary = {"d", "e", "ef foo", "f ", " "}

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, False
)

assert dict(pstate_to_vocab) == {
("NAME", 1): {"d", "e", "ef foo", "f "},
("NAME", 2): {"d", "e", "ef foo", "f "},
("DEF", 1): {
"d",
},
("DEF", 2): {"e", "ef foo"},
("DEF", 3): {
"f ",
},
("__IGNORE_0", 1): {
" ",
},
("__IGNORE_0", 2): {
" ",
},
}

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, True
)

assert dict(pstate_to_vocab) == {
("DEF", 1): {"e", "ef foo"},
("DEF", 2): {
"f ",
},
("DEF", 0): {
"d",
},
("NAME", 1): {"d", "e", "ef foo", "f "},
("NAME", 2): {"d", "e", "ef foo", "f "},
("NAME", 0): {"d", "e", "ef foo", "f "},
("__IGNORE_0", 1): {
" ",
},
("__IGNORE_0", 2): {
" ",
},
("__IGNORE_0", 0): {
" ",
},
}


def test_parse_from_partial_match():
"""Make sure we can continue parsing from an FSM-based partial match."""
pyparser = Lark(
r"""
start: funcdef
funcdef: "def" name "(" ")" ":"
%ignore /[\t \f]+/ // WS
!name: NAME | "match" | "case"
NAME: /[^\W\d]\w*/
""",
parser="lalr",
postlex=PartialPythonIndenter(),
)

terminals_to_states = terminals_to_lalr_states(pyparser)
symbol_names_and_fsms = terminals_to_fsms(pyparser)

term_type = "DEF"
def_fsm = symbol_names_and_fsms[term_type]

# TODO FIXME: This is broken, and it's a bug in `lark`'s Python grammar!
# ptoken = "defx"

ptoken = "ef foo"
pmatch = find_partial_matches(def_fsm, ptoken)
first_pmatch = next(pm for pm in pmatch if pm[0] is not None)
(parser_state,) = create_pmatch_parser_states(
pyparser, terminals_to_states, term_type, ptoken, first_pmatch
)
new_parser_state, expected_next_tokens = parse_to_end(parser_state)
assert expected_next_tokens == {"NAME"}

ptoken = "ef foo():"
pmatch = find_partial_matches(def_fsm, ptoken)
first_pmatch = next(pm for pm in pmatch if pm[0] is not None)
(parser_state,) = create_pmatch_parser_states(
pyparser, terminals_to_states, term_type, ptoken, first_pmatch
)
new_parser_state, expected_next_tokens = parse_to_end(parser_state)
assert not expected_next_tokens

ptoken = "ef ("
pmatch = find_partial_matches(def_fsm, ptoken)
first_pmatch = next(pm for pm in pmatch if pm[0] is not None)
(parser_state,) = create_pmatch_parser_states(
pyparser, terminals_to_states, term_type, ptoken, first_pmatch
)
with pytest.raises(UnexpectedToken):
parse_to_end(parser_state)
