Skip to content

Commit

Permalink
Support using SD3/SD3.5 with T5 text encoder
Browse files — browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 22, 2024
1 parent d00348e commit 35bd623
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generative AI plugin for Krita"""

__version__ = "1.26.0"
__version__ = "1.27.0"

import importlib.util

Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __getitem__(self, key: ControlMode | UpscalerName | str):
def find(self, key: ControlMode | UpscalerName | str, allow_universal=False) -> str | None:
"""Resolve `key` to an installed model file name, or None when no matching resource is found."""
# Style and composition reuse the reference ControlNet model, so normalize the key first.
if key in [ControlMode.style, ControlMode.composition]:
key = ControlMode.reference # Same model with different weight types
# NOTE(review): the next two lines appear to be the before/after pair of a diff
# rendering; only the ResourceId-based lookup should remain in the real file — confirm.
result = self._models.resources.get(resource_id(self.kind, self.arch, key))
result = self._models.find(ResourceId(self.kind, self.arch, key))
# Optional fallback: if the specific control model is missing, try the universal one.
if result is None and allow_universal and isinstance(key, ControlMode):
result = self.find(ControlMode.universal)
return result
Expand Down
8 changes: 8 additions & 0 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,14 @@ def load_dual_clip(self, clip_name1: str, clip_name2: str, type: str):
node = "DualCLIPLoaderGGUF"
return self.add_cached(node, 1, clip_name1=clip_name1, clip_name2=clip_name2, type=type)

def load_triple_clip(self, clip_name1: str, clip_name2: str, clip_name3: str):
    """Add a cached node that loads three text-encoder models at once.

    Uses the GGUF variant of the loader when any of the three file names
    has a `.gguf` extension; otherwise uses the regular TripleCLIPLoader.
    """
    names = (clip_name1, clip_name2, clip_name3)
    uses_gguf = any(name.endswith(".gguf") for name in names)
    node = "TripleCLIPLoaderGGUF" if uses_gguf else "TripleCLIPLoader"
    return self.add_cached(
        node, 1, clip_name1=clip_name1, clip_name2=clip_name2, clip_name3=clip_name3
    )

def load_vae(self, vae_name: str):
    """Add a cached VAELoader node for the given VAE file name."""
    node = "VAELoader"
    return self.add_cached(node, 1, vae_name=vae_name)

Expand Down
6 changes: 3 additions & 3 deletions ai_diffusion/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

# Version identifier for all the resources defined here. This is used as the server version.
# It usually follows the plugin version, but not all new plugin versions also require a server update.
version = "1.26.0"
version = "1.27.0"

comfy_url = "https://github.com/comfyanonymous/ComfyUI"
comfy_version = "1b8089528502a881d0ed2918b2abd54441743dd0"
comfy_version = "8ce2a1052ca03183768da0aaa483024e58b8008c"


class CustomNode(NamedTuple):
Expand Down Expand Up @@ -56,7 +56,7 @@ class CustomNode(NamedTuple):
"GGUF",
"ComfyUI-GGUF",
"https://github.com/city96/ComfyUI-GGUF",
"454955ead3336322215a206edbd7191eb130bba0",
"68ad5fb2149ead7e4978f83f14e045ecd812a394",
["UnetLoaderGGUF", "DualCLIPLoaderGGUF"],
)
]
Expand Down
5 changes: 4 additions & 1 deletion ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
case Arch.sdxl:
clip = w.load_dual_clip(te_models["clip_g"], te_models["clip_l"], type="sdxl")
case Arch.sd3:
clip = w.load_dual_clip(te_models["clip_g"], te_models["clip_l"], type="sd3")
if te_models.find("t5"):
clip = w.load_triple_clip(te_models["clip_l"], te_models["clip_g"], te_models["t5"])
else:
clip = w.load_dual_clip(te_models["clip_g"], te_models["clip_l"], type="sd3")
case Arch.flux:
clip = w.load_dual_clip(te_models["clip_l"], te_models["t5"], type="flux")
case _:
Expand Down

0 comments on commit 35bd623

Please sign in to comment.