From a2bf491946e39d8b2b82a66d860dd33c4dfb9a56 Mon Sep 17 00:00:00 2001 From: Huanghe Date: Sat, 26 Oct 2024 18:46:30 -0500 Subject: [PATCH] Regex complement --- pyproject.toml | 2 +- src/formatron/extractor.py | 4 +- src/formatron/formats/regex.py | 31 ++++++++++- src/formatron/formatter.py | 19 ++++++- tests/snapshots/snap_test_formatter.py | 72 ++++++++++++++++++-------- tests/test_formatter.py | 27 ++++++++++ 6 files changed, 128 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a40e8d8..313e6c91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ ] description = "Formatron empowers everyone to control the output format of language models with minimal overhead." readme = "README.md" -dependencies = ["pydantic>=2,<3","kbnf>=0.3.10,<0.4.0", "general-sam>=1,<2", "jsonschema>=4,<5", "frozendict>=2,<3"] +dependencies = ["pydantic>=2,<3","kbnf>=0.3.13,<0.4.0", "general-sam>=1,<2", "jsonschema>=4,<5", "frozendict>=2,<3"] license = {file = "LICENSE"} keywords = ["deep learning", "language model", "guided generation", "structured generation","constrained decoding"] requires-python = ">=3.10" diff --git a/src/formatron/extractor.py b/src/formatron/extractor.py index 2e350af3..2f1f4e5e 100644 --- a/src/formatron/extractor.py +++ b/src/formatron/extractor.py @@ -199,4 +199,6 @@ def extract(self, input_str: str) -> typing.Optional[tuple[str, str]]: @property def kbnf_definition(self) -> str: - return f"{self.nonterminal} ::= #substrs{repr(self._string)};" \ No newline at end of file + return f"{self.nonterminal} ::= #substrs{repr(self._string)};" + + diff --git a/src/formatron/formats/regex.py b/src/formatron/formats/regex.py index 8d862e0f..0d513290 100644 --- a/src/formatron/formats/regex.py +++ b/src/formatron/formats/regex.py @@ -40,4 +40,33 @@ def extract(self, input_str: str) -> typing.Optional[tuple[str, re.Match | None] @property def kbnf_definition(self) -> str: - return f"{self.nonterminal} ::= #{repr(self._regex.pattern)};" \ No newline at end of file + return f"{self.nonterminal} ::= #{repr(self._regex.pattern)};" + + +class RegexComplementExtractor(NonterminalExtractor): + """ + An extractor that extracts data by matching a regex complement. + """ + + def __init__(self, regex: str, capture_name: str, nonterminal: str): + """ + Initialize the regex complement extractor. + """ + super().__init__(nonterminal, capture_name) + self._regex = re.compile(regex) + + def extract(self, input_str: str) -> typing.Optional[tuple[str, str]]: + """ + Extract the data by matching a regex complement. + + Specifically, the string until the first character in the first match of the regex is extracted if there is a match, + or the entire string is extracted if there is no match. + """ + matched = self._regex.search(input_str) + if not matched: + return "", input_str + return input_str[matched.span()[0]:], input_str[:matched.span()[0]] + + @property + def kbnf_definition(self) -> str: + return f"{self.nonterminal} ::= #ex{repr(self._regex.pattern)};" \ No newline at end of file diff --git a/src/formatron/formatter.py b/src/formatron/formatter.py index ea3808b2..b7ddd4d9 100644 --- a/src/formatron/formatter.py +++ b/src/formatron/formatter.py @@ -13,7 +13,7 @@ from formatron.formats.json import JsonExtractor from formatron.schemas.schema import Schema from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor -from formatron.formats.regex import RegexExtractor +from formatron.formats.regex import RegexComplementExtractor, RegexExtractor @@ -412,6 +412,21 @@ def regex(self, regex: str, *, capture_name: str = None) -> RegexExtractor: """ return self._add_extractor("regex", lambda nonterminal: RegexExtractor(regex, capture_name, nonterminal)) + + def regex_complement(self, regex: str, *, capture_name: str = None) -> RegexComplementExtractor: + """ + Create a regex complement extractor. This is roughly equivalent to 'extract a string that does not match the given regex anywhere'. + + Check out the RegexComplementExtractor docs for more details. + + Args: + regex: The regular expression for extraction. + capture_name: The capture name of the extractor, or `None` if the extractor does not capture. + Returns: + The regex complement extractor. + """ + return self._add_extractor("regex_complement", + lambda nonterminal: RegexComplementExtractor(regex, capture_name, nonterminal)) def str(self, *, stop: typing.Union[str, list[str]] = None, capture_name: typing.Optional[str] = None) -> Extractor: @@ -452,6 +467,8 @@ def substr(self, string: str, *, capture_name: str = None, extract_empty_substri lambda nonterminal: SubstringExtractor(string, capture_name, nonterminal, extract_empty_substring=extract_empty_substring)) + + def build(self, vocabulary: kbnf.Vocabulary, decode: typing.Callable[[list[int]], str], engine_config: kbnf.Config = None) -> Formatter: diff --git a/tests/snapshots/snap_test_formatter.py b/tests/snapshots/snap_test_formatter.py index e1d05ba4..56030a3c 100644 --- a/tests/snapshots/snap_test_formatter.py +++ b/tests/snapshots/snap_test_formatter.py @@ -40,17 +40,17 @@ start ::= \'Today, I want to eat \' __choice_0_0_food \'\\n\' "My food\'s ID is " __choice_3_0_ID \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately." \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __json_4_0_json \'\\n\';''' -snapshots['test_formatter 2'] = '''Today, I want to eat banana -My food's ID is not. +snapshots['test_formatter 2'] = '''Today, I want to eat orange +My food's ID is soo. What's more, indentations are handled -appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: {"name":"Van","weight":1.4,"color":"pink"} +appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: { "name" : "Van" ,"weight" : 1.4, "color" : "pink"} ''' snapshots['test_formatter 3'] = { - 'ID': GenericRepr(""), - 'food': 'banana', + 'ID': GenericRepr(""), + 'food': 'orange', 'json': GenericRepr("Test(name='Van', weight=1.4, color='pink')") } @@ -123,23 +123,23 @@ start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_dict_inference 2'] = '''{"name":"MyName","gender":"male"} +snapshots['test_formatter_dict_inference 2'] = '''{"name":"Tom","gender":"male"} ''' snapshots['test_formatter_dict_inference 3'] = { 'json': { 'gender': 'male', - 'name': 'MyName' + 'name': 'Tom' } } -snapshots['test_formatter_json_schema 1'] = '''{"name":"peter","age":30} +snapshots['test_formatter_json_schema 1'] = '''{"name":"value","age":0} ''' snapshots['test_formatter_json_schema 2'] = { 'json': { - 'age': 30, - 'name': 'peter' + 'age': 0, + 'name': 'value' } } @@ -171,10 +171,39 @@ start ::= __json_0_0_json '\\n';''' +snapshots['test_formatter_regex_complement 1'] = '''__regex_complement_0_0_non_numeric ::= #ex'[0-9]'; +__regex_1_0_numeric ::= #'[0-9]+'; +start ::= 'Text: ' __regex_complement_0_0_non_numeric 'Number: ' __regex_1_0_numeric '\\n';''' + +snapshots['test_formatter_regex_complement 2'] = '''Text: I got $l worth of money from $b. +A: $l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth of money from $b. +$l worth of money from $b is $l worth''' + +snapshots['test_formatter_regex_complement 3'] = { + 'non_numeric': 'Hello, world! Number: ', + 'numeric': GenericRepr("") +} + +snapshots['test_formatter_regex_complement 4'] = { + 'non_numeric': 'Hello, world! Number: ', + 'numeric': GenericRepr("") +} + snapshots['test_formatter_str 1'] = '''__str_0_0 ::= #'.*?(?:\\\\.)'; start ::= __str_0_0 '\\n';''' -snapshots['test_formatter_str 2'] = '˝ she said. "I\'m from the International Union of Mine, Metal and Air-craft Workers. I\'m here to ask you for a place to stay for a while. I can\'t pay you, but I\'ll pay you back in installments. You\'ll be living with me and my family." I was about to protest, but she had already pulled out a business card. "I\'m sorry, but you have to come with me. My name is Angela, and I\'ll be your maid." She handed me the card and I thanked her. I walked over to her car and got in. She drove off. I looked at the address on the card. It was a town called Pleasant Valley. I knew that the building was the old prison camp. I wondered what happened to them all. The rest of the day was spent driving around looking for any sign of the men who were imprisoned there. After what seemed like an eternity, we arrived at the old prison camp. It was a long walk to the prison itself, but it was worth it. I looked around and saw that it was a very beautiful place. There were many trees and flowers everywhere. I looked around for a sign of the guards, but they weren\'t there. I started' +snapshots['test_formatter_str 2'] = '请将上述英文翻译成中文,并且返回正确的翻译方式为英文,因为它不在这个问题中。Answer: 我的名字是Van。请问你想问什么?(我不知道你想问什么,所以我只能回答你)。这是一个简单的问题,你可以告诉我你想问什么。(我不知道你想问什么,所以我只能回答你)。请告诉我你想要了解的是什么?(我不知道你想要了解的是什么,所以我只能回答你)。这是一个有趣的问题,可以让我们聊天。(我不知道你想要聊什么,所以我只能回答你)。请告诉我你想要聊什么?(我不知道你想要聊什么,所以我只能回答你)。这是一个有趣的问题,可以让我们聊天。(我不知道你' snapshots['test_formatter_str 3'] = { } @@ -210,34 +239,31 @@ \r]*"; array_end ::= #"[ \t \r]*\\\\]"; -__json_0_0_json ::= array_begin (__json_0_0_json_value (comma __json_0_0_json_value)*)? array_end; -__json_0_0_json_value ::= object_begin \'"id"\' colon __json_0_0_json_value_id comma \'"name"\' colon __json_0_0_json_value_name comma \'"active"\' colon __json_0_0_json_value_active object_end; -__json_0_0_json_value_active ::= __json_0_0_json_value_active_required?; -__json_0_0_json_value_active_required ::= boolean; -__json_0_0_json_value_name ::= string; -__json_0_0_json_value_id ::= integer; +__json_0_0_json_min ::= __json_0_0_json_item; +__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item array_end; +__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item comma __json_0_0_json_item array_end; +__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item comma __json_0_0_json_item comma __json_0_0_json_item array_end; +__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item comma __json_0_0_json_item comma __json_0_0_json_item comma __json_0_0_json_item array_end; +__json_0_0_json_item ::= json_value; start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "John", "active": true}, {"id": 2, "name": "Bob", "active": true}, {"id": 3, "name": "Charlie", "active": true}] +snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "John"}, {"id": 2, "name": "Jane"}, {"id": 3, "name": "Mary"}] ''' snapshots['test_formatter_top_level_array_json_schema 3'] = { 'json': [ { - 'active': True, 'id': 1, 'name': 'John' }, { - 'active': True, 'id': 2, - 'name': 'Bob' + 'name': 'Jane' }, { - 'active': True, 'id': 3, - 'name': 'Charlie' + 'name': 'Mary' } ] } diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 898b5eb2..65d937cc 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -199,3 +199,30 @@ def test_formatter_alternate_accept(snapshot): snapshot.assert_match(formatter.captures) + +def test_formatter_regex_complement(snapshot): + FormatterBuilder._formatter_builder_counter = 0 + f = FormatterBuilder() + f.append_str(f"Text: {f.regex_complement('[0-9]', capture_name='non_numeric')}") + f.append_line(f"Number: {f.regex('[0-9]+', capture_name='numeric')}") + + model = RWKV( + "assets/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth", 'cuda fp16') + pipeline = formatron.integrations.RWKV.PIPELINE(model, "rwkv_vocab_v20230424", f) + + np.random.seed(42) + snapshot.assert_match(pipeline.formatter.grammar_str) + snapshot.assert_match( + pipeline.generate("Here's some text followed by an integer: ", token_count=256, args=formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5))) + snapshot.assert_match(pipeline.formatter.captures) + + # Test with manual input + formatter = pipeline.formatter + formatter.reset() + + input_text = "Text: Hello, world! Number: 42\n" + for char in input_text: + formatter.accept_bytes(char.encode('utf-8')) + + snapshot.assert_match(formatter.captures) +