
Fix readme's function call examples
Dan-wanna-M committed Aug 10, 2024
1 parent 752d8aa commit d01f53f
Showing 2 changed files with 48 additions and 1 deletion.
25 changes: 24 additions & 1 deletion README.md
@@ -172,14 +172,31 @@ print(tokenizer.batch_decode(model.generate(**inputs,
```python
from formatron import schemas
from formatron.formatter import FormatterBuilder
from transformers import AutoModelForCausalLM
import transformers
from formatron.grammar_generators.json_generator import JsonGenerator
from formatron.integrations.transformers import create_formatter_logits_processor_list
import torch
@schemas.pydantic.callable_schema
def add(a: int, b: int, /, *, c: int):
    return a + b + c

model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
                                             device_map="cuda",
                                             torch_dtype=torch.float16)
tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
inputs = tokenizer(["""<|system|>
You are a helpful assistant.<|end|>
<|user|>a is 1, b is 6 and c is 7. Generate a json containing them.<|end|>
<|assistant|>"""], return_tensors="pt").to("cuda")
f = FormatterBuilder()
f.append_line(f"{f.schema(add, JsonGenerator(), capture_name='json')}")
# TODO: find some models that work
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
                                            max_new_tokens=100, logits_processor=logits_processor)))
print(logits_processor[0].formatters_captures)
# possible output:
# [{'json': 14}]
```
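In the possible output above, the value captured under `json` is `14`, i.e. the return value of `add(1, 6, c=7)`: the formatter constrains the model to emit JSON arguments matching `add`'s signature, and the capture holds what the decorated function returns for those arguments.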
### Json Schema
You can use [pydantic's code generator](https://docs.pydantic.dev/latest/integrations/datamodel_code_generator/)
Expand Down Expand Up @@ -210,6 +227,12 @@ it makes maintaining and updating the pipeline painless in the long run.
### Support OpenAI or, more generally, API-based LLM solutions
These services don't support efficient per-token logits masking, which nullifies most of the benefits
of constrained decoding.
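As a rough illustration of what per-token masking needs (a hedged sketch, not Formatron's actual implementation; the class name and token allowlist are made up for this example), a Hugging Face `LogitsProcessor` edits the raw logits at every decoding step, which a text-in/text-out API never exposes:
```python
import torch
from transformers import LogitsProcessor

class AllowlistLogitsProcessor(LogitsProcessor):
    """Hypothetical example: only a fixed set of token ids may be sampled at each step."""
    def __init__(self, allowed_token_ids):
        self.allowed_token_ids = list(allowed_token_ids)

    def __call__(self, input_ids, scores):
        # Runs once per generated token, with direct access to the raw logits.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed_token_ids] = 0.0
        return scores + mask  # disallowed tokens become impossible to sample
```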
### Semantic Validation
Although constrained decoding can enforce certain formats
in generated text, it cannot guarantee that the output aligns
with the user's intention. In other words, if the model is inadequate
or the prompt is poorly written, it is still possible to generate well-formatted
but meaningless output.
### Context-Sensitive Validation
Unfortunately, many formats require context-sensitive validation.
For example, two keys in a JSON object must not be equal to each other.
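A minimal post-generation check (an illustrative sketch, not part of Formatron) shows how such a context-sensitive rule can still be enforced after decoding:
```python
import json

def reject_duplicate_keys(pairs):
    # json.loads calls this hook with the raw (key, value) pairs of each object,
    # so repeated keys are still visible even though dict() would merge them.
    keys = [key for key, _ in pairs]
    if len(keys) != len(set(keys)):
        raise ValueError(f"duplicate keys: {keys}")
    return dict(pairs)

json.loads('{"a": 1, "b": 2}', object_pairs_hook=reject_duplicate_keys)  # passes
try:
    json.loads('{"a": 1, "a": 2}', object_pairs_hook=reject_duplicate_keys)
except ValueError as e:
    print(e)  # duplicate keys: ['a', 'a']
```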
24 changes: 24 additions & 0 deletions tests/test_transformers_integration.py
@@ -91,6 +91,30 @@ class Goods(ClassSchema):
# possible output:
# [{'json': Goods(name='apples', price=14.4, remaining=14)}]

def test_readme_example4(snapshot):
    from formatron import schemas
    from formatron.formatter import FormatterBuilder
    from formatron.grammar_generators.json_generator import JsonGenerator
    @schemas.pydantic.callable_schema
    def add(a: int, b: int, /, *, c: int):
        return a + b + c

    model = AutoModelForCausalLM.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k",
                                                 device_map="cuda",
                                                 torch_dtype=torch.float16)
    tokenizer = transformers.AutoTokenizer.from_pretrained("NurtureAI/Meta-Llama-3-8B-Instruct-32k")
    inputs = tokenizer(["""<|system|>
You are a helpful assistant.<|end|>
<|user|>a is 1, b is 6 and c is 7. Generate a json containing them.<|end|>
<|assistant|>"""], return_tensors="pt").to("cuda")
    f = FormatterBuilder()
    f.append_line(f"{f.schema(add, JsonGenerator(), capture_name='json')}")
    logits_processor = create_formatter_logits_processor_list(tokenizer, f)
    print(tokenizer.batch_decode(model.generate(**inputs, top_p=0.5, temperature=1,
                                                max_new_tokens=100, logits_processor=logits_processor)))
    print(logits_processor[0].formatters_captures)
    # possible output:
    # [{'json': 14}]

def test_transformers_batched_inference(snapshot):
    f = FormatterBuilder()

