add support for CLIPTextModel single file loading
Lincoln Stein committed Sep 4, 2024
1 parent 125b459 commit ef18ecd
Showing 4 changed files with 97 additions and 6 deletions.
25 changes: 25 additions & 0 deletions invokeai/backend/assets/model_base_conf_files/clip_text_model/config.json
@@ -0,0 +1,25 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.3",
  "vocab_size": 49408
}
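
This new file is a bundled copy of the openai/clip-vit-large-patch14 text-encoder configuration, so the single-file loader can build a CLIPTextModel without fetching anything from the Hub. A quick sanity check that it round-trips through transformers — the path below assumes the file lives under the model_base_conf_files assets package, as the flux.py loader implies:

from transformers import CLIPTextConfig

config = CLIPTextConfig.from_json_file(
    "invokeai/backend/assets/model_base_conf_files/clip_text_model/config.json"
)
assert config.hidden_size == 768
assert config.num_hidden_layers == 12
assert config.max_position_embeddings == 77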
12 changes: 12 additions & 0 deletions invokeai/backend/model_manager/config.py
@@ -403,6 +403,17 @@ def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")


class CLIPEmbedCheckpointConfig(CheckpointConfigBase):
    """Model config for CLIP Embedding checkpoints."""

    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
    format: Literal[ModelFormat.Checkpoint]

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Checkpoint.value}")


class CLIPVisionDiffusersConfig(DiffusersConfigBase):
    """Model config for CLIPVision."""

@@ -478,6 +489,7 @@ def get_model_discriminator_value(v: Any) -> str:
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedCheckpointConfig, CLIPEmbedCheckpointConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
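
For orientation, pydantic resolves this discriminated union by matching the string "<type>.<format>" (as computed by get_model_discriminator_value) against each entry's Tag. A minimal sketch of the dispatch for the new class, assuming the enum values are "clip_embed" and "checkpoint":

from invokeai.backend.model_manager.config import ModelFormat, ModelType

# This is the same string CLIPEmbedCheckpointConfig.get_tag() builds, so a
# probed single-file CLIP text encoder validates against that class.
tag = f"{ModelType.CLIPEmbed.value}.{ModelFormat.Checkpoint.value}"
print(tag)  # expected: "clip_embed.checkpoint"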
47 changes: 45 additions & 2 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -7,8 +7,17 @@
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from transformers import (
    AutoConfig,
    AutoModelForTextEncoding,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

import invokeai.backend.assets.model_base_conf_files as model_conf_files
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
@@ -23,6 +32,7 @@
)
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    CLIPEmbedCheckpointConfig,
    CLIPEmbedDiffusersConfig,
    MainBnbQuantized4bCheckpointConfig,
    MainCheckpointConfig,
@@ -69,7 +79,7 @@ def _load_model(


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
class ClipDiffusersModel(ModelLoader):
    """Class to load CLIP Diffusers-format text-encoder models."""

    def _load_model(
@@ -91,6 +101,39 @@ def _load_model(
        )


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
class ClipCheckpointModel(ModelLoader):
    """Class to load CLIP text-encoder models from single-file checkpoints."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CLIPEmbedCheckpointConfig):
            raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer:
                # CLIP embedding checkpoints don't include a tokenizer, so we cheat and fetch one
                # into the HuggingFace cache.
                # TODO: Fix this ugly workaround
                return CLIPTokenizer.from_pretrained(
                    "InvokeAI/clip-vit-large-patch14-text-encoder", subfolder="bfloat16/tokenizer"
                )
            case SubModelType.TextEncoder:
                text_config = CLIPTextConfig.from_json_file(Path(model_conf_files.__path__[0], config.config_path))
                model = CLIPTextModel(text_config)
                state_dict = load_file(config.path)
                # Single-file checkpoints may carry non-text weights; keep only the text model's.
                new_dict = {key: value for (key, value) in state_dict.items() if key.startswith("text_model.")}
                model.load_state_dict(new_dict)
                model.eval()
                return model

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )
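
Stripped of the registry plumbing, the TextEncoder branch above amounts to the sketch below. The file names are illustrative placeholders; the filtering step matters because single-file CLIP checkpoints may carry extra entries (a logit_scale scalar, vision-tower weights) that CLIPTextModel's strict state-dict load would reject:

from pathlib import Path

from safetensors.torch import load_file
from transformers import CLIPTextConfig, CLIPTextModel

# Hypothetical paths -- substitute your own checkpoint and bundled config.
text_config = CLIPTextConfig.from_json_file("clip_text_model/config.json")
model = CLIPTextModel(text_config)
state_dict = load_file(Path("clip_l.safetensors"))
# Keep only the text-model weights; drop everything else in the file.
text_weights = {k: v for k, v in state_dict.items() if k.startswith("text_model.")}
model.load_state_dict(text_weights)
model.eval()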


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
    """Class to load T5 encoder models quantized with bitsandbytes LLM.int8."""
19 changes: 15 additions & 4 deletions invokeai/backend/model_manager/probe.py
@@ -8,7 +8,6 @@
import torch
from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
@@ -27,6 +26,7 @@
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.silence_warnings import SilenceWarnings

CkptType = Dict[str | int, Any]
@@ -180,7 +180,9 @@ def probe(
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE, ModelType.CLIPEmbed] and fields[
"format"
] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
]:
@@ -203,7 +205,6 @@ def probe(
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)

model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info

@@ -252,6 +253,8 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType]
                return ModelType.IPAdapter
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion
            # Single-file CLIP text encoders expose keys such as
            # "text_model.embeddings.token_embedding.weight" and
            # "text_model.encoder.layers.0.self_attn.q_proj.weight".
            elif key.startswith(("text_model.embeddings", "text_model.encoder")):
                return ModelType.CLIPEmbed

        # diffusers-ti
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -388,6 +391,8 @@ def _get_checkpoint_config_path(
                if base_type is BaseModelType.StableDiffusionXL
                else "stable-diffusion/v2-inference.yaml"
            )
        elif model_type is ModelType.CLIPEmbed:
            return Path("clip_text_model", "config.json")
        else:
            raise InvalidModelConfigException(
                f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
@@ -650,6 +655,11 @@ def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPEmbedCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # A CLIP text encoder is not tied to one base architecture (the same
        # CLIP-L weights serve SD and FLUX pipelines), so report Any.
        return BaseModelType.Any


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()
@@ -807,7 +817,7 @@ def get_base_type(self) -> BaseModelType:
        if (self.model_path / "unet" / "config.json").exists():
            return super().get_base_type()
        else:
            logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
            InvokeAILogger.get_logger().warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
            return BaseModelType.StableDiffusion1

    def get_format(self) -> ModelFormat:
@@ -941,6 +951,7 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPEmbed, CLIPEmbedCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)

