Stable Diffusion 3 local support (#1018)
andrewfrench authored and collindutter committed Aug 2, 2024
1 parent 54cd8cc commit 56b4104
Showing 15 changed files with 1,058 additions and 11 deletions.
127 changes: 127 additions & 0 deletions docs/griptape-framework/drivers/image-generation-drivers.md
@@ -179,3 +179,130 @@ agent = Agent(tools=[

agent.run("Generate a watercolor painting of a dog riding a skateboard")
```

### HuggingFace Pipelines

!!! info
    This driver requires the `drivers-image-generation-huggingface` [extra](../index.md#extras).

The [HuggingFace Pipelines Image Generation Driver](../../reference/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.md) enables image generation through locally-hosted models using the HuggingFace [Diffusers](https://huggingface.co/docs/diffusers/en/index) library. This Driver requires a [Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/index.md) to prepare the appropriate Pipeline.

This Driver requires a `model` configuration, specifying the model to use for image generation. The value of the `model` configuration must be one of the following:

- A model name from the HuggingFace Model Hub, like `stabilityai/stable-diffusion-3-medium-diffusers`
- A path to the directory containing a model on the filesystem, like `./models/stable-diffusion-3/`
- A path to a file containing a model on the filesystem, like `./models/sd3_medium_incl_clips.safetensors`

The `device` configuration specifies the hardware device used to run inference. Common values include `cuda` (CUDA-enabled GPUs), `cpu` (a device's CPU), and `mps` (Apple silicon GPUs). For more information, see [HuggingFace's documentation](https://huggingface.co/docs/transformers/en/perf_infer_gpu_one) on GPU inference.
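
For example, a minimal sketch of the Driver configured to run a local model checkpoint on Apple silicon (the checkpoint path is illustrative):

```python title="PYTEST_IGNORE"
from griptape.drivers import (
    HuggingFacePipelineImageGenerationDriver,
    StableDiffusion3ImageGenerationPipelineDriver,
)

# Illustrative local checkpoint; a Model Hub name or a model directory
# works the same way.
driver = HuggingFacePipelineImageGenerationDriver(
    model="./models/sd3_medium_incl_clips.safetensors",
    device="mps",
    pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(),
)
```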

#### Stable Diffusion 3 Image Generation Pipeline Driver

!!! info
    The `Stable Diffusion 3 Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.

The [Stable Diffusion 3 Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.md) provides a `StableDiffusion3Pipeline` for text-to-image generations via the [HuggingFace Pipelines Image Generation Driver's](../../reference/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.md) `.try_text_to_image()` method. This Driver accepts a text prompt and configurations including the Stable Diffusion 3 model, output image size, generation seed, and number of inference steps.

Image generation consumes substantial memory. On devices with limited VRAM, it may be necessary to enable the `enable_model_cpu_offload` or `drop_t5_encoder` configurations. For more information, see [HuggingFace's documentation](https://huggingface.co/docs/diffusers/en/optimization/memory) on reduced memory usage.
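
For instance, a minimal sketch of a memory-constrained configuration using the options named above, plus the pipeline driver's `torch_dtype` to load weights at half precision:

```python title="PYTEST_IGNORE"
import torch

from griptape.drivers import (
    HuggingFacePipelineImageGenerationDriver,
    StableDiffusion3ImageGenerationPipelineDriver,
)

# Offload idle submodels to the CPU, drop the large T5 text encoder, and
# load weights at half precision to reduce peak VRAM usage.
driver = HuggingFacePipelineImageGenerationDriver(
    model="stabilityai/stable-diffusion-3-medium-diffusers",
    device="cuda",
    pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(
        enable_model_cpu_offload=True,
        drop_t5_encoder=True,
        torch_dtype=torch.float16,
    ),
)
```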

```python title="PYTEST_IGNORE"
from griptape.structures import Pipeline
from griptape.tasks import PromptImageGenerationTask
from griptape.engines import PromptImageGenerationEngine
from griptape.drivers import (
    HuggingFacePipelineImageGenerationDriver,
    StableDiffusion3ImageGenerationPipelineDriver,
)
from griptape.artifacts import TextArtifact

image_generation_task = PromptImageGenerationTask(
    input=TextArtifact("landscape photograph, verdant, countryside, 8k"),
    image_generation_engine=PromptImageGenerationEngine(
        image_generation_driver=HuggingFacePipelineImageGenerationDriver(
            model="stabilityai/stable-diffusion-3-medium-diffusers",
            device="cuda",
            pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(
                height=512,
                width=512,
            ),
        ),
    ),
)

output_artifact = Pipeline(tasks=[image_generation_task]).run().output
```

#### Stable Diffusion 3 Img2Img Image Generation Pipeline Driver

!!! info
    The `Stable Diffusion 3 Img2Img Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.

The [Stable Diffusion 3 Img2Img Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.md) provides a `StableDiffusion3Img2ImgPipeline` for image-to-image generations. This Driver accepts a text prompt, an input image, and configurations including the Stable Diffusion 3 model, output image size, inference steps, generation seed, and the strength of generation over the input image.

```python title="PYTEST_IGNORE"
from pathlib import Path

from griptape.structures import Pipeline
from griptape.tasks import VariationImageGenerationTask
from griptape.engines import VariationImageGenerationEngine
from griptape.drivers import (
    HuggingFacePipelineImageGenerationDriver,
    StableDiffusion3Img2ImgImageGenerationPipelineDriver,
)
from griptape.artifacts import TextArtifact
from griptape.loaders import ImageLoader

prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())

image_variation_task = VariationImageGenerationTask(
    input=(prompt_artifact, input_image_artifact),
    image_generation_engine=VariationImageGenerationEngine(
        image_generation_driver=HuggingFacePipelineImageGenerationDriver(
            model="stabilityai/stable-diffusion-3-medium-diffusers",
            device="cuda",
            pipeline_driver=StableDiffusion3Img2ImgImageGenerationPipelineDriver(
                height=1024,
                width=1024,
            ),
        ),
    ),
)

output_artifact = Pipeline(tasks=[image_variation_task]).run().output
```
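
The strength configuration controls how far the generation may diverge from the input image. A hypothetical sketch; the `strength` parameter name below is an assumption based on the Diffusers img2img convention, not confirmed by the reference docs:

```python title="PYTEST_IGNORE"
from griptape.drivers import StableDiffusion3Img2ImgImageGenerationPipelineDriver

# Hypothetical: higher values give the model more freedom to diverge from
# the input image; lower values stay closer to it.
pipeline_driver = StableDiffusion3Img2ImgImageGenerationPipelineDriver(
    height=1024,
    width=1024,
    strength=0.8,  # assumed parameter name
)
```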

#### Stable Diffusion 3 ControlNet Image Generation Pipeline Driver

!!! info
    The `Stable Diffusion 3 ControlNet Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.

The [Stable Diffusion 3 ControlNet Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.md) provides a `StableDiffusion3ControlNetPipeline` for image-to-image generations guided by a control image. This Driver accepts a text prompt, a control image, and configurations including the Stable Diffusion 3 model, ControlNet model, output image size, generation seed, inference steps, and the degree to which the model adheres to the control image.

```python title="PYTEST_IGNORE"
from pathlib import Path

from griptape.structures import Pipeline
from griptape.tasks import VariationImageGenerationTask
from griptape.engines import VariationImageGenerationEngine
from griptape.drivers import (
    HuggingFacePipelineImageGenerationDriver,
    StableDiffusion3ControlNetImageGenerationPipelineDriver,
)
from griptape.artifacts import TextArtifact
from griptape.loaders import ImageLoader

prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes())

controlnet_task = VariationImageGenerationTask(
    input=(prompt_artifact, control_image_artifact),
    image_generation_engine=VariationImageGenerationEngine(
        image_generation_driver=HuggingFacePipelineImageGenerationDriver(
            model="stabilityai/stable-diffusion-3-medium-diffusers",
            device="cuda",
            pipeline_driver=StableDiffusion3ControlNetImageGenerationPipelineDriver(
                controlnet_model="InstantX/SD3-Controlnet-Canny",
                controlnet_conditioning_scale=0.8,
                height=768,
                width=1024,
            ),
        ),
    ),
)

output_artifact = Pipeline(tasks=[controlnet_task]).run().output
```
21 changes: 21 additions & 0 deletions griptape/drivers/__init__.py
@@ -54,13 +54,29 @@
)
from .image_generation_model.bedrock_titan_image_generation_model_driver import BedrockTitanImageGenerationModelDriver

from .image_generation_pipeline.base_image_generation_pipeline_driver import (
    BaseDiffusionImageGenerationPipelineDriver,
)
from .image_generation_pipeline.stable_diffusion_3_image_generation_pipeline_driver import (
    StableDiffusion3ImageGenerationPipelineDriver,
)
from .image_generation_pipeline.stable_diffusion_3_img_2_img_image_generation_pipeline_driver import (
    StableDiffusion3Img2ImgImageGenerationPipelineDriver,
)
from .image_generation_pipeline.stable_diffusion_3_controlnet_image_generation_pipeline_driver import (
    StableDiffusion3ControlNetImageGenerationPipelineDriver,
)

from .image_generation.base_image_generation_driver import BaseImageGenerationDriver
from .image_generation.base_multi_model_image_generation_driver import BaseMultiModelImageGenerationDriver
from .image_generation.openai_image_generation_driver import OpenAiImageGenerationDriver
from .image_generation.leonardo_image_generation_driver import LeonardoImageGenerationDriver
from .image_generation.amazon_bedrock_image_generation_driver import AmazonBedrockImageGenerationDriver
from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver
from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver
from .image_generation.huggingface_pipeline_image_generation_driver import (
    HuggingFacePipelineImageGenerationDriver,
)

from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver
from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver
@@ -164,13 +180,18 @@
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
"BaseDiffusionImageGenerationPipelineDriver",
"StableDiffusion3ImageGenerationPipelineDriver",
"StableDiffusion3Img2ImgImageGenerationPipelineDriver",
"StableDiffusion3ControlNetImageGenerationPipelineDriver",
"BaseImageGenerationDriver",
"BaseMultiModelImageGenerationDriver",
"OpenAiImageGenerationDriver",
"LeonardoImageGenerationDriver",
"AmazonBedrockImageGenerationDriver",
"AzureOpenAiImageGenerationDriver",
"DummyImageGenerationDriver",
"HuggingFacePipelineImageGenerationDriver",
"BaseImageQueryModelDriver",
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
98 changes: 98 additions & 0 deletions griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
@@ -0,0 +1,98 @@
from __future__ import annotations

import io
from abc import ABC
from typing import Optional

from attrs import define, field

from griptape.artifacts import ImageArtifact
from griptape.drivers import BaseDiffusionImageGenerationPipelineDriver, BaseImageGenerationDriver
from griptape.utils import import_optional_dependency


@define
class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

    For more information, see the HuggingFace documentation for Diffusers:
        https://huggingface.co/docs/diffusers/en/index

    Attributes:
        pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model.
        device: The hardware device used for inference. For example, "cpu", "cuda", or "mps".
        output_format: The format the generated image is returned in. Defaults to "png".
    """

    pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True})
    device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    output_format: str = field(default="png", kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        output_image = pipeline(
            prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device)
        ).images[0]

        # Serialize the generated PIL image to the configured output format.
        buffer = io.BytesIO()
        output_image.save(buffer, format=self.output_format.upper())

        return ImageArtifact(
            value=buffer.getvalue(),
            format=self.output_format.lower(),
            height=output_image.height,
            width=output_image.width,
            prompt=prompt,
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        pil_image = import_optional_dependency("PIL.Image")

        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        input_image = pil_image.open(io.BytesIO(image.value))
        # The size of the input image drives the size of the output image.
        # Resize the input image to the configured dimensions.
        output_width, output_height = self.pipeline_driver.output_image_dimensions
        if input_image.height != output_height or input_image.width != output_width:
            input_image = input_image.resize((output_width, output_height))

        output_image = pipeline(
            prompt,
            **self.pipeline_driver.make_image_param(input_image),
            **self.pipeline_driver.make_additional_params(negative_prompts, self.device),
        ).images[0]

        buffer = io.BytesIO()
        output_image.save(buffer, format=self.output_format.upper())

        return ImageArtifact(
            value=buffer.getvalue(),
            format=self.output_format.lower(),
            height=output_image.height,
            width=output_image.width,
            prompt=prompt,
        )

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError("Inpainting is not supported by this driver.")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError("Outpainting is not supported by this driver.")
Empty file.
25 changes: 25 additions & 0 deletions griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@@ -0,0 +1,25 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from attrs import define

if TYPE_CHECKING:
from PIL.Image import Image


@define
class BaseDiffusionImageGenerationPipelineDriver(ABC):
    @abstractmethod
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ...

    @abstractmethod
    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: ...

    @abstractmethod
    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict: ...

    @property
    @abstractmethod
    def output_image_dimensions(self) -> tuple[int, int]: ...
87 changes: 87 additions & 0 deletions griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from PIL.Image import Image


@define
class StableDiffusion3ControlNetImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 models with ControlNet.

    For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline:
        https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3

    Attributes:
        controlnet_model: The ControlNet model to use for image generation.
        controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None.
    """

    controlnet_model: str = field(kw_only=True)
    controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
        sd3_controlnet_pipeline = import_optional_dependency(
            "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
        ).StableDiffusion3ControlNetPipeline

        pipeline_params = {}
        controlnet_pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype
            controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

        # Dropping the T5 text encoder reduces memory usage on constrained devices.
        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # For both Stable Diffusion and ControlNet, models can be provided either
        # as a path to a local file or as a HuggingFace model repo name.
        # We use the from_single_file method if the model is a local file and the
        # from_pretrained method if the model is a local directory or hosted on HuggingFace.
        if os.path.isfile(self.controlnet_model):
            pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file(
                self.controlnet_model, **controlnet_pipeline_params
            )
        else:
            pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained(
                self.controlnet_model, **controlnet_pipeline_params
            )

        if os.path.isfile(model):
            pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)
        else:
            pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        if image is None:
            raise ValueError("Input image is required for ControlNet pipelines.")

        return {"control_image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        additional_params = super().make_additional_params(negative_prompts, device)

        # Height and width are not passed to the ControlNet pipeline; the control
        # image, already resized to the configured dimensions, drives the output size.
        del additional_params["height"]
        del additional_params["width"]

        if self.controlnet_conditioning_scale is not None:
            additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

        return additional_params