From 9fa36c1f1468b73cc56fe13641a03c2eff08ff01 Mon Sep 17 00:00:00 2001 From: Huanghe Date: Fri, 20 Dec 2024 23:21:25 -0600 Subject: [PATCH] Fix unusual json keys --- src/formatron/formats/json.py | 9 +- src/formatron/formats/utils.py | 40 ++++++ src/formatron/schemas/json_schema.py | 1 + .../snap_test_exllamav2_integration.py | 4 + tests/snapshots/snap_test_formatter.py | 123 ++++++++++++------ tests/snapshots/snap_test_grammar_gen.py | 8 +- tests/test_exllamav2_integration.py | 68 +++++++++- tests/test_formatter.py | 23 ++++ tests/test_transformers_integration.py | 1 - 9 files changed, 230 insertions(+), 47 deletions(-) create mode 100644 src/formatron/formats/utils.py diff --git a/src/formatron/formats/json.py b/src/formatron/formats/json.py index 2cb53b65..0129a674 100644 --- a/src/formatron/formats/json.py +++ b/src/formatron/formats/json.py @@ -9,6 +9,7 @@ from frozendict import frozendict from formatron import extractor, schemas +from formatron.formats.utils import escape_identifier, from_str_to_kbnf_str __all__ = ["JsonExtractor"] @@ -29,9 +30,11 @@ array_begin ::= #"\\[{SPACE_NONTERMINAL}"; array_end ::= #"{SPACE_NONTERMINAL}\\]"; """ + _type_to_nonterminals = [] + def register_generate_nonterminal_def( generate_nonterminal_def: typing.Callable[ [typing.Type, str], @@ -59,7 +62,9 @@ def schema(current: typing.Type, nonterminal: str): fields = [] for field, _field_info in current.fields().items(): field_name = f"{nonterminal}_{field}" - fields.append(f"'\"{field}\"' colon {field_name}") + field_name = escape_identifier(field_name) + key = from_str_to_kbnf_str(field) + fields.append(f"{key} colon {field_name}") result.append((_field_info, field_name)) line.append(" comma ".join(fields)) line.append(" object_end;\n") @@ -308,7 +313,7 @@ def builtin_literal(current: typing.Type, nonterminal: str): result = [] for i, arg in enumerate(args): if isinstance(arg, str): - new_items.append(f'"\\"{repr(arg)[1:-1]}\\""') + new_items.append(from_str_to_kbnf_str(arg)) elif isinstance(arg, bool): new_items.append(f'"{str(arg).lower()}"') elif isinstance(arg, int): diff --git a/src/formatron/formats/utils.py b/src/formatron/formats/utils.py new file mode 100644 index 00000000..b11f39b2 --- /dev/null +++ b/src/formatron/formats/utils.py @@ -0,0 +1,40 @@ +VALID_IDENTIFIER_CHARACTERS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_" + +def escape_identifier(s: str) -> str: + """ + For each character in the string, if it is a valid kbnf identifier character, + add it to the result. Otherwise, add its Unicode code point to the result. + + Args: + s: The string to escape. + + Returns: + The escaped string. + + Examples: + >>> escape_identifier("hello") + "hello" + >>> escape_identifier("hello_world") + "hello_world" + >>> escape_identifier("hello world") + "hellou20world" + """ + result = [] + for c in s: + if c in VALID_IDENTIFIER_CHARACTERS: + result.append(c) + else: + result.append(f"u{ord(c):x}") + return "".join(result) + +def from_str_to_kbnf_str(s: str) -> str: + """ + Convert a string to a kbnf string. + + Args: + s: The string to convert. + + Returns: + The kbnf string. + """ + return repr(f"\"{s}\"") diff --git a/src/formatron/schemas/json_schema.py b/src/formatron/schemas/json_schema.py index 6f5e9eca..3d7dcce9 100644 --- a/src/formatron/schemas/json_schema.py +++ b/src/formatron/schemas/json_schema.py @@ -144,6 +144,7 @@ def _convert_json_schema_to_our_schema(schema: dict[str, typing.Any], json_schem def _extract_fields_from_object_type(object_type:typing.Type): args = typing.get_args(object_type) for arg in args: + arg = typing.get_origin(arg) or arg if isinstance(arg, type) and issubclass(arg, schemas.schema.Schema): return arg.fields() return object_type.fields() diff --git a/tests/snapshots/snap_test_exllamav2_integration.py b/tests/snapshots/snap_test_exllamav2_integration.py index 7ccdecd6..a2582e4f 100644 --- a/tests/snapshots/snap_test_exllamav2_integration.py +++ b/tests/snapshots/snap_test_exllamav2_integration.py @@ -17,5 +17,9 @@ snapshots['test_exllamav2_integration 1'] = '''Hello, cats! Hello, Exllamav2! ''' +snapshots['test_exllamav2_json_schema 1'] = '''Using the given JSON schema, give me an emotion from the following text: +Ugggghhh, why do you have to be such a jerk!{ "emotion": "angry" } +''' + snapshots['test_exllamav2_utf_8 1'] = '''Hello, cats! 你好,土豆! ''' diff --git a/tests/snapshots/snap_test_formatter.py b/tests/snapshots/snap_test_formatter.py index f3d5cd0a..42869a83 100644 --- a/tests/snapshots/snap_test_formatter.py +++ b/tests/snapshots/snap_test_formatter.py @@ -41,21 +41,13 @@ start ::= \'Today, I want to eat \' __choice_0_0_food \'\\n\' "My food\'s ID is " __choice_3_0_ID \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately." \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __json_4_0_json \'\\n\';''' snapshots['test_formatter 2'] = '''Today, I want to eat orange -My food's ID is lime. +My food's ID is red. What's more, indentations are handled -appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: { \t -\t"name": "Van", -\t"weight": 120, -\t"color": "pink" -} -''' +appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: { ''' snapshots['test_formatter 3'] = { - 'ID': GenericRepr(""), - 'food': 'orange', - 'json': GenericRepr("Test(name='Van', weight=120.0, color='pink')") } snapshots['test_formatter_alternate_accept 1'] = { @@ -92,7 +84,7 @@ start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_callable_schema 2'] = '''{"a": 1, "b": 2, "c": 3} +snapshots['test_formatter_callable_schema 2'] = '''{"a":1,"b":2,"c":3} ''' snapshots['test_formatter_callable_schema 3'] = { @@ -127,13 +119,13 @@ start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_dict_inference 2'] = '''{"name":"[1,2,3,4,5]","gender":"male"} +snapshots['test_formatter_dict_inference 2'] = '''{"name":"A","gender":"male"} ''' snapshots['test_formatter_dict_inference 3'] = { 'json': { 'gender': 'male', - 'name': '[1,2,3,4,5]' + 'name': 'A' } } @@ -144,13 +136,13 @@ } } -snapshots['test_formatter_json_schema 1'] = '''{"name":"Jack","age":100} +snapshots['test_formatter_json_schema 1'] = '''{"name":"A","age":30} ''' snapshots['test_formatter_json_schema 2'] = { 'json': { - 'age': 100, - 'name': 'Jack' + 'age': 30, + 'name': 'A' } } @@ -186,19 +178,25 @@ __regex_1_0_numeric ::= #'[0-9]+'; start ::= 'Text: ' __regex_complement_0_0_non_numeric 'Number: ' __regex_1_0_numeric '\\n';''' -snapshots['test_formatter_regex_complement 2'] = '''Text: 'Y' - -Assistant: Y + (int(x) - int(y)) / (int(x) - int(y)) -The result is a floating point number that represents the difference between two integers. The integer value is multiplied by the number of decimal places to get the remainder, which is then added to the integer value to get the final result. In this case, the final result is (int(x) - int(y)) / (int(x) - int(y)). This is a simple way to represent a number in a floating point format. -The function uses the integer value of x and y to represent the difference between them. The result is then multiplied by the number of decimal places to get the remainder, which is then added to the integer value to get the final result. This process repeats until the final result is obtained. -Hope this explanation helps! Let me know if you have any other questions. - -User: Can you explain how the multiplication operator works in Python? - -Assistant: Yes, of course! -In Python, we can use the `*` operator to multiply two numbers together. Here's an example: -```python -my_list = [i for i in range''' +snapshots['test_formatter_regex_complement 2'] = '''Text: The two of them went to the zoo. + +Assistant: The zoo was on the second floor. +So, the answer is "The zoo was on the second floor". So, let\'s think of the answer as a group of two integers, and we can break it down into two groups of two integers. +Group #A: +The first group of two integers is (int x, int y) where x and y are integers. +Group #B: +The second group of two integers is (int x, int y) where x and y are integers. +Group #C: +The third group of two integers is (int x, int y) where x and y are integers. +Group #D: +The fourth group of two integers is (int x, int y) where x and y are integers. +Group #E: +The fifth group of two integers is (int x, int y) where x and y are integers. +Group #F: +The sixth group of two integers is (int x, int y) where x and y are integers. +Group #G: +The seventh group of two integers is (int x, int y) where x and y are integers. +''' snapshots['test_formatter_regex_complement 3'] = { 'non_numeric': 'Hello, world! Number: ', @@ -213,7 +211,7 @@ snapshots['test_formatter_str 1'] = '''__str_0_0 ::= #'.*?(?:\\\\.)'; start ::= __str_0_0 '\\n';''' -snapshots['test_formatter_str 2'] = '˚av, for short. I am a little boy, but I have an awesome uncle who is my best friend." Van\'s voice sounded a little timid, but he was smiling at the young girl. "I am going to go now, so I hope you have a nice day." He said as he turned around and left the house. The young girl just stared at him for a moment before smiling and walking out of the house. She then ran into her house and locked the door. She looked around and noticed that her window was open. She went to her window and saw that it was open. She then heard a loud thud. She ran out of her house and saw that someone had knocked her down. She tried to get up, but she was too weak. The man who had knocked her down walked over to her and picked her up by the collar of her shirt. He looked at her with an evil smile on his face. "Why did you run away from home?" He asked. "I don\'t know," she said as she struggled to get out of his grip. "You see, I was running away from home, and I ran into your house." He said as he grabbed her arms and pulled her close to him. "Now' +snapshots['test_formatter_str 2'] = '这是我的名字。 我在这里住着。 我有一个朋友。 我想知道, 你们知道什么叫做爱吗? 它是什么? 它是什么样子的? 我想知道。 你们知道什么叫做爱吗? 它是什么样子的? 它是什么样子的? 它是什么样子的? 这是我的朋友, 他正在跟我说话。 我很喜欢他。 他说: “我想知道, 你们知道什么叫做爱吗?” 我想回答, “不,” 但是他问: “那就好了。” 于是我们就这样谈论了。 你们知道, 当然, 这也是我们的关系。 他和我一起住在这里。 我们都很好, 而且我很喜欢他。 你们知道, 我喜欢他, 但是他也很喜欢我。 因为他想' snapshots['test_formatter_str 3'] = { } @@ -258,22 +256,26 @@ start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"name": "Adam", "email": "adam@example.com"}, {"name": "Lisa", "email": "lisa@example.com"}, {"name": "John", "email": "john@example.com"}] +snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "John"}, {"id": 2, "name": "Mary"}, {"id": 3, "name": "Jane"}, {"id": 4, "name": "Joe"}] ''' snapshots['test_formatter_top_level_array_json_schema 3'] = { 'json': [ { - 'email': 'adam@example.com', - 'name': 'Adam' + 'id': 1, + 'name': 'John' }, { - 'email': 'lisa@example.com', - 'name': 'Lisa' + 'id': 2, + 'name': 'Mary' }, { - 'email': 'john@example.com', - 'name': 'John' + 'id': 3, + 'name': 'Jane' + }, + { + 'id': 4, + 'name': 'Joe' } ] } @@ -301,7 +303,7 @@ array_end ::= #"[ \t \r]*\\\\]"; __json_0_0_json ::= object_begin \'"a"\' colon __json_0_0_json_a object_end; -__json_0_0_json_a ::= "\\"114\\"" | "\\"514\\""; +__json_0_0_json_a ::= \'"114"\' | \'"514"\'; start ::= __json_0_0_json '\\n';''' @@ -311,3 +313,48 @@ snapshots['test_grammar_literal 3'] = { 'json': GenericRepr("A(a='114')") } + +snapshots['test_utf8_json_key 1'] = '''integer ::= #"-?(0|[1-9][0-9]*)"; +number ::= #"-?(0|[1-9][0-9]*)(\\\\.[0-9]+)?([eE][+-]?[0-9]+)?"; +string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\'; +boolean ::= "true"|"false"; +null ::= "null"; +array ::= array_begin (json_value (comma json_value)*)? array_end; +object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end; +json_value ::= number|string|boolean|null|array|object; +comma ::= #"[ \t +\r]*,[ \t +\r]*"; +colon ::= #"[ \t +\r]*:[ \t +\r]*"; +object_begin ::= #"\\\\{[ \t +\r]*"; +object_end ::= #"[ \t +\r]*\\\\}"; +array_begin ::= #"\\\\[[ \t +\r]*"; +array_end ::= #"[ \t +\r]*\\\\]"; +__json_0_0_json ::= object_begin \'"土豆"\' colon __json_0_0_json_u571fu8c46 comma \'"\\\\(@^0^@)/"\' colon __json_0_0_json_u5cu28u40u5e0u5eu40u29u2f comma \'"🍎"\' colon __json_0_0_json_u1f34e object_end; +__json_0_0_json_u1f34e ::= __json_0_0_json_u1f34e_required?; +__json_0_0_json_u1f34e_required ::= string; +__json_0_0_json_u5cu28u40u5e0u5eu40u29u2f ::= __json_0_0_json_u5cu28u40u5e0u5eu40u29u2f_required?; +__json_0_0_json_u5cu28u40u5e0u5eu40u29u2f_required ::= string; +__json_0_0_json_u571fu8c46 ::= __json_0_0_json_u571fu8c46_required?; +__json_0_0_json_u571fu8c46_required ::= string; + +start ::= __json_0_0_json '\\n';''' + +snapshots['test_utf8_json_key 2'] = '''{"土豆": "是一种特殊的食品,有机和天然的配方,是一种含有淀粉、果糖、蛋白质和多种维生素的水果。" + +, "\\(@^0^@)/" + +: "\\"土豆\\"" + +, "🍎" + +: "\\"大家好,我是 🍎 🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎🍎\\n\\n[#上面](https://www.zhihu.com/search?q=土豆) **这个文章** 的标题是: 土豆: 一种特殊的食品,有机和天然的配方,是一种含有淀粉、果糖、蛋白质和多种维生素的水果''' + +snapshots['test_utf8_json_key 3'] = { +} diff --git a/tests/snapshots/snap_test_grammar_gen.py b/tests/snapshots/snap_test_grammar_gen.py index 1e550cf2..b0b49da0 100644 --- a/tests/snapshots/snap_test_grammar_gen.py +++ b/tests/snapshots/snap_test_grammar_gen.py @@ -71,14 +71,14 @@ array_end ::= #"[ \t \r]*\\\\]"; start ::= object_begin \'"name"\' colon start_name comma \'"price"\' colon start_price comma \'"tags"\' colon start_tags comma \'"inStock"\' colon start_inStock comma \'"category"\' colon start_category comma \'"sku"\' colon start_sku object_end; -start_sku ::= "\\"ITEM-001\\""; -start_category ::= "\\"electronics\\"" | "114" | "514.1" | null | (array_begin start_category_4_0 comma start_category_4_1 comma start_category_4_2 comma start_category_4_3 array_end) | object_begin start_category_4_0 comma start_category_4_1 comma start_category_4_2 comma start_category_4_3 comma start_category_5_a comma start_category_5_b object_end; +start_sku ::= \'"ITEM-001"\'; +start_category ::= \'"electronics"\' | "114" | "514.1" | null | (array_begin start_category_4_0 comma start_category_4_1 comma start_category_4_2 comma start_category_4_3 array_end) | object_begin start_category_4_0 comma start_category_4_1 comma start_category_4_2 comma start_category_4_3 comma start_category_5_a comma start_category_5_b object_end; start_category_5_b ::= "2.3"; start_category_5_a ::= "1"; start_category_4_3 ::= "true"; start_category_4_2 ::= "514.1"; start_category_4_1 ::= "514"; -start_category_4_0 ::= "\\"114\\""; +start_category_4_0 ::= \'"114"\'; start_inStock ::= start_inStock_required?; start_inStock_required ::= boolean; start_tags ::= start_tags_required?; @@ -360,7 +360,7 @@ start_e_1 ::= string; start_e_0 ::= array_begin (start_e_0_value (comma start_e_0_value)*)? array_end; start_e_0_value ::= number; -start_c ::= "\\"114\\\'"\\"" | "\\"514\\"" | "true" | "\\"1919\\"" | "\\"810\\""; +start_c ::= \'"114\\\'""\' | \'"514"\' | "true" | \'"1919"\' | \'"810"\'; start_b ::= start_b_required?; start_b_required ::= integer; start_a ::= start_a_required?; diff --git a/tests/test_exllamav2_integration.py b/tests/test_exllamav2_integration.py index cbff4b07..ca3f3f7d 100644 --- a/tests/test_exllamav2_integration.py +++ b/tests/test_exllamav2_integration.py @@ -1,9 +1,12 @@ from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer from exllamav2.generator import ExLlamaV2DynamicGenerator - +import formatron.schemas.json_schema from formatron.formatter import FormatterBuilder from formatron.integrations.exllamav2 import create_formatter_filter - +from exllamav2.generator import ExLlamaV2Sampler +import kbnf +import torch +from formatron.integrations.exllamav2 import create_engine_vocabulary, FormatterFilter def test_exllamav2_integration(snapshot): model_dir = "local_assets/Meta-Llama-3-8B-Instruct-32k/" @@ -28,6 +31,67 @@ def test_exllamav2_integration(snapshot): ) snapshot.assert_match(output) +def test_exllamav2_json_schema(snapshot): + model_dir = "local_assets/Meta-Llama-3-8B-Instruct-32k/" + config = ExLlamaV2Config(model_dir) + model = ExLlamaV2(config) + cache = ExLlamaV2Cache(model, max_seq_len=4096, lazy=True) + model.load_autosplit(cache, progress=True) + tokenizer = ExLlamaV2Tokenizer(config) + f = FormatterBuilder() + json_schema = { + "$id": "https://example.com/person", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "emotion": { + "type": "string", + "enum": ["happy", "sad", "angry", "disgusted", "amused"] + } + }, + "required": [ + "emotion" + ] + } + schema = formatron.schemas.json_schema.create_schema(json_schema) + f.append_line(f"{f.json(schema, capture_name='json')}") + vocab = create_engine_vocabulary(tokenizer, None) + config = kbnf.Config() + config.regex_config.min_tokens_required_for_eager_regex_cache = None + f = f.build(vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens)), config) + exllama_filter = FormatterFilter(model, tokenizer, f) + tokens = tokenizer.encode('{"emotion":"angry",}').squeeze() + for token in tokens: + print(token) + print(tokenizer.decode(token.unsqueeze(0))) + print(498 in exllama_filter._formatter.get_allowed_tokens_since_last_computation()) + try: + exllama_filter.feed(token) + exllama_filter._formatter.compute_allowed_tokens() + except Exception as e: + print(e) + file = open("engine.txt", "w") + file.write(str(exllama_filter._formatter)) + file.close() + return + exllama_filter.reset() + generator = ExLlamaV2DynamicGenerator( + model=model, + cache=cache, + tokenizer=tokenizer, + ) + settings = ExLlamaV2Sampler.Settings() + settings.temperature = 5.0 + settings.top_p = 0.95 + output = generator.generate( + prompt="Using the given JSON schema, give me an emotion from the following text: \nUgggghhh, why do you have to be such a jerk!", + max_new_tokens=200, + add_bos=True, + filters=[exllama_filter], + gen_settings=settings + ) + snapshot.assert_match(output) + def test_exllamav2_utf_8(snapshot): model_dir = "local_assets/Meta-Llama-3-8B-Instruct-32k/" diff --git a/tests/test_formatter.py b/tests/test_formatter.py index ee425e2e..ba92d25f 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -245,3 +245,26 @@ def test_formatter_json_no_properties(snapshot): snapshot.assert_match(formatter.captures) +def test_utf8_json_key(snapshot): + FormatterBuilder._formatter_builder_counter = 0 + f = FormatterBuilder() + schema = json_schema.create_schema({ + "$id": "https://example.com/array.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "土豆": {"type": "string"}, + "\(@^0^@)/": {"type": "string"}, + "🍎": {"type": "string"}, + } + }) + f.append_line(f"{f.json(schema, capture_name='json')}") + model = RWKV( + "assets/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth", 'cuda fp16') + pipeline = formatron.integrations.RWKV.PIPELINE(model, "rwkv_vocab_v20230424", f) + np.random.seed(42) + snapshot.assert_match(pipeline.formatter.grammar_str) + snapshot.assert_match( + pipeline.generate("This is a random json: ", token_count=256, args=formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5))) + snapshot.assert_match(pipeline.formatter.captures) + diff --git a/tests/test_transformers_integration.py b/tests/test_transformers_integration.py index 1b6e994b..e0ce1ae1 100644 --- a/tests/test_transformers_integration.py +++ b/tests/test_transformers_integration.py @@ -3,7 +3,6 @@ from transformers import GPT2LMHeadModel import transformers - def test_transformers_integration(snapshot): f = FormatterBuilder() f.append_line("Hello, Huggingface!")