Skip to content

Commit

Permalink
change schema dict
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok committed May 30, 2024
1 parent 944df28 commit f7580f6
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 18 deletions.
33 changes: 30 additions & 3 deletions docs/reference/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@ print(result)

Outlines provides [custom Pydantic types](types.md) so you do not have to write regular expressions for common types, such as phone numbers or zip codes.

## Using a JSON Schema
## Using a String of a JSON Schema

Instead of a Pydantic model you can pass a string that represents a [JSON Schema](https://json-schema.org/) specification to `generate.json`:

```python
from pydantic import BaseModel

from outlines import models
from outlines import generate

Expand All @@ -82,6 +80,35 @@ print(result)
# User(name="John", last_name="Doe", id=11)
```

## Using a Dictionary of a JSON Schema

You can also pass a dictionary that represents a [JSON Schema](https://json-schema.org/) specification to `generate.json`:

```python
from outlines import models
from outlines import generate

model = models.transformers("mistralai/Mistral-7B-v0.1")

schema_dict = {
"title": "User",
"type": "object",
"properties": {
"name": {"type": "string"},
"last_name": {"type": "string"},
"id": {"type": "integer"}
}
}

generator = generate.json(model, schema_dict)
result = generator(
"Create a user profile with the fields name, last_name and id"
)
print(result)
# {'name': 'John', 'last_name': 'Doe', 'id': 11}
```


## From a function's signature

Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`:
Expand Down
21 changes: 16 additions & 5 deletions tests/generate/test_integration_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,24 @@ def test_llamacpp_json_dict(model):
prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n"

schema_dict = {
"title": "spam",
"type": "object",
"properties": {
"foo": {"type": "boolean"},
"bar": {"type": "string", "maxLength": 4},
"user_id": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"title": "User Id",
},
"name": {
"additionalProperties": {"type": "integer"},
"title": "Name",
"type": "object",
},
"password": {
"anyOf": [{"type": "string"}, {"type": "integer"}],
"title": "Password",
},
},
"required": ["foo", "bar"],
"required": ["user_id", "name", "password"],
"title": "UserPydantic",
"type": "object",
}

result = generate.json(model, schema_dict, whitespace_pattern="")(
Expand Down
21 changes: 16 additions & 5 deletions tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,24 @@ def test_transformers_json_dict():
prompt = "Output some JSON "

schema_dict = {
"title": "spam",
"type": "object",
"properties": {
"foo": {"type": "integer"},
"bar": {"type": "string", "maxLength": 4},
"user_id": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"title": "User Id",
},
"name": {
"additionalProperties": {"type": "integer"},
"title": "Name",
"type": "object",
},
"password": {
"anyOf": [{"type": "string"}, {"type": "integer"}],
"title": "Password",
},
},
"required": ["foo", "bar"],
"required": ["user_id", "name", "password"],
"title": "UserPydantic",
"type": "object",
}

rng = torch.Generator()
Expand Down
21 changes: 16 additions & 5 deletions tests/generate/test_integration_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,24 @@ def test_vllm_json_dict(model):
prompt = "Output some JSON. "

schema_dict = {
"title": "spam",
"type": "object",
"properties": {
"foo": {"type": "boolean"},
"bar": {"type": "string", "maxLength": 4},
"user_id": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"title": "User Id",
},
"name": {
"additionalProperties": {"type": "integer"},
"title": "Name",
"type": "object",
},
"password": {
"anyOf": [{"type": "string"}, {"type": "integer"}],
"title": "Password",
},
},
"required": ["foo", "bar"],
"required": ["user_id", "name", "password"],
"title": "UserPydantic",
"type": "object",
}

sampling_params = SamplingParams(temperature=0)
Expand Down

0 comments on commit f7580f6

Please sign in to comment.