diff --git a/docs/reference/json.md b/docs/reference/json.md index 85e1a846a..24b606bcb 100644 --- a/docs/reference/json.md +++ b/docs/reference/json.md @@ -50,13 +50,11 @@ print(result) Outlines provides [custom Pydantic types](types.md) so you do not have to write regular expressions for common types, such as phone numbers or zip codes. -## Using a JSON Schema +## Using a String of a JSON Schema Instead of a Pydantic model you can pass a string that represents a [JSON Schema](https://json-schema.org/) specification to `generate.json`: ```python -from pydantic import BaseModel - from outlines import models from outlines import generate @@ -82,6 +80,35 @@ print(result) # User(name="John", last_name="Doe", id=11) ``` +## Using a Dictionary of a JSON Schema + +You can also pass in dictionary that represents a [JSON Schema](https://json-schema.org/) specification to `generate.json`: + +```python +from outlines import models +from outlines import generate + +model = models.transformers("mistralai/Mistral-7B-v0.1") + +schema_dict = { + "title": "User", + "type": "object", + "properties": { + "name": {"type": "string"}, + "last_name": {"type": "string"}, + "id": {"type": "integer"} + } +} + +generator = generate.json(model, schema_dict) +result = generator( + "Create a user profile with the fields name, last_name and id" +) +print(result) +# User(name="John", last_name="Doe", id=11) +``` + + ## From a function's signature Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`: diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 218853d5f..0203ddefc 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -248,13 +248,24 @@ def test_llamacpp_json_dict(model): prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n" schema_dict = { - "title": "spam", - "type": "object", "properties": { - "foo": {"type": "boolean"}, - "bar": {"type": "string", "maxLength": 4}, + "user_id": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "User Id", + }, + "name": { + "additionalProperties": {"type": "integer"}, + "title": "Name", + "type": "object", + }, + "password": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + "title": "Password", + }, }, - "required": ["foo", "bar"], + "required": ["user_id", "name", "password"], + "title": "UserPydantic", + "type": "object", } result = generate.json(model, schema_dict, whitespace_pattern="")( diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index 9ef6344f0..d9754ecdf 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -448,13 +448,24 @@ def test_transformers_json_dict(): prompt = "Output some JSON " schema_dict = { - "title": "spam", - "type": "object", "properties": { - "foo": {"type": "integer"}, - "bar": {"type": "string", "maxLength": 4}, + "user_id": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "User Id", + }, + "name": { + "additionalProperties": {"type": "integer"}, + "title": "Name", + "type": "object", + }, + "password": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + "title": "Password", + }, }, - "required": ["foo", "bar"], + "required": ["user_id", "name", "password"], + "title": "UserPydantic", + "type": "object", } rng = torch.Generator() diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py index d28543894..e4d703a18 100644 --- a/tests/generate/test_integration_vllm.py +++ b/tests/generate/test_integration_vllm.py @@ -235,13 +235,24 @@ def test_vllm_json_dict(model): prompt = "Output some JSON. " schema_dict = { - "title": "spam", - "type": "object", "properties": { - "foo": {"type": "boolean"}, - "bar": {"type": "string", "maxLength": 4}, + "user_id": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "User Id", + }, + "name": { + "additionalProperties": {"type": "integer"}, + "title": "Name", + "type": "object", + }, + "password": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + "title": "Password", + }, }, - "required": ["foo", "bar"], + "required": ["user_id", "name", "password"], + "title": "UserPydantic", + "type": "object", } sampling_params = SamplingParams(temperature=0)