Skip to content

Commit

Permalink
Make map_partial_states_to_vocab return vocab indices and filter matches
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard authored and rlouf committed Jul 6, 2023
1 parent ced6f50 commit c855860
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 42 deletions.
30 changes: 24 additions & 6 deletions outlines/text/parsing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from collections import ChainMap, defaultdict
from copy import copy
from itertools import chain
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Iterable, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Optional,
Set,
Tuple,
)

import interegular
import regex
Expand Down Expand Up @@ -329,7 +339,10 @@ def map_partial_states_to_vocab(
vocabulary: Iterable[str],
terminals_to_fsms_map: Dict[str, FSM],
map_to_antecedents: bool = False,
) -> DefaultDict[PartialParseState, Set[str]]:
partial_match_filter: Callable[
[str, Optional[int], Tuple[int, ...]], bool
] = lambda *args: True,
) -> DefaultDict[PartialParseState, Set[int]]:
"""Construct a map from partial parse states to the vocabulary elements that start in those states.
Parameters
Expand All @@ -342,13 +355,18 @@ def map_partial_states_to_vocab(
When ``True``, return a map with keys that are the antecedent partial
parse states. In other words, this is a map that can be used to
determine valid next tokens given a parse state.
partial_match_filter
A callable that determines which partial matches to keep. The first
argument is the string being matched; the rest are the unpacked partial
match return values of `find_partial_matches`.
"""

pstate_to_vocab = defaultdict(set)
for symbol_name, fsm in terminals_to_fsms_map.items():
for tk in vocabulary:
for _, states in find_partial_matches(fsm, tk):
pstate_to_vocab[(symbol_name, states[0])].add(tk)
for i, vocab_string in enumerate(vocabulary):
for end_idx, state_seq in find_partial_matches(fsm, vocab_string):
if partial_match_filter(vocab_string, end_idx, state_seq):
pstate_to_vocab[(symbol_name, state_seq[0])].add(i)

if not map_to_antecedents:
return pstate_to_vocab
Expand All @@ -373,7 +391,7 @@ def map_partial_states_to_vocab(

# A version of `pstate_to_vocab` that is keyed on states that *transition to*
# the original keys of `pstate_to_vocab`.
_pstate_to_vocab: DefaultDict[PartialParseState, Set[str]] = defaultdict(set)
_pstate_to_vocab: DefaultDict[PartialParseState, Set[int]] = defaultdict(set)
for pstate, vocab in pstate_to_vocab.items():
for next_pstate in rev_ts_pstate_to_substates[pstate]:
_pstate_to_vocab[next_pstate] |= vocab
Expand Down
121 changes: 85 additions & 36 deletions tests/text/test_parsing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import random
import re

import interegular
import pytest
from lark import Lark
Expand Down Expand Up @@ -155,7 +158,7 @@ def test_partial_match():
}


def test_partial_match_preprocessing():
def test_map_partial_states_to_vocab_python():
pyparser = Lark.open_from_package(
"lark",
"python.lark",
Expand All @@ -171,54 +174,36 @@ def test_partial_match_preprocessing():
k: v for k, v in symbol_names_and_fsms.items() if k in test_symbols
}

vocabulary = {"d", "e", "ef foo", "f ", " "}
vocabulary = ["d", "e", "ef foo", "f ", " "]

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, False
)

assert dict(pstate_to_vocab) == {
("NAME", 1): {"d", "e", "ef foo", "f "},
("NAME", 2): {"d", "e", "ef foo", "f "},
("DEF", 1): {
"d",
},
("DEF", 2): {"e", "ef foo"},
("DEF", 3): {
"f ",
},
("__IGNORE_0", 1): {
" ",
},
("__IGNORE_0", 2): {
" ",
},
("__IGNORE_0", 2): {4},
("__IGNORE_0", 1): {4},
("NAME", 2): {0, 1, 2, 3},
("NAME", 1): {0, 1, 2, 3},
("DEF", 1): {0},
("DEF", 2): {1, 2},
("DEF", 3): {3},
}

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, True
)

assert dict(pstate_to_vocab) == {
("DEF", 1): {"e", "ef foo"},
("DEF", 2): {
"f ",
},
("DEF", 0): {
"d",
},
("NAME", 1): {"d", "e", "ef foo", "f "},
("NAME", 2): {"d", "e", "ef foo", "f "},
("NAME", 0): {"d", "e", "ef foo", "f "},
("__IGNORE_0", 1): {
" ",
},
("__IGNORE_0", 2): {
" ",
},
("__IGNORE_0", 0): {
" ",
},
("__IGNORE_0", 1): {4},
("__IGNORE_0", 2): {4},
("__IGNORE_0", 0): {4},
("NAME", 1): {0, 1, 2, 3},
("NAME", 2): {0, 1, 2, 3},
("NAME", 0): {0, 1, 2, 3},
("DEF", 0): {0},
("DEF", 1): {1, 2},
("DEF", 2): {3},
}


Expand Down Expand Up @@ -278,3 +263,67 @@ def test_parse_from_partial_match():
)
with pytest.raises(UnexpectedToken):
parse_to_end(parser_state)


def test_map_partial_states_to_vocab_regex():
    """Map a float-like regex's parse states to vocabulary indices, then
    sample a string token-by-token from that map and check it matches."""
    float_regex = r"(([0-9]+)?([.]([0-9]*)?)?|[.][0-9]+)"
    fsm = interegular.parse_pattern(float_regex).simplify().to_fsm()

    vocabulary = ["1.", "2", "3.", ".", ".80", "42", "1a", " ", "0", "a", "b", "$"]

    # Keep only matches that consume an entire vocabulary string--i.e. drop
    # matches covering just a proper prefix of the string.
    def partial_match_filter(string, end_idx, state_seq):
        return end_idx is None or end_idx >= len(string) - 1

    pstate_to_vocab = map_partial_states_to_vocab(
        vocabulary, {"FLOAT": fsm}, True, partial_match_filter
    )

    assert dict(pstate_to_vocab) == {
        ("FLOAT", 0): {0, 1, 2, 3, 4, 5, 8},
        ("FLOAT", 3): {0, 1, 2, 3, 4, 5, 8},
        ("FLOAT", 1): {0, 1, 2, 3, 4, 5, 8},
        ("FLOAT", 5): {1, 5, 8},
        ("FLOAT", 7): {1, 5, 8},
        ("FLOAT", 4): {1, 5, 8},
        ("FLOAT", 6): {1, 5, 8},
        ("FLOAT", 2): {1, 5, 8},
    }

    # Freeze the support sets so `random.sample` can draw from them
    pstate_to_vocab = {k: tuple(v) for k, v in pstate_to_vocab.items()}

    random.seed(24080)

    # Begin sampling from the FSM's initial state
    pstate = ("FLOAT", fsm.initial)

    sample_seq = ""

    # Rank matches by end index; a match with `end_idx is None` ranks lowest
    def match_rank(pmatch):
        end_idx, _ = pmatch
        return end_idx if end_idx is not None else -1

    for _ in range(10):
        support = pstate_to_vocab[pstate]

        (token_idx,) = random.sample(support, 1)

        sample_seq += vocabulary[token_idx]

        # Parse the entire sampled sequence/string
        # TODO: We could continue from the previous parse state, but this is
        # easier for now and only for demonstration purposes.
        matches = find_partial_matches(fsm, sample_seq)

        # Use the/a longest match (ties go to the first one `max` sees)
        best = max(matches, key=match_rank)

        # Advance to the state in which the best match ends
        pstate = (pstate[0], best[1][-1])

        # TODO: We could check whether the FSM is done (i.e. in a
        # final/accept state) and end the sampling loop early.

    # Make sure the whole sampled string matches the regex
    assert re.fullmatch(float_regex, sample_seq) is not None

0 comments on commit c855860

Please sign in to comment.