From d01f53f16107b5018da776a690c56af12f20599f Mon Sep 17 00:00:00 2001
From: Huanghe
Date: Sat, 10 Aug 2024 17:37:27 -0500
Subject: [PATCH] Fix readme's function call examples

---
 README.md                              | 25 ++++++++++++++++++++++++-
 tests/test_transformers_integration.py | 24 ++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b36f2aec..33209fdc 100644
--- a/README.md
+++ b/README.md
@@ -172,14 +172,31 @@ print(tokenizer.batch_decode(model.generate(**inputs,
 ```python
 from formatron import schemas
 from formatron.formatter import FormatterBuilder
+from transformers import AutoModelForCausalLM
+import transformers
 from formatron.grammar_generators.json_generator import JsonGenerator
+from formatron.integrations.transformers import create_formatter_logits_processor_list
+import torch
 @schemas.pydantic.callable_schema
 def add(a: int, b: int, /, *, c: int):
     return a + b + c
 
+model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
+                                             device_map="cuda",
+                                             torch_dtype=torch.float16)
+tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
+inputs = tokenizer(["""<|system|>
+You are a helpful assistant.<|end|>
+<|user|>a is 1, b is 6 and c is 7. Generate a json containing them.<|end|>
+<|assistant|>"""], return_tensors="pt").to("cuda")
 f = FormatterBuilder()
 f.append_line(f"{f.schema(add, JsonGenerator(), capture_name='json')}")
-# TODO: find some models that work
+logits_processor = create_formatter_logits_processor_list(tokenizer, f)
+print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
+                                            max_new_tokens=100, logits_processor=logits_processor)))
+print(logits_processor[0].formatters_captures)
+# possible output:
+# [{'json': 14}]
 ```
 ### Json Schema
 You can use [pydantic's code generator](https://docs.pydantic.dev/latest/integrations/datamodel_code_generator/)
@@ -210,6 +227,12 @@ it makes maintaining and updating the pipeline painless in the long run.
 
 ### Support OpenAI or in general API-based LLM solutions
 They don't support efficient logits masking per token, nullifying most benefits of constrained decoding.
+### Semantic Validation
+Although constrained decoding can enforce certain formats
+in the generated text, it cannot guarantee that the output aligns
+with the user's intent. In other words, if the model is inadequate
+or the prompt is poorly written, it is still possible to generate
+well-formatted but meaningless output.
 ### Context-Sensitive Validation
 Unfortunately, many formats require context-sensitive validation.
 For example, two keys in a JSON object must not be equal to each other.
diff --git a/tests/test_transformers_integration.py b/tests/test_transformers_integration.py
index 029875bd..d81ca056 100644
--- a/tests/test_transformers_integration.py
+++ b/tests/test_transformers_integration.py
@@ -91,6 +91,30 @@ class Goods(ClassSchema):
     # possible output:
     # [{'json': Goods(name='apples', price=14.4, remaining=14)}]
 
+def test_readme_example4(snapshot):
+    from formatron import schemas
+    from formatron.formatter import FormatterBuilder
+    from formatron.grammar_generators.json_generator import JsonGenerator
+    @schemas.pydantic.callable_schema
+    def add(a: int, b: int, /, *, c: int):
+        return a + b + c
+
+    model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
+                                                 device_map="cuda",
+                                                 torch_dtype=torch.float16)
+    tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
+    inputs = tokenizer(["""<|system|>
+    You are a helpful assistant.<|end|>
+    <|user|>a is 1, b is 6 and c is 7. Generate a json containing them.<|end|>
+    <|assistant|>"""], return_tensors="pt").to("cuda")
+    f = FormatterBuilder()
+    f.append_line(f"{f.schema(add, JsonGenerator(), capture_name='json')}")
+    logits_processor = create_formatter_logits_processor_list(tokenizer, f)
+    print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
+                                                max_new_tokens=100, logits_processor=logits_processor)))
+    print(logits_processor[0].formatters_captures)
+    # possible output:
+    # [{'json': 14}]
 
 def test_transformers_batched_inference(snapshot):
     f = FormatterBuilder()
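
Note on the `# possible output: [{'json': 14}]` comments above: with `callable_schema`, the model is constrained to emit JSON arguments for `add`, and the capture stores the function's return value (1 + 6 + 7 = 14) rather than the raw JSON text. Below is a minimal plain-Python sketch of that behavior, assuming the constrained model emits `{"a": 1, "b": 6, "c": 7}`; the exact generated string is a hypothetical illustration, not taken from a real run.

```python
# Sketch (assumption): formatron parses the JSON arguments the model was
# constrained to emit, calls the decorated function, and the capture holds
# the return value, matching the [{'json': 14}] output shown in the patch.
import json


def add(a: int, b: int, /, *, c: int):
    return a + b + c


generated = '{"a": 1, "b": 6, "c": 7}'  # hypothetical constrained model output
args = json.loads(generated)
result = add(args["a"], args["b"], c=args["c"])
assert result == 14  # the value captured under 'json' in the README example
print(result)
```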