
Commit da21d59

DN6 and a-r-r-o-w authored
[Single File] Add Single File support for HunYuan video (#10320)
* update

* Update src/diffusers/loaders/single_file_utils.py

Co-authored-by: Aryan <[email protected]>

---------

Co-authored-by: Aryan <[email protected]>
1 parent 7c2f0af commit da21d59
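
What this enables, roughly, is loading original-format HunyuanVideo transformer weights directly through from_single_file. A minimal sketch, assuming a local checkpoint path and a dtype chosen purely for illustration:

import torch

from diffusers import HunyuanVideoTransformer3DModel

# Hypothetical path to an original-format HunyuanVideo transformer checkpoint.
ckpt_path = "path/to/hunyuan_video_transformer.safetensors"

# from_single_file is provided by FromOriginalModelMixin, which this commit
# adds to HunyuanVideoTransformer3DModel.
transformer = HunyuanVideoTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)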

3 files changed: +145 −2 lines

src/diffusers/loaders/single_file_model.py

Lines changed: 7 additions & 1 deletion
@@ -28,6 +28,7 @@
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
@@ -101,6 +102,10 @@
         "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "HunyuanVideoTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }

@@ -220,6 +225,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         local_files_only = kwargs.pop("local_files_only", None)
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
+        config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
@@ -297,7 +303,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 subfolder=subfolder,
                 local_files_only=local_files_only,
                 token=token,
-                revision=revision,
+                revision=config_revision,
             )
             expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
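
The new config_revision kwarg is forwarded to cls.load_config in place of revision, so the diffusers-format config can be pinned to a revision independent of the checkpoint itself. A hedged usage sketch (the repo id, subfolder, and revision values are illustrative assumptions):

from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_single_file(
    "path/to/hunyuan_video_transformer.safetensors",  # hypothetical checkpoint
    config="hunyuanvideo-community/HunyuanVideo",     # diffusers-format config repo (assumed)
    subfolder="transformer",
    config_revision="main",                           # assumed revision, applies to the config repo only
)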

src/diffusers/loaders/single_file_utils.py

Lines changed: 135 additions & 0 deletions
@@ -108,6 +108,7 @@
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -162,6 +163,7 @@
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
 }

 # Use to configure model sample size when original config is provided
@@ -624,6 +626,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
         model_type = "mochi-1-preview"

+    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+        model_type = "hunyuan-video"
+
     else:
         model_type = "v1"
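
The detection keys off a single tensor name that only the original HunyuanVideo transformer carries. A standalone sketch of the same check (a hypothetical helper, not the library function):

HUNYUAN_VIDEO_KEY = "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"

def looks_like_hunyuan_video(state_dict) -> bool:
    # Mirrors the branch added to infer_diffusers_model_type above.
    return HUNYUAN_VIDEO_KEY in state_dict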

@@ -2522,3 +2527,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

     return new_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+        state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+        state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+        state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+        state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+            q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+            state_dict[f"{new_key}.attn.to_q.weight"] = q
+            state_dict[f"{new_key}.attn.to_k.weight"] = k
+            state_dict[f"{new_key}.attn.to_v.weight"] = v
+            state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+        elif "linear1.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+            q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+            state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+            state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+            state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+            state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    def update_state_dict_(state_dict, old_key, new_key):
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    for key in list(checkpoint.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(checkpoint, key, new_key)
+
+    for key in list(checkpoint.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, checkpoint)
+
+    return checkpoint
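
For intuition, the remap helpers above split fused projections along dim 0: dual-stream blocks store q/k/v stacked in a single qkv tensor, and single-stream blocks additionally fold the MLP input projection into linear1. A toy-sized sketch of that splitting (the sizes are assumptions for illustration; the real hidden size is 3072):

import torch

hidden_size = 8
mlp_dim = 4 * hidden_size

# remap_img_attn_qkv_ / remap_txt_attn_qkv_: three equal chunks become to_q, to_k, to_v.
fused_qkv = torch.randn(3 * hidden_size, hidden_size)
to_q, to_k, to_v = fused_qkv.chunk(3, dim=0)
assert to_q.shape == to_k.shape == to_v.shape == (hidden_size, hidden_size)

# remap_single_transformer_blocks_: q, k, v plus the proj_mlp weight live in linear1.
linear1_weight = torch.randn(3 * hidden_size + mlp_dim, hidden_size)
split_size = (hidden_size, hidden_size, hidden_size, mlp_dim)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
assert mlp.shape == (mlp_dim, hidden_size)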

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,8 @@
 import torch.nn as nn
 import torch.nn.functional as F

+from diffusers.loaders import FromOriginalModelMixin
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -500,7 +502,7 @@ def forward(
         return hidden_states, encoder_hidden_states


-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
