diff --git a/docs/griptape-framework/drivers/image-generation-drivers.md b/docs/griptape-framework/drivers/image-generation-drivers.md
index 75316134ae..aae25c64dc 100644
--- a/docs/griptape-framework/drivers/image-generation-drivers.md
+++ b/docs/griptape-framework/drivers/image-generation-drivers.md
@@ -179,3 +179,130 @@ agent = Agent(tools=[
 agent.run("Generate a watercolor painting of a dog riding a skateboard")
 ```
+
+### HuggingFace Pipelines
+
+!!! info
+    This driver requires the `drivers-image-generation-huggingface` [extra](../index.md#extras).
+
+The [HuggingFace Pipelines Image Generation Driver](../../reference/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.md) enables image generation through locally hosted models using the HuggingFace [Diffusers](https://huggingface.co/docs/diffusers/en/index) library. This Driver requires a [Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/index.md) to prepare the appropriate Pipeline.
+
+This Driver requires a `model` configuration specifying the model to use for image generation. The value of the `model` configuration must be one of the following:
+
+ - A model name from the HuggingFace Model Hub, like `stabilityai/stable-diffusion-3-medium-diffusers`
+ - A path to a directory containing a model on the filesystem, like `./models/stable-diffusion-3/`
+ - A path to a file containing a model on the filesystem, like `./models/sd3_medium_incl_clips.safetensors`
+
+The `device` configuration specifies the hardware device used to run inference. Common values include `cuda` (CUDA-enabled GPUs), `cpu` (a device's CPU), and `mps` (Apple silicon GPUs). For more information, see [HuggingFace's documentation](https://huggingface.co/docs/transformers/en/perf_infer_gpu_one) on GPU inference.
+
+#### Stable Diffusion 3 Image Generation Pipeline Driver
+
+!!! info
+    The `Stable Diffusion 3 Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.
+
+The [Stable Diffusion 3 Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.md) provides a `StableDiffusion3Pipeline` for text-to-image generations via the [HuggingFace Pipelines Image Generation Driver's](../../reference/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.md) `.try_text_to_image()` method. This Driver accepts a text prompt and configurations including Stable Diffusion 3 model, output image size, generation seed, and inference steps.
+
+Image generation consumes substantial memory. On devices with limited VRAM, it may be necessary to enable the `enable_model_cpu_offload` or `drop_t5_encoder` configurations. For more information, see [HuggingFace's documentation](https://huggingface.co/docs/diffusers/en/optimization/memory) on reduced memory usage.
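+
+For example, a minimal sketch of a memory-constrained driver configuration (the flag values shown are illustrative, not recommendations):
+
+```python title="PYTEST_IGNORE"
+from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
+    StableDiffusion3ImageGenerationPipelineDriver
+
+driver = HuggingFacePipelineImageGenerationDriver(
+    model="stabilityai/stable-diffusion-3-medium-diffusers",
+    device="cuda",
+    pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(
+        enable_model_cpu_offload=True,  # Offload idle submodels to the CPU between forward passes.
+        drop_t5_encoder=True,  # Skip the large T5 text encoder, trading some prompt fidelity for memory.
+    ),
+)
+```
+
+A complete text-to-image example: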
+
+```python title="PYTEST_IGNORE"
+from griptape.structures import Pipeline
+from griptape.tasks import PromptImageGenerationTask
+from griptape.engines import PromptImageGenerationEngine
+from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
+    StableDiffusion3ImageGenerationPipelineDriver
+from griptape.artifacts import TextArtifact
+
+image_generation_task = PromptImageGenerationTask(
+    input=TextArtifact("landscape photograph, verdant, countryside, 8k"),
+    image_generation_engine=PromptImageGenerationEngine(
+        image_generation_driver=HuggingFacePipelineImageGenerationDriver(
+            model="stabilityai/stable-diffusion-3-medium-diffusers",
+            device="cuda",
+            pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(
+                height=512,
+                width=512,
+            )
+        )
+    )
+)
+
+output_artifact = Pipeline(tasks=[image_generation_task]).run().output
+```
+
+#### Stable Diffusion 3 Img2Img Image Generation Pipeline Driver
+
+!!! info
+    The `Stable Diffusion 3 Img2Img Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.
+
+The [Stable Diffusion 3 Img2Img Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.md) provides a `StableDiffusion3Img2ImgPipeline` for image-to-image generations. This Driver accepts a text prompt, an input image, and configurations including Stable Diffusion 3 model, output image size, inference steps, generation seed, and strength of generation over the input image.
+
+```python title="PYTEST_IGNORE"
+from pathlib import Path
+
+from griptape.structures import Pipeline
+from griptape.tasks import VariationImageGenerationTask
+from griptape.engines import VariationImageGenerationEngine
+from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
+    StableDiffusion3Img2ImgImageGenerationPipelineDriver
+from griptape.artifacts import TextArtifact
+from griptape.loaders import ImageLoader
+
+prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
+input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
+
+image_variation_task = VariationImageGenerationTask(
+    input=(prompt_artifact, input_image_artifact),
+    image_generation_engine=VariationImageGenerationEngine(
+        image_generation_driver=HuggingFacePipelineImageGenerationDriver(
+            model="stabilityai/stable-diffusion-3-medium-diffusers",
+            device="cuda",
+            pipeline_driver=StableDiffusion3Img2ImgImageGenerationPipelineDriver(
+                height=1024,
+                width=1024,
+            )
+        )
+    )
+)
+
+output_artifact = Pipeline(tasks=[image_variation_task]).run().output
+```
+
+#### Stable Diffusion 3 ControlNet Image Generation Pipeline Driver
+
+!!! info
+    The `Stable Diffusion 3 ControlNet Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.
+
+The [Stable Diffusion 3 ControlNet Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.md) provides a `StableDiffusion3ControlNetPipeline` for image-to-image generations guided by a control image. This Driver accepts a text prompt, a control image, and configurations including Stable Diffusion 3 model, ControlNet model, output image size, generation seed, inference steps, and the degree to which the model adheres to the control image.
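+
+The example below loads a prepared canny edge control image from disk. As a rough sketch, one way to derive such an image from an ordinary photograph is Pillow's edge-detection filter (file names here are illustrative; canny preprocessors more commonly use OpenCV's `cv2.Canny`):
+
+```python title="PYTEST_IGNORE"
+from PIL import Image, ImageFilter
+
+# Convert a source photo to a grayscale edge map to serve as a rough
+# canny-style control image.
+source_image = Image.open("landscape_photo.png").convert("L")
+edge_image = source_image.filter(ImageFilter.FIND_EDGES)
+edge_image.save("canny_control_image.png")
+```
+
+With a control image in hand, the full ControlNet example: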
+
+```python title="PYTEST_IGNORE"
+from pathlib import Path
+
+from griptape.structures import Pipeline
+from griptape.tasks import VariationImageGenerationTask
+from griptape.engines import VariationImageGenerationEngine
+from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
+    StableDiffusion3ControlNetImageGenerationPipelineDriver
+from griptape.artifacts import TextArtifact
+from griptape.loaders import ImageLoader
+
+prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
+control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes())
+
+controlnet_task = VariationImageGenerationTask(
+    input=(prompt_artifact, control_image_artifact),
+    image_generation_engine=VariationImageGenerationEngine(
+        image_generation_driver=HuggingFacePipelineImageGenerationDriver(
+            model="stabilityai/stable-diffusion-3-medium-diffusers",
+            device="cuda",
+            pipeline_driver=StableDiffusion3ControlNetImageGenerationPipelineDriver(
+                controlnet_model="InstantX/SD3-Controlnet-Canny",
+                controlnet_conditioning_scale=0.8,
+                height=768,
+                width=1024,
+            )
+        )
+    )
+)
+
+output_artifact = Pipeline(tasks=[controlnet_task]).run().output
+```
diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py
index 4a516caf9c..f948f1be1d 100644
--- a/griptape/drivers/__init__.py
+++ b/griptape/drivers/__init__.py
@@ -54,6 +54,19 @@
 )
 from .image_generation_model.bedrock_titan_image_generation_model_driver import BedrockTitanImageGenerationModelDriver
+from .image_generation_pipeline.base_image_generation_pipeline_driver import (
+    BaseDiffusionImageGenerationPipelineDriver,
+)
+from .image_generation_pipeline.stable_diffusion_3_image_generation_pipeline_driver import (
+    StableDiffusion3ImageGenerationPipelineDriver,
+)
+from .image_generation_pipeline.stable_diffusion_3_img_2_img_image_generation_pipeline_driver import (
+    StableDiffusion3Img2ImgImageGenerationPipelineDriver,
+)
+from .image_generation_pipeline.stable_diffusion_3_controlnet_image_generation_pipeline_driver import (
+    StableDiffusion3ControlNetImageGenerationPipelineDriver,
+)
+
 from .image_generation.base_image_generation_driver import BaseImageGenerationDriver
 from .image_generation.base_multi_model_image_generation_driver import BaseMultiModelImageGenerationDriver
 from .image_generation.openai_image_generation_driver import OpenAiImageGenerationDriver
@@ -61,6 +74,9 @@
 from .image_generation.amazon_bedrock_image_generation_driver import AmazonBedrockImageGenerationDriver
 from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver
 from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver
+from .image_generation.huggingface_pipeline_image_generation_driver import (
+    HuggingFacePipelineImageGenerationDriver,
+)
 from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver
 from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver
@@ -164,6 +180,10 @@
     "BaseImageGenerationModelDriver",
     "BedrockStableDiffusionImageGenerationModelDriver",
     "BedrockTitanImageGenerationModelDriver",
+    "BaseDiffusionImageGenerationPipelineDriver",
+    "StableDiffusion3ImageGenerationPipelineDriver",
+    "StableDiffusion3Img2ImgImageGenerationPipelineDriver",
+    "StableDiffusion3ControlNetImageGenerationPipelineDriver",
     "BaseImageGenerationDriver",
     "BaseMultiModelImageGenerationDriver",
     "OpenAiImageGenerationDriver",
@@ -171,6 +191,7 @@
"AmazonBedrockImageGenerationDriver", "AzureOpenAiImageGenerationDriver", "DummyImageGenerationDriver", + "HuggingFacePipelineImageGenerationDriver", "BaseImageQueryModelDriver", "BedrockClaudeImageQueryModelDriver", "BaseImageQueryDriver", diff --git a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py new file mode 100644 index 0000000000..46dbcd331c --- /dev/null +++ b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import io +from abc import ABC +from typing import Optional + +from attrs import define, field + +from griptape.artifacts import ImageArtifact +from griptape.drivers import BaseDiffusionImageGenerationPipelineDriver, BaseImageGenerationDriver +from griptape.utils import import_optional_dependency + + +@define +class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC): + """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline. + + For more information, see the HuggingFace documentation for Diffusers: + https://huggingface.co/docs/diffusers/en/index + + Attributes: + pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model. + device: The hardware device used for inference. For example, "cpu", "cuda", or "mps". + output_format: The format the generated image is returned in. Defaults to "png". + """ + + pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True}) + device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + output_format: str = field(default="png", kw_only=True, metadata={"serializable": True}) + + def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: + pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device) + + prompt = ", ".join(prompts) + output_image = pipeline( + prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device) + ).images[0] + + buffer = io.BytesIO() + output_image.save(buffer, format=self.output_format.upper()) + + return ImageArtifact( + value=buffer.getvalue(), + format=self.output_format.lower(), + height=output_image.height, + width=output_image.width, + prompt=prompt, + ) + + def try_image_variation( + self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None + ) -> ImageArtifact: + pil_image = import_optional_dependency("PIL.Image") + + pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device) + + prompt = ", ".join(prompts) + input_image = pil_image.open(io.BytesIO(image.value)) + # The size of the input image drives the size of the output image. + # Resize the input image to the configured dimensions. 
+ output_width, output_height = self.pipeline_driver.output_image_dimensions + if input_image.height != output_height or input_image.width != output_width: + input_image = input_image.resize((output_width, output_height)) + + output_image = pipeline( + prompt, + **self.pipeline_driver.make_image_param(input_image), + **self.pipeline_driver.make_additional_params(negative_prompts, self.device), + ).images[0] + + buffer = io.BytesIO() + output_image.save(buffer, format=self.output_format.upper()) + + return ImageArtifact( + value=buffer.getvalue(), + format=self.output_format.lower(), + height=output_image.height, + width=output_image.width, + prompt=prompt, + ) + + def try_image_inpainting( + self, + prompts: list[str], + image: ImageArtifact, + mask: ImageArtifact, + negative_prompts: Optional[list[str]] = None, + ) -> ImageArtifact: + raise NotImplementedError("Inpainting is not supported by this driver.") + + def try_image_outpainting( + self, + prompts: list[str], + image: ImageArtifact, + mask: ImageArtifact, + negative_prompts: Optional[list[str]] = None, + ) -> ImageArtifact: + raise NotImplementedError("Outpainting is not supported by this driver.") diff --git a/griptape/drivers/image_generation_pipeline/__init__.py b/griptape/drivers/image_generation_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py b/griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py new file mode 100644 index 0000000000..418034e7c0 --- /dev/null +++ b/griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +from attrs import define + +if TYPE_CHECKING: + from PIL.Image import Image + + +@define +class BaseDiffusionImageGenerationPipelineDriver(ABC): + @abstractmethod + def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ... + + @abstractmethod + def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: ... + + @abstractmethod + def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict: ... + + @property + @abstractmethod + def output_image_dimensions(self) -> tuple[int, int]: ... diff --git a/griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py b/griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py new file mode 100644 index 0000000000..063735fb06 --- /dev/null +++ b/griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Optional + +from attrs import define, field + +from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from PIL.Image import Image + + +@define +class StableDiffusion3ControlNetImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver): + """Image generation model driver for Stable Diffusion 3 models with ControlNet. 
+ + For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline: + https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3 + + Attributes: + controlnet_model: The ControlNet model to use for image generation. + controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None. + """ + + controlnet_model: str = field(kw_only=True) + controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) + + def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: + sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel + sd3_controlnet_pipeline = import_optional_dependency( + "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet" + ).StableDiffusion3ControlNetPipeline + + pipeline_params = {} + controlnet_pipeline_params = {} + if self.torch_dtype is not None: + pipeline_params["torch_dtype"] = self.torch_dtype + controlnet_pipeline_params["torch_dtype"] = self.torch_dtype + + if self.drop_t5_encoder: + pipeline_params["text_encoder_3"] = None + pipeline_params["tokenizer_3"] = None + + # For both Stable Diffusion and ControlNet, models can be provided either + # as a path to a local file or as a HuggingFace model repo name. + # We use the from_single_file method if the model is a local file and the + # from_pretrained method if the model is a local directory or hosted on HuggingFace. + if os.path.isfile(self.controlnet_model): + pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file( + self.controlnet_model, **controlnet_pipeline_params + ) + else: + pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained( + self.controlnet_model, **controlnet_pipeline_params + ) + + if os.path.isfile(model): + pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params) + else: + pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params) + + if self.enable_model_cpu_offload: + pipeline.enable_model_cpu_offload() + + if device is not None: + pipeline.to(device) + + return pipeline + + def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: + if image is None: + raise ValueError("Input image is required for ControlNet pipelines.") + + return {"control_image": image} + + def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]: + additional_params = super().make_additional_params(negative_prompts, device) + + del additional_params["height"] + del additional_params["width"] + + if self.controlnet_conditioning_scale is not None: + additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale + + return additional_params diff --git a/griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py b/griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py new file mode 100644 index 0000000000..53e90c3e2b --- /dev/null +++ b/griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Optional + +from attrs import define, field + +from griptape.drivers.image_generation_pipeline.base_image_generation_pipeline_driver import ( + BaseDiffusionImageGenerationPipelineDriver, +) +from griptape.utils import 
import_optional_dependency
+
+if TYPE_CHECKING:
+    import torch
+    from PIL.Image import Image
+
+
+@define
+class StableDiffusion3ImageGenerationPipelineDriver(BaseDiffusionImageGenerationPipelineDriver):
+    """Image generation model driver for Stable Diffusion 3 models.
+
+    For more information, see the HuggingFace documentation for the StableDiffusion3Pipeline:
+        https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3
+
+    Attributes:
+        width: The width of the generated image. Defaults to 1024. Must be a multiple of 64.
+        height: The height of the generated image. Defaults to 1024. Must be a multiple of 64.
+        seed: The random seed to use for image generation. If not provided, a random seed will be used.
+        guidance_scale: The strength of the guidance loss. If not provided, the default value will be used.
+        steps: The number of inference steps to use in image generation. If not provided, the default value will be used.
+        torch_dtype: The torch data type to use for image generation. If not provided, the default value will be used.
+        enable_model_cpu_offload: Whether to offload idle pipeline submodels to the CPU to reduce peak memory usage. Defaults to False.
+        drop_t5_encoder: Whether to skip loading the T5 text encoder, reducing memory usage at some cost to prompt fidelity. Defaults to False.
+    """
+
+    width: int = field(default=1024, kw_only=True, metadata={"serializable": True})
+    height: int = field(default=1024, kw_only=True, metadata={"serializable": True})
+    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
+    guidance_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
+    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
+    torch_dtype: Optional[torch.dtype] = field(default=None, kw_only=True, metadata={"serializable": True})
+    enable_model_cpu_offload: bool = field(default=False, kw_only=True, metadata={"serializable": True})
+    drop_t5_encoder: bool = field(default=False, kw_only=True, metadata={"serializable": True})
+
+    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
+        sd3_pipeline = import_optional_dependency(
+            "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3"
+        ).StableDiffusion3Pipeline
+
+        pipeline_params = {}
+        if self.torch_dtype is not None:
+            pipeline_params["torch_dtype"] = self.torch_dtype
+
+        if self.drop_t5_encoder:
+            pipeline_params["text_encoder_3"] = None
+            pipeline_params["tokenizer_3"] = None
+
+        # A model can be provided either as a path to a local file
+        # or as a HuggingFace model repo name.
+        if os.path.isfile(model):
+            # If the model provided is a local file (not a directory),
+            # we load it using the from_single_file method.
+            pipeline = sd3_pipeline.from_single_file(model, **pipeline_params)
+        else:
+            # If the model is a local directory or hosted on HuggingFace,
+            # we load it using the from_pretrained method.
+            pipeline = sd3_pipeline.from_pretrained(model, **pipeline_params)
+
+        if self.enable_model_cpu_offload:
+            pipeline.enable_model_cpu_offload()
+
+        # Move inference to a particular device if requested.
+ if device is not None: + pipeline.to(device) + + return pipeline + + def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: + return None + + def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]: + torch_generator = import_optional_dependency("torch").Generator + + additional_params = {} + if negative_prompts: + additional_params["negative_prompt"] = ", ".join(negative_prompts) + + if self.width is not None: + additional_params["width"] = self.width + + if self.height is not None: + additional_params["height"] = self.height + + if self.seed is not None: + additional_params["generator"] = [torch_generator(device=device).manual_seed(self.seed)] + + if self.guidance_scale is not None: + additional_params["guidance_scale"] = self.guidance_scale + + if self.steps is not None: + additional_params["num_inference_steps"] = self.steps + + return additional_params + + @property + def output_image_dimensions(self) -> tuple[int, int]: + return self.width, self.height diff --git a/griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py b/griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py new file mode 100644 index 0000000000..8276b110bb --- /dev/null +++ b/griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Optional + +from attrs import define, field + +from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from PIL.Image import Image + + +@define +class StableDiffusion3Img2ImgImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver): + """Image generation model driver for Stable Diffusion 3 model image to image pipelines. + + For more information, see the HuggingFace documentation for the StableDiffusion3Img2ImgPipeline: + https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py + + Attributes: + strength: A value [0.0, 1.0] that determines the strength of the initial image in the output. + """ + + strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) + + def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: + sd3_img2img_pipeline = import_optional_dependency( + "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img" + ).StableDiffusion3Img2ImgPipeline + + pipeline_params = {} + if self.torch_dtype is not None: + pipeline_params["torch_dtype"] = self.torch_dtype + + if self.drop_t5_encoder: + pipeline_params["text_encoder_3"] = None + pipeline_params["tokenizer_3"] = None + + # A model can be provided either as a path to a local file + # or as a HuggingFace model repo name. + if os.path.isfile(model): + # If the model provided is a local file (not a directory), + # we load it using the from_single_file method. + pipeline = sd3_img2img_pipeline.from_single_file(model, **pipeline_params) + else: + # If the model is a local directory or hosted on HuggingFace, + # we load it using the from_pretrained method. 
+ pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params) + + if self.enable_model_cpu_offload: + pipeline.enable_model_cpu_offload() + + # Move inference to particular device if requested. + if device is not None: + pipeline.to(device) + + return pipeline + + def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: + if image is None: + raise ValueError("Input image is required for image to image pipelines.") + + return {"image": image} + + def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]: + additional_params = super().make_additional_params(negative_prompts, device) + + # Explicit height and width params are not supported, but + # are instead inferred from input image. + del additional_params["height"] + del additional_params["width"] + + if self.strength is not None: + additional_params["strength"] = self.strength + + return additional_params diff --git a/poetry.lock b/poetry.lock index 2738d6c972..58fff8c4e4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,18 +2,18 @@ [[package]] name = "accelerate" -version = "0.31.0" +version = "0.32.1" description = "Accelerate" optional = true python-versions = ">=3.8.0" files = [ - {file = "accelerate-0.31.0-py3-none-any.whl", hash = "sha256:0fc608dc49584f64d04711a39711d73cb0ad4ef3d21cddee7ef2216e29471144"}, - {file = "accelerate-0.31.0.tar.gz", hash = "sha256:b5199865b26106ccf9205acacbe8e4b3b428ad585e7c472d6a46f6fb75b6c176"}, + {file = "accelerate-0.32.1-py3-none-any.whl", hash = "sha256:71fcf4be00872194071de561634268b71417d7f5b16b178e2fa76b6f117c52b0"}, + {file = "accelerate-0.32.1.tar.gz", hash = "sha256:3999acff0237cd0d4f9fd98b42d5a3163544777b53fc4f1eec886b77e992d177"}, ] [package.dependencies] huggingface-hub = "*" -numpy = ">=1.17" +numpy = ">=1.17,<2.0.0" packaging = ">=20.0" psutil = "*" pyyaml = "*" @@ -1201,6 +1201,40 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "diffusers" +version = "0.30.0.dev0" +description = "State-of-the-art diffusion in PyTorch and JAX." 
+optional = true +python-versions = ">=3.8.0" +files = [] +develop = false + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.23.2" +importlib_metadata = "*" +numpy = "*" +Pillow = "*" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.3.1" + +[package.extras] +dev = ["GitPython (<3.1.19)", "Jinja2", "Jinja2", "accelerate (>=0.31.0)", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] +docs = ["hf-doc-builder (>=0.3.0)"] +flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"] +quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"] +test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.41.2)"] +torch = ["accelerate (>=0.31.0)", "torch (>=1.4)"] +training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] + +[package.source] +type = "git" +url = "https://github.com/griptape-ai/diffusers.git" +reference = "main" +resolved_reference = "90c1f182683a9bb51e370816d063b2e3aba53fc4" + [[package]] name = "distlib" version = "0.3.8" @@ -2500,13 +2534,9 @@ files = [ {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"}, {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"}, {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = 
"sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, @@ -5355,6 +5385,68 @@ files = [ cryptography = ">=2.0" jeepney = ">=0.6" +[[package]] +name = "sentencepiece" +version = "0.2.0" +description = "SentencePiece python wrapper" +optional = true +python-versions = "*" +files = [ + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"}, + {file = 
"sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"}, + {file = 
"sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"}, + {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"}, +] + [[package]] name = "sentinels" version = "1.0.0" @@ -6686,7 +6778,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "trafilatura", "transformers", "voyageai"] +all = ["accelerate", "anthropic", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] 
drivers-embedding-cohere = ["cohere"] @@ -6697,6 +6789,7 @@ drivers-embedding-voyageai = ["voyageai"] drivers-event-listener-amazon-iot = ["boto3"] drivers-event-listener-amazon-sqs = ["boto3"] drivers-event-listener-pusher = ["pusher"] +drivers-image-generation-huggingface = ["accelerate", "diffusers", "pillow", "sentencepiece", "torch"] drivers-memory-conversation-amazon-dynamodb = ["boto3"] drivers-memory-conversation-redis = ["redis"] drivers-observability-datadog = ["opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk"] @@ -6735,4 +6828,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "214d3ce5365f5ce4d852e44019dbbea4659422bdc0231f7cf745a47eeb90ae18" +content-hash = "20511683fde939102f4d9e331fd9ecd064ad6cc490a5bf2f6c3f4366b4e43447" diff --git a/pyproject.toml b/pyproject.toml index 98bbe4a038..8bb0ca5d45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,10 @@ opentelemetry-api = {version = "^1.25.0", optional = true} opentelemetry-instrumentation = {version = "^0.46b0", optional = true} opentelemetry-instrumentation-threading = {version = "^0.46b0", optional = true} opentelemetry-exporter-otlp-proto-http = {version = "^1.25.0", optional = true} +diffusers = {git = "https://github.com/griptape-ai/diffusers.git", branch = "main", optional = true} +accelerate = {version = "^0.32.1", optional = true} +sentencepiece = {version = "^0.2.0", optional = true} +torch = {version = "^2.3.1", optional = true} # loaders pandas = {version = "^1.3", optional = true} @@ -134,6 +138,14 @@ drivers-observability-datadog = [ "opentelemetry-exporter-otlp-proto-http", ] +drivers-image-generation-huggingface = [ + "diffusers", + "accelerate", + "sentencepiece", + "torch", + "pillow", +] + loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] @@ -174,11 +186,15 @@ all = [ "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-exporter-otlp-proto-http", + "diffusers", + "accelerate", + "sentencepiece", + "torch", + "pillow", # loaders "pandas", "pypdf", - "pillow", "mail-parser", "filetype", ] diff --git a/tests/unit/drivers/image_generation/test_huggingface_pipeline_image_generation_driver.py b/tests/unit/drivers/image_generation/test_huggingface_pipeline_image_generation_driver.py new file mode 100644 index 0000000000..d47e6a1071 --- /dev/null +++ b/tests/unit/drivers/image_generation/test_huggingface_pipeline_image_generation_driver.py @@ -0,0 +1,72 @@ +import io +from unittest.mock import Mock + +import pytest +from PIL import Image + +from griptape.artifacts import ImageArtifact +from griptape.drivers import ( + BaseDiffusionImageGenerationPipelineDriver, + HuggingFacePipelineImageGenerationDriver, +) + + +class TestHuggingFacePipelineImageGenerationDriver: + @pytest.fixture() + def image_artifact(self): + buffer = io.BytesIO() + Image.new("RGB", (256, 256)).save(buffer, "PNG") + return ImageArtifact(buffer.getvalue(), format="png", width=256, height=256) + + @pytest.fixture() + def model_driver(self): + model_driver = Mock(spec=BaseDiffusionImageGenerationPipelineDriver) + mock_pipeline = Mock() + mock_pipeline.return_value = Mock() + mock_pipeline.return_value.images = [Image.new("RGB", (256, 256))] + model_driver.prepare_pipeline.return_value = mock_pipeline + model_driver.make_image_param.return_value = {"image": Image.new("RGB", (256, 256))} + 
model_driver.make_additional_params.return_value = {"negative_prompt": ["sample negative prompt"]} + model_driver.output_image_dimensions = (256, 256) + + return model_driver + + @pytest.fixture() + def driver(self, model_driver): + return HuggingFacePipelineImageGenerationDriver(model="repo/model", pipeline_driver=model_driver) + + def test_init(self, driver): + assert driver + + def test_try_text_to_image(self, driver): + image_artifact = driver.try_text_to_image(prompts=["test prompt"]) + + assert image_artifact + assert image_artifact.mime_type == "image/png" + assert image_artifact.width == 256 + assert image_artifact.height == 256 + + def test_try_image_variation(self, driver, image_artifact): + image_artifact = driver.try_image_variation(prompts=["test prompt"], image=image_artifact) + + assert image_artifact + assert image_artifact.mime_type == "image/png" + assert image_artifact.width == 256 + assert image_artifact.height == 256 + + def test_try_image_inpainting(self, driver): + with pytest.raises(NotImplementedError): + driver.try_image_inpainting(prompts=["test prompt"], image=Mock(), mask=Mock()) + + def test_try_image_outpainting(self, driver): + with pytest.raises(NotImplementedError): + driver.try_image_outpainting(prompts=["test prompt"], image=Mock(), mask=Mock()) + + def test_configurable_output_format(self, driver): + driver.output_format = "jpeg" + image_artifact = driver.try_text_to_image(prompts=["test prompt"]) + + assert image_artifact + assert image_artifact.mime_type == "image/jpeg" + assert image_artifact.width == 256 + assert image_artifact.height == 256 diff --git a/tests/unit/drivers/image_generation_pipeline/__init__.py b/tests/unit/drivers/image_generation_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py b/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py new file mode 100644 index 0000000000..986eb553de --- /dev/null +++ b/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py @@ -0,0 +1,126 @@ +from unittest.mock import Mock, patch + +import pytest +import torch +from PIL import Image + +from griptape.drivers import StableDiffusion3ControlNetImageGenerationPipelineDriver + + +class TestStableDiffusion3ControlNetPipelineImageGenerationModelDriver: + @pytest.fixture() + def model_driver(self): + return StableDiffusion3ControlNetImageGenerationPipelineDriver(controlnet_model="controlnet_model") + + @pytest.fixture() + def mock_import(self, monkeypatch): + mock = Mock() + monkeypatch.setattr( + "griptape.drivers.image_generation_pipeline.stable_diffusion_3_controlnet_image_generation_pipeline_driver.import_optional_dependency", + mock, + ) + + return mock + + def test_prepare_pipeline_local_file(self, model_driver, mock_import): + mock_sd3_controlnet_model = Mock() + mock_sd3_controlnet_pipeline = Mock() + mock_import.side_effect = [ + Mock(SD3ControlNetModel=mock_sd3_controlnet_model), + Mock(StableDiffusion3ControlNetPipeline=mock_sd3_controlnet_pipeline), + ] + + with patch("os.path.isfile", return_value=True): + result = model_driver.prepare_pipeline("local_model", "cuda") + + mock_sd3_controlnet_model.from_single_file.assert_called_once_with("controlnet_model") + mock_sd3_controlnet_pipeline.from_single_file.assert_called_once_with( + "local_model", 
controlnet=mock_sd3_controlnet_model.from_single_file.return_value + ) + + assert result == mock_sd3_controlnet_pipeline.from_single_file.return_value + result.to.assert_called_once_with("cuda") + + def test_prepare_pipeline_huggingface_model(self, model_driver, mock_import): + mock_sd3_controlnet_model = Mock() + mock_sd3_controlnet_pipeline = Mock() + mock_import.side_effect = [ + Mock(SD3ControlNetModel=mock_sd3_controlnet_model), + Mock(StableDiffusion3ControlNetPipeline=mock_sd3_controlnet_pipeline), + ] + + with patch("os.path.isfile", return_value=False): + result = model_driver.prepare_pipeline("huggingface/model", "cuda") + + mock_sd3_controlnet_model.from_pretrained.assert_called_once_with("controlnet_model") + mock_sd3_controlnet_pipeline.from_pretrained.assert_called_once_with( + "huggingface/model", controlnet=mock_sd3_controlnet_model.from_pretrained.return_value + ) + + assert result == mock_sd3_controlnet_pipeline.from_pretrained.return_value + result.to.assert_called_once_with("cuda") + + def test_prepare_pipeline_with_options(self, model_driver, mock_import): + mock_sd3_controlnet_model = Mock() + mock_sd3_controlnet_pipeline = Mock() + mock_import.side_effect = [ + Mock(SD3ControlNetModel=mock_sd3_controlnet_model), + Mock(StableDiffusion3ControlNetPipeline=mock_sd3_controlnet_pipeline), + ] + + model_driver.torch_dtype = torch.float16 + model_driver.drop_t5_encoder = True + model_driver.enable_model_cpu_offload = True + + result = model_driver.prepare_pipeline("huggingface/model", "cpu") + + mock_sd3_controlnet_pipeline.from_pretrained.assert_called_once_with( + "huggingface/model", + controlnet=mock_sd3_controlnet_model.from_pretrained.return_value, + torch_dtype=torch.float16, + text_encoder_3=None, + tokenizer_3=None, + ) + + assert result == mock_sd3_controlnet_pipeline.from_pretrained.return_value + result.to.assert_called_once_with("cpu") + result.enable_model_cpu_offload.assert_called_once() + + def test_make_image_param(self, model_driver): + mock_image = Mock(spec=Image.Image) + result = model_driver.make_image_param(mock_image) + + assert result == {"control_image": mock_image} + + def test_make_image_param_without_image(self, model_driver): + with pytest.raises(ValueError): + model_driver.make_image_param(None) + + def test_make_additional_params(self, model_driver, mock_import): + mock_torch = Mock() + mock_import.return_value = mock_torch + + model_driver.guidance_scale = 7.5 + model_driver.steps = 50 + model_driver.controlnet_conditioning_scale = 0.8 + + result = model_driver.make_additional_params(["no cats", "no dogs"], "cuda") + + expected = { + "negative_prompt": "no cats, no dogs", + "guidance_scale": 7.5, + "num_inference_steps": 50, + "controlnet_conditioning_scale": 0.8, + } + + assert result == expected + assert "height" not in result + assert "width" not in result + + def test_output_image_dimensions(self, model_driver): + model_driver.width = 512 + model_driver.height = 768 + + dimensions = model_driver.output_image_dimensions + + assert dimensions == (512, 768) diff --git a/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py b/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py new file mode 100644 index 0000000000..2ad28bec5b --- /dev/null +++ b/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py @@ -0,0 +1,105 @@ +from unittest.mock import 
Mock, patch + +import pytest +import torch +from PIL import Image + +from griptape.drivers import StableDiffusion3Img2ImgImageGenerationPipelineDriver + + +class TestStableDiffusion3Img2ImgPipelineImageGenerationModelDriver: + @pytest.fixture() + def model_driver(self): + return StableDiffusion3Img2ImgImageGenerationPipelineDriver() + + @pytest.fixture() + def mock_import(self, monkeypatch): + mock = Mock() + monkeypatch.setattr( + "griptape.drivers.image_generation_pipeline.stable_diffusion_3_img_2_img_image_generation_pipeline_driver.import_optional_dependency", + mock, + ) + + return mock + + def test_prepare_pipeline_local_file(self, model_driver, mock_import): + mock_sd3_img2img_pipeline = Mock() + mock_import.return_value.StableDiffusion3Img2ImgPipeline = mock_sd3_img2img_pipeline + + with patch("os.path.isfile", return_value=True): + result = model_driver.prepare_pipeline("local_model", "cuda") + + mock_sd3_img2img_pipeline.from_single_file.assert_called_once_with("local_model") + + assert result == mock_sd3_img2img_pipeline.from_single_file.return_value + result.to.assert_called_once_with("cuda") + + def test_prepare_pipeline_huggingface_model(self, model_driver, mock_import): + mock_sd3_img2img_pipeline = Mock() + mock_import.return_value.StableDiffusion3Img2ImgPipeline = mock_sd3_img2img_pipeline + + with patch("os.path.isfile", return_value=False): + result = model_driver.prepare_pipeline("huggingface/model", "cuda") + + mock_sd3_img2img_pipeline.from_pretrained.assert_called_once_with("huggingface/model") + + assert result == mock_sd3_img2img_pipeline.from_pretrained.return_value + result.to.assert_called_once_with("cuda") + + def test_prepare_pipeline_with_options(self, model_driver, mock_import): + mock_sd3_img2img_pipeline = Mock() + mock_import.return_value.StableDiffusion3Img2ImgPipeline = mock_sd3_img2img_pipeline + + model_driver.torch_dtype = torch.float16 + model_driver.drop_t5_encoder = True + model_driver.enable_model_cpu_offload = True + + result = model_driver.prepare_pipeline("huggingface/model", "cpu") + + mock_sd3_img2img_pipeline.from_pretrained.assert_called_once_with( + "huggingface/model", + torch_dtype=torch.float16, + text_encoder_3=None, + tokenizer_3=None, + ) + assert result == mock_sd3_img2img_pipeline.from_pretrained.return_value + result.to.assert_called_once_with("cpu") + result.enable_model_cpu_offload.assert_called_once() + + def test_make_image_param(self, model_driver): + mock_image = Mock(spec=Image.Image) + result = model_driver.make_image_param(mock_image) + assert result == {"image": mock_image} + + def test_make_image_param_without_image(self, model_driver): + with pytest.raises(ValueError): + model_driver.make_image_param(None) + + def test_make_additional_params(self, model_driver, mock_import): + mock_torch = Mock() + mock_import.return_value = mock_torch + + model_driver.guidance_scale = 7.5 + model_driver.steps = 50 + model_driver.strength = 0.75 + + result = model_driver.make_additional_params(["no cats", "no dogs"], "cuda") + + expected = { + "negative_prompt": "no cats, no dogs", + "guidance_scale": 7.5, + "num_inference_steps": 50, + "strength": 0.75, + } + + assert result == expected + assert "height" not in result + assert "width" not in result + + def test_output_image_dimensions(self, model_driver): + model_driver.width = 512 + model_driver.height = 768 + + dimensions = model_driver.output_image_dimensions + + assert dimensions == (512, 768) diff --git 
a/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_pipeline_image_generation_model_driver.py b/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_pipeline_image_generation_model_driver.py new file mode 100644 index 0000000000..36077f43a9 --- /dev/null +++ b/tests/unit/drivers/image_generation_pipeline/test_stable_diffusion_3_pipeline_image_generation_model_driver.py @@ -0,0 +1,94 @@ +from unittest.mock import Mock, patch + +import pytest +import torch +from PIL import Image + +from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver + + +class TestStableDiffusion3PipelineImageGenerationModelDriver: + @pytest.fixture() + def model_driver(self): + return StableDiffusion3ImageGenerationPipelineDriver() + + @pytest.fixture() + def mock_import(self, monkeypatch): + mock = Mock() + monkeypatch.setattr( + "griptape.drivers.image_generation_pipeline.stable_diffusion_3_image_generation_pipeline_driver.import_optional_dependency", + mock, + ) + + return mock + + def test_prepare_pipeline_local_file(self, model_driver, mock_import): + mock_sd3_pipeline = Mock() + mock_import.return_value.StableDiffusion3Pipeline = mock_sd3_pipeline + + with patch("os.path.isfile", return_value=True): + result = model_driver.prepare_pipeline("local_model", "cuda") + + mock_sd3_pipeline.from_single_file.assert_called_once_with("local_model") + + assert result == mock_sd3_pipeline.from_single_file.return_value + result.to.assert_called_once_with("cuda") + + def test_prepare_pipeline_huggingface_model(self, model_driver, mock_import): + mock_sd3_pipeline = Mock() + mock_import.return_value.StableDiffusion3Pipeline = mock_sd3_pipeline + + with patch("os.path.isfile", return_value=False): + result = model_driver.prepare_pipeline("huggingface/model", None) + + mock_sd3_pipeline.from_pretrained.assert_called_once_with("huggingface/model") + assert result == mock_sd3_pipeline.from_pretrained.return_value + result.to.assert_not_called() + + def test_prepare_pipeline_with_options(self, model_driver, mock_import): + mock_sd3_pipeline = Mock() + mock_import.return_value.StableDiffusion3Pipeline = mock_sd3_pipeline + + model_driver.torch_dtype = torch.float16 + model_driver.drop_t5_encoder = True + model_driver.enable_model_cpu_offload = True + + result = model_driver.prepare_pipeline("huggingface/model", "cpu") + + mock_sd3_pipeline.from_pretrained.assert_called_once_with( + "huggingface/model", + torch_dtype=torch.float16, + text_encoder_3=None, + tokenizer_3=None, + ) + assert result == mock_sd3_pipeline.from_pretrained.return_value + result.to.assert_called_once_with("cpu") + result.enable_model_cpu_offload.assert_called_once() + + def test_make_image_param(self, model_driver): + assert model_driver.make_image_param(Mock(spec=Image.Image)) is None + + def test_make_additional_params(self, model_driver, mock_import): + mock_torch = Mock() + mock_import.return_value = mock_torch + + model_driver.guidance_scale = 7.5 + model_driver.steps = 50 + + result = model_driver.make_additional_params(["no cats", "no dogs"], "cuda") + + expected = { + "negative_prompt": "no cats, no dogs", + "width": 1024, + "height": 1024, + "guidance_scale": 7.5, + "num_inference_steps": 50, + } + + assert result == expected + + def test_output_image_dimensions(self, model_driver): + model_driver.width = 512 + model_driver.height = 768 + + assert model_driver.output_image_dimensions == (512, 768)