Stable Diffusion 3 local support #1018

Merged 28 commits on Jul 29, 2024

Commits (28)
031aabd Add 🤗 diffusers image generation driver (andrewfrench, Jul 25, 2024)
1e4f3c5 Add image generation model drivers for diffusers (andrewfrench, Jul 25, 2024)
413bbb4 Add driver for sd3+controlnet (andrewfrench, Jul 25, 2024)
71461bf Export new drivers (andrewfrench, Jul 25, 2024)
10b3702 Optional dependencies update (andrewfrench, Jul 25, 2024)
e81b195 Import fixes (andrewfrench, Jul 25, 2024)
9d9f385 Linter fixes, pyright fixes (andrewfrench, Jul 25, 2024)
a4e2f89 Fix make format artifact (andrewfrench, Jul 25, 2024)
e6bc346 Docstrings (andrewfrench, Jul 25, 2024)
2742140 Add unit tests, small fixes (andrewfrench, Jul 25, 2024)
e9e9ddf Update img2img input field name (andrewfrench, Jul 25, 2024)
3d87cb7 Address comments, update poetry.lock (andrewfrench, Jul 25, 2024)
2157096 whoops, fix tests (andrewfrench, Jul 26, 2024)
6e74c29 Configurable image output format (andrewfrench, Jul 26, 2024)
4df8dde Update docstring (andrewfrench, Jul 26, 2024)
7f0c211 Refactor pipeline drivers (andrewfrench, Jul 26, 2024)
d84c724 model_driver -> pipeline_driver (andrewfrench, Jul 26, 2024)
3180a18 Downgrade torch (andrewfrench, Jul 26, 2024)
77304f6 Expose memory saving pipeline options (andrewfrench, Jul 27, 2024)
1296a46 Mark torch as optional again (andrewfrench, Jul 27, 2024)
40a6e8a Support from_single_file for img2img pipeline (andrewfrench, Jul 27, 2024)
1e72051 Add tests for new options (andrewfrench, Jul 28, 2024)
ce89d33 Transfer fork to griptape-ai (andrewfrench, Jul 28, 2024)
7b5950a git@ -> https:// (andrewfrench, Jul 28, 2024)
25c2d39 poetry lock --no-update (andrewfrench, Jul 28, 2024)
1aaeef6 Update reference links (andrewfrench, Jul 28, 2024)
313f8a0 Test coverage (andrewfrench, Jul 28, 2024)
f8d1e09 Merge branch 'dev' into sd3-local (collindutter, Jul 29, 2024)
127 changes: 127 additions & 0 deletions docs/griptape-framework/drivers/image-generation-drivers.md
@@ -179,3 +179,130 @@ agent = Agent(tools=[

agent.run("Generate a watercolor painting of a dog riding a skateboard")
```

### HuggingFace Pipelines

!!! info
This driver requires the `drivers-image-generation-huggingface` [extra](../index.md#extras).

The [HuggingFace Pipelines Image Generation Driver](../../reference/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.md) enables image generation through locally hosted models using the HuggingFace [Diffusers](https://huggingface.co/docs/diffusers/en/index) library. This Driver requires a [Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/index.md) to prepare the appropriate Pipeline.

This Driver requires a `model` configuration, specifying the model to use for image generation. The value of the `model` configuration must be one of the following:

- A model name from the HuggingFace Model Hub, like `stabilityai/stable-diffusion-3-medium-diffusers`
- A path to the directory containing a model on the filesystem, like `./models/stable-diffusion-3/`
- A path to a file containing a model on the filesystem, like `./models/sd3_medium_incl_clips.safetensors`

The `device` configuration specifies the hardware device used to run inference. Common values include `cuda` (CUDA-enabled GPUs), `cpu` (the device's CPU), and `mps` (Apple silicon GPUs). For more information, see [HuggingFace's documentation](https://huggingface.co/docs/transformers/en/perf_infer_gpu_one) on GPU inference.
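
For example, a driver loading a local single-file checkpoint (one of the `model` forms above) on an Apple silicon device might be configured as in the following sketch; the checkpoint path is illustrative:

```python title="PYTEST_IGNORE"
from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
    StableDiffusion3ImageGenerationPipelineDriver

driver = HuggingFacePipelineImageGenerationDriver(
    # A hypothetical local checkpoint; a Model Hub name or model directory also works.
    model="./models/sd3_medium_incl_clips.safetensors",
    device="mps",  # or "cuda" / "cpu", depending on available hardware
    pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(
        height=512,
        width=512,
    ),
)
```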

#### Stable Diffusion 3 Image Generation Pipeline Driver

!!! info
The `Stable Diffusion 3 Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.

The [Stable Diffusion 3 Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.md) provides a `StableDiffusion3Pipeline` for text-to-image generations via the [HuggingFace Pipelines Image Generation Driver's](../../reference/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.md) `.try_text_to_image()` method. This Driver accepts a text prompt and configurations including the Stable Diffusion 3 model, output image size, generation seed, and inference steps.

Image generation consumes substantial memory. On devices with limited VRAM, it may be necessary to enable the `enable_model_cpu_offload` or `drop_t5_encoder` configurations. For more information, see [HuggingFace's documentation](https://huggingface.co/docs/diffusers/en/optimization/memory) on reduced memory usage.
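
As a minimal sketch, a Pipeline Driver with both memory-saving options enabled might look like this; whether they are needed depends on available VRAM, and the complete example that follows shows the driver in context:

```python title="PYTEST_IGNORE"
from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver

# Trade speed and prompt fidelity for lower peak VRAM usage: offload pipeline
# submodules to CPU between steps and skip loading the large T5 text encoder.
pipeline_driver = StableDiffusion3ImageGenerationPipelineDriver(
    height=512,
    width=512,
    enable_model_cpu_offload=True,
    drop_t5_encoder=True,
)
```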

```python title="PYTEST_IGNORE"
from griptape.structures import Pipeline
from griptape.tasks import PromptImageGenerationTask
from griptape.engines import PromptImageGenerationEngine
from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
StableDiffusion3ImageGenerationPipelineDriver
from griptape.artifacts import TextArtifact

image_generation_task = PromptImageGenerationTask(
input=TextArtifact("landscape photograph, verdant, countryside, 8k"),
image_generation_engine=PromptImageGenerationEngine(
image_generation_driver=HuggingFacePipelineImageGenerationDriver(
model="stabilityai/stable-diffusion-3-medium-diffusers",
device="cuda",
pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(
height=512,
width=512,
)
)
)
)

output_artifact = Pipeline(tasks=[image_generation_task]).run().output
```

#### Stable Diffusion 3 Img2Img Image Generation Pipeline Driver

!!! info
    The `Stable Diffusion 3 Img2Img Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.

The [Stable Diffusion 3 Img2Img Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.md) provides a `StableDiffusion3Img2ImgPipeline` for image-to-image generations. This Driver accepts a text prompt, an input image, and configurations including the Stable Diffusion 3 model, output image size, inference steps, generation seed, and strength of generation over the input image.

```python title="PYTEST_IGNORE"
from pathlib import Path

from griptape.structures import Pipeline
from griptape.tasks import VariationImageGenerationTask
from griptape.engines import VariationImageGenerationEngine
from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
StableDiffusion3Img2ImgImageGenerationPipelineDriver
from griptape.artifacts import TextArtifact, ImageArtifact
from griptape.loaders import ImageLoader

prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())

image_variation_task = VariationImageGenerationTask(
input=(prompt_artifact, input_image_artifact),
    image_generation_engine=VariationImageGenerationEngine(
image_generation_driver=HuggingFacePipelineImageGenerationDriver(
model="stabilityai/stable-diffusion-3-medium-diffusers",
device="cuda",
pipeline_driver=StableDiffusion3Img2ImgImageGenerationPipelineDriver(
height=1024,
width=1024,
)
)
)
)

output_artifact = Pipeline(tasks=[image_variation_task]).run().output
```

#### Stable Diffusion 3 ControlNet Image Generation Pipeline Driver

!!! info
    The `Stable Diffusion 3 ControlNet Image Generation Pipeline Driver` requires the `drivers-image-generation-huggingface` extra.

The [Stable Diffusion 3 ControlNet Image Generation Pipeline Driver](../../reference/griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.md) provides a `StableDiffusion3ControlNetPipeline` for image-to-image generations guided by a control image. This Driver accepts a text prompt, a control image, and configurations including the Stable Diffusion 3 model, ControlNet model, output image size, generation seed, inference steps, and the degree to which the model adheres to the control image.

```python title="PYTEST_IGNORE"
from pathlib import Path

from griptape.structures import Pipeline
from griptape.tasks import VariationImageGenerationTask
from griptape.engines import VariationImageGenerationEngine
from griptape.drivers import HuggingFacePipelineImageGenerationDriver, \
StableDiffusion3ControlNetImageGenerationPipelineDriver
from griptape.artifacts import TextArtifact, ImageArtifact
from griptape.loaders import ImageLoader

prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes())

controlnet_task = VariationImageGenerationTask(
input=(prompt_artifact, control_image_artifact),
    image_generation_engine=VariationImageGenerationEngine(
image_generation_driver=HuggingFacePipelineImageGenerationDriver(
model="stabilityai/stable-diffusion-3-medium-diffusers",
device="cuda",
pipeline_driver=StableDiffusion3ControlNetImageGenerationPipelineDriver(
controlnet_model="InstantX/SD3-Controlnet-Canny",
control_strength=0.8,
height=768,
width=1024,
)
)
)
)

output_artifact = Pipeline(tasks=[controlnet_task]).run().output
```
21 changes: 21 additions & 0 deletions griptape/drivers/__init__.py
@@ -54,13 +54,29 @@
)
from .image_generation_model.bedrock_titan_image_generation_model_driver import BedrockTitanImageGenerationModelDriver

from .image_generation_pipeline.base_image_generation_pipeline_driver import (
BaseDiffusionImageGenerationPipelineDriver,
)
from .image_generation_pipeline.stable_diffusion_3_image_generation_pipeline_driver import (
StableDiffusion3ImageGenerationPipelineDriver,
)
from .image_generation_pipeline.stable_diffusion_3_img_2_img_image_generation_pipeline_driver import (
StableDiffusion3Img2ImgImageGenerationPipelineDriver,
)
from .image_generation_pipeline.stable_diffusion_3_controlnet_image_generation_pipeline_driver import (
StableDiffusion3ControlNetImageGenerationPipelineDriver,
)

from .image_generation.base_image_generation_driver import BaseImageGenerationDriver
from .image_generation.base_multi_model_image_generation_driver import BaseMultiModelImageGenerationDriver
from .image_generation.openai_image_generation_driver import OpenAiImageGenerationDriver
from .image_generation.leonardo_image_generation_driver import LeonardoImageGenerationDriver
from .image_generation.amazon_bedrock_image_generation_driver import AmazonBedrockImageGenerationDriver
from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver
from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver
from .image_generation.huggingface_pipeline_image_generation_driver import (
HuggingFacePipelineImageGenerationDriver,
)

from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver
from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver
@@ -164,13 +180,18 @@
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
"BaseDiffusionImageGenerationPipelineDriver",
"StableDiffusion3ImageGenerationPipelineDriver",
"StableDiffusion3Img2ImgImageGenerationPipelineDriver",
"StableDiffusion3ControlNetImageGenerationPipelineDriver",
"BaseImageGenerationDriver",
"BaseMultiModelImageGenerationDriver",
"OpenAiImageGenerationDriver",
"LeonardoImageGenerationDriver",
"AmazonBedrockImageGenerationDriver",
"AzureOpenAiImageGenerationDriver",
"DummyImageGenerationDriver",
"HuggingFacePipelineImageGenerationDriver",
"BaseImageQueryModelDriver",
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
98 changes: 98 additions & 0 deletions griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
@@ -0,0 +1,98 @@
from __future__ import annotations

import io
from abc import ABC
from typing import Optional

from attrs import define, field

from griptape.artifacts import ImageArtifact
from griptape.drivers import BaseDiffusionImageGenerationPipelineDriver, BaseImageGenerationDriver
from griptape.utils import import_optional_dependency


@define
class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
"""Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

For more information, see the HuggingFace documentation for Diffusers:
https://huggingface.co/docs/diffusers/en/index

Attributes:
pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model.
device: The hardware device used for inference. For example, "cpu", "cuda", or "mps".
output_format: The format the generated image is returned in. Defaults to "png".
"""

pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True})
device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
output_format: str = field(default="png", kw_only=True, metadata={"serializable": True})

def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

prompt = ", ".join(prompts)
output_image = pipeline(
prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device)
).images[0]

buffer = io.BytesIO()
output_image.save(buffer, format=self.output_format.upper())

return ImageArtifact(
value=buffer.getvalue(),
format=self.output_format.lower(),
height=output_image.height,
width=output_image.width,
prompt=prompt,
)

def try_image_variation(
self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
pil_image = import_optional_dependency("PIL.Image")

pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

prompt = ", ".join(prompts)
input_image = pil_image.open(io.BytesIO(image.value))
# The size of the input image drives the size of the output image.
# Resize the input image to the configured dimensions.
output_width, output_height = self.pipeline_driver.output_image_dimensions
if input_image.height != output_height or input_image.width != output_width:
input_image = input_image.resize((output_width, output_height))

output_image = pipeline(
prompt,
**self.pipeline_driver.make_image_param(input_image),
**self.pipeline_driver.make_additional_params(negative_prompts, self.device),
).images[0]

buffer = io.BytesIO()
output_image.save(buffer, format=self.output_format.upper())

return ImageArtifact(
value=buffer.getvalue(),
format=self.output_format.lower(),
height=output_image.height,
width=output_image.width,
prompt=prompt,
)

def try_image_inpainting(
self,
prompts: list[str],
image: ImageArtifact,
mask: ImageArtifact,
negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
raise NotImplementedError("Inpainting is not supported by this driver.")

def try_image_outpainting(
self,
prompts: list[str],
image: ImageArtifact,
mask: ImageArtifact,
negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
raise NotImplementedError("Outpainting is not supported by this driver.")
Empty file.
25 changes: 25 additions & 0 deletions griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@@ -0,0 +1,25 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from attrs import define

if TYPE_CHECKING:
from PIL.Image import Image


@define
class BaseDiffusionImageGenerationPipelineDriver(ABC):
@abstractmethod
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ...

@abstractmethod
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: ...

@abstractmethod
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict: ...

@property
@abstractmethod
def output_image_dimensions(self) -> tuple[int, int]: ...
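
For illustration, a minimal hypothetical implementation of this interface might look like the sketch below; it is not part of this change, and the class name and defaults are invented:

```python
from __future__ import annotations

from typing import Any, Optional

from attrs import define, field

from griptape.drivers import BaseDiffusionImageGenerationPipelineDriver


@define
class ExampleTextToImagePipelineDriver(BaseDiffusionImageGenerationPipelineDriver):
    """Hypothetical sketch showing the shape of the pipeline driver interface."""

    height: int = field(default=512, kw_only=True)
    width: int = field(default=512, kw_only=True)

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        # A real driver would load the appropriate diffusers pipeline for
        # `model` here and move it to `device`.
        raise NotImplementedError

    def make_image_param(self, image: Optional[Any]) -> Optional[dict]:
        # A pure text-to-image pipeline takes no image parameter.
        return None

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict:
        # Keyword arguments forwarded to the pipeline call by the outer driver.
        return {
            "negative_prompt": ", ".join(negative_prompts) if negative_prompts else None,
            "height": self.height,
            "width": self.width,
        }

    @property
    def output_image_dimensions(self) -> tuple[int, int]:
        return self.width, self.height
```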
87 changes: 87 additions & 0 deletions griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from PIL.Image import Image


@define
class StableDiffusion3ControlNetImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
"""Image generation model driver for Stable Diffusion 3 models with ControlNet.

For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline:
https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3

Attributes:
controlnet_model: The ControlNet model to use for image generation.
controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None.
"""

controlnet_model: str = field(kw_only=True)
controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
sd3_controlnet_pipeline = import_optional_dependency(
"diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
).StableDiffusion3ControlNetPipeline

pipeline_params = {}
controlnet_pipeline_params = {}
if self.torch_dtype is not None:
pipeline_params["torch_dtype"] = self.torch_dtype
controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

if self.drop_t5_encoder:
pipeline_params["text_encoder_3"] = None
pipeline_params["tokenizer_3"] = None

# For both Stable Diffusion and ControlNet, models can be provided either
# as a path to a local file or as a HuggingFace model repo name.
# We use the from_single_file method if the model is a local file and the
# from_pretrained method if the model is a local directory or hosted on HuggingFace.
if os.path.isfile(self.controlnet_model):
pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file(
self.controlnet_model, **controlnet_pipeline_params
)
else:
pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained(
self.controlnet_model, **controlnet_pipeline_params
)

if os.path.isfile(model):
pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)
else:
pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params)

if self.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()

if device is not None:
pipeline.to(device)

return pipeline

def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
if image is None:
raise ValueError("Input image is required for ControlNet pipelines.")

return {"control_image": image}

def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
additional_params = super().make_additional_params(negative_prompts, device)

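        # The control image drives the output dimensions (the outer driver resizes
        # it to the configured size), so height and width are omitted here.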
del additional_params["height"]
del additional_params["width"]

if self.controlnet_conditioning_scale is not None:
additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

return additional_params