Skip to content

Commit

Permalink
docs: Document and test PydanticOutputFunctionsParser (langchain-ai#1…
Browse files Browse the repository at this point in the history
…5759)

This PR adds documentation and testing to
`PydanticOutputFunctionsParser(OutputFunctionsParser)`.
  • Loading branch information
eyurtsev authored Jan 18, 2024
1 parent 3502a40 commit 5d8c147
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 2 deletions.
46 changes: 44 additions & 2 deletions libs/langchain/langchain/output_parsers/openai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,52 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An


class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as a pydantic object."""
"""Parse an output as a pydantic object.
This parser is used to parse the output of a ChatModel that uses
OpenAI function format to invoke functions.
The parser extracts the function call invocation and matches
them to the pydantic schema provided.
An exception will be raised if the function call does not match
the provided schema.
Example:
... code-block:: python
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "cookie",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)
class Cookie(BaseModel):
name: str
age: int
class Dog(BaseModel):
species: str
# Full output
parser = PydanticOutputFunctionsParser(
pydantic_schema={"cookie": Cookie, "dog": Dog}
)
result = parser.parse_result([chat_generation])
"""

pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
"""The pydantic schema to parse the output with."""
"""The pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to
determine which schema to use.
"""

@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any, Dict

import pytest
Expand All @@ -7,7 +8,9 @@

from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.pydantic_v1 import BaseModel


def test_json_output_function_parser() -> None:
Expand Down Expand Up @@ -134,3 +137,61 @@ def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:

with pytest.raises(OutputParserException):
JsonOutputFunctionsParser().parse_result([chat_generation])


def test_pydantic_output_functions_parser() -> None:
"""Test pydantic output functions parser."""
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "function_name",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)

class Model(BaseModel):
"""Test model."""

name: str
age: int

# Full output
parser = PydanticOutputFunctionsParser(pydantic_schema=Model)
result = parser.parse_result([chat_generation])
assert result == Model(name="value", age=10)


def test_pydantic_output_functions_parser_multiple_schemas() -> None:
"""Test that the parser works if providing multiple pydantic schemas."""

message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "cookie",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)

class Cookie(BaseModel):
"""Test model."""

name: str
age: int

class Dog(BaseModel):
"""Test model."""

species: str

# Full output
parser = PydanticOutputFunctionsParser(
pydantic_schema={"cookie": Cookie, "dog": Dog}
)
result = parser.parse_result([chat_generation])
assert result == Cookie(name="value", age=10)

0 comments on commit 5d8c147

Please sign in to comment.