diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa724cdd..84e048bf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -143,3 +143,37 @@ If you discover a potential security issue in this project we ask that you notif ## Licensing See the [LICENSE](./LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. + +## Stability Specific Instructions + +* Modified files + + 1. src/strands/event_loop/streaming.py + 2. tests/strands/event_loop/test_streaming.py + + These files have fixes to make the image return to agent consumers work. + +* New Files + 1. src/strands/models/_stabilityaiclient.py - A class with a rest client implementation for Stability + 2. src/strands/models/stability.py - The main model provider implementation + 3. tests-integ/test_model_stability.py - Integration test to test the image generation workflow. + 4. tests/TODO.lis - A text file containing TODOs. + 5. tests/strands/models/test_stability.py - Unit tests for the model provider + + +* Running the stability tests and testing the code + +You can run the integration test for the stability model as follows +``` + cd path/to/strands-agents-sdk-python + export STABILITY_API_KEY= + hatch test tests-integ/test_model_stability.py +``` + +You can also install this as a editable module in a python project and write a agent + +``` +pip install -e /path/to/strands-agents-sdk-python +Sample code similar to +https://gist.github.com/sayanc82/17f0d2442acd78f74f8393f58b552c54 +``` \ No newline at end of file diff --git a/TODO.lis b/TODO.lis new file mode 100644 index 00000000..7316285a --- /dev/null +++ b/TODO.lis @@ -0,0 +1,4 @@ +Add support for all the parameters for core and 3.5 +Do specific config validation for the 3 models, support parameters like cfg for sd3.5 +Cleanup code :- Modularize. Move validation to its own methods for example. +Add documentation similar to other model providers diff --git a/pyproject.toml b/pyproject.toml index b17dcfb2..3efe90bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,10 +54,12 @@ dev = [ "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", + "Pillow>=11.0", "pre-commit>=3.2.0,<4.2.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", + "types-Pillow>=10.2.0.20240822", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -119,6 +121,7 @@ lint-fix = [ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", + "Pillow>=11.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "pytest-cov>=4.1.0,<5.0.0", @@ -220,6 +223,10 @@ ignore_missing_imports = false module = "litellm" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "PIL" +ignore_missing_imports = true + [tool.ruff] line-length = 120 include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 0e9d472b..6aa5aca6 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -102,13 +102,14 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: def handle_content_block_delta( - event: ContentBlockDeltaEvent, state: dict[str, Any] + event: ContentBlockDeltaEvent, state: dict[str, Any], **kwargs: Any ) -> tuple[dict[str, Any], dict[str, Any]]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: event: Delta event. state: The current state of message processing. + **kwargs: Additional keyword arguments to pass to the callback handler. Returns: Updated state with appended text or tool input. @@ -150,7 +151,11 @@ def handle_content_block_delta( "delta": delta_content, "reasoning": True, } - + elif "image" in delta_content: + # Handle the new ImageContent structure + image_content = delta_content["image"] + state["image"] = image_content + callback_event["callback"] = {"image": image_content, "delta": delta_content, **kwargs} return state, callback_event @@ -168,7 +173,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: current_tool_use = state["current_tool_use"] text = state["text"] reasoning_text = state["reasoningText"] - + image = state["image"] if current_tool_use: if "input" not in current_tool_use: current_tool_use["input"] = "" @@ -192,7 +197,6 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: elif text: content.append({"text": text}) state["text"] = "" - elif reasoning_text: content.append( { @@ -205,6 +209,9 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: } ) state["reasoningText"] = "" + elif image: + content.append({"image": image}) + state["image"] = "" return state @@ -272,6 +279,7 @@ def process_stream( "current_tool_use": {}, "reasoningText": "", "signature": "", + "image": None, } state["content"] = state["message"]["content"] diff --git a/src/strands/models/_stabilityaiclient.py b/src/strands/models/_stabilityaiclient.py new file mode 100644 index 00000000..dbd6dde8 --- /dev/null +++ b/src/strands/models/_stabilityaiclient.py @@ -0,0 +1,384 @@ +import base64 +from enum import Enum +from io import BytesIO +from typing import Any, BinaryIO, Dict, Optional, Union, cast + +import requests +from PIL import Image + +# Validation classes and functions + + +class Mode(Enum): + TEXT_TO_IMAGE = "text-to-image" + IMAGE_TO_IMAGE = "image-to-image" + + +class OutputFormat(Enum): + """Supported output formats for image generation.""" + + JPEG = "jpeg" + PNG = "png" + WEBP = "webp" + + +class StylePreset(Enum): + """Supported style presets for image generation.""" + + THREE_D_MODEL = "3d-model" + ANALOG_FILM = "analog-film" + ANIME = "anime" + CINEMATIC = "cinematic" + COMIC_BOOK = "comic-book" + DIGITAL_ART = "digital-art" + ENHANCE = "enhance" + FANTASY_ART = "fantasy-art" + ISOMETRIC = "isometric" + LINE_ART = "line-art" + LOW_POLY = "low-poly" + MODELING_COMPOUND = "modeling-compound" + NEON_PUNK = "neon-punk" + ORIGAMI = "origami" + PHOTOGRAPHIC = "photographic" + PIXEL_ART = "pixel-art" + TILE_TEXTURE = "tile-texture" + + +def _validate_image_pixels_and_aspect_ratio(image: Union[str, BinaryIO]) -> None: + """Validates the number of pixels in the 'image' field of the request. + + The image must have a total pixel count between 4,096 and 9,437,184 (inclusive). + Not implemented yet (but required for stable image services): + If the model is outpaint, the aspect ratio must be between 1:2.5 and 2.5:1. + + Args: + image: Either a base64-encoded string or a BinaryIO object + """ + # Get the raw image data + if isinstance(image, str): + # Decode base64 string + try: + image_data = base64.b64decode(image) + except Exception as e: + raise ValueError("Invalid base64 encoding for 'image'") from e + else: + # Read from BinaryIO + image_data = image.read() + image.seek(0) # Reset the file pointer so it can be read again later + + # Attempt to open the image using Pillow + try: + with Image.open(BytesIO(image_data)) as img: + width, height = img.size + except Exception as e: + raise ValueError("Unable to open or process the image data") from e + + # Check the image type based on magic bytes (JPEG, PNG, WebP) + image_format = None + if image_data.startswith(b"\xff\xd8\xff"): # JPEG magic number + image_format = "jpeg" + elif image_data.startswith(b"\x89\x50\x4e\x47"): # PNG magic number + image_format = "png" + elif image_data.startswith(b"\x52\x49\x46\x46") and image_data[8:12] == b"WEBP": # WebP magic number + image_format = "webp" + + if not image_format: + raise ValueError("Unsupported image format. Only JPEG, PNG, or WebP are allowed.") + + total_pixels = width * height + MIN_PIXELS = 4096 + MAX_PIXELS = 9437184 + + if total_pixels < MIN_PIXELS or total_pixels > MAX_PIXELS: + raise ValueError( + f"Image total pixel count {total_pixels} is invalid. Image size (height x width) must be between " + f"{MIN_PIXELS} and {MAX_PIXELS} pixels." + ) + + +class StabilityAiError(Exception): + """Base exception for Stability AI API errors. + + Attributes: + message: Error message + status_code: HTTP status code (if applicable) + response_data: Full response data from API (if available) + """ + + def __init__(self, message: str, status_code: Optional[int] = None, response_data: Optional[Dict[str, Any]] = None): + """Initialize the exception. + + Args: + message: Error message + status_code: HTTP status code that caused the error + response_data: Full response data from the API + """ + super().__init__(message) + self.message = message + self.status_code = status_code + self.response_data = response_data or {} + + def __str__(self) -> str: + """String representation of the error.""" + if self.status_code: + return f"[HTTP {self.status_code}] {self.message}" + return self.message + + +### Specific error classes for common API errors +class AuthenticationError(StabilityAiError): + """Raised when authentication fails (401).""" + + def __init__(self, message: str = "Authentication failed", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=401, response_data=response_data) + + +class BadRequestError(StabilityAiError): + """Raised when request parameters are invalid (400).""" + + def __init__(self, message: str = "Invalid request parameters", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=400, response_data=response_data) + + +class ContentModerationError(StabilityAiError): + """Raised when content is flagged by moderation (403).""" + + def __init__(self, message: str = "Content flagged by moderation", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=403, response_data=response_data) + + +class PayloadTooLargeError(StabilityAiError): + """Raised when request exceeds size limit (413).""" + + def __init__(self, message: str = "Request too large (max 10MiB)", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=413, response_data=response_data) + + +class ValidationError(StabilityAiError): + """Raised when request validation fails (422).""" + + def __init__(self, message: str = "Request validation failed", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=422, response_data=response_data) + + +class RateLimitError(StabilityAiError): + """Raised when rate limit is exceeded (429).""" + + def __init__(self, message: str = "Rate limit exceeded", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=429, response_data=response_data) + + +class InternalServerError(StabilityAiError): + """Raised when server encounters an error (500).""" + + def __init__(self, message: str = "Internal server error", response_data: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code=500, response_data=response_data) + + +class NetworkError(StabilityAiError): + """Raised when network request fails.""" + + def __init__(self, message: str = "Network request failed", original_error: Optional[Exception] = None): + super().__init__(message) + self.original_error = original_error + + +class StabilityAiClient: + """Client for interacting with the Stability AI API.""" + + MODEL_ID_TO_BASE_URL = { + "stability.stable-image-core-v1:1": "https://api.stability.ai/v2beta/stable-image/generate/core", + "stability.stable-image-ultra-v1:1": "https://api.stability.ai/v2beta/stable-image/generate/ultra", + "stability.sd3-5-large-v1:0": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + } + + def __init__(self, api_key: str, client_id: Optional[str] = None, client_version: Optional[str] = None): + """Initialize the Stability AI client. + + Args: + api_key: Your Stability API key + client_id: Optional client ID for debugging + client_version: Optional client version for debugging + """ + self.api_key = api_key + self.client_id = client_id + self.client_version = client_version + + def _get_headers(self, accept: str = "image/*") -> Dict[str, str]: + """Get the headers for the API request. + + Args: + accept: The accept header value (image/* or application/json) + + Returns: + Dict of headers + """ + headers = {"Authorization": f"Bearer {self.api_key}", "Accept": accept} + + if self.client_id: + headers["stability-client-id"] = self.client_id + if self.client_version: + headers["stability-client-version"] = self.client_version + + return headers + + def generate_image_bytes(self, model_id: str, **kwargs: Any) -> requests.Response: + """Generate an image using the Stability AI API. + + Args: + model_id: The model ID to use for the API request. + **kwargs: See _generate_image for available parameters + + Returns: + requests.Response object contining image as bytes + """ + kwargs["return_json"] = False + return cast(requests.Response, self._generate_image(model_id, **kwargs)) + + def generate_image_json(self, model_id: str, **kwargs: Any) -> Dict[str, Any]: + """Generate an image using the Stability AI API. + + Args: + model_id: The model ID to use for the API request. + **kwargs: See _generate_image for available parameters + + Returns: + JSON response with base64 image + """ + kwargs["return_json"] = True + return cast(Dict[str, Any], self._generate_image(model_id, **kwargs)) + + def _generate_image( + self, + model_id: str, + prompt: str, + negative_prompt: Optional[str] = None, + aspect_ratio: str = "1:1", + cfg_scale: Optional[int] = 4, + seed: Optional[int] = None, + output_format: Union[OutputFormat, str] = "png", + image: Optional[BinaryIO] = None, + mode: Union[Mode] = Mode.TEXT_TO_IMAGE, + style_preset: Optional[str] = None, + strength: Optional[float] = 0.35, + return_json: bool = False, + **extra_kwargs: Any, + ) -> Union[requests.Response, Dict[str, Any]]: + """Generate an image using the Stability AI API. + + Args: + model_id: The model ID to use for the API request + prompt: Text prompt for image generation + negative_prompt: Optional text describing what not to include + aspect_ratio: Aspect ratio of the output image + cfg_scale: Optional classifier-free guidance scale, only used for stability.sd3-5-large-v1:0 + seed: Random seed for generation + output_format: Output format (jpeg, png, webp) + image: Optional input image for img2img + mode: "text-to-image" or "image-to-image" + style_preset: Optional style preset + strength: Required when image is provided, controls influence of input image + return_json: If True, returns JSON response with base64 image + **extra_kwargs: Additional keyword arguments + + Returns: + Either image bytes or JSON response with base64 image + + Raises: + StabilityAiError: If the API request fails + """ + # Validate and prepare the base URL + if model_id not in self.MODEL_ID_TO_BASE_URL: + raise ValueError(f"Invalid model_id: {model_id}. Must be one of: {list(self.MODEL_ID_TO_BASE_URL.keys())}") + base_url = self.MODEL_ID_TO_BASE_URL[model_id] + + if isinstance(output_format, str): + try: + output_format = OutputFormat(output_format) + except ValueError as e: + raise ValueError( + f"Invalid output_format: {output_format}. Must be one of: {[e.value for e in OutputFormat]}" + ) from e + + if isinstance(mode, str): + try: + mode = Mode(mode) + except ValueError as e: + raise ValueError(f"Invalid mode: {mode}. Must be one of: {[e.value for e in Mode]}") from e + + # Prepare the multipart form data + files: Dict[str, Union[BinaryIO, str]] = {} + data: Dict[str, Any] = {} + + # Add all parameters to data as strings + data["prompt"] = prompt + if negative_prompt: + data["negative_prompt"] = negative_prompt + if aspect_ratio: + data["aspect_ratio"] = aspect_ratio + if seed is not None: + data["seed"] = seed + if output_format: + data["output_format"] = output_format.value + if style_preset: + data["style_preset"] = style_preset + + # Handle input image if provided + if image: + _validate_image_pixels_and_aspect_ratio(image) + files["image"] = image + + if len(files) == 0: + files["none"] = "" + try: + # Make the API request + response = requests.post( + base_url, + headers=self._get_headers("application/json" if return_json else "image/*"), + data=data, + files=files, + ) + + # Handle successful response + if response.status_code == 200: + if return_json: + return cast(Dict[str, Any], response.json()) + # If return_json is False, return the full requests.Response object + # This is because data like the seed and finish_reason are not included in the image bytes response + # but need to be retreived from the headers + else: + return cast(requests.Response, response) + # return response + + # Parse error response + try: + error_data = response.json() + error_message = error_data.get("errors", error_data.get("message", "Unknown error")) + except ValueError: + error_data = {} + error_message = response.text or "Unknown error" + + # Handle specific error cases + if response.status_code == 401: + raise AuthenticationError( + f"Unauthorized: check authentication credentials: {error_message}", response_data=error_data + ) + elif response.status_code == 400: + raise BadRequestError(f"Invalid parameters: {error_message}", response_data=error_data) + elif response.status_code == 403: + raise ContentModerationError("Request flagged by content moderation", response_data=error_data) + elif response.status_code == 413: + raise PayloadTooLargeError("Request too large (max 10MiB)", response_data=error_data) + elif response.status_code == 422: + raise ValidationError(f"Request rejected: {error_message}", response_data=error_data) + elif response.status_code == 429: + raise RateLimitError("Rate limit exceeded (max 150 requests per 10 seconds)", response_data=error_data) + elif response.status_code == 500: + raise InternalServerError("Internal server error", response_data=error_data) + else: + raise StabilityAiError( + f"Unexpected error: {error_message}", status_code=response.status_code, response_data=error_data + ) + + except requests.exceptions.RequestException as e: + raise NetworkError(f"Request failed: {str(e)}", original_error=e) from e diff --git a/src/strands/models/stability.py b/src/strands/models/stability.py new file mode 100644 index 00000000..1fc96397 --- /dev/null +++ b/src/strands/models/stability.py @@ -0,0 +1,429 @@ +"""Stability AI model provider. + +- Docs: https://platform.stability.ai/ +""" + +import base64 +import logging +from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast + +from pydantic import BaseModel +from typing_extensions import NotRequired, Unpack, override + +from strands.types.content import Messages +from strands.types.exceptions import ( + ContentModerationException, + EventLoopException, + ModelAuthenticationException, + ModelServiceException, + ModelThrottledException, + ModelValidationException, +) +from strands.types.models import Model +from strands.types.streaming import ContentBlockDelta, ContentBlockDeltaEvent, StreamEvent +from strands.types.tools import ToolSpec + +from ._stabilityaiclient import ( + AuthenticationError, + BadRequestError, + ContentModerationError, + InternalServerError, + Mode, + NetworkError, + OutputFormat, + PayloadTooLargeError, + RateLimitError, + StabilityAiClient, + StabilityAiError, + StylePreset, + ValidationError, +) + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class Defaults: + """Default values for Stability AI configuration.""" + + ASPECT_RATIO = "1:1" + OUTPUT_FORMAT = OutputFormat.PNG + STYLE_PRESET = StylePreset.PHOTOGRAPHIC + STRENGTH = 0.35 + MODE = Mode.TEXT_TO_IMAGE + + +class ChunkTypes: + """Chunk type constants.""" + + MESSAGE_START = "message_start" + CONTENT_START = "content_start" + CONTENT_BLOCK_DELTA = "content_block_delta" + CONTENT_STOP = "content_stop" + MESSAGE_STOP = "message_stop" + + +class StabilityAiImageModel(Model): + """Stability AI image generation model provider.""" + + class StabilityAiImageModelConfig(TypedDict): + """Configuration for Stability AI image model. + + Attributes: + model_id: ID of the Stability AI model (required). + aspect_ratio: Aspect ratio of the output image. + cfg_scale: CFG scale for image generation (only used for stability.sd3-5-large-v1:0). + seed: Random seed for generation. + output_format: Output format (jpeg, png, webp). + style_preset: Style preset for image generation. + image: Input image for img2img generation. + mode: Mode of operation (text-to-image, image-to-image). + return_json: Return JSON response with base64-encoded image data. Defaults to False. + strength: Influence of input image on output (0.0-1.0). + """ + + # Required parameters + model_id: str + + # Optional parameters with defaults + aspect_ratio: NotRequired[str] # defaults to "1:1" + cfg_scale: NotRequired[int] # defaults to 4. Only used for stability.sd3-5-large-v1:0 + seed: NotRequired[int] # defaults to random + output_format: NotRequired[OutputFormat] # defaults to PNG + style_preset: NotRequired[StylePreset] # defaults to PHOTOGRAPHIC + image: NotRequired[str] # defaults to None + mode: NotRequired[Mode] # defaults to "text-to-image" + return_json: NotRequired[bool] # defaults to False + strength: NotRequired[float] # defaults to 0.35 + + def __init__(self, api_key: str, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: + """Initialize the Stability AI model provider. + + Args: + api_key: The API key for connecting to Stability AI. + **model_config: Configuration options for the model. + """ + # pop the return_json to avoid it being passed into the request + self.return_json = model_config.pop("return_json", False) + + config_dict = {**{"output_format": Defaults.OUTPUT_FORMAT}, **dict(model_config)} + self._validate_and_convert_config(config_dict) + + self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) + logger.debug("config=<%s> | initializing", self.config) + + model_id = self.config.get("model_id") + if model_id is None: + raise ValueError("model_id is required") + self.client = StabilityAiClient(api_key=api_key) + + def _convert_output_format(self, config_dict: dict[str, Any]) -> None: + """Convert string output_format to enum if needed.""" + if "output_format" in config_dict and isinstance(config_dict["output_format"], str): + try: + config_dict["output_format"] = OutputFormat(config_dict["output_format"]) + except ValueError as e: + valid_formats = [f.value for f in OutputFormat] + raise ValueError(f"output_format must be one of: {valid_formats}") from e + + def _convert_style_preset(self, config_dict: dict[str, Any]) -> None: + """Convert string style_preset to enum if needed.""" + if "style_preset" in config_dict and isinstance(config_dict["style_preset"], str): + try: + config_dict["style_preset"] = StylePreset(config_dict["style_preset"]) + except ValueError as e: + valid_presets = [p.value for p in StylePreset] + raise ValueError(f"style_preset must be one of: {valid_presets}") from e + + self.config = cast(StabilityAiImageModel.StabilityAiImageModelConfig, config_dict) + logger.debug("config=<%s> | initializing", self.config) + + def _validate_and_convert_config(self, config_dict: dict[str, Any]) -> None: + """Validate and convert configuration values to proper types.""" + # Validate required fields first + if "model_id" not in config_dict: + raise ValueError("model_id is required in configuration") + + # Validate model_id is one of the supported models + valid_model_ids = [ + "stability.stable-image-core-v1:1", + "stability.stable-image-ultra-v1:1", + "stability.sd3-5-large-v1:0", + ] + if config_dict["model_id"] not in valid_model_ids: + raise ValueError(f"Invalid model_id: {config_dict['model_id']}. Must be one of: {valid_model_ids}") + + if config_dict["model_id"] == "stability.sd3-5-large-v1:0": + allowed_formats = [OutputFormat.PNG, OutputFormat.JPEG] # Compare enum to enum + if config_dict["output_format"] not in allowed_formats: + raise ValueError( + f"'output_format' must be one of {[f.value for f in allowed_formats]} " + f"for {config_dict['model_id']}. Got '{config_dict['output_format']}'." + ) + + # Warn if cfg_scale is used with non-SD3.5 models + if "cfg_scale" in config_dict and config_dict["model_id"] != "stability.sd3-5-large-v1:0": + logger.warning( + "cfg_scale is only supported for stability.sd3-5-large-v1:0. It will be ignored for model %s", + config_dict["model_id"], + ) + # Convert other fields + self._convert_output_format(config_dict) + self._convert_style_preset(config_dict) + + def _extract_prompt_from_messages(self, messages: Messages) -> str: + """Extract the last user message as prompt. + + Args: + messages: List of conversation messages. + + Returns: + The extracted prompt text. + + Raises: + ValueError: If no user message with text content is found. + """ + for message in reversed(messages): + if message["role"] == "user": + for content in message["content"]: + if isinstance(content, dict) and "text" in content: + return content["text"] + raise ValueError("No user message found in the conversation") + + def _build_base_request(self, prompt: str) -> dict[str, Any]: + """Build the base request with required parameters. + + Args: + prompt: The text prompt for image generation. + + Returns: + Dictionary with base request parameters. + """ + return { + "prompt": prompt, + "aspect_ratio": self.config.get("aspect_ratio", Defaults.ASPECT_RATIO), + "output_format": self.config.get("output_format", Defaults.OUTPUT_FORMAT).value, + "style_preset": self.config.get("style_preset", Defaults.STYLE_PRESET).value, + "mode": self.config.get("mode", Defaults.MODE).value, + } + + def _add_optional_parameters(self, request: dict[str, Any]) -> None: + """Add optional parameters to the request if they exist in config. + + Args: + request: The request dictionary to modify. + """ + # Only add cfg_scale for SD3.5 model + if "cfg_scale" in self.config and self.config["model_id"] == "stability.sd3-5-large-v1:0": + request["cfg_scale"] = self.config["cfg_scale"] + + if "seed" in self.config: + request["seed"] = self.config["seed"] + if self.config.get("image") is not None: + request["image"] = self.config["image"] + request["strength"] = self.config.get("strength", Defaults.STRENGTH) + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + ) -> T: + """Get structured output from the model. + + Args: + output_model(Type[BaseModel]): The output model to use for the agent. + prompt(Messages): The prompt to use for the agent. + callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None. + """ + return output_model() + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format a Stability AI model request. + + Args: + messages: List of messages containing the conversation history + tool_specs: Optional list of tool specifications + system_prompt: Optional system prompt + + Returns: + Formatted request parameters for the Stability AI API. + """ + prompt = self._extract_prompt_from_messages(messages) + request = self._build_base_request(prompt) + self._add_optional_parameters(request) + + return request + + @override + def update_config(self, **model_config: Unpack[StabilityAiImageModelConfig]) -> None: # type: ignore[override] + """Update the model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> StabilityAiImageModelConfig: + """Get the model configuration. + + Returns: + The model configuration. + """ + return self.config + + def _format_message_start(self) -> StreamEvent: + """Format message start event.""" + return {"messageStart": {"role": "assistant"}} + + def _format_content_start(self) -> StreamEvent: + """Format content start event.""" + return {"contentBlockStart": {"start": {}}} + + def _format_content_block_delta(self, event: dict[str, Any]) -> StreamEvent: + """Format content block delta event. + + Args: + event: The event containing image data. + + Returns: + Formatted content block delta event. + """ + data = event.get("data", b"") + + # Handle both bytes and base64 string + if isinstance(data, bytes): + image_bytes = data + else: + # Assume it's a base64 string + image_bytes = base64.b64decode(data) + + content_block_delta = cast( + ContentBlockDelta, + { + "image": { + "format": self.config["output_format"].value, + "source": {"bytes": image_bytes}, + } + }, + ) + content_block_delta_event = ContentBlockDeltaEvent(delta=content_block_delta) + return {"contentBlockDelta": content_block_delta_event} + + def _format_content_stop(self) -> StreamEvent: + """Format content stop event.""" + return {"contentBlockStop": {}} + + def _format_message_stop(self, event: dict[str, Any]) -> StreamEvent: + """Format message stop event. + + Args: + event: The event containing stop reason. + + Returns: + Formatted message stop event. + """ + return {"messageStop": {"stopReason": event["data"]}} + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an event into a standardized message chunk. + + Args: + event: A response event from the Stability AI model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + """ + chunk_type = event["chunk_type"] + + match chunk_type: + case ChunkTypes.MESSAGE_START: + return self._format_message_start() + case ChunkTypes.CONTENT_START: + return self._format_content_start() + case ChunkTypes.CONTENT_BLOCK_DELTA: + return self._format_content_block_delta(event) + case ChunkTypes.CONTENT_STOP: + return self._format_content_stop() + case ChunkTypes.MESSAGE_STOP: + return self._format_message_stop(event) + case _: + raise RuntimeError(f"chunk_type=<{chunk_type}> | unknown type") + + @override + def stream(self, request: dict[str, Any]) -> Iterable[Any]: + """Send the request to the Stability AI model and get a streaming response. + + Args: + request: The formatted request to send to the Stability AI model. + + Returns: + An iterable of response events from the Stability AI model. + + Raises: + ModelAuthenticationException: If authentication fails + ModelValidationException: If request validation fails + ContentModerationException: If content is flagged + ModelThrottledException: If rate limit is exceeded + ModelServiceException: If server error occurs + EventLoopException: If network error occurs + """ + yield {"chunk_type": ChunkTypes.MESSAGE_START} + yield {"chunk_type": ChunkTypes.CONTENT_START, "data_type": "text"} + + model_id = self.config["model_id"] + + try: + if self.return_json: + json_response = self.client.generate_image_json(model_id, **request) + image_data = json_response.get("image") # base64 string + finish_reason = json_response.get("finish_reason", "SUCCESS") + else: + bytes_response = self.client.generate_image_bytes(model_id, **request) # requests.Response object + image_data = bytes_response.content # raw bytes from response body + finish_reason = bytes_response.headers.get("finish-reason", "SUCCESS") # from HTTP header + + # Yield the image data as a single event + yield { + "chunk_type": ChunkTypes.CONTENT_BLOCK_DELTA, + "data_type": "image", + "data": image_data, # Either bytes or base64 string + } + yield {"chunk_type": ChunkTypes.CONTENT_STOP, "data_type": "text"} + yield {"chunk_type": ChunkTypes.MESSAGE_STOP, "data": finish_reason} + + except AuthenticationError as e: + logger.error("Authentication failed: %s", str(e)) + raise ModelAuthenticationException(str(e)) from e + + except RateLimitError as e: + logger.warning("Rate limit exceeded: %s", str(e)) + raise ModelThrottledException(str(e)) from e + + except ContentModerationError as e: + logger.warning("Content flagged by moderation: %s", str(e)) + raise ContentModerationException(str(e)) from e + + except (ValidationError, BadRequestError, PayloadTooLargeError) as e: + logger.error("Request validation failed: %s", str(e)) + raise ModelValidationException(str(e)) from e + + except InternalServerError as e: + logger.error("Server error during image generation: %s", str(e)) + raise ModelServiceException(str(e), is_transient=True) from e + + except NetworkError as e: + logger.error("Network error during image generation: %s", str(e)) + raise EventLoopException(e.original_error or e, request_state=request) from e + + except StabilityAiError as e: + # Catch any other StabilityAiError subclasses + logger.error("Unexpected error during image generation: %s", str(e)) + raise ModelServiceException(str(e), is_transient=False) from e diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 790e9094..f5b531b0 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -118,6 +118,7 @@ class DeltaContent(TypedDict, total=False): text: str toolUse: Dict[Literal["input"], str] + image: ImageContent class ContentBlockStartToolUse(TypedDict): @@ -142,7 +143,7 @@ class ContentBlockStart(TypedDict, total=False): toolUse: Optional[ContentBlockStartToolUse] -class ContentBlockDelta(TypedDict): +class ContentBlockDelta(TypedDict, total=False): """The content block delta event. Attributes: diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1ffeba4e..d500bf96 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -52,3 +52,52 @@ def __init__(self, message: str) -> None: super().__init__(message) pass + + +class ModelAuthenticationException(Exception): + """Exception raised when model authentication fails. + + This exception is raised when the API key or other authentication + credentials are invalid or expired. + """ + + pass + + +class ModelValidationException(Exception): + """Exception raised when model input validation fails. + + This exception is raised when the input parameters don't meet the + model's requirements (e.g., invalid formats, out-of-range values). + """ + + pass + + +class ContentModerationException(Exception): + """Exception raised when content is flagged by safety filters. + + This exception is raised when the model's safety systems reject + the input or output content as inappropriate. + """ + + pass + + +class ModelServiceException(Exception): + """Exception raised for model service errors. + + This is a general exception for server-side errors that aren't + covered by more specific exceptions. + """ + + def __init__(self, message: str, is_transient: bool = False) -> None: + """Initialize exception. + + Args: + message: Error message + is_transient: Whether the error is likely transient (retryable) + """ + self.message = message + self.is_transient = is_transient + super().__init__(message) diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index 9c99b210..dcd1f868 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -12,6 +12,7 @@ from .content import ContentBlockStart, Role from .event_loop import Metrics, StopReason, Usage from .guardrails import Trace +from .media import ImageContent class MessageStartEvent(TypedDict): @@ -78,11 +79,13 @@ class ContentBlockDelta(TypedDict, total=False): reasoningContent: Contains content regarding the reasoning that is carried out by the model. text: Text fragment being streamed. toolUse: Tool use input fragment being streamed. + image: Image content being streamed. """ reasoningContent: ReasoningContentBlockDelta text: str toolUse: ContentBlockDeltaToolUse + image: ImageContent class ContentBlockDeltaEvent(TypedDict, total=False): diff --git a/tests-integ/test_model_stability.py b/tests-integ/test_model_stability.py new file mode 100644 index 00000000..864d24c0 --- /dev/null +++ b/tests-integ/test_model_stability.py @@ -0,0 +1,64 @@ +import os + +import pytest + +from strands import Agent +from strands.models.stability import OutputFormat, StabilityAiImageModel + + +@pytest.fixture +def model_id(request): + return request.param + + +@pytest.fixture +def model(model_id): + return StabilityAiImageModel( + api_key=os.getenv("STABILITY_API_KEY"), # Use the API key loaded from .env + model_id=model_id, + aspect_ratio="16:9", + output_format=OutputFormat.PNG, + ) + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif( + "STABILITY_API_KEY" not in os.environ, + reason="STABILITY_API_KEY environment variable missing", +) +@pytest.mark.parametrize( + "model_id", + [ + "stability.stable-image-core-v1:1", + "stability.stable-image-ultra-v1:1", + "stability.sd3-5-large-v1:0", + ], + indirect=True, +) +def test_agent(agent): + result = agent("dark high contrast render of a psychedelic tree of life illuminating dust in a mystical cave.") + + # Initialize variables + image_data = None + image_format = None + + # Find image content + for content in result.message.get("content", []): + if isinstance(content, dict) and "image" in content: + image_data = content["image"]["source"]["bytes"] + image_format = content["image"]["format"] + break + + # Verify we found an image + assert image_data is not None, "No image data found in the response" + assert image_format is not None, "No image format found in the response" + + # Verify image data is not empty + assert len(image_data) > 0, "Image data should not be empty" + + # Verify image format is PNG + assert image_format == "png", f"Expected image format to be 'png', got '{image_format}'" diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index e91f4986..6608a46c 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1,5 +1,3 @@ -import unittest.mock - import pytest import strands @@ -133,6 +131,13 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, ), + # Image + ( + {"delta": {"image": {"format": "png", "source": {"bytes": b"image_data"}}}}, + {"image": {"format": "png", "source": {"bytes": b"image_data"}}}, + {"image": {"format": "png", "source": {"bytes": b"image_data"}}}, + {"image": {"format": "png", "source": {"bytes": b"image_data"}}}, + ), # Empty ( {"delta": {}}, @@ -161,12 +166,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {"toolUseId": "123", "name": "test", "input": '{"key": "value"}'}, "text": "", "reasoningText": "", + "image": None, }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), # Tool Use - Missing input @@ -176,12 +183,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {"toolUseId": "123", "name": "test"}, "text": "", "reasoningText": "", + "image": None, }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), # Text @@ -191,12 +200,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "test", "reasoningText": "", + "image": None, }, { "content": [{"text": "test"}], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), # Reasoning @@ -207,6 +218,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "text": "", "reasoningText": "test", "signature": "123", + "image": None, }, { "content": [{"reasoningContent": {"reasoningText": {"text": "test", "signature": "123"}}}], @@ -214,6 +226,24 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "text": "", "reasoningText": "", "signature": "123", + "image": None, + }, + ), + # Image + ( + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "image": {"format": "png", "source": {"bytes": b"image_data"}}, + }, + { + "content": [{"image": {"format": "png", "source": {"bytes": b"image_data"}}}], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "image": "", }, ), # Empty @@ -223,12 +253,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", + "image": None, }, ), ], @@ -524,6 +556,104 @@ def test_extract_usage_metrics(): }, ], ), + # Image Message - FIXED + ( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"image": {"format": "png", "source": {"bytes": b"image_data"}}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + "metrics": {"latencyMs": 1}, + } + }, + ], + [ + { + "callback": { + "event": { + "messageStart": { + "role": "assistant", + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStart": { + "start": {}, + }, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockDelta": { + "delta": { + "image": {"format": "png", "source": {"bytes": b"image_data"}}, + }, + }, + }, + }, + }, + { + "callback": { + "image": {"format": "png", "source": {"bytes": b"image_data"}}, + "delta": { + "image": {"format": "png", "source": {"bytes": b"image_data"}}, + }, + }, + }, + { + "callback": { + "event": { + "contentBlockStop": {}, + }, + }, + }, + { + "callback": { + "event": { + "messageStop": { + "stopReason": "end_turn", + }, + }, + }, + }, + { + "callback": { + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + }, + }, + }, + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [{"image": {"format": "png", "source": {"bytes": b"image_data"}}}], + }, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ), + }, + ], + ), ], ) def test_process_stream(response, exp_events): @@ -532,63 +662,3 @@ def test_process_stream(response, exp_events): tru_events = list(stream) assert tru_events == exp_events - - -def test_stream_messages(): - mock_model = unittest.mock.MagicMock() - mock_model.converse.return_value = [ - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - ] - - stream = strands.event_loop.streaming.stream_messages( - mock_model, - system_prompt="test prompt", - messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, - ) - - tru_events = list(stream) - exp_events = [ - { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "test", - }, - }, - }, - }, - }, - { - "callback": { - "data": "test", - "delta": { - "text": "test", - }, - }, - }, - { - "callback": { - "event": { - "contentBlockStop": {}, - }, - }, - }, - { - "stop": ( - "end_turn", - {"role": "assistant", "content": [{"text": "test"}]}, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, - ) - }, - ] - assert tru_events == exp_events - - mock_model.converse.assert_called_with( - [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], - None, - "test prompt", - ) diff --git a/tests/strands/models/test_stability.py b/tests/strands/models/test_stability.py new file mode 100644 index 00000000..10c43c56 --- /dev/null +++ b/tests/strands/models/test_stability.py @@ -0,0 +1,657 @@ +import base64 +import io +import unittest.mock + +import pytest +from PIL import Image + +import strands +from strands.models._stabilityaiclient import ( + AuthenticationError, + BadRequestError, + InternalServerError, + NetworkError, + OutputFormat, + PayloadTooLargeError, + RateLimitError, + StabilityAiError, + StylePreset, + ValidationError, +) +from strands.models._stabilityaiclient import ( + ContentModerationError as StabilityContentModerationError, +) +from strands.models.stability import StabilityAiImageModel +from strands.types.exceptions import ( + ContentModerationException, + EventLoopException, + ModelAuthenticationException, + ModelServiceException, + ModelThrottledException, + ModelValidationException, +) + + +@pytest.fixture +def stability_client_cls(): + with unittest.mock.patch.object(strands.models.stability, "StabilityAiClient") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def stability_client(stability_client_cls): + return stability_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "stability.stable-image-ultra-v1:1" + + +@pytest.fixture +def model(stability_client, model_id): + _ = stability_client + return StabilityAiImageModel(api_key="test_key", model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "a beautiful sunset over mountains"}]}] + + +def test__init__(stability_client_cls, model_id): + model = StabilityAiImageModel( + api_key="test_key", + model_id=model_id, + aspect_ratio="16:9", + output_format=OutputFormat.JPEG, + style_preset=StylePreset.PHOTOGRAPHIC, + ) + + tru_config = model.get_config() + exp_config = { + "model_id": model_id, + "aspect_ratio": "16:9", + "output_format": OutputFormat.JPEG, + "style_preset": StylePreset.PHOTOGRAPHIC, + } + + assert tru_config == exp_config + stability_client_cls.assert_called_once_with(api_key="test_key") + + +def test__init__with_string_enums(stability_client_cls, model_id): + model = StabilityAiImageModel( + api_key="test_key", + model_id=model_id, + output_format="jpeg", + style_preset="photographic", + ) + + tru_config = model.get_config() + exp_config = { + "model_id": model_id, + "output_format": OutputFormat.JPEG, + "style_preset": StylePreset.PHOTOGRAPHIC, + } + + assert tru_config == exp_config + + +def test__init__with_invalid_output_format(): + with pytest.raises(ValueError) as exc_info: + StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + output_format="invalid", + ) + assert "output_format must be one of:" in str(exc_info.value) + + +def test__init__with_invalid_style_preset(): + with pytest.raises(ValueError) as exc_info: + StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + style_preset="invalid", + ) + assert "style_preset must be one of:" in str(exc_info.value) + + +def test_update_config(model, model_id): + model.update_config( + model_id=model_id, + aspect_ratio="16:9", + output_format=OutputFormat.JPEG, + ) + + tru_config = model.get_config() + exp_config = {"model_id": model_id, "aspect_ratio": "16:9", "output_format": OutputFormat.JPEG} + + assert tru_config == exp_config + + +def test_format_request(model, messages): + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "mode": "text-to-image", + "style_preset": "photographic", + } + + assert request == exp_request + + +def test_format_request_with_optional_params(model, messages): + model.update_config( + seed=12345, + image="base64_encoded_image", + strength=0.5, + ) + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "seed": 12345, + "image": "base64_encoded_image", + "mode": "text-to-image", + "strength": 0.5, + } + + assert request == exp_request + + +def test_format_request_with_cfg_scale_sd35(stability_client, messages): + """Test that cfg_scale is included in request for SD3.5 model.""" + model = StabilityAiImageModel( + api_key="test_key", + model_id="stability.sd3-5-large-v1:0", + cfg_scale=8, + ) + + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "mode": "text-to-image", + "style_preset": "photographic", + "cfg_scale": 8, + } + + assert request == exp_request + + +def test_format_request_with_cfg_scale_non_sd35(stability_client, messages): + """Test that cfg_scale is NOT included in request for non-SD3.5 models.""" + model = StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + cfg_scale=8, # This should be ignored + ) + + request = model.format_request(messages) + + exp_request = { + "prompt": "a beautiful sunset over mountains", + "aspect_ratio": "1:1", + "output_format": "png", + "mode": "text-to-image", + "style_preset": "photographic", + # Note: cfg_scale is not passed in + } + + assert request == exp_request + assert "cfg_scale" not in request + + +def test_update_config_change_model_id(model, messages): + """Test updating config to change model_id.""" + # Initial model uses stability.stable-image-ultra-v1:1 from fixture + initial_config = model.get_config() + assert initial_config["model_id"] == "stability.stable-image-ultra-v1:1" + + # Update to different model + model.update_config( + model_id="stability.stable-image-core-v1:1", + aspect_ratio="16:9", + ) + + updated_config = model.get_config() + exp_config = { + "model_id": "stability.stable-image-core-v1:1", + "aspect_ratio": "16:9", + "output_format": OutputFormat.PNG, + } + + assert updated_config == exp_config + + # Verify the model uses the new model_id in requests + request = model.format_request(messages) + assert request["aspect_ratio"] == "16:9" + + +def test_stream_image_to_image_mode(model, stability_client): + """Test successful image-to-image generation stream, receiving data as bytes.""" + # Create a 64x64 white PNG image + from unittest.mock import Mock + + # Create a white 64x64 image + white_image = Image.new("RGB", (64, 64), color="white") + + # Convert to PNG bytes + img_buffer = io.BytesIO() + white_image.save(img_buffer, format="PNG") + img_bytes = img_buffer.getvalue() + + # Base64 encode the image + input_image_base64 = base64.b64encode(img_bytes).decode("utf-8") + + # Mock response - generate_image_bytes returns a Response object + transformed_image_bytes = b"fake_transformed_image_data" + mock_response = Mock() + mock_response.content = transformed_image_bytes + mock_response.headers = {"finish-reason": "SUCCESS"} + + stability_client.generate_image_bytes.return_value = mock_response + + request = { + "prompt": "transform this image into a sunset scene", + "image": input_image_base64, + "mode": "image-to-image", + "strength": 0.75, + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + } + + events = list(model.stream(request)) + + # Verify the stream events + assert len(events) == 5 + assert events[0] == {"chunk_type": "message_start"} + assert events[1] == {"chunk_type": "content_start", "data_type": "text"} + assert events[2]["chunk_type"] == "content_block_delta" + assert events[2]["data_type"] == "image" + # The stream method should pass raw bytes directly + assert events[2]["data"] == transformed_image_bytes + assert events[3] == {"chunk_type": "content_stop", "data_type": "text"} + assert events[4] == {"chunk_type": "message_stop", "data": "SUCCESS"} + + # Verify generate_image_bytes was called + stability_client.generate_image_bytes.assert_called_once_with(model.config["model_id"], **request) + + +def test_format_request_no_user_message(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + messages = [{"role": "assistant", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError) as exc_info: + model.format_request(messages) + assert "No user message found in the conversation" in str(exc_info.value) + + +def test_format_chunk_message_start(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "message_start"} + + chunk = model.format_chunk(event) + assert chunk == {"messageStart": {"role": "assistant"}} + + +def test_format_chunk_content_start(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "content_start"} + + chunk = model.format_chunk(event) + assert chunk == {"contentBlockStart": {"start": {}}} + + +def test_format_chunk_content_block_delta(model): + """Test formatting content block delta event.""" + # For bytes data (return_json=False) + event = { + "chunk_type": "content_block_delta", + "data": b"raw_image_data", # Pass actual bytes + "data_type": "image", + } + + result = model.format_chunk(event) + + assert result == { + "contentBlockDelta": { + "delta": { + "image": { + "format": "png", + "source": {"bytes": b"raw_image_data"}, + } + } + } + } + + # For base64 string data (return_json=True) + event_base64 = { + "chunk_type": "content_block_delta", + "data": base64.b64encode(b"raw_image_data").decode("utf-8"), # Valid base64 + "data_type": "image", + } + + result_base64 = model.format_chunk(event_base64) + + assert result_base64 == { + "contentBlockDelta": { + "delta": { + "image": { + "format": "png", + "source": {"bytes": b"raw_image_data"}, + } + } + } + } + + +def test_format_chunk_content_stop(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "content_stop"} + + chunk = model.format_chunk(event) + assert chunk == {"contentBlockStop": {}} + + +def test_format_chunk_message_stop(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "message_stop", "data": "stop"} + + chunk = model.format_chunk(event) + assert chunk == {"messageStop": {"stopReason": "stop"}} + + +def test_format_chunk_unknown_type(): + model = StabilityAiImageModel(api_key="test_key", model_id="stability.stable-image-core-v1:1") + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError) as exc_info: + model.format_chunk(event) + assert "unknown type" in str(exc_info.value) + + +# Test exceptions for stream method + + +def test_stream_authentication_error(model, stability_client): + """Test that AuthenticationError is converted to ModelAuthenticationException.""" + stability_client.generate_image_bytes.side_effect = AuthenticationError( + "Invalid API key", response_data={"error": "unauthorized"} + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelAuthenticationException) as exc_info: + list(model.stream(request)) + + assert "Invalid API key" in str(exc_info.value) + + +def test_stream_content_moderation_error(model, stability_client): + """Test that ContentModerationError is converted to ContentModerationException.""" + stability_client.generate_image_bytes.side_effect = StabilityContentModerationError( + "Content flagged by moderation", response_data={"error": "content_policy_violation"} + ) + + request = { + "prompt": "an unclothed woman on the beach", + "seed": 7, + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ContentModerationException) as exc_info: + list(model.stream(request)) + + assert "Content flagged by moderation" in str(exc_info.value) + + +def test_stream_validation_error(model, stability_client): + """Test that ValidationError is converted to ModelValidationException.""" + stability_client.generate_image_bytes.side_effect = ValidationError( + "Prompt exceeds maximum length", response_data={"error": "validation_error", "field": "prompt"} + ) + + request = { + "prompt": "a" * 10001, + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelValidationException) as exc_info: + list(model.stream(request)) + + assert "Prompt exceeds maximum length" in str(exc_info.value) + + +def test_stream_bad_request_error(model, stability_client): + """Test that BadRequestError is converted to ModelValidationException.""" + stability_client.generate_image_bytes.side_effect = BadRequestError( + "Invalid aspect ratio", response_data={"error": "bad_request"} + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "invalid", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelValidationException) as exc_info: + list(model.stream(request)) + + assert "Invalid aspect ratio" in str(exc_info.value) + + +def test_stream_payload_too_large_error(model, stability_client): + """Test that PayloadTooLargeError is converted to ModelValidationException.""" + stability_client.generate_image_bytes.side_effect = PayloadTooLargeError("Request size exceeds 10MB limit") + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelValidationException) as exc_info: + list(model.stream(request)) + + assert "Request size exceeds 10MB limit" in str(exc_info.value) + + +def test_stream_rate_limit_error(model, stability_client): + """Test that RateLimitError is converted to ModelThrottledException.""" + stability_client.generate_image_bytes.side_effect = RateLimitError( + "Rate limit exceeded. Please retry after 60 seconds.", response_data={"retry_after": 60} + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelThrottledException) as exc_info: + list(model.stream(request)) + + assert "Rate limit exceeded" in str(exc_info.value) + + +def test_stream_internal_server_error(model, stability_client): + """Test that InternalServerError is converted to ModelServiceException with is_transient=True.""" + stability_client.generate_image_bytes.side_effect = InternalServerError("Service temporarily unavailable") + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelServiceException) as exc_info: + list(model.stream(request)) + + assert "Service temporarily unavailable" in str(exc_info.value) + assert exc_info.value.is_transient is True + + +def test_stream_network_error(model, stability_client): + """Test that NetworkError is converted to EventLoopException.""" + original_error = ConnectionError("Connection timed out") + stability_client.generate_image_bytes.side_effect = NetworkError( + "Network request failed", original_error=original_error + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(EventLoopException) as exc_info: + list(model.stream(request)) + + assert exc_info.value.original_exception == original_error + assert exc_info.value.request_state == request + + +def test_stream_generic_stability_error(model, stability_client): + """Test that generic StabilityAiError is converted to ModelServiceException with is_transient=False.""" + stability_client.generate_image_bytes.side_effect = StabilityAiError("Unexpected error occurred", status_code=418) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelServiceException) as exc_info: + list(model.stream(request)) + + assert "Unexpected error occurred" in str(exc_info.value) + assert exc_info.value.is_transient is False + + +def test_stream_success_image_bytes(model, stability_client): + """Test successful image generation stream.""" + # Mock generate_image_bytes to return a mock Response object + from unittest.mock import Mock + + raw_image_data = b"fake_image_data" + mock_response = Mock() + mock_response.content = raw_image_data + mock_response.headers = {"finish-reason": "SUCCESS"} + + stability_client.generate_image_bytes.return_value = mock_response + + request = { + "prompt": "a beautiful sunset", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + events = list(model.stream(request)) + + assert len(events) == 5 + assert events[0] == {"chunk_type": "message_start"} + assert events[1] == {"chunk_type": "content_start", "data_type": "text"} + assert events[2]["chunk_type"] == "content_block_delta" + assert events[2]["data_type"] == "image" + # The stream method should pass raw bytes directly + assert events[2]["data"] == raw_image_data + assert events[3] == {"chunk_type": "content_stop", "data_type": "text"} + assert events[4] == {"chunk_type": "message_stop", "data": "SUCCESS"} + + # Verify the client was called with generate_image_bytes + stability_client.generate_image_bytes.assert_called_once_with(model.config["model_id"], **request) + + +def test_stream_with_invalid_api_key_string(): + """Test that invalid API key string raises ModelAuthenticationException.""" + model = StabilityAiImageModel( + api_key="12345", # Invalid API key (valid str format but not authorized) + model_id="stability.stable-image-core-v1:1", + ) + + # Mock the client to raise AuthenticationError when called + with unittest.mock.patch.object(model.client, "generate_image_bytes") as mock_generate: + mock_generate.side_effect = AuthenticationError("Invalid API key", response_data={"error": "unauthorized"}) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + with pytest.raises(ModelAuthenticationException) as exc_info: + list(model.stream(request)) + + assert "Invalid API key" in str(exc_info.value) + + +def test_stream_with_json_response(stability_client_cls): + """Test streaming with JSON response.""" + # Set up the mock before creating the model + mock_client = stability_client_cls.return_value + mock_response = {"image": base64.b64encode(b"fake_image_data").decode("utf-8"), "finish_reason": "SUCCESS"} + mock_client.generate_image_json.return_value = mock_response + + # Now create the model with return_json=True + model = StabilityAiImageModel( + api_key="test_key", + model_id="stability.stable-image-core-v1:1", + return_json=True, # Explicitly enable JSON mode + ) + + request = { + "prompt": "test prompt", + "aspect_ratio": "1:1", + "output_format": "png", + "style_preset": "photographic", + "mode": "text-to-image", + } + + events = list(model.stream(request)) + + # Verify the response + assert len(events) == 5 + assert events[2]["data"] == mock_response["image"] + + # Now generate_image_json should be called + mock_client.generate_image_json.assert_called_once_with(model.config["model_id"], **request)