From d15b69a28867a69cee670a35222b724bb80e943b Mon Sep 17 00:00:00 2001 From: sayanc82 Date: Wed, 18 Jun 2025 00:40:19 -0400 Subject: [PATCH 01/10] feat(Add-generate-support): Added first cut support for ultra, core and 3.5 --- CONTRIBUTING.md | 34 +++ src/strands/event_loop/streaming.py | 13 +- src/strands/models/_stabilityaiclient.py | 222 +++++++++++++++++ src/strands/models/stability.py | 264 +++++++++++++++++++++ src/strands/types/content.py | 3 +- src/strands/types/streaming.py | 3 + tests-integ/test_model_stability.py | 64 +++++ tests/TODO.lis | 5 + tests/strands/event_loop/test_streaming.py | 58 +++++ tests/strands/models/test_stability.py | 196 +++++++++++++++ 10 files changed, 858 insertions(+), 4 deletions(-) create mode 100644 src/strands/models/_stabilityaiclient.py create mode 100644 src/strands/models/stability.py create mode 100644 tests-integ/test_model_stability.py create mode 100644 tests/TODO.lis create mode 100644 tests/strands/models/test_stability.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa724cdd..84e048bf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -143,3 +143,37 @@ If you discover a potential security issue in this project we ask that you notif ## Licensing See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. + +## Stability Specific Instructions + +* Modified files + + 1. src/strands/event_loop/streaming.py + 2. tests/strands/event_loop/test_streaming.py + + These files have fixes to make the image return to agent consumers work. + +* New Files + 1. src/strands/models/_stabilityaiclient.py - A class with a rest client implementation for Stability + 2. src/strands/models/stability.py - The main model provider implementation + 3. tests-integ/test_model_stability.py - Integration test to test the image generation workflow. + 4. tests/TODO.lis - A text file containing TODOs. + 5. tests/strands/models/test_stability.py - Unit tests for the model provider + + +* Running the stability tests and testing the code + +You can run the integration test for the stability model as follows +``` + cd path/to/strands-agents-sdk-python + export STABILITY_API_KEY= + hatch test tests-integ/test_model_stability.py +``` + +You can also install this as a editable module in a python project and write a agent + +``` +pip install -e /path/to/strands-agents-sdk-python +Sample code similar to +https://gist.github.com/sayanc82/17f0d2442acd78f74f8393f58b552c54 +``` \ No newline at end of file diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6e8a806f..26bd33f6 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -152,7 +152,11 @@ def handle_content_block_delta( reasoning=True, **kwargs, ) - + elif "image" in delta_content: + # Handle the new ImageContent structure + image_content = delta_content["image"] + state["image"] = image_content + callback_handler(data=image_content, delta=delta_content, **kwargs) return state @@ -170,7 +174,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: current_tool_use = state["current_tool_use"] text = state["text"] reasoning_text = state["reasoningText"] - + image = state["image"] if current_tool_use: if "input" not in current_tool_use: current_tool_use["input"] = "" @@ -194,7 +198,6 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: elif text: content.append({"text": text}) state["text"] = "" - elif reasoning_text: content.append( { @@ -207,6 +210,9 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: } ) state["reasoningText"] = "" + elif image: + content.append({"image": image}) + state["image"] = "" return state @@ -279,6 +285,7 @@ def process_stream( "current_tool_use": {}, "reasoningText": "", "signature": "", + "image": None, } state["content"] = state["message"]["content"] diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py new file mode 100644 index 00000000..232262f8 --- /dev/null +++ b/src/strands/models/_stabilityaiclient.py @@ -0,0 +1,222 @@ +from typing import Any, BinaryIO, Dict, Optional, Union, cast + +import requests + + +class StabilityAiError(Exception): + """Base exception for Stability AI API errors.""" + + pass + + +class StabilityAiClient: + """Client for interacting with the Stability AI API.""" + + MODEL_ID_TO_BASE_URL = { + "stability.stable-image-core-v1:1": "https://api.stability.ai/v2beta/stable-image/generate/core", + "stability.stable-image-ultra-v1:1": "https://api.stability.ai/v2beta/stable-image/generate/ultra", + "stability.sd3-5-large-v1:0": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + } + + def __init__( + self, api_key: str, model_id: str, client_id: Optional[str] = None, client_version: Optional[str] = None + ): + """Initialize the Stability AI client. + + Args: + api_key: Your Stability API key + model_id: The model ID to use for the API request.See MODEL_ID_TO_BASE_URL for available models + client_id: Optional client ID for debugging + client_version: Optional client version for debugging + """ + self.model_id = model_id + self.base_url = self.MODEL_ID_TO_BASE_URL[model_id] + self.api_key = api_key + self.client_id = client_id + self.client_version = client_version + + def _get_headers(self, accept: str = "image/*") -> Dict[str, str]: + """Get the headers for the API request. + + Args: + accept: The accept header value (image/* or application/json) + + Returns: + Dict of headers + """ + headers = {"Authorization": f"Bearer {self.api_key}", "Accept": accept} + + if self.client_id: + headers["stability-client-id"] = self.client_id + if self.client_version: + headers["stability-client-version"] = self.client_version + + return headers + + def generate_image_bytes( + self, + prompt: str, + negative_prompt: Optional[str] = None, + aspect_ratio: str = "1:1", + seed: Optional[int] = None, + output_format: str = "png", + image: Optional[BinaryIO] = None, + style_preset: Optional[str] = None, + strength: Optional[float] = None, + ) -> bytes: + """Generate an image using the Stability AI API. + + Args: + prompt: Text prompt for image generation + negative_prompt: Optional text describing what not to include + aspect_ratio: Aspect ratio of the output image + seed: Random seed for generation + output_format: Output format (jpeg, png, webp) + image: Optional input image for img2img + style_preset: Optional style preset + strength: Required when image is provided, controls influence of input image + + Returns: bytes of the image + """ + return cast( + bytes, + self._generate_image( + prompt, + negative_prompt, + aspect_ratio, + seed, + output_format, + image, + style_preset, + strength, + return_json=False, + ), + ) + + def generate_image_json( + self, + prompt: str, + negative_prompt: Optional[str] = None, + aspect_ratio: str = "1:1", + seed: Optional[int] = None, + output_format: str = "png", + image: Optional[BinaryIO] = None, + style_preset: Optional[str] = None, + strength: Optional[float] = None, + ) -> Dict[str, Any]: + """Generate an image using the Stability AI API. + + Args: + prompt: Text prompt for image generation + negative_prompt: Optional text describing what not to include + aspect_ratio: Aspect ratio of the output image + seed: Random seed for generation + output_format: Output format (jpeg, png, webp) + image: Optional input image for img2img + style_preset: Optional style preset + strength: Required when image is provided, controls influence of input image + return_json: If True, returns JSON response with base64 image + + Returns: + Either image bytes or JSON response with base64 image + """ + return cast( + Dict[str, Any], + self._generate_image( + prompt, + negative_prompt, + aspect_ratio, + seed, + output_format, + image, + style_preset, + strength, + return_json=True, + ), + ) + + def _generate_image( + self, + prompt: str, + negative_prompt: Optional[str] = None, + aspect_ratio: str = "1:1", + seed: Optional[int] = None, + output_format: str = "png", + image: Optional[BinaryIO] = None, + style_preset: Optional[str] = None, + strength: Optional[float] = None, + return_json: bool = False, + ) -> Union[bytes, Dict[str, Any]]: + """Generate an image using the Stability AI API. + + Args: + prompt: Text prompt for image generation + negative_prompt: Optional text describing what not to include + aspect_ratio: Aspect ratio of the output image + seed: Random seed for generation + output_format: Output format (jpeg, png, webp) + image: Optional input image for img2img + style_preset: Optional style preset + strength: Required when image is provided, controls influence of input image + return_json: If True, returns JSON response with base64 image + + Returns: + Either image bytes or JSON response with base64 image + + Raises: + StabilityAiError: If the API request fails + """ + # Prepare the multipart form data + files: Dict[str, Union[BinaryIO, str]] = {} + data: Dict[str, Any] = {} + + # Add all parameters to data as strings + data["prompt"] = prompt + if negative_prompt: + data["negative_prompt"] = negative_prompt + if aspect_ratio: + data["aspect_ratio"] = aspect_ratio + if seed is not None: + data["seed"] = seed + if output_format: + data["output_format"] = output_format + if style_preset: + data["style_preset"] = style_preset + + # Handle input image if provided + if image: + files["image"] = image + + if len(files) == 0: + files["none"] = "" + try: + # Make the API request + response = requests.post( + self.base_url, + headers=self._get_headers("application/json" if return_json else "image/*"), + data=data, + files=files, + ) + + # Handle different response status codes + if response.status_code == 200: + if return_json: + return cast(Dict[str, Any], response.json()) + return cast(bytes, response.content) + elif response.status_code == 400: + raise StabilityAiError(f"Invalid parameters: {response.json().get('errors', 'Unknown error')}") + elif response.status_code == 403: + raise StabilityAiError("Request flagged by content moderation") + elif response.status_code == 413: + raise StabilityAiError("Request too large (max 10MiB)") + elif response.status_code == 422: + raise StabilityAiError(f"Request rejected: {response.json().get('errors', 'Unknown error')}") + elif response.status_code == 429: + raise StabilityAiError("Rate limit exceeded (max 150 requests per 10 seconds)") + elif response.status_code == 500: + raise StabilityAiError("Internal server error") + else: + raise StabilityAiError(f"Unexpected error: {response.status_code}") + + except requests.exceptions.RequestException as e: + raise StabilityAiError(f"Request failed: {str(e)}") from e diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py new file mode 100644 index 00000000..3768a8d0 --- /dev/null +++ b/src/strands/models/stability.py @@ -0,0 +1,264 @@ +"""Stability AI model provider. + +- Docs: https://platform.stability.ai/ +""" + +import base64 +import logging +from enum import Enum +from typing import Any, Iterable, Optional, TypedDict, cast + +from typing_extensions import NotRequired, Unpack, override + +from strands.types.content import Messages +from strands.types.models import Model +from strands.types.streaming import ContentBlockDelta, ContentBlockDeltaEvent, StreamEvent +from strands.types.tools import ToolSpec + +from ._stabilityaiclient import StabilityAiClient, StabilityAiError + +logger = logging.getLogger(__name__) + + +class OutputFormat(Enum): + """Supported output formats for image generation.""" + + JPEG = "jpeg" + PNG = "png" + WEBP = "webp" + + +class StylePreset(Enum): + """Supported style presets for image generation.""" + + THREE_D_MODEL = "3d-model" + ANALOG_FILM = "analog-film" + ANIME = "anime" + CINEMATIC = "cinematic" + COMIC_BOOK = "comic-book" + DIGITAL_ART = "digital-art" + ENHANCE = "enhance" + FANTASY_ART = "fantasy-art" + ISOMETRIC = "isometric" + LINE_ART = "line-art" + LOW_POLY = "low-poly" + MODELING_COMPOUND = "modeling-compound" + NEON_PUNK = "neon-punk" + ORIGAMI = "origami" + PHOTOGRAPHIC = "photographic" + PIXEL_ART = "pixel-art" + TILE_TEXTURE = "tile-texture" + + +class StabilityAiImageModel(Model): + """Your custom model provider implementation.""" + + class StabilityAiImageModelConfig(TypedDict): + """Configuration your model. + + Attributes: + model_id: ID of Custom model (required). + params: Model parameters (e.g., max_tokens). + """ + + """ + image - the image to use as the starting point for the generation + strength - controls how much influence the image parameter has on the output image + aspect_ratio - the aspect ratio of the output image + seed - the randomness seed to use for the generation + output_format - the format of the output image + """ + # Required parameters + model_id: str + + # Optional parameters with defaults + aspect_ratio: NotRequired[str] # defaults to "1:1" + seed: NotRequired[int] # defaults to random + output_format: NotRequired[OutputFormat] # defaults to PNG + style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC + image: NotRequired[str] # defaults to None + strength: NotRequired[float] # defaults to 0.35 + + def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: + """Initialize provider instance. + + Args: + api_key: The API key for connecting to your Custom model. + **model_config: Configuration options for Custom model. + """ + # Set default values for optional parameters + + defaults = { + "output_format": OutputFormat.PNG, + } + + # Update defaults with provided config + config_dict = {**defaults, **dict(model_config)} + + # Convert string output_format to enum if provided as string + if "output_format" in config_dict and isinstance(config_dict["output_format"], str): + try: + config_dict["output_format"] = OutputFormat(config_dict["output_format"]) + except ValueError as e: + raise ValueError(f"output_format must be one of: {[f.value for f in OutputFormat]}") from e + + # Convert string style_preset to enum if provided as string + if "style_preset" in config_dict and isinstance(config_dict["style_preset"], str): + try: + config_dict["style_preset"] = StylePreset(config_dict["style_preset"]) + except ValueError as e: + raise ValueError(f"style_preset must be one of: {[f.value for f in StylePreset]}") from e + + self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) + logger.debug("config=<%s> | initializing", self.config) + + model_id = self.config.get("model_id") + if model_id is None: + raise ValueError("model_id is required") + self.client = StabilityAiClient(api_key=api_key, model_id=model_id) + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format a Stability AI model request. + + Args: + messages: List of messages containing the conversation history + tool_specs: Optional list of tool specifications + system_prompt: Optional system prompt + + Returns: + Formatted request parameters for the Stability AI API + """ + # Extract the last user message as the prompt + # We do not need all the previous messages as context unlike an llm + prompt = "" + + for message in reversed(messages): + if message["role"] == "user": + # Find the text content in the message + for content in message["content"]: + if isinstance(content, dict) and "text" in content: + prompt = content["text"] + break + break + + if not prompt: + raise ValueError("No user message found in the conversation") + + # Format the request + request = { + "prompt": prompt, + "aspect_ratio": self.config.get("aspect_ratio", "1:1"), + "output_format": self.config.get("output_format", OutputFormat.PNG).value, + "style_preset": self.config.get("style_preset", StylePreset.PHOTOGRAPHIC).value, + } + + # Add optional parameters if they exist in config + if "seed" in self.config: + request["seed"] = self.config["seed"] # type: ignore[assignment] + if self.config.get("image") is not None: + request["image"] = self.config["image"] + request["strength"] = self.config.get("strength", 0.35) # type: ignore[assignment] + + return request + + @override + def update_config(self, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: # type: ignore[override] + """Update the model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> StabilityAiImageModelConfig: + """Get the model configuration. + + Returns: + The model configuration. + """ + return self.config + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + return {"contentBlockStart": {"start": {}}} + + case "content_block_delta": + # Have to do this cast as there are two different ContentBlockDelta types + # with different structures. The cast is for mypy to explicitly understand + # the right one, otherwise it is getting confused. + content_block_delta = cast( + ContentBlockDelta, + { + "image": { + "format": self.config["output_format"].value, + "source": {"bytes": base64.b64decode(event.get("data", b""))}, + } + }, + ) + content_block_delta_event = ContentBlockDeltaEvent(delta=content_block_delta) + return {"contentBlockDelta": content_block_delta_event} + + case "content_stop": + return {"contentBlockStop": {}} + case "message_stop": + return {"messageStop": {"stopReason": event["data"]}} + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + def stream(self, request: dict[str, Any]) -> Iterable[Any]: + """Send the request to the Stability AI model and get a streaming response. + + Args: + request: The formatted request to send to the Stability AI model. + + Returns: + An iterable of response events from the Stability AI model. + + Raises: + StabilityAiError: If the API request fails + """ + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + try: + # Generate the image + response_json = self.client.generate_image_json( + prompt=request["prompt"], + negative_prompt=request.get("negative_prompt"), + aspect_ratio=request.get("aspect_ratio", "1:1"), + seed=request.get("seed"), + output_format=request.get("output_format", "png"), + image=request.get("image"), + style_preset=request.get("style_preset"), + strength=request.get("strength", 0.35), + ) + # Yield the image data as a single event + + yield {"chunk_type": "content_block_delta", "data_type": "image", "data": response_json.get("image")} + yield {"chunk_type": "content_stop", "data_type": "text"} + + yield {"chunk_type": "message_stop", "data": response_json.get("finish_reason")} + except StabilityAiError as e: + logger.error("Failed to generate image: %s", str(e)) + raise diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 790e9094..f5b531b0 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -118,6 +118,7 @@ class DeltaContent(TypedDict, total=False): text: str toolUse: Dict[Literal["input"], str] + image: ImageContent class ContentBlockStartToolUse(TypedDict): @@ -142,7 +143,7 @@ class ContentBlockStart(TypedDict, total=False): toolUse: Optional[ContentBlockStartToolUse] -class ContentBlockDelta(TypedDict): +class ContentBlockDelta(TypedDict, total=False): """The content block delta event. Attributes: diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index 9c99b210..dcd1f868 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -12,6 +12,7 @@ from .content import ContentBlockStart, Role from .event_loop import Metrics, StopReason, Usage from .guardrails import Trace +from .media import ImageContent class MessageStartEvent(TypedDict): @@ -78,11 +79,13 @@ class ContentBlockDelta(TypedDict, total=False): reasoningContent: Contains content regarding the reasoning that is carried out by the model. text: Text fragment being streamed. toolUse: Tool use input fragment being streamed. + image: Image content being streamed. """ reasoningContent: ReasoningContentBlockDelta text: str toolUse: ContentBlockDeltaToolUse + image: ImageContent class ContentBlockDeltaEvent(TypedDict, total=False): diff --git a/tests-integ/test_model_stability.py b/tests-integ/test_model_stability.py new file mode 100644 index 00000000..864d24c0 --- /dev/null +++ b/tests-integ/test_model_stability.py @@ -0,0 +1,64 @@ +import os + +import pytest + +from strands import Agent +from strands.models.stability import OutputFormat, StabilityAiImageModel + + +@pytest.fixture +def model_id(request): + return request.param + + +@pytest.fixture +def model(model_id): + return StabilityAiImageModel( + api_key=os.getenv("STABILITY_API_KEY"), # Use the API key loaded from .env + model_id=model_id, + aspect_ratio="16:9", + output_format=OutputFormat.PNG, + ) + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif( + "STABILITY_API_KEY" not in os.environ, + reason="STABILITY_API_KEY environment variable missing", +) +@pytest.mark.parametrize( + "model_id", + [ + "stability.stable-image-core-v1:1", + "stability.stable-image-ultra-v1:1", + "stability.sd3-5-large-v1:0", + ], + indirect=True, +) +def test_agent(agent): + result = agent("dark high contrast render of a psychedelic tree of life illuminating dust in a mystical cave.") + + # Initialize variables + image_data = None + image_format = None + + # Find image content + for content in result.message.get("content", []): + if isinstance(content, dict) and "image" in content: + image_data = content["image"]["source"]["bytes"] + image_format = content["image"]["format"] + break + + # Verify we found an image + assert image_data is not None, "No image data found in the response" + assert image_format is not None, "No image format found in the response" + + # Verify image data is not empty + assert len(image_data) > 0, "Image data should not be empty" + + # Verify image format is PNG + assert image_format == "png", f"Expected image format to be 'png', got '{image_format}'" diff --git a/tests/TODO.lis b/tests/TODO.lis new file mode 100644 index 00000000..3339a4e4 --- /dev/null +++ b/tests/TODO.lis @@ -0,0 +1,5 @@ +1. Add model id validation +2. Do specific config validation for the 3 models, support parameters like cfg for sd3.5 +3. Cleanup code :- Modularize. Move validation to its own methods for example. +4. Add test for the fixes I have made for plumbing images through the framework +5. Add documentation similar to other model providers \ No newline at end of file diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c24e7e48..54236ca0 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -139,6 +139,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, ), + # Image + ( + {"delta": {"image": {"format": "png", "source": {"bytes": b"image_data"}}}}, + {"image": {"format": "png", "source": {"bytes": b"image_data"}}}, + {"image": {"format": "png", "source": {"bytes": b"image_data"}}}, + {"data": {"format": "png", "source": {"bytes": b"image_data"}}}, + ), # Empty ( {"delta": {}}, @@ -175,12 +182,14 @@ def callback_handler(**kwargs): "current_tool_use": {"toolUseId": "123", "name": "test", "input": '{"key": "value"}'}, "text": "", "reasoningText": "", + "image": None, }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), # Tool Use - Missing input @@ -190,12 +199,14 @@ def callback_handler(**kwargs): "current_tool_use": {"toolUseId": "123", "name": "test"}, "text": "", "reasoningText": "", + "image": None, }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), # Text @@ -205,12 +216,14 @@ def callback_handler(**kwargs): "current_tool_use": {}, "text": "test", "reasoningText": "", + "image": None, }, { "content": [{"text": "test"}], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), # Reasoning @@ -221,6 +234,7 @@ def callback_handler(**kwargs): "text": "", "reasoningText": "test", "signature": "123", + "image": None, }, { "content": [{"reasoningContent": {"reasoningText": {"text": "test", "signature": "123"}}}], @@ -228,6 +242,24 @@ def callback_handler(**kwargs): "text": "", "reasoningText": "", "signature": "123", + "image": None, + }, + ), + # Image + ( + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "image": {"format": "png", "source": {"bytes": b"image_data"}}, + }, + { + "content": [{"image": {"format": "png", "source": {"bytes": b"image_data"}}}], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "image": "", }, ), # Empty @@ -237,12 +269,14 @@ def callback_handler(**kwargs): "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), ], @@ -355,6 +389,30 @@ def test_extract_usage_metrics(): {"calls": 1}, [{"role": "user", "content": [{"text": "REDACTED"}]}], ), + ( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"image": {"format": "png", "source": {"bytes": b"image_data"}}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + "metrics": {"latencyMs": 1}, + } + }, + ], + "end_turn", + { + "role": "assistant", + "content": [{"image": {"format": "png", "source": {"bytes": b"image_data"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + {"calls": 1}, + [{"role": "user", "content": [{"text": "Some input!"}]}], + ), ], ) def test_process_stream( diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py new file mode 100644 index 00000000..1f86c7d8 --- /dev/null +++ b/tests/strands/models/test_stability.py @@ -0,0 +1,196 @@ +import base64 +import unittest.mock + +import pytest + +import strands +from strands.models.stability import OutputFormat, StabilityAiImageModel, StylePreset + + +@pytest.fixture +def stability_client_cls(): + with unittest.mock.patch.object(strands.models.stability, "StabilityAiClient") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def stability_client(stability_client_cls): + return stability_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "stability.stable-image-ultra-v1:1" + + +@pytest.fixture +def model(stability_client, model_id): + _ = stability_client + return StabilityAiImageModel(api_key="test_key", model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "a beautiful sunset over mountains"}]}] + + +def test__init__(stability_client_cls, model_id): + model = StabilityAiImageModel( + api_key="test_key", + model_id=model_id, + aspect_ratio="16:9", + output_format=OutputFormat.JPEG, + style_preset=StylePreset.PHOTOGRAPHIC, + ) + + tru_config = model.get_config() + exp_config = { + "model_id": model_id, + "aspect_ratio": "16:9", + "output_format": OutputFormat.JPEG, + "style_preset": StylePreset.PHOTOGRAPHIC, + } + + assert tru_config == exp_config + stability_client_cls.assert_called_once_with(api_key="test_key", model_id=model_id) + + +def test__init__with_string_enums(stability_client_cls, model_id): + model = StabilityAiImageModel( + api_key="test_key", + model_id=model_id, + output_format="jpeg", + style_preset="photographic", + ) + + tru_config = model.get_config() + exp_config = {"model_id": model_id, "output_format": OutputFormat.JPEG, "style_preset": StylePreset.PHOTOGRAPHIC} + + assert tru_config == exp_config + + +def test__init__with_invalid_output_format(): + with pytest.raises(ValueError) as exc_info: + StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + output_format="invalid", + ) + assert "output_format must be one of:" in str(exc_info.value) + + +def test__init__with_invalid_style_preset(): + with pytest.raises(ValueError) as exc_info: + StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + style_preset="invalid", + ) + assert "style_preset must be one of:" in str(exc_info.value) + + +def test_update_config(model, model_id): + model.update_config( + model_id=model_id, + aspect_ratio="16:9", + output_format=OutputFormat.JPEG, + ) + + tru_config = model.get_config() + exp_config = {"model_id": model_id, "aspect_ratio": "16:9", "output_format": OutputFormat.JPEG} + + assert tru_config == exp_config + + +def test_format_request(model, messages): + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + } + + assert request == exp_request + + +def test_format_request_with_optional_params(model, messages): + model.update_config( + seed=12345, + image="base64_encoded_image", + strength=0.5, + ) + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "seed": 12345, + "image": "base64_encoded_image", + "strength": 0.5, + } + + assert request == exp_request + + +def test_format_request_no_user_message(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + messages = [{"role": "assistant", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError) as exc_info: + model.format_request(messages) + assert "No user message found in the conversation" in str(exc_info.value) + + +def test_format_chunk_message_start(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "message_start"} + + chunk = model.format_chunk(event) + assert chunk == {"messageStart": {"role": "assistant"}} + + +def test_format_chunk_content_start(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "content_start"} + + chunk = model.format_chunk(event) + assert chunk == {"contentBlockStart": {"start": {}}} + + +def test_format_chunk_content_block_delta(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + raw_image_data = b"raw_image_data" + base64_encoded_data = base64.b64encode(raw_image_data) + event = {"chunk_type": "content_block_delta", "data": base64_encoded_data} + + chunk = model.format_chunk(event) + assert chunk == {"contentBlockDelta": {"delta": {"image": {"format": "png", "source": {"bytes": raw_image_data}}}}} + + +def test_format_chunk_content_stop(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "content_stop"} + + chunk = model.format_chunk(event) + assert chunk == {"contentBlockStop": {}} + + +def test_format_chunk_message_stop(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "message_stop", "data": "stop"} + + chunk = model.format_chunk(event) + assert chunk == {"messageStop": {"stopReason": "stop"}} + + +def test_format_chunk_unknown_type(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError) as exc_info: + model.format_chunk(event) + assert "unknown type" in str(exc_info.value) From f371180d19053a7a735b4e2b6cc61819137327d6 Mon Sep 17 00:00:00 2001 From: sayanc82 Date: Wed, 18 Jun 2025 09:38:12 -0400 Subject: [PATCH 02/10] refactor(Modularize-code-for-better-readability): Added methods for config validation, added well named classes and variables --- TODO.lis | 4 + src/strands/models/stability.py | 271 ++++++++++++++++++++------------ tests/TODO.lis | 5 - 3 files changed, 176 insertions(+), 104 deletions(-) create mode 100644 TODO.lis delete mode 100644 tests/TODO.lis diff --git a/TODO.lis b/TODO.lis new file mode 100644 index 00000000..7316285a --- /dev/null +++ b/TODO.lis @@ -0,0 +1,4 @@ +Add support for all the parameters for core and 3.5 +Do specific config validation for the 3 models, support parameters like cfg for sd3.5 +Cleanup code :- Modularize. Move validation to its own methods for example. +Add documentation similar to other model providers diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index 3768a8d0..1668c5ab 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -50,24 +50,41 @@ class StylePreset(Enum): TILE_TEXTURE = "tile-texture" +class Defaults: + """Default values for Stability AI configuration.""" + + ASPECT_RATIO = "1:1" + OUTPUT_FORMAT = OutputFormat.PNG + STYLE_PRESET = StylePreset.PHOTOGRAPHIC + STRENGTH = 0.35 + + +class ChunkTypes: + """Chunk type constants.""" + + MESSAGE_START = "message_start" + CONTENT_START = "content_start" + CONTENT_BLOCK_DELTA = "content_block_delta" + CONTENT_STOP = "content_stop" + MESSAGE_STOP = "message_stop" + + class StabilityAiImageModel(Model): - """Your custom model provider implementation.""" + """Stability AI image generation model provider.""" class StabilityAiImageModelConfig(TypedDict): - """Configuration your model. + """Configuration for Stability AI image model. Attributes: - model_id: ID of Custom model (required). - params: Model parameters (e.g., max_tokens). + model_id: ID of the Stability AI model (required). + aspect_ratio: Aspect ratio of the output image. + seed: Random seed for generation. + output_format: Output format (jpeg, png, webp). + style_preset: Style preset for image generation. + image: Input image for img2img generation. + strength: Influence of input image on output (0.0-1.0). """ - """ - image - the image to use as the starting point for the generation - strength - controls how much influence the image parameter has on the output image - aspect_ratio - the aspect ratio of the output image - seed - the randomness seed to use for the generation - output_format - the format of the output image - """ # Required parameters model_id: str @@ -80,88 +97,110 @@ class StabilityAiImageModelConfig(TypedDict): strength: NotRequired[float] # defaults to 0.35 def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: - """Initialize provider instance. + """Initialize the Stability AI model provider. Args: - api_key: The API key for connecting to your Custom model. - **model_config: Configuration options for Custom model. + api_key: The API key for connecting to Stability AI. + **model_config: Configuration options for the model. """ - # Set default values for optional parameters + config_dict = {**{"output_format": Defaults.OUTPUT_FORMAT}, **dict(model_config)} + self._validate_and_convert_config(config_dict) - defaults = { - "output_format": OutputFormat.PNG, - } + self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) + logger.debug("config=<%s> | initializing", self.config) + + model_id = self.config.get("model_id") + if model_id is None: + raise ValueError("model_id is required") + self.client = StabilityAiClient(api_key=api_key, model_id=model_id) - # Update defaults with provided config - config_dict = {**defaults, **dict(model_config)} + def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: + """Validate and convert configuration values to proper types.""" + self._convert_output_format(config_dict) + self._convert_style_preset(config_dict) - # Convert string output_format to enum if provided as string + def _convert_output_format(self, config_dict: dict[str, Any]) -> None: + """Convert string output_format to enum if needed.""" if "output_format" in config_dict and isinstance(config_dict["output_format"], str): try: config_dict["output_format"] = OutputFormat(config_dict["output_format"]) except ValueError as e: - raise ValueError(f"output_format must be one of: {[f.value for f in OutputFormat]}") from e + valid_formats = [f.value for f in OutputFormat] + raise ValueError(f"output_format must be one of: {valid_formats}") from e - # Convert string style_preset to enum if provided as string + def _convert_style_preset(self, config_dict: dict[str, Any]) -> None: + """Convert string style_preset to enum if needed.""" if "style_preset" in config_dict and isinstance(config_dict["style_preset"], str): try: config_dict["style_preset"] = StylePreset(config_dict["style_preset"]) except ValueError as e: - raise ValueError(f"style_preset must be one of: {[f.value for f in StylePreset]}") from e + valid_presets = [p.value for p in StylePreset] + raise ValueError(f"style_preset must be one of: {valid_presets}") from e - self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) - logger.debug("config=<%s> | initializing", self.config) - - model_id = self.config.get("model_id") - if model_id is None: - raise ValueError("model_id is required") - self.client = StabilityAiClient(api_key=api_key, model_id=model_id) - - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format a Stability AI model request. + def _extract_prompt_from_messages(self, messages: Messages) -> str: + """Extract the last user message as prompt. Args: - messages: List of messages containing the conversation history - tool_specs: Optional list of tool specifications - system_prompt: Optional system prompt + messages: List of conversation messages. Returns: - Formatted request parameters for the Stability AI API - """ - # Extract the last user message as the prompt - # We do not need all the previous messages as context unlike an llm - prompt = "" + The extracted prompt text. + Raises: + ValueError: If no user message with text content is found. + """ for message in reversed(messages): if message["role"] == "user": - # Find the text content in the message for content in message["content"]: if isinstance(content, dict) and "text" in content: - prompt = content["text"] - break - break + return content["text"] + raise ValueError("No user message found in the conversation") - if not prompt: - raise ValueError("No user message found in the conversation") + def _build_base_request(self, prompt: str) -> dict[str, Any]: + """Build the base request with required parameters. - # Format the request - request = { + Args: + prompt: The text prompt for image generation. + + Returns: + Dictionary with base request parameters. + """ + return { "prompt": prompt, - "aspect_ratio": self.config.get("aspect_ratio", "1:1"), - "output_format": self.config.get("output_format", OutputFormat.PNG).value, - "style_preset": self.config.get("style_preset", StylePreset.PHOTOGRAPHIC).value, + "aspect_ratio": self.config.get("aspect_ratio", Defaults.ASPECT_RATIO), + "output_format": self.config.get("output_format", Defaults.OUTPUT_FORMAT).value, + "style_preset": self.config.get("style_preset", Defaults.STYLE_PRESET).value, } - # Add optional parameters if they exist in config + def _add_optional_parameters(self, request: dict[str, Any]) -> None: + """Add optional parameters to the request if they exist in config. + + Args: + request: The request dictionary to modify. + """ if "seed" in self.config: - request["seed"] = self.config["seed"] # type: ignore[assignment] + request["seed"] = self.config["seed"] if self.config.get("image") is not None: request["image"] = self.config["image"] - request["strength"] = self.config.get("strength", 0.35) # type: ignore[assignment] + request["strength"] = self.config.get("strength", Defaults.STRENGTH) + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format a Stability AI model request. + Args: + messages: List of messages containing the conversation history. + tool_specs: Optional list of tool specifications (unused for image generation). + system_prompt: Optional system prompt (unused for image generation). + + Returns: + Formatted request parameters for the Stability AI API. + """ + prompt = self._extract_prompt_from_messages(messages) + request = self._build_base_request(prompt) + self._add_optional_parameters(request) return request @override @@ -182,49 +221,78 @@ def get_config(self) -> StabilityAiImageModelConfig: """ return self.config + def _format_message_start(self) -> StreamEvent: + """Format message start event.""" + return {"messageStart": {"role": "assistant"}} + + def _format_content_start(self) -> StreamEvent: + """Format content start event.""" + return {"contentBlockStart": {"start": {}}} + + def _format_content_block_delta(self, event: dict[str, Any]) -> StreamEvent: + """Format content block delta event. + + Args: + event: The event containing image data. + + Returns: + Formatted content block delta event. + """ + content_block_delta = cast( + ContentBlockDelta, + { + "image": { + "format": self.config["output_format"].value, + "source": {"bytes": base64.b64decode(event.get("data", b""))}, + } + }, + ) + content_block_delta_event = ContentBlockDeltaEvent(delta=content_block_delta) + return {"contentBlockDelta": content_block_delta_event} + + def _format_content_stop(self) -> StreamEvent: + """Format content stop event.""" + return {"contentBlockStop": {}} + + def _format_message_stop(self, event: dict[str, Any]) -> StreamEvent: + """Format message stop event. + + Args: + event: The event containing stop reason. + + Returns: + Formatted message stop event. + """ + return {"messageStop": {"stopReason": event["data"]}} + @override def format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format an OpenAI response event into a standardized message chunk. + """Format an event into a standardized message chunk. Args: - event: A response event from the OpenAI compatible model. + event: A response event from the Stability AI model. Returns: The formatted chunk. Raises: RuntimeError: If chunk_type is not recognized. - This error should never be encountered as chunk_type is controlled in the stream method. """ - match event["chunk_type"]: - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_start": - return {"contentBlockStart": {"start": {}}} - - case "content_block_delta": - # Have to do this cast as there are two different ContentBlockDelta types - # with different structures. The cast is for mypy to explicitly understand - # the right one, otherwise it is getting confused. - content_block_delta = cast( - ContentBlockDelta, - { - "image": { - "format": self.config["output_format"].value, - "source": {"bytes": base64.b64decode(event.get("data", b""))}, - } - }, - ) - content_block_delta_event = ContentBlockDeltaEvent(delta=content_block_delta) - return {"contentBlockDelta": content_block_delta_event} - - case "content_stop": - return {"contentBlockStop": {}} - case "message_stop": - return {"messageStop": {"stopReason": event["data"]}} + chunk_type = event["chunk_type"] + + match chunk_type: + case ChunkTypes.MESSAGE_START: + return self._format_message_start() + case ChunkTypes.CONTENT_START: + return self._format_content_start() + case ChunkTypes.CONTENT_BLOCK_DELTA: + return self._format_content_block_delta(event) + case ChunkTypes.CONTENT_STOP: + return self._format_content_stop() + case ChunkTypes.MESSAGE_STOP: + return self._format_message_stop(event) case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + raise RuntimeError(f"chunk_type=<{chunk_type}> | unknown type") @override def stream(self, request: dict[str, Any]) -> Iterable[Any]: @@ -237,28 +305,33 @@ def stream(self, request: dict[str, Any]) -> Iterable[Any]: An iterable of response events from the Stability AI model. Raises: - StabilityAiError: If the API request fails + StabilityAiError: If the API request fails. """ - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} + yield {"chunk_type": ChunkTypes.MESSAGE_START} + yield {"chunk_type": ChunkTypes.CONTENT_START, "data_type": "text"} + try: # Generate the image response_json = self.client.generate_image_json( prompt=request["prompt"], negative_prompt=request.get("negative_prompt"), - aspect_ratio=request.get("aspect_ratio", "1:1"), + aspect_ratio=request.get("aspect_ratio", Defaults.ASPECT_RATIO), seed=request.get("seed"), output_format=request.get("output_format", "png"), image=request.get("image"), style_preset=request.get("style_preset"), - strength=request.get("strength", 0.35), + strength=request.get("strength", Defaults.STRENGTH), ) - # Yield the image data as a single event - yield {"chunk_type": "content_block_delta", "data_type": "image", "data": response_json.get("image")} - yield {"chunk_type": "content_stop", "data_type": "text"} + # Yield the image data as a single event + yield { + "chunk_type": ChunkTypes.CONTENT_BLOCK_DELTA, + "data_type": "image", + "data": response_json.get("image"), + } + yield {"chunk_type": ChunkTypes.CONTENT_STOP, "data_type": "text"} + yield {"chunk_type": ChunkTypes.MESSAGE_STOP, "data": response_json.get("finish_reason")} - yield {"chunk_type": "message_stop", "data": response_json.get("finish_reason")} except StabilityAiError as e: logger.error("Failed to generate image: %s", str(e)) raise diff --git a/tests/TODO.lis b/tests/TODO.lis deleted file mode 100644 index 3339a4e4..00000000 --- a/tests/TODO.lis +++ /dev/null @@ -1,5 +0,0 @@ -1. Add model id validation -2. Do specific config validation for the 3 models, support parameters like cfg for sd3.5 -3. Cleanup code :- Modularize. Move validation to its own methods for example. -4. Add test for the fixes I have made for plumbing images through the framework -5. Add documentation similar to other model providers \ No newline at end of file From 520a47942c5707ed62c698b464ed139ed5419cac Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 18 Jun 2025 15:35:54 +0000 Subject: [PATCH 03/10] fix(stability): add type annotations and parameter validation --- pyproject.toml | 7 + src/strands/models/_stabilityaiclient.py | 201 +++++++++++++++-------- src/strands/models/stability.py | 13 +- 3 files changed, 138 insertions(+), 83 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17bc110e..0ed20aa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,10 +54,12 @@ dev = [ "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", + "Pillow>=11.0", "pre-commit>=3.2.0,<4.2.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", + "types-Pillow>=10.2.0.20240822", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -118,6 +120,7 @@ lint-fix = [ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", + "Pillow>=11.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "pytest-cov>=4.1.0,<5.0.0", @@ -198,6 +201,10 @@ ignore_missing_imports = false module = "litellm" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "PIL" +ignore_missing_imports = true + [tool.ruff] line-length = 120 include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index 232262f8..30645131 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -1,6 +1,95 @@ +import base64 +from enum import Enum +from io import BytesIO from typing import Any, BinaryIO, Dict, Optional, Union, cast import requests +from PIL import Image + + +# Validation classes and functions +# Other validation is performed in the JSON workflow configs +class ModeEnum(str, Enum): + TEXT_TO_IMAGE = "text-to-image" + IMAGE_TO_IMAGE = "image-to-image" + + +class OutputFormat(Enum): + PNG = "png" + JPEG = "jpeg" + WEBP = "webp" + + +class StylePresetEnum(str, Enum): + THREE_D_MODEL = "3d-model" + ANALOG_FILM = "analog-film" + ANIME = "anime" + CINEMATIC = "cinematic" + COMIC_BOOK = "comic-book" + DIGITAL_ART = "digital-art" + ENHANCE = "enhance" + FANTASY_ART = "fantasy-art" + ISOMETRIC = "isometric" + LINE_ART = "line-art" + LOW_POLY = "low-poly" + MODELING_COMPOUND = "modeling-compound" + NEON_PUNK = "neon-punk" + ORIGAMI = "origami" + PHOTOGRAPHIC = "photographic" + PIXEL_ART = "pixel-art" + TILE_TEXTURE = "tile-texture" + + +def _validate_image_pixels_and_aspect_ratio(image: Union[str, BinaryIO]) -> None: + """Validates the number of pixels in the 'image' field of the request. + + The image must have a total pixel count between 4,096 and 9,437,184 (inclusive). + Not implemented yet (but required for stable image services): + If the model is outpaint, the aspect ratio must be between 1:2.5 and 2.5:1. + + Args: + image: Either a base64-encoded string or a BinaryIO object + """ + # Get the raw image data + if isinstance(image, str): + # Decode base64 string + try: + image_data = base64.b64decode(image) + except Exception as e: + raise ValueError("Invalid base64 encoding for 'image'") from e + else: + # Read from BinaryIO + image_data = image.read() + image.seek(0) # Reset the file pointer so it can be read again later + + # Attempt to open the image using Pillow + try: + with Image.open(BytesIO(image_data)) as img: + width, height = img.size + except Exception as e: + raise ValueError("Unable to open or process the image data") from e + + # Check the image type based on magic bytes (JPEG, PNG, WebP) + image_format = None + if image_data.startswith(b"\xff\xd8\xff"): # JPEG magic number + image_format = "jpeg" + elif image_data.startswith(b"\x89\x50\x4e\x47"): # PNG magic number + image_format = "png" + elif image_data.startswith(b"\x52\x49\x46\x46") and image_data[8:12] == b"WEBP": # WebP magic number + image_format = "webp" + + if not image_format: + raise ValueError("Unsupported image format. Only JPEG, PNG, or WebP are allowed.") + + total_pixels = width * height + MIN_PIXELS = 4096 + MAX_PIXELS = 9437184 + + if total_pixels < MIN_PIXELS or total_pixels > MAX_PIXELS: + raise ValueError( + f"Image total pixel count {total_pixels} is invalid. Image size (height x width) must be between " + f"{MIN_PIXELS} and {MAX_PIXELS} pixels." + ) class StabilityAiError(Exception): @@ -53,87 +142,29 @@ def _get_headers(self, accept: str = "image/*") -> Dict[str, str]: return headers - def generate_image_bytes( - self, - prompt: str, - negative_prompt: Optional[str] = None, - aspect_ratio: str = "1:1", - seed: Optional[int] = None, - output_format: str = "png", - image: Optional[BinaryIO] = None, - style_preset: Optional[str] = None, - strength: Optional[float] = None, - ) -> bytes: + def generate_image_bytes(self, **kwargs: Any) -> bytes: """Generate an image using the Stability AI API. Args: - prompt: Text prompt for image generation - negative_prompt: Optional text describing what not to include - aspect_ratio: Aspect ratio of the output image - seed: Random seed for generation - output_format: Output format (jpeg, png, webp) - image: Optional input image for img2img - style_preset: Optional style preset - strength: Required when image is provided, controls influence of input image + **kwargs: See _generate_image for available parameters - Returns: bytes of the image + Returns: + bytes of the image """ - return cast( - bytes, - self._generate_image( - prompt, - negative_prompt, - aspect_ratio, - seed, - output_format, - image, - style_preset, - strength, - return_json=False, - ), - ) + kwargs["return_json"] = False + return cast(bytes, self._generate_image(**kwargs)) - def generate_image_json( - self, - prompt: str, - negative_prompt: Optional[str] = None, - aspect_ratio: str = "1:1", - seed: Optional[int] = None, - output_format: str = "png", - image: Optional[BinaryIO] = None, - style_preset: Optional[str] = None, - strength: Optional[float] = None, - ) -> Dict[str, Any]: + def generate_image_json(self, **kwargs: Any) -> Dict[str, Any]: """Generate an image using the Stability AI API. Args: - prompt: Text prompt for image generation - negative_prompt: Optional text describing what not to include - aspect_ratio: Aspect ratio of the output image - seed: Random seed for generation - output_format: Output format (jpeg, png, webp) - image: Optional input image for img2img - style_preset: Optional style preset - strength: Required when image is provided, controls influence of input image - return_json: If True, returns JSON response with base64 image + **kwargs: See _generate_image for available parameters Returns: - Either image bytes or JSON response with base64 image + JSON response with base64 image """ - return cast( - Dict[str, Any], - self._generate_image( - prompt, - negative_prompt, - aspect_ratio, - seed, - output_format, - image, - style_preset, - strength, - return_json=True, - ), - ) + kwargs["return_json"] = True + return cast(Dict[str, Any], self._generate_image(**kwargs)) def _generate_image( self, @@ -141,11 +172,13 @@ def _generate_image( negative_prompt: Optional[str] = None, aspect_ratio: str = "1:1", seed: Optional[int] = None, - output_format: str = "png", + output_format: Union[OutputFormat, str] = "png", image: Optional[BinaryIO] = None, + mode: Union[ModeEnum, str] = ModeEnum.TEXT_TO_IMAGE, style_preset: Optional[str] = None, - strength: Optional[float] = None, + strength: Optional[float] = 0.35, return_json: bool = False, + **extra_kwargs: Any, ) -> Union[bytes, Dict[str, Any]]: """Generate an image using the Stability AI API. @@ -156,9 +189,11 @@ def _generate_image( seed: Random seed for generation output_format: Output format (jpeg, png, webp) image: Optional input image for img2img + mode: "text-to-image" or "image-to-image" style_preset: Optional style preset strength: Required when image is provided, controls influence of input image return_json: If True, returns JSON response with base64 image + **extra_kwargs: Additional keyword arguments (will be ignored with a warning) Returns: Either image bytes or JSON response with base64 image @@ -166,6 +201,20 @@ def _generate_image( Raises: StabilityAiError: If the API request fails """ + if isinstance(output_format, str): + try: + output_format = OutputFormat(output_format) + except ValueError as e: + raise ValueError( + f"Invalid output_format: {output_format}. Must be one of: {[e.value for e in OutputFormat]}" + ) from e + + if isinstance(mode, str): + try: + mode = ModeEnum(mode) + except ValueError as e: + raise ValueError(f"Invalid mode: {mode}. Must be one of: {[e.value for e in ModeEnum]}") from e + # Prepare the multipart form data files: Dict[str, Union[BinaryIO, str]] = {} data: Dict[str, Any] = {} @@ -179,12 +228,16 @@ def _generate_image( if seed is not None: data["seed"] = seed if output_format: - data["output_format"] = output_format + data["output_format"] = output_format.value if style_preset: + allowed_presets = [preset.value for preset in StylePresetEnum] + if style_preset not in allowed_presets: + raise ValueError(f"'style_preset' must be one of {allowed_presets}. Got '{style_preset}'.") data["style_preset"] = style_preset # Handle input image if provided if image: + _validate_image_pixels_and_aspect_ratio(image) files["image"] = image if len(files) == 0: @@ -203,6 +256,10 @@ def _generate_image( if return_json: return cast(Dict[str, Any], response.json()) return cast(bytes, response.content) + elif response.status_code == 401: + raise StabilityAiError( + f"Unauthorized: check authentication credentials: {response.json().get('errors', 'Unknown error')}" + ) elif response.status_code == 400: raise StabilityAiError(f"Invalid parameters: {response.json().get('errors', 'Unknown error')}") elif response.status_code == 403: diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index 3768a8d0..ce9120dd 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -242,17 +242,8 @@ def stream(self, request: dict[str, Any]) -> Iterable[Any]: yield {"chunk_type": "message_start"} yield {"chunk_type": "content_start", "data_type": "text"} try: - # Generate the image - response_json = self.client.generate_image_json( - prompt=request["prompt"], - negative_prompt=request.get("negative_prompt"), - aspect_ratio=request.get("aspect_ratio", "1:1"), - seed=request.get("seed"), - output_format=request.get("output_format", "png"), - image=request.get("image"), - style_preset=request.get("style_preset"), - strength=request.get("strength", 0.35), - ) + # Generate the image #TODO add generate_image_bytes + response_json = self.client.generate_image_json(**request) # Yield the image data as a single event yield {"chunk_type": "content_block_delta", "data_type": "image", "data": response_json.get("image")} From 59e87d0237d7bdb2d7749eca343c907aca26c9fc Mon Sep 17 00:00:00 2001 From: satsumas Date: Thu, 19 Jun 2025 14:24:41 +0000 Subject: [PATCH 04/10] refactor(stability): separate error classes and add mode parameter --- src/strands/models/_stabilityaiclient.py | 121 ++++++++++++++++++++--- src/strands/models/stability.py | 109 ++++++++++++++++++++ tests/strands/models/test_stability.py | 2 + 3 files changed, 218 insertions(+), 14 deletions(-) diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index 30645131..52011a52 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -93,9 +93,90 @@ def _validate_image_pixels_and_aspect_ratio(image: Union[str, BinaryIO]) -> None class StabilityAiError(Exception): - """Base exception for Stability AI API errors.""" + """Base exception for Stability AI API errors. - pass + Attributes: + message: Error message + status_code: HTTP status code (if applicable) + response_data: Full response data from API (if available) + """ + + def __init__(self, message: str, status_code: Optional[int] = None, response_data: Optional[Dict[str, Any]] = None): + """Initialize the exception. + + Args: + message: Error message + status_code: HTTP status code that caused the error + response_data: Full response data from the API + """ + super().__init__(message) + self.message = message + self.status_code = status_code + self.response_data = response_data or {} + + def __str__(self) -> str: + """String representation of the error.""" + if self.status_code: + return f"[HTTP {self.status_code}] {self.message}" + return self.message + + +### Specific error classes for common API errors +class AuthenticationError(StabilityAiError): + """Raised when authentication fails (401).""" + + def __init__(self, message: str = "Authentication failed", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=401, response_data=response_data) + + +class BadRequestError(StabilityAiError): + """Raised when request parameters are invalid (400).""" + + def __init__(self, message: str = "Invalid request parameters", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=400, response_data=response_data) + + +class ContentModerationError(StabilityAiError): + """Raised when content is flagged by moderation (403).""" + + def __init__(self, message: str = "Content flagged by moderation", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=403, response_data=response_data) + + +class PayloadTooLargeError(StabilityAiError): + """Raised when request exceeds size limit (413).""" + + def __init__(self, message: str = "Request too large (max 10MiB)", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=413, response_data=response_data) + + +class ValidationError(StabilityAiError): + """Raised when request validation fails (422).""" + + def __init__(self, message: str = "Request validation failed", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=422, response_data=response_data) + + +class RateLimitError(StabilityAiError): + """Raised when rate limit is exceeded (429).""" + + def __init__(self, message: str = "Rate limit exceeded", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=429, response_data=response_data) + + +class InternalServerError(StabilityAiError): + """Raised when server encounters an error (500).""" + + def __init__(self, message: str = "Internal server error", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=500, response_data=response_data) + + +class NetworkError(StabilityAiError): + """Raised when network request fails.""" + + def __init__(self, message: str = "Network request failed", original_error: Optional[Exception] = None): + super().__init__(message) + self.original_error = original_error class StabilityAiClient: @@ -251,29 +332,41 @@ def _generate_image( files=files, ) - # Handle different response status codes + # Handle successful response if response.status_code == 200: if return_json: return cast(Dict[str, Any], response.json()) return cast(bytes, response.content) - elif response.status_code == 401: - raise StabilityAiError( - f"Unauthorized: check authentication credentials: {response.json().get('errors', 'Unknown error')}" + + # Parse error response + try: + error_data = response.json() + error_message = error_data.get("errors", error_data.get("message", "Unknown error")) + except ValueError: + error_data = {} + error_message = response.text or "Unknown error" + + # Handle specific error cases + if response.status_code == 401: + raise AuthenticationError( + f"Unauthorized: check authentication credentials: {error_message}", response_data=error_data ) elif response.status_code == 400: - raise StabilityAiError(f"Invalid parameters: {response.json().get('errors', 'Unknown error')}") + raise BadRequestError(f"Invalid parameters: {error_message}", response_data=error_data) elif response.status_code == 403: - raise StabilityAiError("Request flagged by content moderation") + raise ContentModerationError("Request flagged by content moderation", response_data=error_data) elif response.status_code == 413: - raise StabilityAiError("Request too large (max 10MiB)") + raise PayloadTooLargeError("Request too large (max 10MiB)", response_data=error_data) elif response.status_code == 422: - raise StabilityAiError(f"Request rejected: {response.json().get('errors', 'Unknown error')}") + raise ValidationError(f"Request rejected: {error_message}", response_data=error_data) elif response.status_code == 429: - raise StabilityAiError("Rate limit exceeded (max 150 requests per 10 seconds)") + raise RateLimitError("Rate limit exceeded (max 150 requests per 10 seconds)", response_data=error_data) elif response.status_code == 500: - raise StabilityAiError("Internal server error") + raise InternalServerError("Internal server error", response_data=error_data) else: - raise StabilityAiError(f"Unexpected error: {response.status_code}") + raise StabilityAiError( + f"Unexpected error: {error_message}", status_code=response.status_code, response_data=error_data + ) except requests.exceptions.RequestException as e: - raise StabilityAiError(f"Request failed: {str(e)}") from e + raise NetworkError(f"Request failed: {str(e)}", original_error=e) from e diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index ce9120dd..38d36f0e 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -50,6 +50,29 @@ class StylePreset(Enum): TILE_TEXTURE = "tile-texture" +<<<<<<< HEAD +======= +class Defaults: + """Default values for Stability AI configuration.""" + + ASPECT_RATIO = "1:1" + OUTPUT_FORMAT = OutputFormat.PNG + STYLE_PRESET = StylePreset.PHOTOGRAPHIC + STRENGTH = 0.35 + MODE = "text-to-image" + + +class ChunkTypes: + """Chunk type constants.""" + + MESSAGE_START = "message_start" + CONTENT_START = "content_start" + CONTENT_BLOCK_DELTA = "content_block_delta" + CONTENT_STOP = "content_stop" + MESSAGE_STOP = "message_stop" + + +>>>>>>> 2f20399 (refactor(stability): separate error classes and add mode parameter) class StabilityAiImageModel(Model): """Your custom model provider implementation.""" @@ -57,8 +80,19 @@ class StabilityAiImageModelConfig(TypedDict): """Configuration your model. Attributes: +<<<<<<< HEAD model_id: ID of Custom model (required). params: Model parameters (e.g., max_tokens). +======= + model_id: ID of the Stability AI model (required). + aspect_ratio: Aspect ratio of the output image. + seed: Random seed for generation. + output_format: Output format (jpeg, png, webp). + style_preset: Style preset for image generation. + image: Input image for img2img generation. + mode: Mode of operation (text-to-image, image-to-image). + strength: Influence of input image on output (0.0-1.0). +>>>>>>> 2f20399 (refactor(stability): separate error classes and add mode parameter) """ """ @@ -77,6 +111,7 @@ class StabilityAiImageModelConfig(TypedDict): output_format: NotRequired[OutputFormat] # defaults to PNG style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC image: NotRequired[str] # defaults to None + mode: NotRequired[str] strength: NotRequired[float] # defaults to 0.35 def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: @@ -117,6 +152,80 @@ def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelCon raise ValueError("model_id is required") self.client = StabilityAiClient(api_key=api_key, model_id=model_id) +<<<<<<< HEAD +======= + def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: + """Validate and convert configuration values to proper types.""" + self._convert_output_format(config_dict) + self._convert_style_preset(config_dict) + + def _convert_output_format(self, config_dict: dict[str, Any]) -> None: + """Convert string output_format to enum if needed.""" + if "output_format" in config_dict and isinstance(config_dict["output_format"], str): + try: + config_dict["output_format"] = OutputFormat(config_dict["output_format"]) + except ValueError as e: + valid_formats = [f.value for f in OutputFormat] + raise ValueError(f"output_format must be one of: {valid_formats}") from e + + def _convert_style_preset(self, config_dict: dict[str, Any]) -> None: + """Convert string style_preset to enum if needed.""" + if "style_preset" in config_dict and isinstance(config_dict["style_preset"], str): + try: + config_dict["style_preset"] = StylePreset(config_dict["style_preset"]) + except ValueError as e: + valid_presets = [p.value for p in StylePreset] + raise ValueError(f"style_preset must be one of: {valid_presets}") from e + + def _extract_prompt_from_messages(self, messages: Messages) -> str: + """Extract the last user message as prompt. + + Args: + messages: List of conversation messages. + + Returns: + The extracted prompt text. + + Raises: + ValueError: If no user message with text content is found. + """ + for message in reversed(messages): + if message["role"] == "user": + for content in message["content"]: + if isinstance(content, dict) and "text" in content: + return content["text"] + raise ValueError("No user message found in the conversation") + + def _build_base_request(self, prompt: str) -> dict[str, Any]: + """Build the base request with required parameters. + + Args: + prompt: The text prompt for image generation. + + Returns: + Dictionary with base request parameters. + """ + return { + "prompt": prompt, + "aspect_ratio": self.config.get("aspect_ratio", Defaults.ASPECT_RATIO), + "output_format": self.config.get("output_format", Defaults.OUTPUT_FORMAT).value, + "style_preset": self.config.get("style_preset", Defaults.STYLE_PRESET).value, + "mode": self.config.get("mode", Defaults.MODE), + } + + def _add_optional_parameters(self, request: dict[str, Any]) -> None: + """Add optional parameters to the request if they exist in config. + + Args: + request: The request dictionary to modify. + """ + if "seed" in self.config: + request["seed"] = self.config["seed"] + if self.config.get("image") is not None: + request["image"] = self.config["image"] + request["strength"] = self.config.get("strength", Defaults.STRENGTH) + +>>>>>>> 2f20399 (refactor(stability): separate error classes and add mode parameter) @override def format_request( self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py index 1f86c7d8..188fe96e 100644 --- a/tests/strands/models/test_stability.py +++ b/tests/strands/models/test_stability.py @@ -109,6 +109,7 @@ def test_format_request(model, messages): "prompt": "a beautiful sunset over mountains", "aspect_ratio": "1:1", "output_format": "png", + "mode": "text-to-image", "style_preset": "photographic", } @@ -130,6 +131,7 @@ def test_format_request_with_optional_params(model, messages): "style_preset": "photographic", "seed": 12345, "image": "base64_encoded_image", + "mode": "text-to-image", "strength": 0.5, } From e93e660a73ccb92cceb50c62406127290b4fe5f3 Mon Sep 17 00:00:00 2001 From: satsumas Date: Thu, 19 Jun 2025 16:12:29 +0000 Subject: [PATCH 05/10] refactor(stability): model_id is used by stream method to determine model to address --- src/strands/models/_stabilityaiclient.py | 26 ++++---- src/strands/models/stability.py | 79 +++++++++--------------- tests/strands/models/test_stability.py | 2 +- 3 files changed, 45 insertions(+), 62 deletions(-) diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index 52011a52..cad2c336 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -188,19 +188,14 @@ class StabilityAiClient: "stability.sd3-5-large-v1:0": "https://api.stability.ai/v2beta/stable-image/generate/sd3", } - def __init__( - self, api_key: str, model_id: str, client_id: Optional[str] = None, client_version: Optional[str] = None - ): + def __init__(self, api_key: str, client_id: Optional[str] = None, client_version: Optional[str] = None): """Initialize the Stability AI client. Args: api_key: Your Stability API key - model_id: The model ID to use for the API request.See MODEL_ID_TO_BASE_URL for available models client_id: Optional client ID for debugging client_version: Optional client version for debugging """ - self.model_id = model_id - self.base_url = self.MODEL_ID_TO_BASE_URL[model_id] self.api_key = api_key self.client_id = client_id self.client_version = client_version @@ -223,32 +218,35 @@ def _get_headers(self, accept: str = "image/*") -> Dict[str, str]: return headers - def generate_image_bytes(self, **kwargs: Any) -> bytes: + def generate_image_bytes(self, model_id: str, **kwargs: Any) -> bytes: """Generate an image using the Stability AI API. Args: + model_id: The model ID to use for the API request. **kwargs: See _generate_image for available parameters Returns: bytes of the image """ kwargs["return_json"] = False - return cast(bytes, self._generate_image(**kwargs)) + return cast(bytes, self._generate_image(model_id, **kwargs)) - def generate_image_json(self, **kwargs: Any) -> Dict[str, Any]: + def generate_image_json(self, model_id: str, **kwargs: Any) -> Dict[str, Any]: """Generate an image using the Stability AI API. Args: + model_id: The model ID to use for the API request. **kwargs: See _generate_image for available parameters Returns: JSON response with base64 image """ kwargs["return_json"] = True - return cast(Dict[str, Any], self._generate_image(**kwargs)) + return cast(Dict[str, Any], self._generate_image(model_id, **kwargs)) def _generate_image( self, + model_id: str, prompt: str, negative_prompt: Optional[str] = None, aspect_ratio: str = "1:1", @@ -264,6 +262,7 @@ def _generate_image( """Generate an image using the Stability AI API. Args: + model_id: The model ID to use for the API request prompt: Text prompt for image generation negative_prompt: Optional text describing what not to include aspect_ratio: Aspect ratio of the output image @@ -282,6 +281,11 @@ def _generate_image( Raises: StabilityAiError: If the API request fails """ + # Validate and prepare the base URL + if model_id not in self.MODEL_ID_TO_BASE_URL: + raise ValueError(f"Invalid model_id: {model_id}. Must be one of: {list(self.MODEL_ID_TO_BASE_URL.keys())}") + base_url = self.MODEL_ID_TO_BASE_URL[model_id] + if isinstance(output_format, str): try: output_format = OutputFormat(output_format) @@ -326,7 +330,7 @@ def _generate_image( try: # Make the API request response = requests.post( - self.base_url, + base_url, headers=self._get_headers("application/json" if return_json else "image/*"), data=data, files=files, diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index 38d36f0e..78ccd436 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -50,8 +50,6 @@ class StylePreset(Enum): TILE_TEXTURE = "tile-texture" -<<<<<<< HEAD -======= class Defaults: """Default values for Stability AI configuration.""" @@ -72,7 +70,6 @@ class ChunkTypes: MESSAGE_STOP = "message_stop" ->>>>>>> 2f20399 (refactor(stability): separate error classes and add mode parameter) class StabilityAiImageModel(Model): """Your custom model provider implementation.""" @@ -80,10 +77,6 @@ class StabilityAiImageModelConfig(TypedDict): """Configuration your model. Attributes: -<<<<<<< HEAD - model_id: ID of Custom model (required). - params: Model parameters (e.g., max_tokens). -======= model_id: ID of the Stability AI model (required). aspect_ratio: Aspect ratio of the output image. seed: Random seed for generation. @@ -92,7 +85,6 @@ class StabilityAiImageModelConfig(TypedDict): image: Input image for img2img generation. mode: Mode of operation (text-to-image, image-to-image). strength: Influence of input image on output (0.0-1.0). ->>>>>>> 2f20399 (refactor(stability): separate error classes and add mode parameter) """ """ @@ -111,7 +103,7 @@ class StabilityAiImageModelConfig(TypedDict): output_format: NotRequired[OutputFormat] # defaults to PNG style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC image: NotRequired[str] # defaults to None - mode: NotRequired[str] + mode: NotRequired[str] # defaults to "text-to-image" strength: NotRequired[float] # defaults to 0.35 def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: @@ -147,15 +139,27 @@ def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelCon self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) logger.debug("config=<%s> | initializing", self.config) - model_id = self.config.get("model_id") - if model_id is None: - raise ValueError("model_id is required") - self.client = StabilityAiClient(api_key=api_key, model_id=model_id) + # model_id = self.config.get("model_id") + # if model_id is None: + # raise ValueError("model_id is required") + self.client = StabilityAiClient(api_key=api_key) -<<<<<<< HEAD -======= def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: """Validate and convert configuration values to proper types.""" + # Validate required fields first + if "model_id" not in config_dict: + raise ValueError("model_id is required in configuration") + + # Validate model_id is one of the supported models + valid_model_ids = [ + "stability.stable-image-core-v1:1", + "stability.stable-image-ultra-v1:1", + "stability.sd3-5-large-v1:0", + ] + if config_dict["model_id"] not in valid_model_ids: + raise ValueError(f"Invalid model_id: {config_dict['model_id']}. Must be one of: {valid_model_ids}") + + # Convert other fields self._convert_output_format(config_dict) self._convert_style_preset(config_dict) @@ -225,7 +229,6 @@ def _add_optional_parameters(self, request: dict[str, Any]) -> None: request["image"] = self.config["image"] request["strength"] = self.config.get("strength", Defaults.STRENGTH) ->>>>>>> 2f20399 (refactor(stability): separate error classes and add mode parameter) @override def format_request( self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None @@ -238,38 +241,11 @@ def format_request( system_prompt: Optional system prompt Returns: - Formatted request parameters for the Stability AI API + Formatted request parameters for the Stability AI API. """ - # Extract the last user message as the prompt - # We do not need all the previous messages as context unlike an llm - prompt = "" - - for message in reversed(messages): - if message["role"] == "user": - # Find the text content in the message - for content in message["content"]: - if isinstance(content, dict) and "text" in content: - prompt = content["text"] - break - break - - if not prompt: - raise ValueError("No user message found in the conversation") - - # Format the request - request = { - "prompt": prompt, - "aspect_ratio": self.config.get("aspect_ratio", "1:1"), - "output_format": self.config.get("output_format", OutputFormat.PNG).value, - "style_preset": self.config.get("style_preset", StylePreset.PHOTOGRAPHIC).value, - } - - # Add optional parameters if they exist in config - if "seed" in self.config: - request["seed"] = self.config["seed"] # type: ignore[assignment] - if self.config.get("image") is not None: - request["image"] = self.config["image"] - request["strength"] = self.config.get("strength", 0.35) # type: ignore[assignment] + prompt = self._extract_prompt_from_messages(messages) + request = self._build_base_request(prompt) + self._add_optional_parameters(request) return request @@ -348,11 +324,14 @@ def stream(self, request: dict[str, Any]) -> Iterable[Any]: Raises: StabilityAiError: If the API request fails """ - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} + yield {"chunk_type": ChunkTypes.MESSAGE_START} + yield {"chunk_type": ChunkTypes.CONTENT_START, "data_type": "text"} + + model_id = self.config["model_id"] + try: # Generate the image #TODO add generate_image_bytes - response_json = self.client.generate_image_json(**request) + response_json = self.client.generate_image_json(model_id, **request) # Yield the image data as a single event yield {"chunk_type": "content_block_delta", "data_type": "image", "data": response_json.get("image")} diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py index 188fe96e..fa279d00 100644 --- a/tests/strands/models/test_stability.py +++ b/tests/strands/models/test_stability.py @@ -52,7 +52,7 @@ def test__init__(stability_client_cls, model_id): } assert tru_config == exp_config - stability_client_cls.assert_called_once_with(api_key="test_key", model_id=model_id) + stability_client_cls.assert_called_once_with(api_key="test_key") def test__init__with_string_enums(stability_client_cls, model_id): From 461e756514255e6f3cc114a6ceb934504d1a4dfc Mon Sep 17 00:00:00 2001 From: satsumas Date: Thu, 19 Jun 2025 20:56:32 +0000 Subject: [PATCH 06/10] refactor(stability): only SD3.5 accepts cfg_scale parameter --- src/strands/models/_stabilityaiclient.py | 4 +- src/strands/models/stability.py | 14 ++++- tests/strands/models/test_stability.py | 71 ++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index cad2c336..12ba188a 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -250,6 +250,7 @@ def _generate_image( prompt: str, negative_prompt: Optional[str] = None, aspect_ratio: str = "1:1", + cfg_scale: Optional[int] = 4, seed: Optional[int] = None, output_format: Union[OutputFormat, str] = "png", image: Optional[BinaryIO] = None, @@ -266,6 +267,7 @@ def _generate_image( prompt: Text prompt for image generation negative_prompt: Optional text describing what not to include aspect_ratio: Aspect ratio of the output image + cfg_scale: Optional classifier-free guidance scale, only used for stability.sd3-5-large-v1:0 seed: Random seed for generation output_format: Output format (jpeg, png, webp) image: Optional input image for img2img @@ -273,7 +275,7 @@ def _generate_image( style_preset: Optional style preset strength: Required when image is provided, controls influence of input image return_json: If True, returns JSON response with base64 image - **extra_kwargs: Additional keyword arguments (will be ignored with a warning) + **extra_kwargs: Additional keyword arguments Returns: Either image bytes or JSON response with base64 image diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index 78ccd436..cd500e62 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -79,6 +79,7 @@ class StabilityAiImageModelConfig(TypedDict): Attributes: model_id: ID of the Stability AI model (required). aspect_ratio: Aspect ratio of the output image. + cfg_scale: CFG scale for image generation (only used for stability.sd3-5-large-v1:0). seed: Random seed for generation. output_format: Output format (jpeg, png, webp). style_preset: Style preset for image generation. @@ -99,6 +100,7 @@ class StabilityAiImageModelConfig(TypedDict): # Optional parameters with defaults aspect_ratio: NotRequired[str] # defaults to "1:1" + cfg_scale: NotRequired[int] # defaults to 4. Only used for stability.sd3-5-large-v1:0 seed: NotRequired[int] # defaults to random output_format: NotRequired[OutputFormat] # defaults to PNG style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC @@ -139,9 +141,6 @@ def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelCon self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) logger.debug("config=<%s> | initializing", self.config) - # model_id = self.config.get("model_id") - # if model_id is None: - # raise ValueError("model_id is required") self.client = StabilityAiClient(api_key=api_key) def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: @@ -159,6 +158,12 @@ def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: if config_dict["model_id"] not in valid_model_ids: raise ValueError(f"Invalid model_id: {config_dict['model_id']}. Must be one of: {valid_model_ids}") + # Warn if cfg_scale is used with non-SD3.5 models + if "cfg_scale" in config_dict and config_dict["model_id"] != "stability.sd3-5-large-v1:0": + logger.warning( + "cfg_scale is only supported for stability.sd3-5-large-v1:0. It will be ignored for model %s", + config_dict["model_id"], + ) # Convert other fields self._convert_output_format(config_dict) self._convert_style_preset(config_dict) @@ -223,6 +228,9 @@ def _add_optional_parameters(self, request: dict[str, Any]) -> None: Args: request: The request dictionary to modify. """ + # Only add cfg_scale for SD3.5 model + if "cfg_scale" in self.config and self.config["model_id"] == "stability.sd3-5-large-v1:0": + request["cfg_scale"] = self.config["cfg_scale"] if "seed" in self.config: request["seed"] = self.config["seed"] if self.config.get("image") is not None: diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py index fa279d00..cae51aaa 100644 --- a/tests/strands/models/test_stability.py +++ b/tests/strands/models/test_stability.py @@ -138,6 +138,77 @@ def test_format_request_with_optional_params(model, messages): assert request == exp_request +def test_format_request_with_cfg_scale_sd35(stability_client, messages): + """Test that cfg_scale is included in request for SD3.5 model.""" + model = StabilityAiImageModel( + api_key="test_key", + model_id="stability.sd3-5-large-v1:0", + cfg_scale=8, + ) + + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "mode": "text-to-image", + "style_preset": "photographic", + "cfg_scale": 8, + } + + assert request == exp_request + + +def test_format_request_with_cfg_scale_non_sd35(stability_client, messages): + """Test that cfg_scale is NOT included in request for non-SD3.5 models.""" + model = StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + cfg_scale=8, # This should be ignored + ) + + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "mode": "text-to-image", + "style_preset": "photographic", + # Note: cfg_scale is not passed in + } + + assert request == exp_request + assert "cfg_scale" not in request + + +def test_update_config_change_model_id(model, messages): + """Test updating config to change model_id.""" + # Initial model uses stability.stable-image-ultra-v1:1 from fixture + initial_config = model.get_config() + assert initial_config["model_id"] == "stability.stable-image-ultra-v1:1" + + # Update to different model + model.update_config( + model_id="stability.stable-image-core-v1:1", + aspect_ratio="16:9", + ) + + updated_config = model.get_config() + exp_config = { + "model_id": "stability.stable-image-core-v1:1", + "aspect_ratio": "16:9", + "output_format": OutputFormat.PNG, + } + + assert updated_config == exp_config + + # Verify the model uses the new model_id in requests + request = model.format_request(messages) + assert request["aspect_ratio"] == "16:9" + + def test_format_request_no_user_message(): model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") messages = [{"role": "assistant", "content": [{"text": "test"}]}] From 8ec765112232fd4f407ae494cd74466c5c8d2634 Mon Sep 17 00:00:00 2001 From: satsumas Date: Thu, 19 Jun 2025 21:55:05 +0000 Subject: [PATCH 07/10] fix(stability): remove merge artefacts --- src/strands/models/_stabilityaiclient.py | 23 -------------- src/strands/models/stability.py | 39 +----------------------- 2 files changed, 1 insertion(+), 61 deletions(-) diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index 12ba188a..898c5479 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -20,26 +20,6 @@ class OutputFormat(Enum): WEBP = "webp" -class StylePresetEnum(str, Enum): - THREE_D_MODEL = "3d-model" - ANALOG_FILM = "analog-film" - ANIME = "anime" - CINEMATIC = "cinematic" - COMIC_BOOK = "comic-book" - DIGITAL_ART = "digital-art" - ENHANCE = "enhance" - FANTASY_ART = "fantasy-art" - ISOMETRIC = "isometric" - LINE_ART = "line-art" - LOW_POLY = "low-poly" - MODELING_COMPOUND = "modeling-compound" - NEON_PUNK = "neon-punk" - ORIGAMI = "origami" - PHOTOGRAPHIC = "photographic" - PIXEL_ART = "pixel-art" - TILE_TEXTURE = "tile-texture" - - def _validate_image_pixels_and_aspect_ratio(image: Union[str, BinaryIO]) -> None: """Validates the number of pixels in the 'image' field of the request. @@ -317,9 +297,6 @@ def _generate_image( if output_format: data["output_format"] = output_format.value if style_preset: - allowed_presets = [preset.value for preset in StylePresetEnum] - if style_preset not in allowed_presets: - raise ValueError(f"'style_preset' must be one of {allowed_presets}. Got '{style_preset}'.") data["style_preset"] = style_preset # Handle input image if provided diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index 8d03304d..ff5f538c 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -117,12 +117,7 @@ def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelCon model_id = self.config.get("model_id") if model_id is None: raise ValueError("model_id is required") - self.client = StabilityAiClient(api_key=api_key, model_id=model_id) - - def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: - """Validate and convert configuration values to proper types.""" - self._convert_output_format(config_dict) - self._convert_style_preset(config_dict) + self.client = StabilityAiClient(api_key=api_key) def _convert_output_format(self, config_dict: dict[str, Any]) -> None: """Convert string output_format to enum if needed.""" @@ -145,8 +140,6 @@ def _convert_style_preset(self, config_dict: dict[str, Any]) -> None: self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) logger.debug("config=<%s> | initializing", self.config) - self.client = StabilityAiClient(api_key=api_key) - def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: """Validate and convert configuration values to proper types.""" # Validate required fields first @@ -172,25 +165,6 @@ def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: self._convert_output_format(config_dict) self._convert_style_preset(config_dict) - def _convert_output_format(self, config_dict: dict[str, Any]) -> None: - """Convert string output_format to enum if needed.""" - if "output_format" in config_dict and isinstance(config_dict["output_format"], str): - try: - config_dict["output_format"] = OutputFormat(config_dict["output_format"]) - except ValueError as e: - valid_formats = [f.value for f in OutputFormat] - raise ValueError(f"output_format must be one of: {valid_formats}") from e - - def _convert_style_preset(self, config_dict: dict[str, Any]) -> None: - """Convert string style_preset to enum if needed.""" - if "style_preset" in config_dict and isinstance(config_dict["style_preset"], str): - try: - config_dict["style_preset"] = StylePreset(config_dict["style_preset"]) - except ValueError as e: - valid_presets = [p.value for p in StylePreset] - raise ValueError(f"style_preset must be one of: {valid_presets}") from e - - def _extract_prompt_from_messages(self, messages: Messages) -> str: """Extract the last user message as prompt. @@ -261,17 +235,6 @@ def format_request( request = self._build_base_request(prompt) self._add_optional_parameters(request) - Args: - messages: List of messages containing the conversation history. - tool_specs: Optional list of tool specifications (unused for image generation). - system_prompt: Optional system prompt (unused for image generation). - - Returns: - Formatted request parameters for the Stability AI API. - """ - prompt = self._extract_prompt_from_messages(messages) - request = self._build_base_request(prompt) - self._add_optional_parameters(request) return request @override From a38c88e9e6fc0c4163cbc8258aa35893672dc275 Mon Sep 17 00:00:00 2001 From: satsumas Date: Fri, 20 Jun 2025 10:59:01 +0000 Subject: [PATCH 08/10] feat(stability): StabilityAiImageModel raises exceptions from strands --- src/strands/models/_stabilityaiclient.py | 38 +++- src/strands/models/stability.py | 97 +++++---- src/strands/types/exceptions.py | 49 +++++ tests/strands/models/test_stability.py | 259 ++++++++++++++++++++++- 4 files changed, 397 insertions(+), 46 deletions(-) diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index 898c5479..a5650711 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -6,20 +6,44 @@ import requests from PIL import Image - # Validation classes and functions -# Other validation is performed in the JSON workflow configs -class ModeEnum(str, Enum): + + +class Mode(Enum): TEXT_TO_IMAGE = "text-to-image" IMAGE_TO_IMAGE = "image-to-image" class OutputFormat(Enum): - PNG = "png" + """Supported output formats for image generation.""" + JPEG = "jpeg" + PNG = "png" WEBP = "webp" +class StylePreset(Enum): + """Supported style presets for image generation.""" + + THREE_D_MODEL = "3d-model" + ANALOG_FILM = "analog-film" + ANIME = "anime" + CINEMATIC = "cinematic" + COMIC_BOOK = "comic-book" + DIGITAL_ART = "digital-art" + ENHANCE = "enhance" + FANTASY_ART = "fantasy-art" + ISOMETRIC = "isometric" + LINE_ART = "line-art" + LOW_POLY = "low-poly" + MODELING_COMPOUND = "modeling-compound" + NEON_PUNK = "neon-punk" + ORIGAMI = "origami" + PHOTOGRAPHIC = "photographic" + PIXEL_ART = "pixel-art" + TILE_TEXTURE = "tile-texture" + + def _validate_image_pixels_and_aspect_ratio(image: Union[str, BinaryIO]) -> None: """Validates the number of pixels in the 'image' field of the request. @@ -234,7 +258,7 @@ def _generate_image( seed: Optional[int] = None, output_format: Union[OutputFormat, str] = "png", image: Optional[BinaryIO] = None, - mode: Union[ModeEnum, str] = ModeEnum.TEXT_TO_IMAGE, + mode: Union[Mode] = Mode.TEXT_TO_IMAGE, style_preset: Optional[str] = None, strength: Optional[float] = 0.35, return_json: bool = False, @@ -278,9 +302,9 @@ def _generate_image( if isinstance(mode, str): try: - mode = ModeEnum(mode) + mode = Mode(mode) except ValueError as e: - raise ValueError(f"Invalid mode: {mode}. Must be one of: {[e.value for e in ModeEnum]}") from e + raise ValueError(f"Invalid mode: {mode}. Must be one of: {[e.value for e in Mode]}") from e # Prepare the multipart form data files: Dict[str, Union[BinaryIO, str]] = {} diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index ff5f538c..fbc047de 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -5,51 +5,42 @@ import base64 import logging -from enum import Enum from typing import Any, Iterable, Optional, TypedDict, cast from typing_extensions import NotRequired, Unpack, override from strands.types.content import Messages +from strands.types.exceptions import ( + ContentModerationException, + EventLoopException, + ModelAuthenticationException, + ModelServiceException, + ModelThrottledException, + ModelValidationException, +) from strands.types.models import Model from strands.types.streaming import ContentBlockDelta, ContentBlockDeltaEvent, StreamEvent from strands.types.tools import ToolSpec -from ._stabilityaiclient import StabilityAiClient, StabilityAiError +from ._stabilityaiclient import ( + AuthenticationError, + BadRequestError, + ContentModerationError, + InternalServerError, + Mode, + NetworkError, + OutputFormat, + PayloadTooLargeError, + RateLimitError, + StabilityAiClient, + StabilityAiError, + StylePreset, + ValidationError, +) logger = logging.getLogger(__name__) -class OutputFormat(Enum): - """Supported output formats for image generation.""" - - JPEG = "jpeg" - PNG = "png" - WEBP = "webp" - - -class StylePreset(Enum): - """Supported style presets for image generation.""" - - THREE_D_MODEL = "3d-model" - ANALOG_FILM = "analog-film" - ANIME = "anime" - CINEMATIC = "cinematic" - COMIC_BOOK = "comic-book" - DIGITAL_ART = "digital-art" - ENHANCE = "enhance" - FANTASY_ART = "fantasy-art" - ISOMETRIC = "isometric" - LINE_ART = "line-art" - LOW_POLY = "low-poly" - MODELING_COMPOUND = "modeling-compound" - NEON_PUNK = "neon-punk" - ORIGAMI = "origami" - PHOTOGRAPHIC = "photographic" - PIXEL_ART = "pixel-art" - TILE_TEXTURE = "tile-texture" - - class Defaults: """Default values for Stability AI configuration.""" @@ -57,7 +48,7 @@ class Defaults: OUTPUT_FORMAT = OutputFormat.PNG STYLE_PRESET = StylePreset.PHOTOGRAPHIC STRENGTH = 0.35 - MODE = "text-to-image" + MODE = Mode.TEXT_TO_IMAGE class ChunkTypes: @@ -98,7 +89,7 @@ class StabilityAiImageModelConfig(TypedDict): output_format: NotRequired[OutputFormat] # defaults to PNG style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC image: NotRequired[str] # defaults to None - mode: NotRequired[str] # defaults to "text-to-image" + mode: NotRequired[Mode] # defaults to "text-to-image" strength: NotRequired[float] # defaults to 0.35 def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: @@ -198,7 +189,7 @@ def _build_base_request(self, prompt: str) -> dict[str, Any]: "aspect_ratio": self.config.get("aspect_ratio", Defaults.ASPECT_RATIO), "output_format": self.config.get("output_format", Defaults.OUTPUT_FORMAT).value, "style_preset": self.config.get("style_preset", Defaults.STYLE_PRESET).value, - "mode": self.config.get("mode", Defaults.MODE), + "mode": self.config.get("mode", Defaults.MODE).value, } def _add_optional_parameters(self, request: dict[str, Any]) -> None: @@ -339,7 +330,12 @@ def stream(self, request: dict[str, Any]) -> Iterable[Any]: An iterable of response events from the Stability AI model. Raises: - StabilityAiError: If the API request fails. + ModelAuthenticationException: If authentication fails + ModelValidationException: If request validation fails + ContentModerationException: If content is flagged + ModelThrottledException: If rate limit is exceeded + ModelServiceException: If server error occurs + EventLoopException: If network error occurs """ yield {"chunk_type": ChunkTypes.MESSAGE_START} yield {"chunk_type": ChunkTypes.CONTENT_START, "data_type": "text"} @@ -360,6 +356,31 @@ def stream(self, request: dict[str, Any]) -> Iterable[Any]: yield {"chunk_type": ChunkTypes.CONTENT_STOP, "data_type": "text"} yield {"chunk_type": ChunkTypes.MESSAGE_STOP, "data": response_json.get("finish_reason")} + except AuthenticationError as e: + logger.error("Authentication failed: %s", str(e)) + raise ModelAuthenticationException(str(e)) from e + + except RateLimitError as e: + logger.warning("Rate limit exceeded: %s", str(e)) + raise ModelThrottledException(str(e)) from e + + except ContentModerationError as e: + logger.warning("Content flagged by moderation: %s", str(e)) + raise ContentModerationException(str(e)) from e + + except (ValidationError, BadRequestError, PayloadTooLargeError) as e: + logger.error("Request validation failed: %s", str(e)) + raise ModelValidationException(str(e)) from e + + except InternalServerError as e: + logger.error("Server error during image generation: %s", str(e)) + raise ModelServiceException(str(e), is_transient=True) from e + + except NetworkError as e: + logger.error("Network error during image generation: %s", str(e)) + raise EventLoopException(e.original_error or e, request_state=request) from e + except StabilityAiError as e: - logger.error("Failed to generate image: %s", str(e)) - raise + # Catch any other StabilityAiError subclasses + logger.error("Unexpected error during image generation: %s", str(e)) + raise ModelServiceException(str(e), is_transient=False) from e diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1ffeba4e..d500bf96 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -52,3 +52,52 @@ def __init__(self, message: str) -> None: super().__init__(message) pass + + +class ModelAuthenticationException(Exception): + """Exception raised when model authentication fails. + + This exception is raised when the API key or other authentication + credentials are invalid or expired. + """ + + pass + + +class ModelValidationException(Exception): + """Exception raised when model input validation fails. + + This exception is raised when the input parameters don't meet the + model's requirements (e.g., invalid formats, out-of-range values). + """ + + pass + + +class ContentModerationException(Exception): + """Exception raised when content is flagged by safety filters. + + This exception is raised when the model's safety systems reject + the input or output content as inappropriate. + """ + + pass + + +class ModelServiceException(Exception): + """Exception raised for model service errors. + + This is a general exception for server-side errors that aren't + covered by more specific exceptions. + """ + + def __init__(self, message: str, is_transient: bool = False) -> None: + """Initialize exception. + + Args: + message: Error message + is_transient: Whether the error is likely transient (retryable) + """ + self.message = message + self.is_transient = is_transient + super().__init__(message) diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py index cae51aaa..0f354ffa 100644 --- a/tests/strands/models/test_stability.py +++ b/tests/strands/models/test_stability.py @@ -4,7 +4,32 @@ import pytest import strands -from strands.models.stability import OutputFormat, StabilityAiImageModel, StylePreset + +# Import the StabilityAI exceptions +from strands.models._stabilityaiclient import ( + AuthenticationError, + BadRequestError, + InternalServerError, + NetworkError, + OutputFormat, + PayloadTooLargeError, + RateLimitError, + StabilityAiError, + StylePreset, + ValidationError, +) +from strands.models._stabilityaiclient import ( + ContentModerationError as StabilityContentModerationError, +) +from strands.models.stability import StabilityAiImageModel +from strands.types.exceptions import ( + ContentModerationException, + EventLoopException, + ModelAuthenticationException, + ModelServiceException, + ModelThrottledException, + ModelValidationException, +) @pytest.fixture @@ -267,3 +292,235 @@ def test_format_chunk_unknown_type(): with pytest.raises(RuntimeError) as exc_info: model.format_chunk(event) assert "unknown type" in str(exc_info.value) + + +def test_stream_authentication_error(model, stability_client): + """Test that AuthenticationError is converted to ModelAuthenticationException.""" + stability_client.generate_image_json.side_effect = AuthenticationError( + "Invalid API key", response_data={"error": "unauthorized"} + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelAuthenticationException) as exc_info: + list(model.stream(request)) + + assert "Invalid API key" in str(exc_info.value) + + +def test_stream_content_moderation_error(model, stability_client): + """Test that ContentModerationError is converted to ContentModerationException.""" + stability_client.generate_image_json.side_effect = StabilityContentModerationError( + "Content flagged by moderation", response_data={"error": "content_policy_violation"} + ) + + request = { + "prompt": "an unclothed woman on the beach", + "seed": 7, + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ContentModerationException) as exc_info: + list(model.stream(request)) + + assert "Content flagged by moderation" in str(exc_info.value) + + +def test_stream_validation_error(model, stability_client): + """Test that ValidationError is converted to ModelValidationException.""" + stability_client.generate_image_json.side_effect = ValidationError( + "Prompt exceeds maximum length", response_data={"error": "validation_error", "field": "prompt"} + ) + + request = { + "prompt": "a" * 10001, + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelValidationException) as exc_info: + list(model.stream(request)) + + assert "Prompt exceeds maximum length" in str(exc_info.value) + + +def test_stream_bad_request_error(model, stability_client): + """Test that BadRequestError is converted to ModelValidationException.""" + stability_client.generate_image_json.side_effect = BadRequestError( + "Invalid aspect ratio", response_data={"error": "bad_request"} + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "invalid", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelValidationException) as exc_info: + list(model.stream(request)) + + assert "Invalid aspect ratio" in str(exc_info.value) + + +def test_stream_payload_too_large_error(model, stability_client): + """Test that PayloadTooLargeError is converted to ModelValidationException.""" + stability_client.generate_image_json.side_effect = PayloadTooLargeError("Request size exceeds 10MB limit") + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelValidationException) as exc_info: + list(model.stream(request)) + + assert "Request size exceeds 10MB limit" in str(exc_info.value) + + +def test_stream_rate_limit_error(model, stability_client): + """Test that RateLimitError is converted to ModelThrottledException.""" + stability_client.generate_image_json.side_effect = RateLimitError( + "Rate limit exceeded. Please retry after 60 seconds.", response_data={"retry_after": 60} + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelThrottledException) as exc_info: + list(model.stream(request)) + + assert "Rate limit exceeded" in str(exc_info.value) + + +def test_stream_internal_server_error(model, stability_client): + """Test that InternalServerError is converted to ModelServiceException with is_transient=True.""" + stability_client.generate_image_json.side_effect = InternalServerError("Service temporarily unavailable") + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelServiceException) as exc_info: + list(model.stream(request)) + + assert "Service temporarily unavailable" in str(exc_info.value) + assert exc_info.value.is_transient is True + + +def test_stream_network_error(model, stability_client): + """Test that NetworkError is converted to EventLoopException.""" + original_error = ConnectionError("Connection timed out") + stability_client.generate_image_json.side_effect = NetworkError( + "Network request failed", original_error=original_error + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(EventLoopException) as exc_info: + list(model.stream(request)) + + assert exc_info.value.original_exception == original_error + assert exc_info.value.request_state == request + + +def test_stream_generic_stability_error(model, stability_client): + """Test that generic StabilityAiError is converted to ModelServiceException with is_transient=False.""" + stability_client.generate_image_json.side_effect = StabilityAiError("Unexpected error occurred", status_code=418) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelServiceException) as exc_info: + list(model.stream(request)) + + assert "Unexpected error occurred" in str(exc_info.value) + assert exc_info.value.is_transient is False + + +def test_stream_success(model, stability_client): + """Test successful image generation stream.""" + mock_response = {"image": base64.b64encode(b"fake_image_data").decode("utf-8"), "finish_reason": "SUCCESS"} + stability_client.generate_image_json.return_value = mock_response + + request = { + "prompt": "a beautiful sunset", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + events = list(model.stream(request)) + + assert len(events) == 5 + assert events[0] == {"chunk_type": "message_start"} + assert events[1] == {"chunk_type": "content_start", "data_type": "text"} + assert events[2]["chunk_type"] == "content_block_delta" + assert events[2]["data_type"] == "image" + assert events[2]["data"] == mock_response["image"] + assert events[3] == {"chunk_type": "content_stop", "data_type": "text"} + assert events[4] == {"chunk_type": "message_stop", "data": "SUCCESS"} + + # Verify the client was called with the correct parameters + stability_client.generate_image_json.assert_called_once_with(model.config["model_id"], **request) + + +def test_stream_with_invalid_api_key_string(): + """Test that invalid API key string raises ModelAuthenticationException.""" + model = StabilityAiImageModel( + api_key="12345", # Invalid API key (valid str format but not authorized) + model_id="stability.stable-image-core-v1:1", + ) + + # Mock the client to raise AuthenticationError when called + with unittest.mock.patch.object(model.client, "generate_image_json") as mock_generate: + mock_generate.side_effect = AuthenticationError("Invalid API key", response_data={"error": "unauthorized"}) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelAuthenticationException) as exc_info: + list(model.stream(request)) + + assert "Invalid API key" in str(exc_info.value) From 8b9f411d85c12b7144bf157039f6d2a072dc8882 Mon Sep 17 00:00:00 2001 From: satsumas Date: Fri, 20 Jun 2025 11:26:37 +0000 Subject: [PATCH 09/10] fix(stability): Test image-to-image mode --- tests/strands/models/test_stability.py | 57 ++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py index 0f354ffa..eef20764 100644 --- a/tests/strands/models/test_stability.py +++ b/tests/strands/models/test_stability.py @@ -1,7 +1,9 @@ import base64 +import io import unittest.mock import pytest +from PIL import Image import strands @@ -234,6 +236,58 @@ def test_update_config_change_model_id(model, messages): assert request["aspect_ratio"] == "16:9" +def test_stream_image_to_image_mode(model, stability_client): + """Test successful image-to-image generation""" + # Create a 64x64 white PNG image + white_image = Image.new("RGB", (64, 64), color="white") + + # Convert to PNG bytes + img_buffer = io.BytesIO() + white_image.save(img_buffer, format="PNG") + img_bytes = img_buffer.getvalue() + + # Base64 encode the image + input_image_base64 = base64.b64encode(img_bytes).decode("utf-8") + + # Mock response with a different image + mock_response = { + "image": base64.b64encode(b"fake_transformed_image_data").decode("utf-8"), + "finish_reason": "SUCCESS", + } + stability_client.generate_image_json.return_value = mock_response + + request = { + "prompt": "transform this image into a sunset scene", + "image": input_image_base64, + "mode": "image-to-image", + "strength": 0.75, + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + } + + events = list(model.stream(request)) + + # Verify the stream events + assert len(events) == 5 + assert events[0] == {"chunk_type": "message_start"} + assert events[1] == {"chunk_type": "content_start", "data_type": "text"} + assert events[2]["chunk_type"] == "content_block_delta" + assert events[2]["data_type"] == "image" + assert events[2]["data"] == mock_response["image"] + assert events[3] == {"chunk_type": "content_stop", "data_type": "text"} + assert events[4] == {"chunk_type": "message_stop", "data": "SUCCESS"} + + # Verify the client was called with the correct parameters including image-to-image mode + stability_client.generate_image_json.assert_called_once_with(model.config["model_id"], **request) + + # Verify the request included the base64 image and image-to-image mode + call_args = stability_client.generate_image_json.call_args[1] + assert call_args["mode"] == "image-to-image" + assert call_args["image"] == input_image_base64 + assert call_args["strength"] == 0.75 + + def test_format_request_no_user_message(): model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") messages = [{"role": "assistant", "content": [{"text": "test"}]}] @@ -294,6 +348,9 @@ def test_format_chunk_unknown_type(): assert "unknown type" in str(exc_info.value) +# Test exceptions for stream method + + def test_stream_authentication_error(model, stability_client): """Test that AuthenticationError is converted to ModelAuthenticationException.""" stability_client.generate_image_json.side_effect = AuthenticationError( From 3e13edf38e5e91070969bc939e42a4f7e1c91fd0 Mon Sep 17 00:00:00 2001 From: satsumas Date: Mon, 23 Jun 2025 19:19:20 +0000 Subject: [PATCH 10/10] feat(stability): response images can be returned as bytes --- src/strands/models/_stabilityaiclient.py | 15 ++- src/strands/models/stability.py | 39 +++++- tests/strands/models/test_stability.py | 158 +++++++++++++++++------ 3 files changed, 159 insertions(+), 53 deletions(-) diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py index a5650711..dbd6dde8 100644 --- a/src/strands/models/_stabilityaiclient.py +++ b/src/strands/models/_stabilityaiclient.py @@ -222,7 +222,7 @@ def _get_headers(self, accept: str = "image/*") -> Dict[str, str]: return headers - def generate_image_bytes(self, model_id: str, **kwargs: Any) -> bytes: + def generate_image_bytes(self, model_id: str, **kwargs: Any) -> requests.Response: """Generate an image using the Stability AI API. Args: @@ -230,10 +230,10 @@ def generate_image_bytes(self, model_id: str, **kwargs: Any) -> bytes: **kwargs: See _generate_image for available parameters Returns: - bytes of the image + requests.Response object contining image as bytes """ kwargs["return_json"] = False - return cast(bytes, self._generate_image(model_id, **kwargs)) + return cast(requests.Response, self._generate_image(model_id, **kwargs)) def generate_image_json(self, model_id: str, **kwargs: Any) -> Dict[str, Any]: """Generate an image using the Stability AI API. @@ -263,7 +263,7 @@ def _generate_image( strength: Optional[float] = 0.35, return_json: bool = False, **extra_kwargs: Any, - ) -> Union[bytes, Dict[str, Any]]: + ) -> Union[requests.Response, Dict[str, Any]]: """Generate an image using the Stability AI API. Args: @@ -343,7 +343,12 @@ def _generate_image( if response.status_code == 200: if return_json: return cast(Dict[str, Any], response.json()) - return cast(bytes, response.content) + # If return_json is False, return the full requests.Response object + # This is because data like the seed and finish_reason are not included in the image bytes response + # but need to be retreived from the headers + else: + return cast(requests.Response, response) + # return response # Parse error response try: diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py index fbc047de..1dc31d95 100644 --- a/src/strands/models/stability.py +++ b/src/strands/models/stability.py @@ -76,6 +76,7 @@ class StabilityAiImageModelConfig(TypedDict): style_preset: Style preset for image generation. image: Input image for img2img generation. mode: Mode of operation (text-to-image, image-to-image). + return_json: Return JSON response with base64-encoded image data. Defaults to False. strength: Influence of input image on output (0.0-1.0). """ @@ -90,6 +91,7 @@ class StabilityAiImageModelConfig(TypedDict): style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC image: NotRequired[str] # defaults to None mode: NotRequired[Mode] # defaults to "text-to-image" + return_json: NotRequired[bool] # defaults to False strength: NotRequired[float] # defaults to 0.35 def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: @@ -99,6 +101,9 @@ def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelCon api_key: The API key for connecting to Stability AI. **model_config: Configuration options for the model. """ + # pop the return_json to avoid it being passed into the request + self.return_json = model_config.pop("return_json", False) + config_dict = {**{"output_format": Defaults.OUTPUT_FORMAT}, **dict(model_config)} self._validate_and_convert_config(config_dict) @@ -146,6 +151,14 @@ def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: if config_dict["model_id"] not in valid_model_ids: raise ValueError(f"Invalid model_id: {config_dict['model_id']}. Must be one of: {valid_model_ids}") + if config_dict["model_id"] == "stability.sd3-5-large-v1:0": + allowed_formats = [OutputFormat.PNG, OutputFormat.JPEG] # Compare enum to enum + if config_dict["output_format"] not in allowed_formats: + raise ValueError( + f"'output_format' must be one of {[f.value for f in allowed_formats]} " + f"for {config_dict['model_id']}. Got '{config_dict['output_format']}'." + ) + # Warn if cfg_scale is used with non-SD3.5 models if "cfg_scale" in config_dict and config_dict["model_id"] != "stability.sd3-5-large-v1:0": logger.warning( @@ -263,12 +276,21 @@ def _format_content_block_delta(self, event: dict[str, Any]) -> StreamEvent: Returns: Formatted content block delta event. """ + data = event.get("data", b"") + + # Handle both bytes and base64 string + if isinstance(data, bytes): + image_bytes = data + else: + # Assume it's a base64 string + image_bytes = base64.b64decode(data) + content_block_delta = cast( ContentBlockDelta, { "image": { "format": self.config["output_format"].value, - "source": {"bytes": base64.b64decode(event.get("data", b""))}, + "source": {"bytes": image_bytes}, } }, ) @@ -343,18 +365,23 @@ def stream(self, request: dict[str, Any]) -> Iterable[Any]: model_id = self.config["model_id"] try: - # Generate the image #TODO add generate_image_bytes - response_json = self.client.generate_image_json(model_id, **request) - # Yield the image data as a single event + if self.return_json: + json_response = self.client.generate_image_json(model_id, **request) + image_data = json_response.get("image") # base64 string + finish_reason = json_response.get("finish_reason", "SUCCESS") + else: + bytes_response = self.client.generate_image_bytes(model_id, **request) # requests.Response object + image_data = bytes_response.content # raw bytes from response body + finish_reason = bytes_response.headers.get("finish-reason", "SUCCESS") # from HTTP header # Yield the image data as a single event yield { "chunk_type": ChunkTypes.CONTENT_BLOCK_DELTA, "data_type": "image", - "data": response_json.get("image"), + "data": image_data, # Either bytes or base64 string } yield {"chunk_type": ChunkTypes.CONTENT_STOP, "data_type": "text"} - yield {"chunk_type": ChunkTypes.MESSAGE_STOP, "data": response_json.get("finish_reason")} + yield {"chunk_type": ChunkTypes.MESSAGE_STOP, "data": finish_reason} except AuthenticationError as e: logger.error("Authentication failed: %s", str(e)) diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py index eef20764..10c43c56 100644 --- a/tests/strands/models/test_stability.py +++ b/tests/strands/models/test_stability.py @@ -6,8 +6,6 @@ from PIL import Image import strands - -# Import the StabilityAI exceptions from strands.models._stabilityaiclient import ( AuthenticationError, BadRequestError, @@ -91,7 +89,11 @@ def test__init__with_string_enums(stability_client_cls, model_id): ) tru_config = model.get_config() - exp_config = {"model_id": model_id, "output_format": OutputFormat.JPEG, "style_preset": StylePreset.PHOTOGRAPHIC} + exp_config = { + "model_id": model_id, + "output_format": OutputFormat.JPEG, + "style_preset": StylePreset.PHOTOGRAPHIC, + } assert tru_config == exp_config @@ -237,8 +239,11 @@ def test_update_config_change_model_id(model, messages): def test_stream_image_to_image_mode(model, stability_client): - """Test successful image-to-image generation""" + """Test successful image-to-image generation stream, receiving data as bytes.""" # Create a 64x64 white PNG image + from unittest.mock import Mock + + # Create a white 64x64 image white_image = Image.new("RGB", (64, 64), color="white") # Convert to PNG bytes @@ -249,12 +254,13 @@ def test_stream_image_to_image_mode(model, stability_client): # Base64 encode the image input_image_base64 = base64.b64encode(img_bytes).decode("utf-8") - # Mock response with a different image - mock_response = { - "image": base64.b64encode(b"fake_transformed_image_data").decode("utf-8"), - "finish_reason": "SUCCESS", - } - stability_client.generate_image_json.return_value = mock_response + # Mock response - generate_image_bytes returns a Response object + transformed_image_bytes = b"fake_transformed_image_data" + mock_response = Mock() + mock_response.content = transformed_image_bytes + mock_response.headers = {"finish-reason": "SUCCESS"} + + stability_client.generate_image_bytes.return_value = mock_response request = { "prompt": "transform this image into a sunset scene", @@ -274,18 +280,13 @@ def test_stream_image_to_image_mode(model, stability_client): assert events[1] == {"chunk_type": "content_start", "data_type": "text"} assert events[2]["chunk_type"] == "content_block_delta" assert events[2]["data_type"] == "image" - assert events[2]["data"] == mock_response["image"] + # The stream method should pass raw bytes directly + assert events[2]["data"] == transformed_image_bytes assert events[3] == {"chunk_type": "content_stop", "data_type": "text"} assert events[4] == {"chunk_type": "message_stop", "data": "SUCCESS"} - # Verify the client was called with the correct parameters including image-to-image mode - stability_client.generate_image_json.assert_called_once_with(model.config["model_id"], **request) - - # Verify the request included the base64 image and image-to-image mode - call_args = stability_client.generate_image_json.call_args[1] - assert call_args["mode"] == "image-to-image" - assert call_args["image"] == input_image_base64 - assert call_args["strength"] == 0.75 + # Verify generate_image_bytes was called + stability_client.generate_image_bytes.assert_called_once_with(model.config["model_id"], **request) def test_format_request_no_user_message(): @@ -313,14 +314,47 @@ def test_format_chunk_content_start(): assert chunk == {"contentBlockStart": {"start": {}}} -def test_format_chunk_content_block_delta(): - model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") - raw_image_data = b"raw_image_data" - base64_encoded_data = base64.b64encode(raw_image_data) - event = {"chunk_type": "content_block_delta", "data": base64_encoded_data} +def test_format_chunk_content_block_delta(model): + """Test formatting content block delta event.""" + # For bytes data (return_json=False) + event = { + "chunk_type": "content_block_delta", + "data": b"raw_image_data", # Pass actual bytes + "data_type": "image", + } - chunk = model.format_chunk(event) - assert chunk == {"contentBlockDelta": {"delta": {"image": {"format": "png", "source": {"bytes": raw_image_data}}}}} + result = model.format_chunk(event) + + assert result == { + "contentBlockDelta": { + "delta": { + "image": { + "format": "png", + "source": {"bytes": b"raw_image_data"}, + } + } + } + } + + # For base64 string data (return_json=True) + event_base64 = { + "chunk_type": "content_block_delta", + "data": base64.b64encode(b"raw_image_data").decode("utf-8"), # Valid base64 + "data_type": "image", + } + + result_base64 = model.format_chunk(event_base64) + + assert result_base64 == { + "contentBlockDelta": { + "delta": { + "image": { + "format": "png", + "source": {"bytes": b"raw_image_data"}, + } + } + } + } def test_format_chunk_content_stop(): @@ -353,7 +387,7 @@ def test_format_chunk_unknown_type(): def test_stream_authentication_error(model, stability_client): """Test that AuthenticationError is converted to ModelAuthenticationException.""" - stability_client.generate_image_json.side_effect = AuthenticationError( + stability_client.generate_image_bytes.side_effect = AuthenticationError( "Invalid API key", response_data={"error": "unauthorized"} ) @@ -373,7 +407,7 @@ def test_stream_authentication_error(model, stability_client): def test_stream_content_moderation_error(model, stability_client): """Test that ContentModerationError is converted to ContentModerationException.""" - stability_client.generate_image_json.side_effect = StabilityContentModerationError( + stability_client.generate_image_bytes.side_effect = StabilityContentModerationError( "Content flagged by moderation", response_data={"error": "content_policy_violation"} ) @@ -394,7 +428,7 @@ def test_stream_content_moderation_error(model, stability_client): def test_stream_validation_error(model, stability_client): """Test that ValidationError is converted to ModelValidationException.""" - stability_client.generate_image_json.side_effect = ValidationError( + stability_client.generate_image_bytes.side_effect = ValidationError( "Prompt exceeds maximum length", response_data={"error": "validation_error", "field": "prompt"} ) @@ -414,7 +448,7 @@ def test_stream_validation_error(model, stability_client): def test_stream_bad_request_error(model, stability_client): """Test that BadRequestError is converted to ModelValidationException.""" - stability_client.generate_image_json.side_effect = BadRequestError( + stability_client.generate_image_bytes.side_effect = BadRequestError( "Invalid aspect ratio", response_data={"error": "bad_request"} ) @@ -434,7 +468,7 @@ def test_stream_bad_request_error(model, stability_client): def test_stream_payload_too_large_error(model, stability_client): """Test that PayloadTooLargeError is converted to ModelValidationException.""" - stability_client.generate_image_json.side_effect = PayloadTooLargeError("Request size exceeds 10MB limit") + stability_client.generate_image_bytes.side_effect = PayloadTooLargeError("Request size exceeds 10MB limit") request = { "prompt": "test prompt", @@ -452,7 +486,7 @@ def test_stream_payload_too_large_error(model, stability_client): def test_stream_rate_limit_error(model, stability_client): """Test that RateLimitError is converted to ModelThrottledException.""" - stability_client.generate_image_json.side_effect = RateLimitError( + stability_client.generate_image_bytes.side_effect = RateLimitError( "Rate limit exceeded. Please retry after 60 seconds.", response_data={"retry_after": 60} ) @@ -472,7 +506,7 @@ def test_stream_rate_limit_error(model, stability_client): def test_stream_internal_server_error(model, stability_client): """Test that InternalServerError is converted to ModelServiceException with is_transient=True.""" - stability_client.generate_image_json.side_effect = InternalServerError("Service temporarily unavailable") + stability_client.generate_image_bytes.side_effect = InternalServerError("Service temporarily unavailable") request = { "prompt": "test prompt", @@ -492,7 +526,7 @@ def test_stream_internal_server_error(model, stability_client): def test_stream_network_error(model, stability_client): """Test that NetworkError is converted to EventLoopException.""" original_error = ConnectionError("Connection timed out") - stability_client.generate_image_json.side_effect = NetworkError( + stability_client.generate_image_bytes.side_effect = NetworkError( "Network request failed", original_error=original_error ) @@ -513,7 +547,7 @@ def test_stream_network_error(model, stability_client): def test_stream_generic_stability_error(model, stability_client): """Test that generic StabilityAiError is converted to ModelServiceException with is_transient=False.""" - stability_client.generate_image_json.side_effect = StabilityAiError("Unexpected error occurred", status_code=418) + stability_client.generate_image_bytes.side_effect = StabilityAiError("Unexpected error occurred", status_code=418) request = { "prompt": "test prompt", @@ -530,10 +564,17 @@ def test_stream_generic_stability_error(model, stability_client): assert exc_info.value.is_transient is False -def test_stream_success(model, stability_client): +def test_stream_success_image_bytes(model, stability_client): """Test successful image generation stream.""" - mock_response = {"image": base64.b64encode(b"fake_image_data").decode("utf-8"), "finish_reason": "SUCCESS"} - stability_client.generate_image_json.return_value = mock_response + # Mock generate_image_bytes to return a mock Response object + from unittest.mock import Mock + + raw_image_data = b"fake_image_data" + mock_response = Mock() + mock_response.content = raw_image_data + mock_response.headers = {"finish-reason": "SUCCESS"} + + stability_client.generate_image_bytes.return_value = mock_response request = { "prompt": "a beautiful sunset", @@ -550,12 +591,13 @@ def test_stream_success(model, stability_client): assert events[1] == {"chunk_type": "content_start", "data_type": "text"} assert events[2]["chunk_type"] == "content_block_delta" assert events[2]["data_type"] == "image" - assert events[2]["data"] == mock_response["image"] + # The stream method should pass raw bytes directly + assert events[2]["data"] == raw_image_data assert events[3] == {"chunk_type": "content_stop", "data_type": "text"} assert events[4] == {"chunk_type": "message_stop", "data": "SUCCESS"} - # Verify the client was called with the correct parameters - stability_client.generate_image_json.assert_called_once_with(model.config["model_id"], **request) + # Verify the client was called with generate_image_bytes + stability_client.generate_image_bytes.assert_called_once_with(model.config["model_id"], **request) def test_stream_with_invalid_api_key_string(): @@ -566,7 +608,7 @@ def test_stream_with_invalid_api_key_string(): ) # Mock the client to raise AuthenticationError when called - with unittest.mock.patch.object(model.client, "generate_image_json") as mock_generate: + with unittest.mock.patch.object(model.client, "generate_image_bytes") as mock_generate: mock_generate.side_effect = AuthenticationError("Invalid API key", response_data={"error": "unauthorized"}) request = { @@ -581,3 +623,35 @@ def test_stream_with_invalid_api_key_string(): list(model.stream(request)) assert "Invalid API key" in str(exc_info.value) + + +def test_stream_with_json_response(stability_client_cls): + """Test streaming with JSON response.""" + # Set up the mock before creating the model + mock_client = stability_client_cls.return_value + mock_response = {"image": base64.b64encode(b"fake_image_data").decode("utf-8"), "finish_reason": "SUCCESS"} + mock_client.generate_image_json.return_value = mock_response + + # Now create the model with return_json=True + model = StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + return_json=True, # Explicitly enable JSON mode + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + events = list(model.stream(request)) + + # Verify the response + assert len(events) == 5 + assert events[2]["data"] == mock_response["image"] + + # Now generate_image_json should be called + mock_client.generate_image_json.assert_called_once_with(model.config["model_id"], **request)