Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support checkpoint bundles containing more than the transformer #6808

Merged
merged 6 commits into from
Sep 4, 2024
5 changes: 5 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.util.model_util import convert_bundle_to_flux_transformer_checkpoint
from invokeai.backend.util.silence_warnings import SilenceWarnings

try:
Expand Down Expand Up @@ -190,6 +191,8 @@ def _load_from_singlefile(
with SilenceWarnings():
model = Flux(params[config.config_path])
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model

Expand Down Expand Up @@ -230,5 +233,7 @@ def _load_from_singlefile(
model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model
28 changes: 24 additions & 4 deletions invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,18 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
ckpt = ckpt.get("state_dict", ckpt)

for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
if key.startswith(
(
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
"double_blocks.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
)
):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
Expand Down Expand Up @@ -337,7 +348,10 @@ def _get_checkpoint_config_path(
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
if (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
):
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
Expand Down Expand Up @@ -452,7 +466,10 @@ def __init__(self, model_path: Path):

def get_format(self) -> ModelFormat:
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
if (
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
):
return ModelFormat.BnbQuantizednf4b
return ModelFormat("checkpoint")

Expand All @@ -479,7 +496,10 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
if (
"double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
):
return BaseModelType.Flux
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
Expand Down
26 changes: 26 additions & 0 deletions invokeai/backend/model_manager/util/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,29 @@ def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Ten
break

return lora_token_vector_length


def convert_bundle_to_flux_transformer_checkpoint(
transformer_state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
original_state_dict: dict[str, torch.Tensor] = {}
keys_to_remove: list[str] = []

for k, v in transformer_state_dict.items():
if not k.startswith("model.diffusion_model"):
keys_to_remove.append(k) # This can be removed in the future if we only want to delete transformer keys
continue
if k.endswith("scale"):
# Scale math must be done at bfloat16 due to our current flux model
# support limitations at inference time
v = v.to(dtype=torch.bfloat16)
new_key = k.replace("model.diffusion_model.", "")
original_state_dict[new_key] = v
keys_to_remove.append(k)

# Remove processed keys from the original dictionary, leaving others in case
# other model state dicts need to be pulled
for k in keys_to_remove:
del transformer_state_dict[k]

return original_state_dict
Loading