From c85586069a63d66fa78a98f64ddc798926f7ea59 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Wed, 5 Jul 2023 17:08:05 -0500
Subject: [PATCH] Make map_partial_states_to_vocab return vocab indices and
 filter matches

---
 outlines/text/parsing.py   |  30 +++++++--
 tests/text/test_parsing.py | 121 ++++++++++++++++++++++++++-----------
 2 files changed, 109 insertions(+), 42 deletions(-)

diff --git a/outlines/text/parsing.py b/outlines/text/parsing.py
index 63a7a4917..385bc63d3 100644
--- a/outlines/text/parsing.py
+++ b/outlines/text/parsing.py
@@ -1,7 +1,17 @@
 from collections import ChainMap, defaultdict
 from copy import copy
 from itertools import chain
-from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterable,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import interegular
 import regex
@@ -329,7 +339,10 @@ def map_partial_states_to_vocab(
     vocabulary: Iterable[str],
     terminals_to_fsms_map: Dict[str, FSM],
     map_to_antecedents: bool = False,
-) -> DefaultDict[PartialParseState, Set[str]]:
+    partial_match_filter: Callable[
+        [str, Optional[int], Tuple[int, ...]], bool
+    ] = lambda *args: True,
+) -> DefaultDict[PartialParseState, Set[int]]:
     """Construct a map from partial parse states to the vocabulary elements that start in those states.
 
     Parameters
@@ -342,13 +355,18 @@ def map_partial_states_to_vocab(
         When ``True``, return a map with keys that are the antecedent partial
         parse states. In other words, this is a map that can be used to
         determine valid next tokens given a parse state.
+    partial_match_filter
+        A callable that determines which partial matches to keep. The first
+        argument is the string being matched; the rest are the unpacked partial
+        match return values of `find_partial_matches`.
 
     """
     pstate_to_vocab = defaultdict(set)
     for symbol_name, fsm in terminals_to_fsms_map.items():
-        for tk in vocabulary:
-            for _, states in find_partial_matches(fsm, tk):
-                pstate_to_vocab[(symbol_name, states[0])].add(tk)
+        for i, vocab_string in enumerate(vocabulary):
+            for end_idx, state_seq in find_partial_matches(fsm, vocab_string):
+                if partial_match_filter(vocab_string, end_idx, state_seq):
+                    pstate_to_vocab[(symbol_name, state_seq[0])].add(i)
 
     if not map_to_antecedents:
         return pstate_to_vocab
@@ -373,7 +391,7 @@ def map_partial_states_to_vocab(
 
     # A version of `pstate_to_vocab` that is keyed on states that *transition to*
     # the original keys of `pstate_to_vocab`.
-    _pstate_to_vocab: DefaultDict[PartialParseState, Set[str]] = defaultdict(set)
+    _pstate_to_vocab: DefaultDict[PartialParseState, Set[int]] = defaultdict(set)
     for pstate, vocab in pstate_to_vocab.items():
         for next_pstate in rev_ts_pstate_to_substates[pstate]:
             _pstate_to_vocab[next_pstate] |= vocab
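
A note on the interface change above: with `partial_match_filter` and index-valued
results, the map can be built so that only vocabulary strings that fully match a
terminal are kept. The following is a minimal sketch of such a call, assuming this
patch is applied; the `INT` terminal, the digit pattern, the toy vocabulary, and
`full_match_filter` are illustrative assumptions, not anything in the patch:

    import interegular

    from outlines.text.parsing import map_partial_states_to_vocab

    # A toy terminal FSM that accepts one or more digits.
    int_fsm = interegular.parse_pattern(r"[0-9]+").to_fsm()

    vocabulary = ["1", "30", "2a", "a"]

    # Keep only matches that consume the entire vocabulary string,
    # mirroring the filter used in the regex test below.
    def full_match_filter(string, end_idx, state_seq):
        return end_idx is None or end_idx == len(string) - 1

    pstate_to_vocab = map_partial_states_to_vocab(
        vocabulary, {"INT": int_fsm}, False, full_match_filter
    )

    # Values are now indices into `vocabulary` (here {0, 1}, for "1" and
    # "30"), not the vocabulary strings themselves.
    print(dict(pstate_to_vocab))
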
diff --git a/tests/text/test_parsing.py b/tests/text/test_parsing.py
index 17acd4b8d..097d8c177 100644
--- a/tests/text/test_parsing.py
+++ b/tests/text/test_parsing.py
@@ -1,3 +1,6 @@
+import random
+import re
+
 import interegular
 import pytest
 from lark import Lark
@@ -155,7 +158,7 @@ def test_partial_match():
     }
 
 
-def test_partial_match_preprocessing():
+def test_map_partial_states_to_vocab_python():
     pyparser = Lark.open_from_package(
         "lark",
         "python.lark",
@@ -171,28 +174,20 @@ def test_partial_match_preprocessing():
         k: v for k, v in symbol_names_and_fsms.items() if k in test_symbols
     }
 
-    vocabulary = {"d", "e", "ef foo", "f ", " "}
+    vocabulary = ["d", "e", "ef foo", "f ", " "]
 
     pstate_to_vocab = map_partial_states_to_vocab(
         vocabulary, symbol_names_and_fsms, False
     )
 
     assert dict(pstate_to_vocab) == {
-        ("NAME", 1): {"d", "e", "ef foo", "f "},
-        ("NAME", 2): {"d", "e", "ef foo", "f "},
-        ("DEF", 1): {
-            "d",
-        },
-        ("DEF", 2): {"e", "ef foo"},
-        ("DEF", 3): {
-            "f ",
-        },
-        ("__IGNORE_0", 1): {
-            " ",
-        },
-        ("__IGNORE_0", 2): {
-            " ",
-        },
+        ("__IGNORE_0", 2): {4},
+        ("__IGNORE_0", 1): {4},
+        ("NAME", 2): {0, 1, 2, 3},
+        ("NAME", 1): {0, 1, 2, 3},
+        ("DEF", 1): {0},
+        ("DEF", 2): {1, 2},
+        ("DEF", 3): {3},
     }
 
     pstate_to_vocab = map_partial_states_to_vocab(
@@ -200,25 +195,15 @@ def test_partial_match_preprocessing():
     )
 
     assert dict(pstate_to_vocab) == {
-        ("DEF", 1): {"e", "ef foo"},
-        ("DEF", 2): {
-            "f ",
-        },
-        ("DEF", 0): {
-            "d",
-        },
-        ("NAME", 1): {"d", "e", "ef foo", "f "},
-        ("NAME", 2): {"d", "e", "ef foo", "f "},
-        ("NAME", 0): {"d", "e", "ef foo", "f "},
-        ("__IGNORE_0", 1): {
-            " ",
-        },
-        ("__IGNORE_0", 2): {
-            " ",
-        },
-        ("__IGNORE_0", 0): {
-            " ",
-        },
+        ("__IGNORE_0", 1): {4},
+        ("__IGNORE_0", 2): {4},
+        ("__IGNORE_0", 0): {4},
+        ("NAME", 1): {0, 1, 2, 3},
+        ("NAME", 2): {0, 1, 2, 3},
+        ("NAME", 0): {0, 1, 2, 3},
+        ("DEF", 0): {0},
+        ("DEF", 1): {1, 2},
+        ("DEF", 2): {3},
     }
 
 
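
One payoff of returning vocabulary indices, as the assertions above illustrate, is
that an index set converts directly into a fixed-size boolean mask over the
vocabulary, which is the shape needed to mask token logits during guided
generation. A minimal sketch of that conversion, assuming NumPy; `support_mask`
is a hypothetical helper, not part of this patch:

    import numpy as np

    def support_mask(pstate, pstate_to_vocab, vocab_size):
        # Boolean mask over the vocabulary: True at each index whose
        # string can continue a match from `pstate`.
        mask = np.zeros(vocab_size, dtype=bool)
        mask[list(pstate_to_vocab.get(pstate, set()))] = True
        return mask
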
@@ -278,3 +263,67 @@ def test_parse_from_partial_match():
     )
     with pytest.raises(UnexpectedToken):
         parse_to_end(parser_state)
+
+
+def test_map_partial_states_to_vocab_regex():
+    regex_string = r"(([0-9]+)?([.]([0-9]*)?)?|[.][0-9]+)"
+    regex_pattern = interegular.parse_pattern(regex_string)
+    regex_fsm = regex_pattern.simplify().to_fsm()
+
+    vocabulary = ["1.", "2", "3.", ".", ".80", "42", "1a", " ", "0", "a", "b", "$"]
+
+    # We want the vocabulary strings to entirely match the regex--not just the
+    # prefixes of the vocabulary strings
+    def partial_match_filter(string, end_idx, state_seq):
+        if end_idx is not None and end_idx < len(string) - 1:
+            return False
+        return True
+
+    pstate_to_vocab = map_partial_states_to_vocab(
+        vocabulary, {"FLOAT": regex_fsm}, True, partial_match_filter
+    )
+
+    assert dict(pstate_to_vocab) == {
+        ("FLOAT", 0): {0, 1, 2, 3, 4, 5, 8},
+        ("FLOAT", 3): {0, 1, 2, 3, 4, 5, 8},
+        ("FLOAT", 1): {0, 1, 2, 3, 4, 5, 8},
+        ("FLOAT", 5): {1, 5, 8},
+        ("FLOAT", 7): {1, 5, 8},
+        ("FLOAT", 4): {1, 5, 8},
+        ("FLOAT", 6): {1, 5, 8},
+        ("FLOAT", 2): {1, 5, 8},
+    }
+
+    pstate_to_vocab = {k: tuple(v) for k, v in pstate_to_vocab.items()}
+
+    random.seed(24080)
+
+    # Start at the initial state
+    pstate = ("FLOAT", regex_fsm.initial)
+
+    sample_seq = ""
+
+    for i in range(10):
+        next_support = pstate_to_vocab[pstate]
+
+        (next_sample_idx,) = random.sample(next_support, 1)
+
+        next_sample = vocabulary[next_sample_idx]
+        sample_seq += next_sample
+
+        # Parse the entire sampled sequence/string
+        # TODO: We could continue from the previous parse state, but this is
+        # easier for now and only for demonstration purposes.
+        partial_matches = find_partial_matches(regex_fsm, sample_seq)
+
+        # Use the/a longest match
+        pmatch = max(partial_matches, key=lambda x: x[0] if x[0] is not None else -1)
+
+        # Create the next state
+        pstate = (pstate[0], pmatch[1][-1])
+
+        # TODO: We could check if the FSM is done (i.e. in a final/accept
+        # state) and end the sampling loop
+
+    # Make sure the whole thing matches the regex
+    assert re.fullmatch(regex_string, sample_seq) is not None
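
The second TODO in the sampling loop above can be addressed with the FSM's accept
states: `interegular` FSMs expose a `finals` set alongside the `initial` attribute
the test already uses. A possible early exit, sketched under the assumption that
stopping at the first complete match is acceptable:

    # Inside the sampling loop, after computing the next state:
    pstate = (pstate[0], pmatch[1][-1])

    # Stop sampling once the FSM is in an accepting state, i.e. once
    # `sample_seq` already matches the regex in full.
    if pstate[1] in regex_fsm.finals:
        break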