-
Notifications
You must be signed in to change notification settings - Fork 182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stable Diffusion 3 local support #1018
Changes from 11 commits
031aabd
1e4f3c5
413bbb4
71461bf
10b3702
e81b195
9d9f385
a4e2f89
e6bc346
2742140
e9e9ddf
3d87cb7
2157096
6e74c29
4df8dde
7f0c211
d84c724
3180a18
77304f6
1296a46
40a6e8a
1e72051
ce89d33
7b5950a
25c2d39
1aaeef6
313f8a0
f8d1e09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import annotations | ||
|
||
import io | ||
from abc import ABC | ||
from typing import Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import ImageArtifact | ||
from griptape.drivers import BaseDiffusionPipelineImageGenerationModelDriver, BaseImageGenerationDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
|
||
@define | ||
class HuggingFaceDiffusionPipelineImageGenerationDriver(BaseImageGenerationDriver, ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wondering if we should just rename to |
||
"""Image generation driver for models hosted by Hugging Face's Diffusion Pipeline. | ||
|
||
For more information, see the HuggingFace documentation for Diffusers: | ||
https://huggingface.co/docs/diffusers/en/index | ||
|
||
Attributes: | ||
model_driver: A pipeline image generation model driver typed for the specific pipeline required by the model. | ||
device: The hardware device used for inference. For example, "cpu", "cuda", or "mps". | ||
""" | ||
|
||
model_driver: BaseDiffusionPipelineImageGenerationModelDriver = field(kw_only=True, metadata={"serializable": True}) | ||
device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) | ||
|
||
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from text prompts using the model driver's pipeline.

    Args:
        prompts: Prompt fragments; they are joined with ", " into a single prompt string.
        negative_prompts: Optional negative prompts, forwarded to the model driver's
            `make_additional_params`.

    Returns:
        An ImageArtifact holding the first image produced by the pipeline, encoded as PNG.
    """
    model_driver = self.model_driver
    pipeline = model_driver.prepare_pipeline(self.model, self.device)

    prompt = ", ".join(prompts)
    extra_params = model_driver.make_additional_params(negative_prompts, self.device)
    result = pipeline(prompt, **extra_params)
    output_image = result.images[0]

    # NOTE(review): the output encoding is fixed to PNG; a reviewer suggested
    # exposing the format as a field instead.
    png_bytes = io.BytesIO()
    output_image.save(png_bytes, format="PNG")

    return ImageArtifact(
        value=png_bytes.getvalue(),
        format="png",
        height=output_image.height,
        width=output_image.width,
        prompt=prompt,
    )
|
||
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Generate a variation of an input image guided by text prompts.

    Args:
        prompts: Prompt fragments; they are joined with ", " into a single prompt string.
        image: Artifact whose `value` bytes are opened with Pillow and fed to the pipeline.
        negative_prompts: Optional negative prompts, forwarded to the model driver's
            `make_additional_params`.

    Returns:
        An ImageArtifact holding the first image produced by the pipeline, encoded as PNG.
    """
    # Resolve the optional Pillow dependency first, like a regular top-of-scope import.
    pil_image = import_optional_dependency("PIL.Image")

    pipeline = self.model_driver.prepare_pipeline(self.model, self.device)

    prompt = ", ".join(prompts)
    input_image = pil_image.open(io.BytesIO(image.value))

    # The size of the input image drives the size of the output image, so
    # resize the input image to the configured dimensions when they differ.
    # Pillow's `Image.size` and `Image.resize()` both use (width, height) order;
    # comparing against `size` keeps the check in the same order as the resize.
    # NOTE(review): assumes get_output_image_dimensions() returns (width, height) — confirm.
    # NOTE(review): Codecov reported the resize line as not covered by tests.
    requested_dimensions = self.model_driver.get_output_image_dimensions()
    if requested_dimensions is not None and input_image.size != tuple(requested_dimensions):
        input_image = input_image.resize(requested_dimensions)

    output_image = pipeline(
        prompt,
        **self.model_driver.make_image_param(input_image),
        **self.model_driver.make_additional_params(negative_prompts, self.device),
    ).images[0]

    buffer = io.BytesIO()
    output_image.save(buffer, format="PNG")

    return ImageArtifact(
        value=buffer.getvalue(), format="png", height=output_image.height, width=output_image.width, prompt=prompt
    )
|
||
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject inpainting requests; this driver does not implement them.

    Raises:
        NotImplementedError: Always.
    """
    message = "Inpainting is not supported by this driver."
    raise NotImplementedError(message)
|
||
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject outpainting requests; this driver does not implement them.

    Raises:
        NotImplementedError: Always.
    """
    message = "Outpainting is not supported by this driver."
    raise NotImplementedError(message)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from attrs import define | ||
|
||
if TYPE_CHECKING: | ||
from PIL.Image import Image | ||
|
||
|
||
@define | ||
class BaseDiffusionPipelineImageGenerationModelDriver(ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need a new base class? Can we instead use There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The interface difference between these two model driver types is substantial. Drivers for remote models package inputs and unpackage outputs; drivers for local models package inputs and prepare diffusion pipelines. We could certainly merge these together, but I think that would increase the overall awkwardness of this implementation. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Understood. Should any of these model driver classes be renamed with the new name for |
||
@abstractmethod
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Build and return the pipeline object used to run `model`.

    Args:
        model: Model identifier supplied by the calling image generation driver.
        device: Optional device name supplied by the calling driver; may be None.

    Returns:
        A pipeline object that the caller invokes with a prompt and keyword arguments.
    """
|
||
@abstractmethod
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Wrap an input image in the keyword argument(s) the pipeline expects.

    Args:
        image: The input image, or None when no image is available.

    Returns:
        A mapping of pipeline keyword name to image, or None.
    """
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is image optional if every implementation seems to require it? |
||
|
||
@abstractmethod
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict:
    """Assemble the extra keyword arguments passed to the pipeline call.

    Args:
        negative_prompts: Optional negative prompts supplied by the calling driver.
        device: Optional device name supplied by the calling driver.

    Returns:
        A dict that the caller unpacks into the pipeline invocation.
    """
|
||
@abstractmethod
def get_output_image_dimensions(self) -> Optional[tuple[int, int]]:
    """Return the requested output image dimensions, or None when none are configured.

    Returns:
        A two-integer tuple, or None.
        NOTE(review): the element order (width/height) is not fixed by this signature — confirm.
    """
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should maybe be a |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
if TYPE_CHECKING: | ||
from PIL.Image import Image | ||
else: | ||
StableDiffusion3ControlNetPipeline = import_optional_dependency( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Stable Diffusion's pipelines include versions in their class names and are distinct enough that we would need a driver for each, much in the same way that ControlNet and Img2Img are distinct here. Amaru pointed out that Stable Diffusion 3 is still quite new and many people still prefer their Stable Diffusion XL workflows — to support this, we'd need another driver typed for StableDiffusionXL (same story for SD1.5 and SD2). I don't like the aesthetics of it, but this leaves some space for other SD versions. |
||
"diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet" | ||
).StableDiffusion3ControlNetPipeline | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't have this import here as it'll break bare installs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whoops, leftover test change. Removed. |
||
|
||
|
||
@define | ||
class StableDiffusion3ControlNetPipelineImageGenerationModelDriver(StableDiffusion3PipelineImageGenerationModelDriver): | ||
"""Image generation model driver for Stable Diffusion 3 models with ControlNet. | ||
|
||
For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline: | ||
https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3 | ||
|
||
Attributes: | ||
controlnet_model: The ControlNet model to use for image generation. | ||
controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None. | ||
""" | ||
|
||
controlnet_model: str = field(kw_only=True) | ||
controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) | ||
|
||
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Build a StableDiffusion3ControlNetPipeline with its ControlNet model attached.

    For both Stable Diffusion and ControlNet, a model can be provided either as a
    path to a local file or as a local directory / HuggingFace model repo name.
    Local files are loaded with `from_single_file`; everything else is loaded
    with `from_pretrained`.

    Args:
        model: Path or repo name of the Stable Diffusion 3 model.
        device: Optional device to move the pipeline to; skipped when None.

    Returns:
        The prepared pipeline object.
    """
    # Resolve the optional dependencies up front, like regular imports, so a
    # missing install fails before any model weights are loaded.
    sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
    sd3_controlnet_pipeline = import_optional_dependency(
        "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
    ).StableDiffusion3ControlNetPipeline

    pipeline_params = {}
    controlnet_pipeline_params = {}
    # NOTE(review): Codecov reported this torch_dtype branch as not covered by tests.
    if self.torch_dtype is not None:
        pipeline_params["torch_dtype"] = self.torch_dtype
        controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

    # Load the ControlNet model first; it is handed to the pipeline as a parameter.
    if os.path.isfile(self.controlnet_model):
        load_controlnet = sd3_controlnet_model.from_single_file
    else:
        load_controlnet = sd3_controlnet_model.from_pretrained
    pipeline_params["controlnet"] = load_controlnet(self.controlnet_model, **controlnet_pipeline_params)

    if os.path.isfile(model):
        load_pipeline = sd3_controlnet_pipeline.from_single_file
    else:
        load_pipeline = sd3_controlnet_pipeline.from_pretrained
    pipeline = load_pipeline(model, **pipeline_params)

    # Move inference to a particular device if requested.
    if device is not None:
        pipeline.to(device)

    return pipeline
|
||
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Wrap the input image as the pipeline's `control_image` keyword argument.

    Args:
        image: The control image; must not be None.

    Returns:
        A single-entry dict mapping "control_image" to the image.

    Raises:
        ValueError: If no image is provided.
    """
    if image is not None:
        return {"control_image": image}

    raise ValueError("Input image is required for ControlNet pipelines.")
|
||
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    """Build the extra pipeline keyword arguments for a ControlNet generation call.

    Starts from the parent driver's parameters, removes the explicit `height`
    and `width` entries, and adds `controlnet_conditioning_scale` when configured.

    Args:
        negative_prompts: Optional negative prompts, forwarded to the parent implementation.
        device: Optional device name, forwarded to the parent implementation.

    Returns:
        The keyword arguments to unpack into the pipeline call.
    """
    additional_params = super().make_additional_params(negative_prompts, device)

    # Drop explicit dimensions. `pop` with a default tolerates a parent that
    # did not set them, where `del` would raise KeyError.
    additional_params.pop("height", None)
    additional_params.pop("width", None)

    if self.controlnet_conditioning_scale is not None:
        additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

    return additional_params
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
if TYPE_CHECKING: | ||
from PIL.Image import Image | ||
|
||
|
||
@define | ||
class StableDiffusion3Img2ImgPipelineImageGenerationModelDriver(StableDiffusion3PipelineImageGenerationModelDriver): | ||
"""Image generation model driver for Stable Diffusion 3 model image to image pipelines. | ||
|
||
For more information, see the HuggingFace documentation for the StableDiffusion3Img2ImgPipeline: | ||
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py | ||
|
||
Attributes: | ||
strength: A value [0.0, 1.0] that determines the strength of the initial image in the output. | ||
""" | ||
|
||
strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) | ||
|
||
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Build a StableDiffusion3Img2ImgPipeline for the given model.

    A model may be a local directory or a HuggingFace model repo name, both
    loaded with `from_pretrained`. A path to a single local file is rejected,
    because this pipeline does not yet support single-file loading.

    Args:
        model: Local directory or HuggingFace repo name of the model.
        device: Optional device to move the pipeline to; skipped when None.

    Returns:
        The prepared pipeline object.

    Raises:
        NotImplementedError: If `model` is a path to a single local file.
    """
    # NOTE(review): Codecov reported the torch_dtype assignment as not covered by tests.
    pipeline_params = {} if self.torch_dtype is None else {"torch_dtype": self.torch_dtype}

    sd3_img2img_pipeline = import_optional_dependency(
        "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img"
    ).StableDiffusion3Img2ImgPipeline

    # A local file (not a directory) would need single-file loading, which is
    # the one case this pipeline cannot handle.
    if os.path.isfile(model):
        raise NotImplementedError("StableDiffusion3Img2ImgPipeline does not yet support loading from a single file.")

    # A local directory or a model hosted on HuggingFace is loaded with from_pretrained.
    pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params)

    # Move inference to a particular device if requested.
    if device is not None:
        pipeline.to(device)

    return pipeline
|
||
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Wrap the input image as the pipeline's `image` keyword argument.

    Args:
        image: The source image; must not be None.

    Returns:
        A single-entry dict mapping "image" to the image.

    Raises:
        ValueError: If no image is provided.
    """
    if image is not None:
        return {"image": image}

    raise ValueError("Input image is required for image to image pipelines.")
|
||
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    """Build the extra pipeline keyword arguments for an image-to-image call.

    Starts from the parent driver's parameters, removes the explicit `height`
    and `width` entries, and adds `strength` when configured.

    Args:
        negative_prompts: Optional negative prompts, forwarded to the parent implementation.
        device: Optional device name, forwarded to the parent implementation.

    Returns:
        The keyword arguments to unpack into the pipeline call.
    """
    additional_params = super().make_additional_params(negative_prompts, device)

    # Explicit height and width params are not supported, but are instead
    # inferred from the input image. `pop` with a default tolerates a parent
    # that did not set them, where `del` would raise KeyError.
    additional_params.pop("height", None)
    additional_params.pop("width", None)

    if self.strength is not None:
        additional_params["strength"] = self.strength

    return additional_params
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this subclass
BaseImageGenerationDriver
instead? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume you mean
BaseMultiModelImageGenerationDriver
. That was the original intention, but the Pipeline drivers require a substantially different interface than those that inherit from BaseImageGenerationModelDriver
as required by the BaseMultiModelImageGenerationDriver
.