@@ -108,6 +108,7 @@
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -162,6 +163,7 @@
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
 }
 
 # Use to configure model sample size when original config is provided
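Aside (not part of the diff): the two entries above work together. The CHECKPOINT_KEY_NAMES entry is the fingerprint key the detector looks for in a raw state dict, and DIFFUSERS_DEFAULT_PIPELINE_PATHS supplies the default Hub repo for that model type. A rough sketch of the lookup, assuming the module's existing helpers behave as elsewhere in this file:

    model_type = infer_diffusers_model_type(checkpoint)  # "hunyuan-video" for a matching state dict
    default_repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
    # -> {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}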
@@ -624,6 +626,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
         model_type = "mochi-1-preview"
 
+    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+        model_type = "hunyuan-video"
+
     else:
         model_type = "v1"
 
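Aside (not part of the diff): detection here is a plain membership test on the fingerprint key, so a minimal sketch of the new branch's behavior is:

    import torch

    # Hypothetical minimal state dict carrying only the fingerprint key.
    state_dict = {"txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias": torch.zeros(3072)}
    if CHECKPOINT_KEY_NAMES["hunyuan-video"] in state_dict:
        model_type = "hunyuan-video"  # mirrors the branch added above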
@@ -2522,3 +2527,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
 
     return new_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+        state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+        state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+        state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+        state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+            q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+            state_dict[f"{new_key}.attn.to_q.weight"] = q
+            state_dict[f"{new_key}.attn.to_k.weight"] = k
+            state_dict[f"{new_key}.attn.to_v.weight"] = v
+            state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+        elif "linear1.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+            q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+            state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+            state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+            state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+            state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1_context.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    def update_state_dict_(state_dict, old_key, new_key):
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    for key in list(checkpoint.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(checkpoint, key, new_key)
+
+    for key in list(checkpoint.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, checkpoint)
+
+    return checkpoint
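Aside (not part of the diff): a minimal usage sketch for the new converter. The checkpoint path here is hypothetical; the function renames keys in place and returns the same dict:

    from safetensors.torch import load_file

    # Hypothetical path to an original-format HunyuanVideo transformer checkpoint.
    checkpoint = load_file("hunyuan_video_transformer.safetensors")
    converted = convert_hunyuan_video_transformer_to_diffusers(checkpoint)
    # Fused projections are split, e.g. "double_blocks.0.img_attn_qkv.weight" becomes
    # "transformer_blocks.0.attn.to_q.weight" / ".to_k.weight" / ".to_v.weight".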