Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable Diffusion 3 local support #1018

Merged
merged 28 commits into dev from sd3-local on Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
031aabd
Add 🤗 diffusers image generation driver
andrewfrench Jul 25, 2024
1e4f3c5
Add image generation model drivers for diffusers
andrewfrench Jul 25, 2024
413bbb4
Add driver for sd3+controlnet
andrewfrench Jul 25, 2024
71461bf
Export new drivers
andrewfrench Jul 25, 2024
10b3702
Optional dependencies update
andrewfrench Jul 25, 2024
e81b195
Import fixes
andrewfrench Jul 25, 2024
9d9f385
Linter fixes, pyright fixes
andrewfrench Jul 25, 2024
a4e2f89
Fix make format artifact
andrewfrench Jul 25, 2024
e6bc346
Docstrings
andrewfrench Jul 25, 2024
2742140
Add unit tests, small fixes
andrewfrench Jul 25, 2024
e9e9ddf
Update img2img input field name
andrewfrench Jul 25, 2024
3d87cb7
Address comments, update poetry.lock
andrewfrench Jul 25, 2024
2157096
whoops, fix tests
andrewfrench Jul 26, 2024
6e74c29
Configurable image output format
andrewfrench Jul 26, 2024
4df8dde
Update docstring
andrewfrench Jul 26, 2024
7f0c211
Refactor pipeline drivers
andrewfrench Jul 26, 2024
d84c724
model_driver -> pipeline_driver
andrewfrench Jul 26, 2024
3180a18
Downgrade torch
andrewfrench Jul 26, 2024
77304f6
Expose memory saving pipeline options
andrewfrench Jul 27, 2024
1296a46
Mark torch as optional again
andrewfrench Jul 27, 2024
40a6e8a
Support from_single_file for img2img pipeline
andrewfrench Jul 27, 2024
1e72051
Add tests for new options
andrewfrench Jul 28, 2024
ce89d33
Transfer fork to griptape-ai
andrewfrench Jul 28, 2024
7b5950a
git@ -> https://
andrewfrench Jul 28, 2024
25c2d39
poetry lock --no-update
andrewfrench Jul 28, 2024
1aaeef6
Update reference links
andrewfrench Jul 28, 2024
313f8a0
Test coverage
andrewfrench Jul 28, 2024
f8d1e09
Merge branch 'dev' into sd3-local
collindutter Jul 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@
BedrockStableDiffusionImageGenerationModelDriver,
)
from .image_generation_model.bedrock_titan_image_generation_model_driver import BedrockTitanImageGenerationModelDriver
from .image_generation_model.stable_diffusion_3_pipeline_image_generation_model_driver import (
StableDiffusion3PipelineImageGenerationModelDriver,
)
from .image_generation_model.base_diffusion_pipeline_image_generation_model_driver import (
BaseDiffusionPipelineImageGenerationModelDriver,
)
from .image_generation_model.stable_diffusion_3_img_2_img_pipeline_image_generation_model_driver import (
StableDiffusion3Img2ImgPipelineImageGenerationModelDriver,
)
from .image_generation_model.stable_diffusion_3_controlnet_pipeline_image_generation_model_driver import (
StableDiffusion3ControlNetPipelineImageGenerationModelDriver,
)

from .image_generation.base_image_generation_driver import BaseImageGenerationDriver
from .image_generation.base_multi_model_image_generation_driver import BaseMultiModelImageGenerationDriver
Expand All @@ -61,6 +73,9 @@
from .image_generation.amazon_bedrock_image_generation_driver import AmazonBedrockImageGenerationDriver
from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver
from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver
from .image_generation.huggingface_diffusion_pipeline_image_generation_driver import (
HuggingFaceDiffusionPipelineImageGenerationDriver,
)

from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver
from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver
Expand Down Expand Up @@ -164,13 +179,18 @@
"BaseImageGenerationModelDriver",
"BedrockStableDiffusionImageGenerationModelDriver",
"BedrockTitanImageGenerationModelDriver",
"BaseDiffusionPipelineImageGenerationModelDriver",
"StableDiffusion3PipelineImageGenerationModelDriver",
"StableDiffusion3Img2ImgPipelineImageGenerationModelDriver",
"StableDiffusion3ControlNetPipelineImageGenerationModelDriver",
"BaseImageGenerationDriver",
"BaseMultiModelImageGenerationDriver",
"OpenAiImageGenerationDriver",
"LeonardoImageGenerationDriver",
"AmazonBedrockImageGenerationDriver",
"AzureOpenAiImageGenerationDriver",
"DummyImageGenerationDriver",
"HuggingFaceDiffusionPipelineImageGenerationDriver",
"BaseImageQueryModelDriver",
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

import io
from abc import ABC
from typing import Optional

from attrs import define, field

from griptape.artifacts import ImageArtifact
from griptape.drivers import BaseDiffusionPipelineImageGenerationModelDriver, BaseImageGenerationDriver
from griptape.utils import import_optional_dependency


@define
class HuggingFaceDiffusionPipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

    For more information, see the HuggingFace documentation for Diffusers:
            https://huggingface.co/docs/diffusers/en/index

    Attributes:
        model_driver: A pipeline image generation model driver typed for the specific pipeline required by the model.
        device: The hardware device used for inference. For example, "cpu", "cuda", or "mps".
    """

    model_driver: BaseDiffusionPipelineImageGenerationModelDriver = field(kw_only=True, metadata={"serializable": True})
    device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from text prompts.

        Prompts are joined with ", " into a single prompt string before being passed to the pipeline.
        """
        pipeline = self.model_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        output_image = pipeline(
            prompt, **self.model_driver.make_additional_params(negative_prompts, self.device)
        ).images[0]

        return self._encode_output_image(output_image, prompt)

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Generate a variation of the input image, guided by the text prompts."""
        pipeline = self.model_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        input_image = import_optional_dependency("PIL.Image").open(io.BytesIO(image.value))
        # The size of the input image drives the size of the output image.
        # Resize the input image to the configured dimensions.
        # NOTE(review): dimensions[0] is compared against height and dimensions[1] against
        # width, but PIL's Image.resize expects a (width, height) tuple — confirm ordering.
        requested_dimensions = self.model_driver.get_output_image_dimensions()
        if requested_dimensions is not None and (
            input_image.height != requested_dimensions[0] or input_image.width != requested_dimensions[1]
        ):
            input_image = input_image.resize(requested_dimensions)

        output_image = pipeline(
            prompt,
            **self.model_driver.make_image_param(input_image),
            **self.model_driver.make_additional_params(negative_prompts, self.device),
        ).images[0]

        return self._encode_output_image(output_image, prompt)

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError("Inpainting is not supported by this driver.")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError("Outpainting is not supported by this driver.")

    def _encode_output_image(self, output_image: Image, prompt: str) -> ImageArtifact:
        """Serialize a pipeline output image to a PNG-encoded ImageArtifact.

        Shared by both generation paths; PNG is currently the only supported output encoding.
        """
        buffer = io.BytesIO()
        output_image.save(buffer, format="PNG")

        return ImageArtifact(
            value=buffer.getvalue(), format="png", height=output_image.height, width=output_image.width, prompt=prompt
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from attrs import define

if TYPE_CHECKING:
from PIL.Image import Image


@define
class BaseDiffusionPipelineImageGenerationModelDriver(ABC):
    """Base class for model drivers that configure a Diffusers pipeline for a model family.

    Implementations translate generic image generation inputs into the constructor and
    call arguments expected by a concrete diffusers pipeline class.
    """

    @abstractmethod
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Load and return the pipeline for `model`, moved to `device` when one is given."""
        ...

    @abstractmethod
    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Return the keyword-argument mapping that carries the input image to the pipeline call."""
        ...

    @abstractmethod
    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict:
        """Return extra keyword arguments for the pipeline call."""
        ...

    @abstractmethod
    def get_output_image_dimensions(self) -> Optional[tuple[int, int]]:
        """Return the configured output image dimensions, or None when unconstrained.

        NOTE(review): callers compare element 0 against image height and element 1
        against width, so the tuple appears to be (height, width) — confirm.
        """
        ...

Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from PIL.Image import Image
else:
StableDiffusion3ControlNetPipeline = import_optional_dependency(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having 3 in the file/class name feels odd, maybe let's remove?

Copy link
Member Author

@andrewfrench andrewfrench Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stable Diffusion's pipelines include versions in their class names and are distinct enough that we would need a driver for each, much in the same way that ControlNet and Img2Img are distinct here. Amaru pointed out that Stable Diffusion 3 is still quite new and many people still prefer their Stable Diffusion XL workflows — to support this, we'd need another driver typed for StableDiffusionXL (same story for SD1.5 and SD2). I don't like the aesthetics of it, but this leaves some space for other SD versions.

"diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
).StableDiffusion3ControlNetPipeline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't have this import here as it'll break bare installs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, leftover test change. Removed.



@define
class StableDiffusion3ControlNetPipelineImageGenerationModelDriver(StableDiffusion3PipelineImageGenerationModelDriver):
    """Image generation model driver for Stable Diffusion 3 models with ControlNet.

    For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline:
        https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3

    Attributes:
        controlnet_model: The ControlNet model to use for image generation.
        controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None.
    """

    controlnet_model: str = field(kw_only=True)
    controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Load the SD3 + ControlNet pipeline for `model`, optionally moving it to `device`."""
        # Optional-dependency imports are hoisted to the top of the function,
        # as if they were regular imports.
        sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
        sd3_controlnet_pipeline = import_optional_dependency(
            "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
        ).StableDiffusion3ControlNetPipeline

        pipeline_params = {}
        controlnet_pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype
            controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

        # For both Stable Diffusion and ControlNet, models can be provided either
        # as a path to a local file or as a HuggingFace model repo name.
        # We use the from_single_file method if the model is a local file and the
        # from_pretrained method if the model is a local directory or hosted on HuggingFace.
        controlnet_loader = (
            sd3_controlnet_model.from_single_file
            if os.path.isfile(self.controlnet_model)
            else sd3_controlnet_model.from_pretrained
        )
        pipeline_params["controlnet"] = controlnet_loader(self.controlnet_model, **controlnet_pipeline_params)

        pipeline_loader = (
            sd3_controlnet_pipeline.from_single_file
            if os.path.isfile(model)
            else sd3_controlnet_pipeline.from_pretrained
        )
        pipeline = pipeline_loader(model, **pipeline_params)

        # Move inference to a particular device if requested.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Map the input image to the pipeline's `control_image` keyword argument."""
        if image is None:
            raise ValueError("Input image is required for ControlNet pipelines.")

        return {"control_image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        """Extend the base SD3 params with ControlNet-specific options."""
        additional_params = super().make_additional_params(negative_prompts, device)

        # Explicit height and width params are removed from the base params;
        # presumably the output size follows the control image — confirm against
        # the pipeline's signature.
        del additional_params["height"]
        del additional_params["width"]

        if self.controlnet_conditioning_scale is not None:
            additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

        return additional_params
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from PIL.Image import Image


@define
class StableDiffusion3Img2ImgPipelineImageGenerationModelDriver(StableDiffusion3PipelineImageGenerationModelDriver):
    """Image generation model driver for Stable Diffusion 3 model image to image pipelines.

    For more information, see the HuggingFace documentation for the StableDiffusion3Img2ImgPipeline:
        https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

    Attributes:
        strength: A value [0.0, 1.0] that determines the strength of the initial image in the output.
    """

    strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Load the SD3 img2img pipeline for `model`, optionally moving it to `device`.

        Raises:
            NotImplementedError: If `model` is a path to a single local file, which
                this pipeline cannot load.
        """
        # Optional-dependency import is hoisted to the top of the function,
        # as if it were a regular import.
        sd3_img2img_pipeline = import_optional_dependency(
            "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img"
        ).StableDiffusion3Img2ImgPipeline

        pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype

        # A model can be provided either as a path to a local file
        # or as a HuggingFace model repo name. Unlike the other SD3 pipelines,
        # this pipeline does not support the from_single_file method, so a path
        # to a single local file is rejected up front.
        if os.path.isfile(model):
            raise NotImplementedError(
                "StableDiffusion3Img2ImgPipeline does not yet support loading from a single file."
            )

        # If the model is a local directory or hosted on HuggingFace,
        # we load it using the from_pretrained method.
        pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params)

        # Move inference to particular device if requested.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Map the input image to the pipeline's `image` keyword argument."""
        if image is None:
            raise ValueError("Input image is required for image to image pipelines.")

        return {"image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        """Extend the base SD3 params with img2img-specific options."""
        additional_params = super().make_additional_params(negative_prompts, device)

        # Explicit height and width params are not supported, but
        # are instead inferred from input image.
        del additional_params["height"]
        del additional_params["width"]

        if self.strength is not None:
            additional_params["strength"] = self.strength

        return additional_params
Loading
Loading