diff --git a/griptape/drivers/image_generation/huggingface_diffusion_pipeline_image_generation_driver.py b/griptape/drivers/image_generation/huggingface_diffusion_pipeline_image_generation_driver.py index 5703cbe07..006b160be 100644 --- a/griptape/drivers/image_generation/huggingface_diffusion_pipeline_image_generation_driver.py +++ b/griptape/drivers/image_generation/huggingface_diffusion_pipeline_image_generation_driver.py @@ -1,12 +1,15 @@ import io from abc import ABC -from typing import Optional +from typing import Optional, TYPE_CHECKING from attrs import define, field -from PIL import Image from griptape.artifacts import ImageArtifact from griptape.drivers import BaseDiffusionPipelineImageGenerationModelDriver, BaseImageGenerationDriver +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from PIL import Image @define @@ -35,6 +38,7 @@ def try_image_variation( pipeline = self.model_driver.prepare_pipeline(self.model, self.device) prompt = ", ".join(prompts) + Image = import_optional_dependency("PIL.Image") input_image = Image.open(io.BytesIO(image.value)) # The size of the input image drives the size of the output image. diff --git a/griptape/drivers/image_generation_model/base_diffusion_pipeline_image_generation_model_driver.py b/griptape/drivers/image_generation_model/base_diffusion_pipeline_image_generation_model_driver.py index 75dd8611a..ef0413950 100644 --- a/griptape/drivers/image_generation_model/base_diffusion_pipeline_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/base_diffusion_pipeline_image_generation_model_driver.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING from attrs import define -from PIL.Image import Image + +if TYPE_CHECKING: + from PIL.Image import Image @define diff --git a/griptape/drivers/image_generation_model/stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py b/griptape/drivers/image_generation_model/stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py index 59225bac1..731927546 100644 --- a/griptape/drivers/image_generation_model/stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py @@ -1,12 +1,15 @@ import os -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING from attrs import define, field -from diffusers.models.controlnet_sd3 import SD3ControlNetModel -from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline -from PIL.Image import Image from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet import \ + StableDiffusion3ControlNetPipeline + from PIL.Image import Image @define @@ -25,21 +28,26 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: # as a path to a local file or as a HuggingFace model repo name. # We use the from_single_file method if the model is a local file and the # from_pretrained method if the model is a local directory or hosted on HuggingFace. + sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3.SD3ControlNetModel") if os.path.isfile(self.controlnet_model): - pipeline_params["controlnet"] = SD3ControlNetModel.from_single_file( + pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file( self.controlnet_model, **controlnet_pipeline_params ) else: - pipeline_params["controlnet"] = SD3ControlNetModel.from_pretrained( + pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained( self.controlnet_model, **controlnet_pipeline_params ) + sd3_controlnet_pipeline = import_optional_dependency( + "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet" + ".StableDiffusion3ControlNetPipeline" + ) if os.path.isfile(model): - pipeline = StableDiffusion3ControlNetPipeline.from_single_file(model, **pipeline_params) + pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params) else: - pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(model, **pipeline_params) + pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params) if not isinstance(pipeline, StableDiffusion3ControlNetPipeline): raise ValueError(f"Expected StableDiffusion3ControlNetPipeline, but got {type(pipeline)}.") diff --git a/griptape/drivers/image_generation_model/stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py b/griptape/drivers/image_generation_model/stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py index 24f1e0221..29de439ed 100644 --- a/griptape/drivers/image_generation_model/stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver.py @@ -1,11 +1,15 @@ import os -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING from attrs import define, field -from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline -from PIL.Image import Image from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img import \ + StableDiffusion3Img2ImgPipeline + from PIL.Image import Image @define @@ -19,6 +23,9 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: # A model can be provided either as a path to a local file # or as a HuggingFace model repo name. + sd3_img2img_pipeline = import_optional_dependency( + "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline" + ) if os.path.isfile(model): # If the model provided is a local file (not a directory), # we load it using the from_single_file method. @@ -29,7 +36,7 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: else: # If the model is a local directory or hosted on HuggingFace, # we load it using the from_pretrained method. - pipeline = StableDiffusion3Img2ImgPipeline.from_pretrained(model, **pipeline_params) + pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params) if not isinstance(pipeline, StableDiffusion3Img2ImgPipeline): raise ValueError(f"Expected StableDiffusion3Img2ImgPipeline, but got {type(pipeline)}.") diff --git a/griptape/drivers/image_generation_model/stable_diffusion_3_pipeline_image_generation_model_driver.py b/griptape/drivers/image_generation_model/stable_diffusion_3_pipeline_image_generation_model_driver.py index fdf67bc31..4cb7834c7 100644 --- a/griptape/drivers/image_generation_model/stable_diffusion_3_pipeline_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/stable_diffusion_3_pipeline_image_generation_model_driver.py @@ -1,15 +1,17 @@ import os -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING -import torch from attrs import define, field -from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline -from PIL.Image import Image from griptape.drivers.image_generation_model.base_diffusion_pipeline_image_generation_model_driver import ( BaseDiffusionPipelineImageGenerationModelDriver, ) +from griptape.utils import import_optional_dependency +if TYPE_CHECKING: + from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline + from PIL.Image import Image + import torch @define class StableDiffusion3PipelineImageGenerationModelDriver(BaseDiffusionPipelineImageGenerationModelDriver): @@ -27,14 +29,17 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: # A model can be provided either as a path to a local file # or as a HuggingFace model repo name. + sd3_pipeline = import_optional_dependency( + "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline" + ) if os.path.isfile(model): # If the model provided is a local file (not a directory), # we load it using the from_single_file method. - pipeline = StableDiffusion3Pipeline.from_single_file(model, **pipeline_params) + pipeline = sd3_pipeline.from_single_file(model, **pipeline_params) else: # If the model is a local directory or hosted on HuggingFace, # we load it using the from_pretrained method. - pipeline = StableDiffusion3Pipeline.from_pretrained(model, **pipeline_params) + pipeline = sd3_pipeline.from_pretrained(model, **pipeline_params) if not isinstance(pipeline, StableDiffusion3Pipeline): raise ValueError(f"Expected StableDiffusion3Pipeline, but got {type(pipeline)}.") diff --git a/pyproject.toml b/pyproject.toml index 98bbe4a03..d70708fee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,10 @@ opentelemetry-api = {version = "^1.25.0", optional = true} opentelemetry-instrumentation = {version = "^0.46b0", optional = true} opentelemetry-instrumentation-threading = {version = "^0.46b0", optional = true} opentelemetry-exporter-otlp-proto-http = {version = "^1.25.0", optional = true} +diffusers = {version = "^0.29.2", optional = true} +accelerate = {version = "^0.32.1", optional = true} +sentencepiece = {version = "^0.2.0", optional = true} +torch = {version = "^2.4.0", optional = true} # loaders pandas = {version = "^1.3", optional = true} @@ -134,6 +138,14 @@ drivers-observability-datadog = [ "opentelemetry-exporter-otlp-proto-http", ] +drivers-imagegen-huggingface = [ + "diffusers", + "accelerate", + "sentencepiece", + "torch", + "pillow", +] + loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] @@ -174,6 +186,10 @@ all = [ "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-exporter-otlp-proto-http", + "diffusers", + "accelerate", + "sentencepiece", + "torch", # loaders "pandas",