Fix schema inference structured generation (#994)
* fix: converting ModelMetaclass to model_json_schema

* fix: allow adding an optional literal "json" format to instructor to make methods more interchangeable

* docs: emphasize usability with any framework

* fix: first check if structured_output has been defined

* Update docs/sections/how_to_guides/advanced/structured_generation.md

Co-authored-by: Agus <[email protected]>
davidberenstein1957 and plaguss authored Sep 23, 2024
1 parent c7deafa commit d7e61b5
Showing 3 changed files with 15 additions and 5 deletions.
11 changes: 6 additions & 5 deletions docs/sections/how_to_guides/advanced/structured_generation.md
@@ -129,7 +129,7 @@ These were some simple examples, but one can see the options this opens.

## Instructor

-When working with model providers behind an API, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.llms.AsyncLLM], so you can work with the following LLMs: [`OpenAILLM`][distilabel.llms.OpenAILLM], [`AzureOpenAILLM`][distilabel.llms.AzureOpenAILLM], [`CohereLLM`][distilabel.llms.CohereLLM], [`GroqLLM`][distilabel.llms.GroqLLM], [`LiteLLM`][distilabel.llms.LiteLLM] and [`MistralLLM`][distilabel.llms.MistralLLM].
+For other LLM providers behind APIs, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.llms.AsyncLLM].

!!! Note
For `instructor` integration to work you may need to install the corresponding dependencies:
@@ -155,14 +155,15 @@ class User(BaseModel):

And then we provide that schema to the `structured_output` argument of the LLM:

-!!! Note
-    In this example we are using *open-mixtral-8x22b*, keep in mind not all the models work with the function calling functionality required for this example to work.
+!!! NOTE
+    In this example we are using *Meta Llama 3.1 8B Instruct*, keep in mind not all the models support structured outputs.

```python
from distilabel.llms import MistralLLM

-llm = MistralLLM(
-    model="open-mixtral-8x22b",
+llm = InferenceEndpointsLLM(
+    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    structured_output={"schema": User}
)
llm.load()
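For reference, a minimal end-to-end sketch of the updated example could look like the following. Note that the snippet above still imports `MistralLLM` while instantiating `InferenceEndpointsLLM`, so the sketch imports `InferenceEndpointsLLM` directly; the `User` schema, the prompt, the `max_new_tokens` value, and the availability of a Hugging Face token for Inference Endpoints are assumptions, not part of this diff.

```python
from pydantic import BaseModel

from distilabel.llms import InferenceEndpointsLLM


class User(BaseModel):  # assumed schema, similar to the `User` model in the docs page
    name: str
    last_name: str
    id: int


llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    structured_output={"schema": User},  # a pydantic.BaseModel class, as in the example above
)
llm.load()

# Generate a completion for a single conversation; the prompt is illustrative.
result = llm.generate(
    inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]],
    max_new_tokens=256,
)
print(result)
```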
7 changes: 7 additions & 0 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -26,6 +26,7 @@
    model_validator,
    validate_call,
)
+from pydantic._internal._model_construction import ModelMetaclass
from typing_extensions import Annotated, override

from distilabel.llms.base import AsyncLLM
@@ -363,6 +364,12 @@ def _get_structured_output(
"the `structured_output` attribute."
) from e

+        if structured_output:
+            if isinstance(structured_output["value"], ModelMetaclass):
+                structured_output["value"] = structured_output[
+                    "value"
+                ].model_json_schema()

        return structured_output

    async def _generate_with_text_generation(
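A standalone sketch of what the new branch does: detecting that `structured_output["value"]` is a `pydantic.BaseModel` class (whose metaclass is `ModelMetaclass`) and replacing it with its JSON schema dictionary. The `User` model and the simplified `structured_output` dict are hypothetical stand-ins for the values handled inside `_get_structured_output`.

```python
from pydantic import BaseModel
from pydantic._internal._model_construction import ModelMetaclass


class User(BaseModel):  # hypothetical example model
    name: str
    id: int


# Simplified stand-in for the dict handled in `_get_structured_output`.
structured_output = {"value": User}

# Mirrors the added check: a BaseModel subclass is an instance of ModelMetaclass,
# so it is converted to a plain JSON-schema dict before being passed on.
if isinstance(structured_output["value"], ModelMetaclass):
    structured_output["value"] = structured_output["value"].model_json_schema()

print(structured_output["value"]["title"])              # "User"
print(list(structured_output["value"]["properties"]))   # ["name", "id"]
```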
2 changes: 2 additions & 0 deletions src/distilabel/steps/tasks/typing.py
@@ -49,6 +49,8 @@ class OutlinesStructuredOutputType(TypedDict, total=False):
class InstructorStructuredOutputType(TypedDict, total=False):
"""TypedDict to represent the structured output configuration from `instructor`."""

format: Optional[Literal["json"]]
"""One of "json"."""
schema: Union[Type[BaseModel], Dict[str, Any]]
"""The schema to use for the structured output, a `pydantic.BaseModel` class. """
mode: Optional[str]
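With the new optional `format` key, an instructor-style structured output configuration can be written in the same shape as the `outlines` one, which is what makes the methods more interchangeable. A small sketch, assuming the `User` model below and that the TypedDict is importable from the module path shown above:

```python
from pydantic import BaseModel

from distilabel.steps.tasks.typing import InstructorStructuredOutputType


class User(BaseModel):  # assumed example schema
    name: str
    last_name: str
    id: int


# As before: only the schema is required (total=False keeps every key optional).
config: InstructorStructuredOutputType = {"schema": User}

# New: the optional `format` key mirrors the outlines-style configuration.
config_json: InstructorStructuredOutputType = {"format": "json", "schema": User}
```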
