
Commit 684b3f7

feat: add structured output support using Pydantic models (#60)
* feat: add structured output support using Pydantic models
  - Add method to Agent class for handling structured outputs
  - Create structured_output.py utility for converting Pydantic models to tool specs
  - Improve error handling when extracting model_id from configuration
  - Add integration tests to validate structured output functionality
* fix: import cleanups and unused vars
* feat: wip adding `structured_output` methods
* feat: wip added structured output to bedrock and anthropic
* feat: litellm structured output and some integ tests
* feat: all structured outputs working, tbd llama api
* feat: updated docstring
* fix: otel ci dep issue
* fix: remove unnecessary changes and comments
* feat: basic test WIP
* feat: better test coverage
* fix: remove unused fixture
* fix: resolve some comments
* fix: inline basemodel classes
* feat: update litellm, add checks
* fix: autoformatting issue
* feat: resolves comments
* fix: ollama skip tests, pyproject whitespace diffs
1 parent 735d0c0 commit 684b3f7

21 files changed (+1147 additions, -18 deletions)

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ docs = [
     "sphinx-autodoc-typehints>=1.12.0,<2.0.0",
 ]
 litellm = [
-    "litellm>=1.69.0,<2.0.0",
+    "litellm>=1.72.6,<2.0.0",
 ]
 llamaapi = [
     "llama-api-client>=0.1.0,<1.0.0",
@@ -264,4 +264,4 @@ style = [
     ["instruction", ""],
     ["text", ""],
     ["disabled", "fg:#858585 italic"]
-]
+]

src/strands/agent/agent.py

Lines changed: 31 additions & 1 deletion
@@ -16,10 +16,11 @@
 import random
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union
+from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
 from uuid import uuid4
 
 from opentelemetry import trace
+from pydantic import BaseModel
 
 from ..event_loop.event_loop import event_loop_cycle
 from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
@@ -43,6 +44,9 @@
 
 logger = logging.getLogger(__name__)
 
+# TypeVar for generic structured output
+T = TypeVar("T", bound=BaseModel)
+
 
 # Sentinel class and object to distinguish between explicit None and default parameter value
 class _DefaultCallbackHandlerSentinel:
@@ -386,6 +390,32 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
             # Re-raise the exception to preserve original behavior
             raise
 
+    def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
+        """This method allows you to get structured output from the agent.
+
+        If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
+        If you don't pass in a prompt, it will use only the conversation history to respond.
+        If no conversation history exists and no prompt is provided, an error will be raised.
+
+        For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
+        instruct the model to output the structured data.
+
+        Args:
+            output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel)
+                that the agent will use when responding.
+            prompt(Optional[str]): The prompt to use for the agent.
+        """
+        messages = self.messages
+        if not messages and not prompt:
+            raise ValueError("No conversation history or prompt provided")
+
+        # add the prompt as the last message
+        if prompt:
+            messages.append({"role": "user", "content": [{"text": prompt}]})
+
+        # get the structured output from the model
+        return self.model.structured_output(output_model, messages, self.callback_handler)
+
     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
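For orientation, a minimal usage sketch of the new `Agent.structured_output` method follows. The `PersonInfo` model and the prompt are illustrative, not part of this commit; it assumes `Agent` is importable from the top-level `strands` package.

```python
from pydantic import BaseModel

from strands import Agent


class PersonInfo(BaseModel):
    """Hypothetical schema, used only to illustrate the API."""

    name: str
    age: int


agent = Agent()

# With a prompt: the prompt is appended to the conversation history first.
person = agent.structured_output(PersonInfo, "John Smith is a 30-year-old engineer.")
print(person.name, person.age)  # a validated PersonInfo instance, not free-form text

# Without a prompt, the existing conversation history is used; if the history is
# empty and no prompt is given, the method raises ValueError.
```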

src/strands/models/anthropic.py

Lines changed: 48 additions & 3 deletions
@@ -7,11 +7,15 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, TypedDict, cast
+from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast
 
 import anthropic
+from pydantic import BaseModel
 from typing_extensions import Required, Unpack, override
 
+from ..event_loop.streaming import process_stream
+from ..handlers.callback_handler import PrintingCallbackHandler
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.models import Model
@@ -20,6 +24,8 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class AnthropicModel(Model):
     """Anthropic model provider implementation."""
@@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
             with self.client.messages.stream(**request) as stream:
                 for event in stream:
                     if event.type in AnthropicModel.EVENT_TYPES:
-                        yield event.dict()
+                        yield event.model_dump()
 
                 usage = event.message.usage  # type: ignore
-                yield {"type": "metadata", "usage": usage.dict()}
+                yield {"type": "metadata", "usage": usage.model_dump()}
 
         except anthropic.RateLimitError as error:
             raise ModelThrottledException(str(error)) from error
@@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
             raise ContextWindowOverflowException(str(error)) from error
 
         raise error
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        # process the stream and get the tool use input
+        results = process_stream(
+            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
+        )
+
+        stop_reason, messages, _, _, _ = results
+
+        if stop_reason != "tool_use":
+            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+
+        content = messages["content"]
+        output_response: dict[str, Any] | None = None
+        for block in content:
+            # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
+            # if the tool use name never matches, raise an error.
+            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
+                output_response = block["toolUse"]["input"]
+            else:
+                continue
+
+        if output_response is None:
+            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+
+        return output_model(**output_response)
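The Anthropic and Bedrock providers share the same strategy: convert the Pydantic model to a tool spec with `convert_pydantic_to_tool_spec`, run an ordinary `converse` call offering only that tool, and read the validated result out of the returned `toolUse` input. As a rough sketch, a Converse-style tool spec derived from a Pydantic model would look something like this (the exact output of the utility may differ):

```python
from pydantic import BaseModel


class PersonInfo(BaseModel):
    """Hypothetical model standing in for whatever the caller passes."""

    name: str
    age: int


# Approximate shape of a Converse-style tool spec built from the model;
# the description and field layout here are illustrative assumptions.
tool_spec = {
    "name": "PersonInfo",
    "description": "Extract data matching the PersonInfo schema.",
    "inputSchema": {"json": PersonInfo.model_json_schema()},
}
```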

src/strands/models/bedrock.py

Lines changed: 46 additions & 1 deletion
@@ -6,13 +6,17 @@
 import json
 import logging
 import os
-from typing import Any, Iterable, List, Literal, Optional, cast
+from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast
 
 import boto3
 from botocore.config import Config as BotocoreConfig
 from botocore.exceptions import ClientError
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
+from ..event_loop.streaming import process_stream
+from ..handlers.callback_handler import PrintingCallbackHandler
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.models import Model
@@ -29,6 +33,8 @@
     "too many total text bytes",
 ]
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class BedrockModel(Model):
     """AWS Bedrock model provider implementation.
@@ -477,3 +483,42 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
                 return self._find_detected_and_blocked_policy(item)
         # Otherwise return False
         return False
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+
+        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        # process the stream and get the tool use input
+        results = process_stream(
+            response, callback_handler=callback_handler or PrintingCallbackHandler(), messages=prompt
+        )
+
+        stop_reason, messages, _, _, _ = results
+
+        if stop_reason != "tool_use":
+            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+
+        content = messages["content"]
+        output_response: dict[str, Any] | None = None
+        for block in content:
+            # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
+            # if the tool use name never matches, raise an error.
+            if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
+                output_response = block["toolUse"]["input"]
+            else:
+                continue
+
+        if output_response is None:
+            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+
+        return output_model(**output_response)
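The model-level method can also be driven directly, without going through an `Agent`. A minimal sketch, assuming a default `BedrockModel` and AWS credentials available in the environment; the model construction and message content are illustrative:

```python
from pydantic import BaseModel

from strands.models.bedrock import BedrockModel


class PersonInfo(BaseModel):
    name: str
    age: int


model = BedrockModel()  # assumes the provider's default model_id and ambient AWS credentials
messages = [{"role": "user", "content": [{"text": "John Smith is 30 years old."}]}]
person = model.structured_output(PersonInfo, messages)  # returns a PersonInfo instance
```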

src/strands/models/litellm.py

Lines changed: 47 additions & 2 deletions
@@ -3,17 +3,22 @@
 - Docs: https://docs.litellm.ai/
 """
 
+import json
 import logging
-from typing import Any, Optional, TypedDict, cast
+from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast
 
 import litellm
+from litellm.utils import supports_response_schema
+from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
-from ..types.content import ContentBlock
+from ..types.content import ContentBlock, Messages
 from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class LiteLLMModel(OpenAIModel):
     """LiteLLM model provider implementation."""
@@ -97,3 +102,43 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
             }
 
         return super().format_request_message_content(content)
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        """
+        # The LiteLLM `Client` inits with Chat().
+        # Chat() inits with self.completions
+        # completions() has a method `create()` which wraps the real completion API of Litellm
+        response = self.client.chat.completions.create(
+            model=self.get_config()["model_id"],
+            messages=super().format_request(prompt)["messages"],
+            response_format=output_model,
+        )
+
+        if not supports_response_schema(self.get_config()["model_id"]):
+            raise ValueError("Model does not support response_format")
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the response.")
+
+        # Find the first choice with tool_calls
+        for choice in response.choices:
+            if choice.finish_reason == "tool_calls":
+                try:
+                    # Parse the tool call content as JSON
+                    tool_call_data = json.loads(choice.message.content)
+                    # Instantiate the output model with the parsed data
+                    return output_model(**tool_call_data)
+                except (json.JSONDecodeError, TypeError, ValueError) as e:
+                    raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
+        # If no tool_calls found, raise an error
+        raise ValueError("No tool_calls found in response")
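Unlike the tool-use path above, the LiteLLM provider leans on the upstream model's native JSON-schema support: the Pydantic class is passed directly as `response_format`, and `supports_response_schema` rejects models that cannot honor it (note that in this version the check runs only after the request has been issued). A usage sketch, with an illustrative `model_id`:

```python
from pydantic import BaseModel

from strands.models.litellm import LiteLLMModel


class PersonInfo(BaseModel):
    name: str
    age: int


# The model_id is illustrative; any LiteLLM route whose backing model
# supports response schemas should work here.
model = LiteLLMModel(model_id="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0")
messages = [{"role": "user", "content": [{"text": "John Smith is 30 years old."}]}]
person = model.structured_output(PersonInfo, messages)
```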

src/strands/models/llamaapi.py

Lines changed: 32 additions & 1 deletion
@@ -8,10 +8,11 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional, cast
+from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
 
 import llama_api_client
 from llama_api_client import LlamaAPIClient
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
 from ..types.content import ContentBlock, Messages
@@ -22,6 +23,8 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class LlamaAPIModel(Model):
     """Llama API model provider implementation."""
@@ -384,3 +387,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         # we may have a metrics event here
         if metrics_event:
             yield {"chunk_type": "metadata", "data": metrics_event}
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+
+        Raises:
+            NotImplementedError: Structured output is not currently supported for LlamaAPI models.
+        """
+        # response_format: ResponseFormat = {
+        #     "type": "json_schema",
+        #     "json_schema": {
+        #         "name": output_model.__name__,
+        #         "schema": output_model.model_json_schema(),
+        #     },
+        # }
+        # response = self.client.chat.completions.create(
+        #     model=self.config["model_id"],
+        #     messages=self.format_request(prompt)["messages"],
+        #     response_format=response_format,
+        # )
+        raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.")

src/strands/models/ollama.py

Lines changed: 26 additions & 1 deletion
@@ -5,9 +5,10 @@
 
 import json
 import logging
-from typing import Any, Iterable, Optional, cast
+from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
 
 from ollama import Client as OllamaClient
+from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
 from ..types.content import ContentBlock, Messages
@@ -17,6 +18,8 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound=BaseModel)
+
 
 class OllamaModel(Model):
     """Ollama model provider implementation.
@@ -310,3 +313,25 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         yield {"chunk_type": "content_stop", "data_type": "text"}
         yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else event.done_reason}
         yield {"chunk_type": "metadata", "data": event}
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
+    ) -> T:
+        """Get structured output from the model.
+
+        Args:
+            output_model(Type[BaseModel]): The output model to use for the agent.
+            prompt(Messages): The prompt messages to use for the agent.
+            callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
+        """
+        formatted_request = self.format_request(messages=prompt)
+        formatted_request["format"] = output_model.model_json_schema()
+        formatted_request["stream"] = False
+        response = self.client.chat(**formatted_request)
+
+        try:
+            content = response.message.content.strip()
+            return output_model.model_validate_json(content)
+        except Exception as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
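Ollama takes a third route: the model's JSON schema travels in the request's `format` field, which Ollama uses to constrain decoding, and the non-streamed response body is validated with `model_validate_json`. A minimal sketch; the host, `model_id`, and constructor arguments shown are assumptions for illustration:

```python
from pydantic import BaseModel

from strands.models.ollama import OllamaModel


class PersonInfo(BaseModel):
    name: str
    age: int


# host and model_id are illustrative; point them at whatever your local Ollama serves.
model = OllamaModel(host="http://localhost:11434", model_id="llama3")
messages = [{"role": "user", "content": [{"text": "John Smith is 30 years old."}]}]
person = model.structured_output(PersonInfo, messages)
```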
