Skip to content

Commit

Permalink
Regex complement
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Oct 26, 2024
1 parent 30e5fdc commit a2bf491
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
]
description = "Formatron empowers everyone to control the output format of language models with minimal overhead."
readme = "README.md"
dependencies = ["pydantic>=2,<3","kbnf>=0.3.10,<0.4.0", "general-sam>=1,<2", "jsonschema>=4,<5", "frozendict>=2,<3"]
dependencies = ["pydantic>=2,<3","kbnf>=0.3.13,<0.4.0", "general-sam>=1,<2", "jsonschema>=4,<5", "frozendict>=2,<3"]
license = {file = "LICENSE"}
keywords = ["deep learning", "language model", "guided generation", "structured generation","constrained decoding"]
requires-python = ">=3.10"
Expand Down
4 changes: 3 additions & 1 deletion src/formatron/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,6 @@ def extract(self, input_str: str) -> typing.Optional[tuple[str, str]]:

@property
def kbnf_definition(self) -> str:
return f"{self.nonterminal} ::= #substrs{repr(self._string)};"
return f"{self.nonterminal} ::= #substrs{repr(self._string)};"


31 changes: 30 additions & 1 deletion src/formatron/formats/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,33 @@ def extract(self, input_str: str) -> typing.Optional[tuple[str, re.Match | None]

@property
def kbnf_definition(self) -> str:
return f"{self.nonterminal} ::= #{repr(self._regex.pattern)};"
return f"{self.nonterminal} ::= #{repr(self._regex.pattern)};"


class RegexComplementExtractor(NonterminalExtractor):
"""
An extractor that extracts data by matching a regex complement.
"""

def __init__(self, regex: str, capture_name: str, nonterminal: str):
"""
Initialize the regex complement extractor.
"""
super().__init__(nonterminal, capture_name)
self._regex = re.compile(regex)

def extract(self, input_str: str) -> typing.Optional[tuple[str, str]]:
"""
Extract the data by matching a regex complement.
Specifically, the string until the first character in the first match of the regex is extracted if there is a match,
or the entire string is extracted if there is no match.
"""
matched = self._regex.search(input_str)
if not matched:
return "", input_str
return input_str[matched.span()[0]:], input_str[:matched.span()[0]]

@property
def kbnf_definition(self) -> str:
return f"{self.nonterminal} ::= #ex{repr(self._regex.pattern)};"
19 changes: 18 additions & 1 deletion src/formatron/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from formatron.formats.json import JsonExtractor
from formatron.schemas.schema import Schema
from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor
from formatron.formats.regex import RegexExtractor
from formatron.formats.regex import RegexComplementExtractor, RegexExtractor



Expand Down Expand Up @@ -412,6 +412,21 @@ def regex(self, regex: str, *, capture_name: str = None) -> RegexExtractor:
"""
return self._add_extractor("regex",
lambda nonterminal: RegexExtractor(regex, capture_name, nonterminal))

def regex_complement(self, regex: str, *, capture_name: str = None) -> RegexComplementExtractor:
"""
Create a regex complement extractor. This is roughly equivalent to 'extract a string that does not match the given regex anywhere'.
Check out the RegexComplementExtractor docs for more details.
Args:
regex: The regular expression for extraction.
capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
Returns:
The regex complement extractor.
"""
return self._add_extractor("regex_complement",
lambda nonterminal: RegexComplementExtractor(regex, capture_name, nonterminal))

def str(self, *, stop: typing.Union[str, list[str]] = None,
capture_name: typing.Optional[str] = None) -> Extractor:
Expand Down Expand Up @@ -452,6 +467,8 @@ def substr(self, string: str, *, capture_name: str = None, extract_empty_substri
lambda nonterminal: SubstringExtractor(string, capture_name, nonterminal,
extract_empty_substring=extract_empty_substring))



def build(self, vocabulary: kbnf.Vocabulary,
decode: typing.Callable[[list[int]], str],
engine_config: kbnf.Config = None) -> Formatter:
Expand Down
72 changes: 49 additions & 23 deletions tests/snapshots/snap_test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@
start ::= \'Today, I want to eat \' __choice_0_0_food \'\\n\' "My food\'s ID is " __choice_3_0_ID \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately." \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __json_4_0_json \'\\n\';'''

snapshots['test_formatter 2'] = '''Today, I want to eat banana
My food's ID is not.
snapshots['test_formatter 2'] = '''Today, I want to eat orange
My food's ID is soo.
What's more, indentations
are handled
appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: {"name":"Van","weight":1.4,"color":"pink"}
appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: { "name" : "Van" ,"weight" : 1.4, "color" : "pink"}
'''

snapshots['test_formatter 3'] = {
'ID': GenericRepr("<re.Match object; span=(0, 3), match='not'>"),
'food': 'banana',
'ID': GenericRepr("<re.Match object; span=(0, 3), match='soo'>"),
'food': 'orange',
'json': GenericRepr("Test(name='Van', weight=1.4, color='pink')")
}

Expand Down Expand Up @@ -123,23 +123,23 @@
start ::= __json_0_0_json '\\n';'''

snapshots['test_formatter_dict_inference 2'] = '''{"name":"MyName","gender":"male"}
snapshots['test_formatter_dict_inference 2'] = '''{"name":"Tom","gender":"male"}
'''

snapshots['test_formatter_dict_inference 3'] = {
'json': {
'gender': 'male',
'name': 'MyName'
'name': 'Tom'
}
}

snapshots['test_formatter_json_schema 1'] = '''{"name":"peter","age":30}
snapshots['test_formatter_json_schema 1'] = '''{"name":"value","age":0}
'''

snapshots['test_formatter_json_schema 2'] = {
'json': {
'age': 30,
'name': 'peter'
'age': 0,
'name': 'value'
}
}

Expand Down Expand Up @@ -171,10 +171,39 @@
start ::= __json_0_0_json '\\n';'''

snapshots['test_formatter_regex_complement 1'] = '''__regex_complement_0_0_non_numeric ::= #ex'[0-9]';
__regex_1_0_numeric ::= #'[0-9]+';
start ::= 'Text: ' __regex_complement_0_0_non_numeric 'Number: ' __regex_1_0_numeric '\\n';'''

snapshots['test_formatter_regex_complement 2'] = '''Text: I got $l worth of money from $b.
A: $l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth of money from $b.
$l worth of money from $b is $l worth'''

snapshots['test_formatter_regex_complement 3'] = {
'non_numeric': 'Hello, world! Number: ',
'numeric': GenericRepr("<re.Match object; span=(0, 2), match='42'>")
}

snapshots['test_formatter_regex_complement 4'] = {
'non_numeric': 'Hello, world! Number: ',
'numeric': GenericRepr("<re.Match object; span=(0, 2), match='42'>")
}

snapshots['test_formatter_str 1'] = '''__str_0_0 ::= #'.*?(?:\\\\.)';
start ::= __str_0_0 '\\n';'''

snapshots['test_formatter_str 2'] = '˝ she said. "I\'m from the International Union of Mine, Metal and Air-craft Workers. I\'m here to ask you for a place to stay for a while. I can\'t pay you, but I\'ll pay you back in installments. You\'ll be living with me and my family." I was about to protest, but she had already pulled out a business card. "I\'m sorry, but you have to come with me. My name is Angela, and I\'ll be your maid." She handed me the card and I thanked her. I walked over to her car and got in. She drove off. I looked at the address on the card. It was a town called Pleasant Valley. I knew that the building was the old prison camp. I wondered what happened to them all. The rest of the day was spent driving around looking for any sign of the men who were imprisoned there. After what seemed like an eternity, we arrived at the old prison camp. It was a long walk to the prison itself, but it was worth it. I looked around and saw that it was a very beautiful place. There were many trees and flowers everywhere. I looked around for a sign of the guards, but they weren\'t there. I started'
snapshots['test_formatter_str 2'] = '请将上述英文翻译成中文,并且返回正确的翻译方式为英文,因为它不在这个问题中。Answer: 我的名字是Van。请问你想问什么?(我不知道你想问什么,所以我只能回答你)。这是一个简单的问题,你可以告诉我你想问什么。(我不知道你想问什么,所以我只能回答你)。请告诉我你想要了解的是什么?(我不知道你想要了解的是什么,所以我只能回答你)。这是一个有趣的问题,可以让我们聊天。(我不知道你想要聊什么,所以我只能回答你)。请告诉我你想要聊什么?(我不知道你想要聊什么,所以我只能回答你)。这是一个有趣的问题,可以让我们聊天。(我不知道你'

snapshots['test_formatter_str 3'] = {
}
Expand Down Expand Up @@ -210,34 +239,31 @@
\r]*";
array_end ::= #"[ \t
\r]*\\\\]";
__json_0_0_json ::= array_begin (__json_0_0_json_value (comma __json_0_0_json_value)*)? array_end;
__json_0_0_json_value ::= object_begin \'"id"\' colon __json_0_0_json_value_id comma \'"name"\' colon __json_0_0_json_value_name comma \'"active"\' colon __json_0_0_json_value_active object_end;
__json_0_0_json_value_active ::= __json_0_0_json_value_active_required?;
__json_0_0_json_value_active_required ::= boolean;
__json_0_0_json_value_name ::= string;
__json_0_0_json_value_id ::= integer;
__json_0_0_json_min ::= __json_0_0_json_item;
__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item array_end;
__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item comma __json_0_0_json_item array_end;
__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item comma __json_0_0_json_item comma __json_0_0_json_item array_end;
__json_0_0_json ::= array_begin __json_0_0_json_min comma __json_0_0_json_item comma __json_0_0_json_item comma __json_0_0_json_item comma __json_0_0_json_item array_end;
__json_0_0_json_item ::= json_value;
start ::= __json_0_0_json '\\n';'''

snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "John", "active": true}, {"id": 2, "name": "Bob", "active": true}, {"id": 3, "name": "Charlie", "active": true}]
snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "John"}, {"id": 2, "name": "Jane"}, {"id": 3, "name": "Mary"}]
'''

snapshots['test_formatter_top_level_array_json_schema 3'] = {
'json': [
{
'active': True,
'id': 1,
'name': 'John'
},
{
'active': True,
'id': 2,
'name': 'Bob'
'name': 'Jane'
},
{
'active': True,
'id': 3,
'name': 'Charlie'
'name': 'Mary'
}
]
}
Expand Down
27 changes: 27 additions & 0 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,30 @@ def test_formatter_alternate_accept(snapshot):

snapshot.assert_match(formatter.captures)


def test_formatter_regex_complement(snapshot):
FormatterBuilder._formatter_builder_counter = 0
f = FormatterBuilder()
f.append_str(f"Text: {f.regex_complement('[0-9]', capture_name='non_numeric')}")
f.append_line(f"Number: {f.regex('[0-9]+', capture_name='numeric')}")

model = RWKV(
"assets/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth", 'cuda fp16')
pipeline = formatron.integrations.RWKV.PIPELINE(model, "rwkv_vocab_v20230424", f)

np.random.seed(42)
snapshot.assert_match(pipeline.formatter.grammar_str)
snapshot.assert_match(
pipeline.generate("Here's some text followed by an integer: ", token_count=256, args=formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5)))
snapshot.assert_match(pipeline.formatter.captures)

# Test with manual input
formatter = pipeline.formatter
formatter.reset()

input_text = "Text: Hello, world! Number: 42\n"
for char in input_text:
formatter.accept_bytes(char.encode('utf-8'))

snapshot.assert_match(formatter.captures)

0 comments on commit a2bf491

Please sign in to comment.