-
Notifications
You must be signed in to change notification settings - Fork 182
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stable Diffusion 3 local support (#1018)
- Loading branch information
1 parent
54cd8cc
commit 56b4104
Showing
15 changed files
with
1,058 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from __future__ import annotations | ||
|
||
import io | ||
from abc import ABC | ||
from typing import Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import ImageArtifact | ||
from griptape.drivers import BaseDiffusionImageGenerationPipelineDriver, BaseImageGenerationDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
|
||
@define
class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

    For more information, see the HuggingFace documentation for Diffusers:
        https://huggingface.co/docs/diffusers/en/index

    Attributes:
        pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model.
        device: The hardware device used for inference. For example, "cpu", "cuda", or "mps".
        output_format: The format the generated image is returned in. Defaults to "png".
    """

    pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True})
    device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    output_format: str = field(default="png", kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from text prompts.

        Args:
            prompts: Prompt fragments; they are joined with ", " into a single prompt string.
            negative_prompts: Optional negative prompts, forwarded to the pipeline driver's
                `make_additional_params`.

        Returns:
            An ImageArtifact holding the first image produced by the pipeline, encoded in `output_format`.
        """
        # `self.model` is not declared in this class; presumably it is provided by BaseImageGenerationDriver.
        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        output_image = pipeline(
            prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device)
        ).images[0]

        return self._encode_output_image(output_image, prompt)

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Generate an image from text prompts and an input image.

        Args:
            prompts: Prompt fragments; they are joined with ", " into a single prompt string.
            image: The input image; its bytes are opened with PIL and resized if needed.
            negative_prompts: Optional negative prompts, forwarded to the pipeline driver's
                `make_additional_params`.

        Returns:
            An ImageArtifact holding the first image produced by the pipeline, encoded in `output_format`.

        Raises:
            TypeError: If the pipeline driver's `make_image_param` returns None, i.e. the configured
                pipeline driver does not supply an image argument for the pipeline call.
        """
        pil_image = import_optional_dependency("PIL.Image")

        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        input_image = pil_image.open(io.BytesIO(image.value))
        # The size of the input image drives the size of the output image.
        # Resize the input image to the configured dimensions.
        output_width, output_height = self.pipeline_driver.output_image_dimensions
        if input_image.height != output_height or input_image.width != output_width:
            input_image = input_image.resize((output_width, output_height))

        # make_image_param is declared as Optional; expanding None with ** would fail with an
        # opaque "argument after ** must be a mapping" error, so report the real cause instead.
        image_param = self.pipeline_driver.make_image_param(input_image)
        if image_param is None:
            raise TypeError(
                f"{type(self.pipeline_driver).__name__}.make_image_param returned None; "
                "image variation requires a pipeline driver that accepts an input image."
            )

        output_image = pipeline(
            prompt,
            **image_param,
            **self.pipeline_driver.make_additional_params(negative_prompts, self.device),
        ).images[0]

        return self._encode_output_image(output_image, prompt)

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpainting is not implemented by this driver; always raises NotImplementedError."""
        raise NotImplementedError("Inpainting is not supported by this driver.")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpainting is not implemented by this driver; always raises NotImplementedError."""
        raise NotImplementedError("Outpainting is not supported by this driver.")

    def _encode_output_image(self, output_image, prompt: str) -> ImageArtifact:  # noqa: ANN001
        """Serialize a pipeline output image into an ImageArtifact.

        Args:
            output_image: The image object taken from the pipeline result's `images` list. It must
                provide `save(buffer, format=...)`, `height` and `width` (presumably a PIL image).
            prompt: The joined prompt string recorded on the artifact.

        Returns:
            An ImageArtifact whose value is the image encoded in `output_format`.
        """
        buffer = io.BytesIO()
        # The encoder is given the upper-cased format name, while the artifact records it lower-cased.
        output_image.save(buffer, format=self.output_format.upper())

        return ImageArtifact(
            value=buffer.getvalue(),
            format=self.output_format.lower(),
            height=output_image.height,
            width=output_image.width,
            prompt=prompt,
        )
Empty file.
25 changes: 25 additions & 0 deletions
25
griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from attrs import define | ||
|
||
if TYPE_CHECKING: | ||
from PIL.Image import Image | ||
|
||
|
||
@define
class BaseDiffusionImageGenerationPipelineDriver(ABC):
    """Abstract interface for drivers that build and parameterize a diffusion pipeline.

    Concrete subclasses decide which pipeline is constructed and which keyword
    arguments are passed to it when an image is generated.
    """

    @abstractmethod
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Build and return the pipeline object for `model`, optionally targeting `device`.

        The returned object is invoked as `pipeline(prompt, **params)` by the image
        generation driver that consumes this interface.
        """

    @abstractmethod
    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Return the keyword argument(s) carrying `image` for the pipeline call.

        May return None, as permitted by the return annotation.
        """

    @abstractmethod
    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict:
        """Return the remaining keyword arguments for the pipeline call."""

    @property
    @abstractmethod
    def output_image_dimensions(self) -> tuple[int, int]:
        """The configured output size, consumed by callers as `(width, height)`."""
87 changes: 87 additions & 0 deletions
87
...age_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
if TYPE_CHECKING: | ||
from PIL.Image import Image | ||
|
||
|
||
@define
class StableDiffusion3ControlNetImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 models with ControlNet.

    For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline:
        https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3

    Attributes:
        controlnet_model: The ControlNet model to use for image generation.
        controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None.
    """

    controlnet_model: str = field(kw_only=True)
    controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Load the ControlNet model and the SD3 ControlNet pipeline.

        Args:
            model: Path to a local model file, or a local directory / HuggingFace repo name.
            device: Optional device the pipeline is moved to after loading.

        Returns:
            The loaded StableDiffusion3ControlNetPipeline instance.
        """
        sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
        sd3_controlnet_pipeline = import_optional_dependency(
            "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
        ).StableDiffusion3ControlNetPipeline

        # `torch_dtype`, `drop_t5_encoder` and `enable_model_cpu_offload` are not declared in this
        # class; presumably they are inherited from StableDiffusion3ImageGenerationPipelineDriver.
        pipeline_params: dict[str, Any] = {}
        controlnet_pipeline_params: dict[str, Any] = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype
            controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # The ControlNet model must be loaded first because it is passed to the pipeline loader.
        pipeline_params["controlnet"] = self._load_from_file_or_pretrained(
            sd3_controlnet_model, self.controlnet_model, controlnet_pipeline_params
        )
        pipeline = self._load_from_file_or_pretrained(sd3_controlnet_pipeline, model, pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        # NOTE(review): diffusers advises against manually moving a pipeline after model CPU
        # offload has been enabled - confirm the two options are not meant to be combined.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Wrap the input image as the pipeline's `control_image` keyword argument.

        Raises:
            ValueError: If `image` is None.
        """
        if image is None:
            raise ValueError("Input image is required for ControlNet pipelines.")

        return {"control_image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        """Build the extra pipeline keyword arguments for a ControlNet generation.

        Starts from the parent's parameters, removes `height` and `width`, and adds
        `controlnet_conditioning_scale` when it is configured.
        """
        additional_params = super().make_additional_params(negative_prompts, device)

        # Use pop with a default so the parent omitting either key does not raise KeyError.
        additional_params.pop("height", None)
        additional_params.pop("width", None)

        if self.controlnet_conditioning_scale is not None:
            additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

        return additional_params

    @staticmethod
    def _load_from_file_or_pretrained(loader: Any, source: str, params: dict[str, Any]) -> Any:
        """Load weights via `loader`, choosing the method from what `source` points at.

        For both Stable Diffusion and ControlNet, models can be provided either as a path to a
        local file or as a HuggingFace model repo name. `from_single_file` is used when `source`
        is a local file, and `from_pretrained` when it is a local directory or hosted on HuggingFace.
        """
        if os.path.isfile(source):
            return loader.from_single_file(source, **params)

        return loader.from_pretrained(source, **params)
Oops, something went wrong.