Add support for loading single file CLIPEmbedding models #6813

Open · wants to merge 7 commits into base: main
@@ -0,0 +1,25 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.3",
  "vocab_size": 49408
}
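
Reviewer note (illustration, not part of this diff): this bundled config appears to live under the new invokeai/backend/assets/model_base_conf_files package, judging by the loader changes below. A minimal sketch of how transformers would consume such a file, assuming a hypothetical local path:

# Illustrative only: build a CLIP text-encoder skeleton from the bundled JSON config.
from transformers import CLIPTextConfig, CLIPTextModel

text_config = CLIPTextConfig.from_json_file("clip_text_model/config.json")  # hypothetical local path
model = CLIPTextModel(text_config)  # randomly initialised; real weights come from the single-file checkpoint
assert model.config.max_position_embeddings == 77  # matches the value in the config above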
12 changes: 12 additions & 0 deletions invokeai/backend/model_manager/config.py
@@ -406,6 +406,17 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")


class CLIPEmbedCheckpointConfig(CheckpointConfigBase):
"""Model config for CLIP Embedding checkpoints."""

type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Checkpoint]

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Checkpoint.value}")


class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""

@@ -481,6 +492,7 @@ def get_model_discriminator_value(v: Any) -> str:
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedCheckpointConfig, CLIPEmbedCheckpointConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
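
Reviewer note (illustration, not part of this diff): the new union entry plugs into pydantic's callable-discriminator mechanism. A self-contained sketch with simplified stand-in configs; the tag strings "clip_embed.diffusers" and "clip_embed.checkpoint" are assumptions about the enum string values, not taken from this PR:

# Illustrative only: route a raw dict to the right config class via Tag/Discriminator.
from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class ClipEmbedDiffusers(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    format: Literal["diffusers"] = "diffusers"

class ClipEmbedCheckpoint(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    format: Literal["checkpoint"] = "checkpoint"

def discriminator_value(v: Any) -> str:
    # Mirrors get_model_discriminator_value: route on "<type>.<format>".
    data = v if isinstance(v, dict) else v.__dict__
    return f"{data['type']}.{data['format']}"

AnyClipConfig = Annotated[
    Union[
        Annotated[ClipEmbedDiffusers, Tag("clip_embed.diffusers")],
        Annotated[ClipEmbedCheckpoint, Tag("clip_embed.checkpoint")],
    ],
    Discriminator(discriminator_value),
]

cfg = TypeAdapter(AnyClipConfig).validate_python({"type": "clip_embed", "format": "checkpoint"})
assert isinstance(cfg, ClipEmbedCheckpoint)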
47 changes: 45 additions & 2 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -7,8 +7,17 @@
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from transformers import (
    AutoConfig,
    AutoModelForTextEncoding,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

import invokeai.backend.assets.model_base_conf_files as model_conf_files
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
@@ -23,6 +32,7 @@
)
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    CLIPEmbedCheckpointConfig,
    CLIPEmbedDiffusersConfig,
    MainBnbQuantized4bCheckpointConfig,
    MainCheckpointConfig,
@@ -71,7 +81,7 @@ def _load_model(


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
class ClipDiffusersModel(ModelLoader):
"""Class to load main models."""

def _load_model(
Expand All @@ -93,6 +103,39 @@ def _load_model(
        )


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
class ClipCheckpointModel(ModelLoader):
    """Class to load single-file (checkpoint) CLIP-Embed models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CLIPEmbedCheckpointConfig):
            raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer:
                # CLIP embedding checkpoints don't ship an integrated tokenizer, so we cheat and fetch one
                # into the HuggingFace cache.
                # TODO: Fix this ugly workaround
                return CLIPTokenizer.from_pretrained(
                    "InvokeAI/clip-vit-large-patch14-text-encoder", subfolder="bfloat16/tokenizer"
                )
            case SubModelType.TextEncoder:
                config_json = CLIPTextConfig.from_json_file(Path(model_conf_files.__path__[0], config.config_path))
                model = CLIPTextModel(config_json)
                state_dict = load_file(config.path)
                new_dict = {key: value for (key, value) in state_dict.items() if key.startswith("text_model.")}
                model.load_state_dict(new_dict)
                model.eval()
                return model

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
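
Reviewer note (illustration, not part of this diff): the checkpoint loader above expects a safetensors file whose keys carry the "text_model." prefix. A sketch of producing such a single-file checkpoint from the stock CLIP-L text encoder, assuming network access to the openai/clip-vit-large-patch14 repo:

# Illustrative only: export a CLIP-L text encoder as a single safetensors file
# whose keys all start with "text_model.", the layout the new loader filters on.
from safetensors.torch import save_file
from transformers import CLIPTextModel

encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
state_dict = {k: v.contiguous() for k, v in encoder.state_dict().items()}  # keys: "text_model.embeddings...", etc.
save_file(state_dict, "clip_l_text_encoder.safetensors")

Single-file checkpoints found in the wild may bundle additional components under other prefixes, which is presumably why the loader filters on "text_model." before calling load_state_dict.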
19 changes: 15 additions & 4 deletions invokeai/backend/model_manager/probe.py
@@ -8,7 +8,6 @@
import torch
from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
@@ -31,6 +30,7 @@
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.silence_warnings import SilenceWarnings

CkptType = Dict[str | int, Any]
@@ -184,7 +184,9 @@ def probe(
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE, ModelType.CLIPEmbed] and fields[
"format"
] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
]:
@@ -207,7 +209,6 @@
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)

model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info

@@ -258,6 +259,8 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
                return ModelType.IPAdapter
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion
            elif key.startswith(("text_model.embeddings", "text_model.encoder")):
                return ModelType.CLIPEmbed

        # diffusers-ti
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
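
Reviewer note (illustration, not part of this diff): the same key-prefix heuristic can be checked against a safetensors header without loading any tensors, roughly:

# Illustrative only: classify a file as a CLIP text encoder by peeking at its key names.
from safetensors import safe_open

def looks_like_clip_text_encoder(path: str) -> bool:
    with safe_open(path, framework="pt", device="cpu") as f:
        return any(k.startswith(("text_model.embeddings", "text_model.encoder")) for k in f.keys())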
@@ -394,6 +397,8 @@ def _get_checkpoint_config_path(
                if base_type is BaseModelType.StableDiffusionXL
                else "stable-diffusion/v2-inference.yaml"
            )
        elif model_type is ModelType.CLIPEmbed:
            return Path("clip_text_model", "config.json")
        else:
            raise InvalidModelConfigException(
                f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
@@ -665,6 +670,11 @@ def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPEmbedCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()
@@ -822,7 +832,7 @@ def get_base_type(self) -> BaseModelType:
if (self.model_path / "unet" / "config.json").exists():
return super().get_base_type()
else:
logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
InvokeAILogger.get_logger().warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
return BaseModelType.StableDiffusion1

def get_format(self) -> ModelFormat:
@@ -956,6 +966,7 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPEmbed, CLIPEmbedCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)

Expand Down