Stable Diffusion 3 local support #1018
raise NotImplementedError(
    "StableDiffusion3Img2ImgPipeline does not yet support loading from a single file."
)
For eventual ComfyUI convenience, we accept three model input types:
- path to a single file containing a model
- path to a directory containing model files
- HuggingFace model repo name
The one exception to this is the StableDiffusion3Img2ImgPipeline, which doesn't support .from_single_file(). Models can still be loaded by path to a local directory or by model repo name (and not downloaded again if they're cached locally).
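As a rough illustration of the dispatch described in this comment, here is a minimal sketch. It is not the code in this PR: `choose_loader` and its `supports_single_file` flag are hypothetical names, and the function only returns the name of the diffusers loading method instead of calling it.

```python
import os


def choose_loader(model: str, supports_single_file: bool = True) -> str:
    """Return the name of the loading method to use for `model`.

    Hypothetical helper, for illustration only.
    """
    if os.path.isfile(model):
        if not supports_single_file:
            # Mirrors the Img2Img case, which cannot load from a single file.
            raise NotImplementedError("This pipeline does not support loading from a single file.")
        return "from_single_file"
    # Local directories and HuggingFace repo names both go through from_pretrained.
    return "from_pretrained"
```

A path to an existing file selects `from_single_file`; anything else (a directory or a repo name) falls through to `from_pretrained`.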
pyproject.toml
@@ -134,6 +138,14 @@ drivers-observability-datadog = [
    "opentelemetry-exporter-otlp-proto-http",
]

drivers-imagegen-huggingface = [
Should be named drivers-image-generation-huggingface
@abstractmethod
def get_output_image_dimensions(self) -> Optional[tuple[int, int]]: ...
Should maybe be a @property?
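For reference, an abstract property could look roughly like the sketch below. The class names are hypothetical stand-ins rather than the drivers in this PR, and the sketch uses plain classes instead of attrs `@define`.

```python
from abc import ABC, abstractmethod
from typing import Optional


class ModelDriverWithDimensions(ABC):
    # Hypothetical stand-in for the abstract model driver under review.
    @property
    @abstractmethod
    def output_image_dimensions(self) -> Optional[tuple[int, int]]: ...


class FixedSizeModelDriver(ModelDriverWithDimensions):
    def __init__(self, width: int, height: int) -> None:
        self._width = width
        self._height = height

    @property
    def output_image_dimensions(self) -> Optional[tuple[int, int]]:
        # Read like an attribute: driver.output_image_dimensions
        return (self._width, self._height)
```

Callers would then read `driver.output_image_dimensions` without parentheses, and the base class still cannot be instantiated until the property is implemented.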
@define
class HuggingFaceDiffusionPipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
Should this subclass BaseImageGenerationDriver instead?
I assume you mean BaseMultiModelImageGenerationDriver. That was the original intention, but the Pipeline drivers require a substantially different interface than those that inherit from BaseImageGenerationModelDriver as required by the BaseMultiModelImageGenerationDriver.
@define
class HuggingFaceDiffusionPipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
Wondering if we should just rename to HuggingFacePipelineImageGenerationDriver for similarity with the HuggingFacePipelinePromptDriver that uses transformers.
...age_generation_model/stable_diffusion_3_controlnet_pipeline_image_generation_model_driver.py
# as a path to a local file or as a HuggingFace model repo name.
# We use the from_single_file method if the model is a local file and the
# from_pretrained method if the model is a local directory or hosted on HuggingFace.
sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
Imports should happen at the top of the function as if it were a regular import.
sd3_controlnet_pipeline = import_optional_dependency(
    "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
).StableDiffusion3ControlNetPipeline
Imports should happen at the top of the function as if it were a regular import.
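A minimal sketch of the layout being requested, under stated assumptions: `import_optional_dependency_sketch` is a hypothetical stand-in for the project's `import_optional_dependency` helper, and the standard-library `json` module stands in for diffusers so the example runs anywhere.

```python
import importlib
from types import ModuleType


def import_optional_dependency_sketch(name: str) -> ModuleType:
    # Hypothetical stand-in for the project's import_optional_dependency helper.
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(f"Optional dependency '{name}' is not installed.") from exc


def dump_config(config: dict) -> str:
    # Resolve optional dependencies first, where a regular import would sit.
    json = import_optional_dependency_sketch("json")

    # The rest of the function body follows the imports.
    return json.dumps(config, sort_keys=True)
```

Grouping the lazy imports at the top keeps a missing dependency failing early, before any other work in the function has run.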
).StableDiffusion3ControlNetPipeline
if os.path.isfile(model):
    pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)

Weird that the autoformatter didn't format away this whitespace
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ...

@abstractmethod
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: ...
Why is image optional if every implementation seems to require it?
    return pipeline

def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    return None
I guess this is why image is optional, but this feels odd.
This is odd, and a half-measure. Zooming out, I think the time has come to admit that the design where each driver implements a suite of common image generation types (prompt, variation, inpainting, outpainting) isn't working so well: it forces NotImplementedErrors and backwards workarounds like this. We can move this conversation to another venue, but we should maybe consider making image generation drivers more granular.
Can you please create a ticket to track this work? This feels like an important refactor pre-1.0.
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    StableDiffusion3ControlNetPipeline = import_optional_dependency(
Having 3 in the file/class name feels odd, maybe let's remove?
Stable Diffusion's pipelines include versions in their class names and are distinct enough that we would need a driver for each, much in the same way that ControlNet and Img2Img are distinct here. Amaru pointed out that Stable Diffusion 3 is still quite new and many people still prefer their Stable Diffusion XL workflows. To support this, we'd need another driver typed for StableDiffusionXL (same story for SD1.5 and SD2). I don't like the aesthetics of it, but this leaves some space for other SD versions.
It seemed like I was trying to fit a square peg (model drivers that were actually specific to pipeline types, not models) into a round hole (model drivers meant to define interactions with a model), so I refactored these out to a new type:
Describe your changes
This PR introduces drivers that can be used to generate images using Stable Diffusion 3 locally.
A model-agnostic HuggingFaceDiffusionPipelineImageGenerationDriver manages creating and running inferences on a HuggingFace diffusers pipeline. New model drivers: StableDiffusion3PipelineImageGenerationModelDriver, StableDiffusion3Img2ImgPipelineImageGenerationModelDriver, and StableDiffusion3ControlNetPipelineImageGenerationDriver extend the BaseDiffusionPipelineImageGenerationModelDriver to specify how to prepare the inference pipeline and format pipeline inputs.
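The split between the model-agnostic driver and the model drivers can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's implementation: every class name ending in `Sketch` is hypothetical, `FakePipeline` stands in for a diffusers pipeline, and images are typed as `Any` to avoid a PIL dependency. Only the `prepare_pipeline` and `make_image_param` method names come from the diff above.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseModelDriverSketch(ABC):
    @abstractmethod
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ...

    @abstractmethod
    def make_image_param(self, image: Optional[Any]) -> Optional[dict[str, Any]]: ...


class FakePipeline:
    """Stands in for a diffusers pipeline; it just echoes its arguments."""

    def __call__(self, prompt: str, **kwargs: Any) -> dict[str, Any]:
        return {"prompt": prompt, **kwargs}


class TextToImageModelDriverSketch(BaseModelDriverSketch):
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        return FakePipeline()

    def make_image_param(self, image: Optional[Any]) -> Optional[dict[str, Any]]:
        return None  # text-to-image takes no input image


class Img2ImgModelDriverSketch(BaseModelDriverSketch):
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        return FakePipeline()

    def make_image_param(self, image: Optional[Any]) -> Optional[dict[str, Any]]:
        if image is None:
            raise ValueError("An input image is required.")
        return {"image": image}


class PipelineDriverSketch:
    """Model-agnostic driver: delegates pipeline setup and input formatting."""

    def __init__(self, model: str, model_driver: BaseModelDriverSketch, device: Optional[str] = None) -> None:
        self.model = model
        self.model_driver = model_driver
        self.device = device

    def run(self, prompt: str, image: Optional[Any] = None) -> Any:
        pipeline = self.model_driver.prepare_pipeline(self.model, self.device)
        image_param = self.model_driver.make_image_param(image) or {}
        return pipeline(prompt, **image_param)
```

In this sketch, swapping `TextToImageModelDriverSketch` for `Img2ImgModelDriverSketch` changes how inputs are formatted without touching the outer driver.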