Skip to content

Commit

Permalink
Updates for Pydantic 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 3, 2023
1 parent 759ce89 commit e88f13b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
5 changes: 4 additions & 1 deletion outlines/text/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def validate(validator, result):

@validate.register(BaseModelType)
def validate_pydantic(validator, result):
return validator.parse_raw(result)
if hasattr(validator, "model_validate_json"):
return validator.model_validate_json(result)
else: # pragma: no cover
return validator.parse_raw(result)


@validate.register(FunctionType)
Expand Down
12 changes: 9 additions & 3 deletions outlines/text/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,21 @@ def get_schema_pydantic(model: type[BaseModel]):
if not type(model) == type(BaseModel):
raise TypeError("The `schema` filter only applies to Pydantic models.")

raw_schema = model.schema()
definitions = raw_schema.get("definitions", None)
if hasattr(model, "model_json_schema"):
def_key = "$defs"
raw_schema = model.model_json_schema()
else: # pragma: no cover
def_key = "definitions"
raw_schema = model.schema()

definitions = raw_schema.get(def_key, None)
schema = parse_pydantic_schema(raw_schema, definitions)

return json.dumps(schema, indent=2)


def parse_pydantic_schema(raw_schema, definitions):
"""Parse the output of `Basemodel.schema()`.
"""Parse the output of `Basemodel.[schema|model_json_schema]()`.
This recursively follows the references to other schemas in case
of nested models. Other schemas are stored under the "definitions"
Expand Down

0 comments on commit e88f13b

Please sign in to comment.