[MM] Install and load Stable Diffusion 3 models #6512

Open · wants to merge 21 commits into main

Commits (changes from all 21 commits)

002f824 · add draft SD3 probing; there is an issue with FromOriginalControlNetM… (Jun 13, 2024)
03b9d17 · draft sd3 loading; probable VRAM leak when using quantized submodels (Jun 13, 2024)
c79d9b9 · wip: Add Initial support for select SD3 models in UI (blessedcoolant, Jun 14, 2024)
0c970bc · wip: add SD3 Model Loader Invocation (blessedcoolant, Jun 14, 2024)
ddbd2eb · wip: add Transformer Field to Node UI (blessedcoolant, Jun 14, 2024)
4123603 · chore: remove unrequired changes to v1 workflow field types (blessedcoolant, Jun 14, 2024)
78f704e · tweak installer to select correct components of HF SD3 diffusers models (Jun 14, 2024)
ac0396e · Merge branch 'lstein/feat/sd3-model-loading' of github.com:invoke-ai/… (Jun 14, 2024)
554809c · return correct base type for sd3 VAEs (Jun 15, 2024)
f65d50a · wip: basic wrapper for generating sd3 images (blessedcoolant, Jun 15, 2024)
423057a · add config variable to suppress loading of sd3 text_encoder_3 T5 model (Jun 16, 2024)
be14fd5 · fix: height and width not working on sd3 node (blessedcoolant, Jun 17, 2024)
22b5c03 · Revert "fix: height and width not working on sd3 node" (blessedcoolant, Jun 17, 2024)
9dce4f0 · scale default RAM cache by size of system virtual memory (Jun 18, 2024)
cd99ef2 · Merge branch 'main' into lstein/feat/sd3-model-loading (blessedcoolant, Jun 20, 2024)
c403efa · fix: Make TE5 Optional (blessedcoolant, Jun 20, 2024)
66260fd · fix: Update Clip 3 slot title & lint issues (blessedcoolant, Jun 20, 2024)
445561e · add sd3 to starter models (Jun 20, 2024)
95377ea · add non-commercial use message to sd3 starter; rebuild frontend (Jun 21, 2024)
28f1d25 · unpin dependencies; fix typo in sd3.py (Jun 21, 2024)
39881d3 · fix installer logic for tokenizer_3 and text_encoder_3 (Jun 22, 2024)
3 changes: 3 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     MainModel = "MainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
+    SD3MainModel = "SD3MainModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
     LoRAModel = "LoRAModelField"

@@ -125,6 +126,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     unet = "UNet (scheduler, LoRAs)"
+    transformer = "Transformer"
     vae = "VAE"
     cond = "Conditioning tensor"
     controlnet_model = "ControlNet model to load"

@@ -133,6 +135,7 @@
     main_model = "Main model (UNet, VAE, CLIP) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
+    sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
     lora_weight = "The weight at which the LoRA is applied to each model"
     compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
9 changes: 1 addition & 8 deletions invokeai/app/invocations/latents_to_image.py
@@ -12,14 +12,7 @@

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    LatentsField,
-    WithBoard,
-    WithMetadata,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
22 changes: 15 additions & 7 deletions invokeai/app/invocations/model.py
@@ -8,13 +8,7 @@
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType

-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    Classification,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output


 class ModelIdentifierField(BaseModel):

@@ -54,13 +48,27 @@ class UNetField(BaseModel):
     freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")


+class TransformerField(BaseModel):
+    transformer: ModelIdentifierField = Field(description="Info to load transformer submodel")
+    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
+
+
 class CLIPField(BaseModel):
     tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
     text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
     skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


+class SD3CLIPField(BaseModel):
+    tokenizer_1: ModelIdentifierField = Field(description="Info to load tokenizer 1 submodel")
+    text_encoder_1: ModelIdentifierField = Field(description="Info to load text_encoder 1 submodel")
+    tokenizer_2: ModelIdentifierField = Field(description="Info to load tokenizer 2 submodel")
+    text_encoder_2: ModelIdentifierField = Field(description="Info to load text_encoder 2 submodel")
+    tokenizer_3: Optional[ModelIdentifierField] = Field(description="Info to load tokenizer 3 submodel")
+    text_encoder_3: Optional[ModelIdentifierField] = Field(description="Info to load text_encoder 3 submodel")
+
+
 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
200 changes: 200 additions & 0 deletions invokeai/app/invocations/sd3.py
@@ -0,0 +1,200 @@
from contextlib import ExitStack
from typing import Optional, cast

import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from pydantic import field_validator
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Input,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import get_scheduler
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.model_manager.config import SubModelType

sd3_pipeline: Optional[StableDiffusion3Pipeline] = None


class FakeVae:
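    # Stand-in for the real VAE: StableDiffusion3Pipeline only needs a vae object
    # with a config.block_out_channels list to derive its scale factor, and no
    # decoding happens here because the pipeline is invoked with output_type="latent".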
    class FakeVaeConfig:
        def __init__(self) -> None:
            self.block_out_channels = [0]

    def __init__(self) -> None:
        self.config = FakeVae.FakeVaeConfig()


@invocation_output("sd3_model_loader_output")
class SD3ModelLoaderOutput(BaseInvocationOutput):
"""Stable Diffuion 3 base model loader output"""

transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
class SD3ModelLoaderInvocation(BaseInvocation):
"""Loads an SD3 base model, outputting its submodels."""

model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)

def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
model_key = self.model.key

if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")

transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
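        # The T5 components are optional: SD3 packages installed without
        # tokenizer_3/text_encoder_3 fall back to None here.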
        try:
            tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
            text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
        except Exception:
            tokenizer_3 = None
            text_encoder_3 = None
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})

        return SD3ModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, scheduler=scheduler),
            clip=SD3CLIPField(
                tokenizer_1=tokenizer_1,
                text_encoder_1=text_encoder_1,
                tokenizer_2=tokenizer_2,
                text_encoder_2=text_encoder_2,
                tokenizer_3=tokenizer_3,
                text_encoder_3=text_encoder_3,
            ),
            vae=VAEField(vae=vae),
        )


@invocation(
"sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
)
class StableDiffusion3Invocation(BaseInvocation):
"""Generates an image using Stable Diffusion 3."""

transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
ui_order=0,
)
clip: SD3CLIPField = InputField(
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
ui_order=1,
)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=2,
)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler_f",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
positive_prompt: str = InputField(default="", title="Positive Prompt")
negative_prompt: str = InputField(default="", title="Negative Prompt")
steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
use_clip_3: bool = InputField(default=True, description="Use TE5 Encoder of SD3", title="Use TE5 Encoder")

seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
)
width: int = InputField(
default=1024,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description=FieldDescriptions.width,
)
height: int = InputField(
default=1024,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description=FieldDescriptions.height,
)

@field_validator("seed", mode="before")
def modulo_seed(cls, v: int):
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        with ExitStack() as stack:
            tokenizer_1 = stack.enter_context(context.models.load(self.clip.tokenizer_1))
            tokenizer_2 = stack.enter_context(context.models.load(self.clip.tokenizer_2))
            text_encoder_1 = stack.enter_context(context.models.load(self.clip.text_encoder_1))
            text_encoder_2 = stack.enter_context(context.models.load(self.clip.text_encoder_2))
            transformer = stack.enter_context(context.models.load(self.transformer.transformer))

            assert isinstance(transformer, SD3Transformer2DModel)
            assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
            assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
            assert isinstance(tokenizer_1, CLIPTokenizer)
            assert isinstance(tokenizer_2, CLIPTokenizer)

            if self.use_clip_3 and self.clip.tokenizer_3 and self.clip.text_encoder_3:
                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
                assert isinstance(text_encoder_3, T5EncoderModel)
                assert isinstance(tokenizer_3, T5TokenizerFast)
            else:
                tokenizer_3 = None
                text_encoder_3 = None

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.transformer.scheduler,
                scheduler_name=self.scheduler,
                seed=self.seed,
            )

            sd3_pipeline = StableDiffusion3Pipeline(
                transformer=transformer,
                vae=FakeVae(),
                text_encoder=text_encoder_1,
                text_encoder_2=text_encoder_2,
                text_encoder_3=text_encoder_3,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
                tokenizer_3=tokenizer_3,
                scheduler=scheduler,
            )

            results = sd3_pipeline(
                self.positive_prompt,
                negative_prompt=self.negative_prompt,
                num_inference_steps=self.steps,
                guidance_scale=self.guidance_scale,
                output_type="latent",
            )

            latents = cast(torch.Tensor, results.images[0])
            latents = latents.unsqueeze(0)

        latents_name = context.tensors.save(latents)
        return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)
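
Note: this node emits latents rather than a finished image. Decoding presumably happens downstream in the existing latents-to-image node, fed by the VAE output of the SD3 Main Model loader, which is why the FakeVae stub above never has to decode anything.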
3 changes: 2 additions & 1 deletion invokeai/app/services/config/config_default.py
@@ -32,6 +32,7 @@
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
+SYSTEM_RAM_TO_CACHE_SIZE_FACTOR = 0.25  # after 60 GB, default ram cache will scale by this factor
 CONFIG_SCHEMA_VERSION = "4.0.1"

@@ -45,7 +46,7 @@ def get_default_ram_cache_size() -> float:
     max_ram = psutil.virtual_memory().total / GB

     if max_ram >= 60:
-        return 15.0
+        return max_ram * SYSTEM_RAM_TO_CACHE_SIZE_FACTOR
     if max_ram >= 30:
         return 7.5
     if max_ram >= 14:
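
Worked example of the new scaling: on a 64 GB machine the default RAM cache grows from the old fixed 15.0 GB to 64 x 0.25 = 16.0 GB, and it keeps growing with installed memory. A minimal check of the new top tier (a sketch; it assumes only that psutil is installed, which the surrounding code already requires):

import psutil

GB = 2**30
SYSTEM_RAM_TO_CACHE_SIZE_FACTOR = 0.25

max_ram = psutil.virtual_memory().total / GB
if max_ram >= 60:
    # A 64 GB system now gets a 16.0 GB default cache (previously a fixed 15.0 GB).
    print(f"default ram cache: {max_ram * SYSTEM_RAM_TO_CACHE_SIZE_FACTOR:.1f} GB")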
4 changes: 4 additions & 0 deletions invokeai/backend/model_manager/config.py
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
     StableDiffusion2 = "sd-2"
     StableDiffusionXL = "sdxl"
     StableDiffusionXLRefiner = "sdxl-refiner"
+    StableDiffusion3 = "sd-3"
     # Kandinsky2_1 = "kandinsky-2.1"

@@ -75,8 +76,11 @@ class SubModelType(str, Enum):
     UNet = "unet"
     TextEncoder = "text_encoder"
     TextEncoder2 = "text_encoder_2"
+    TextEncoder3 = "text_encoder_3"
     Tokenizer = "tokenizer"
     Tokenizer2 = "tokenizer_2"
+    Tokenizer3 = "tokenizer_3"
+    Transformer = "transformer"
     VAE = "vae"
     VAEDecoder = "vae_decoder"
     VAEEncoder = "vae_encoder"
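
For orientation, the new SubModelType values line up with the subfolder names of an SD3 checkpoint in diffusers format. A sketch of the layout, assuming the stabilityai/stable-diffusion-3-medium-diffusers repository structure for illustration:

stable-diffusion-3-medium-diffusers/
    transformer/       -> SubModelType.Transformer
    text_encoder_3/    -> SubModelType.TextEncoder3  (the optional T5 encoder)
    tokenizer_3/       -> SubModelType.Tokenizer3
    vae/               -> SubModelType.VAE
    text_encoder/, text_encoder_2/, tokenizer/, tokenizer_2/, scheduler/ -> existing entries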
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/load/load_default.py
@@ -84,6 +84,8 @@ def _convert_and_load(
         except IndexError:
             pass

+        self._logger.info(f"Loading {config.key}:{submodel_type}")
+
         cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
@@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
     device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
+    is_quantized: bool = False
     loaded: bool = False
     _locks: int = 0