diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index c0d571a5864d..de6cd2981b96 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -270,6 +270,8 @@
   title: LatteTransformer3DModel
 - local: api/models/lumina_nextdit2d
   title: LuminaNextDiT2DModel
+- local: api/models/mochi_transformer3d
+  title: MochiTransformer3DModel
 - local: api/models/pixart_transformer2d
   title: PixArtTransformer2DModel
 - local: api/models/prior_transformer
@@ -306,6 +308,8 @@
   title: AutoencoderKLAllegro
 - local: api/models/autoencoderkl_cogvideox
   title: AutoencoderKLCogVideoX
+- local: api/models/autoencoderkl_mochi
+  title: AutoencoderKLMochi
 - local: api/models/asymmetricautoencoderkl
   title: AsymmetricAutoencoderKL
 - local: api/models/consistency_decoder_vae
@@ -400,6 +404,8 @@
   title: Lumina-T2X
 - local: api/pipelines/marigold
   title: Marigold
+- local: api/pipelines/mochi
+  title: Mochi
 - local: api/pipelines/panorama
   title: MultiDiffusion
 - local: api/pipelines/musicldm
diff --git a/docs/source/en/api/models/autoencoderkl_mochi.md b/docs/source/en/api/models/autoencoderkl_mochi.md
new file mode 100644
index 000000000000..9747de4af937
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_mochi.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLMochi
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import AutoencoderKLMochi
+
+vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLMochi
+
+[[autodoc]] AutoencoderKLMochi
+  - decode
+  - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/mochi_transformer3d.md b/docs/source/en/api/models/mochi_transformer3d.md
new file mode 100644
index 000000000000..05e28654d58c
--- /dev/null
+++ b/docs/source/en/api/models/mochi_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# MochiTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data, introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import MochiTransformer3DModel
+
+transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
+```
+
+## MochiTransformer3DModel
+
+[[autodoc]] MochiTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/mochi.md b/docs/source/en/api/pipelines/mochi.md
new file mode 100644
index 000000000000..f29297e5901c
--- /dev/null
+++ b/docs/source/en/api/pipelines/mochi.md
@@ -0,0 +1,36 @@
+
+
+# Mochi
+
+[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) is a video generation model from Genmo.
+
+*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
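+
+The snippet below is a minimal text-to-video sketch. The memory-saving calls (`enable_model_cpu_offload`, `enable_vae_tiling`) and the generation arguments shown are illustrative rather than required; see the [`MochiPipeline`] reference below for the full signature.
+
+```python
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline in bfloat16 and enable offloading/tiling to keep peak VRAM manageable
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+prompt = "A close-up of a chameleon walking along a branch, shifting colors as it moves."
+frames = pipe(prompt, num_frames=84).frames[0]
+export_to_video(frames, "mochi.mp4", fps=30)
+```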
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+## MochiPipeline
+
+[[autodoc]] MochiPipeline
+  - all
+  - __call__
+
+## MochiPipelineOutput
+
+[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
new file mode 100644
index 000000000000..892fd871c554
--- /dev/null
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -0,0 +1,461 @@
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from safetensors.torch import load_file
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+TOKENIZER_MAX_LENGTH = 256
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", required=True, type=str)
+parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
+parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
+parser.add_argument("--dtype", type=str, default=None)
+
+args = parser.parse_args()
+
+
+# This is specific to `AdaLayerNormContinuous`:
+# the Diffusers implementation splits the linear projection into (scale, shift), while Mochi splits it into (shift, scale)
+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
+def swap_proj_gate(weight):
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
+    original_state_dict = load_file(ckpt_path, device="cpu")
+    new_state_dict = {}
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] =
original_state_dict.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight") + new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." + + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + else: + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + + # Visual attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop( + old_prefix + "attn.proj_y.bias" + ) + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + original_state_dict.pop(old_prefix + "mlp_x.w1.weight") + ) + new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + original_state_dict.pop(old_prefix + "mlp_y.w1.weight") + ) + 
new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop( + old_prefix + "mlp_y.w2.weight" + ) + + # Output layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.mod.weight"), dim=0 + ) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0) + new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies") + + print("Remaining Keys:", original_state_dict.keys()) + + return new_state_dict + + +def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path): + encoder_state_dict = load_file(encoder_ckpt_path, device="cpu") + decoder_state_dict = load_file(decoder_ckpt_path, device="cpu") + new_state_dict = {} + + # ==== Decoder ===== + prefix = "decoder." + + # Convert conv_in + new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight") + new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias") + + # Convert block_in (MochiMidBlock3D) + for i in range(3): # layers_per_block[-1] = 3 + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.5.bias" + ) + + # Convert up_blocks (MochiUpBlock3D) + down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4] + for block in range(3): + for i in range(down_block_layers[block]): + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.0.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.0.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.2.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.2.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.3.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.3.bias" + ) + 
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.5.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.5.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.proj.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias") + + # Convert block_out (MochiMidBlock3D) + for i in range(3): # layers_per_block[0] = 3 + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.5.bias" + ) + + # Convert proj_out (Conv1x1 ~= nn.Linear) + new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight") + new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias") + + print("Remaining Decoder Keys:", decoder_state_dict.keys()) + + # ==== Encoder ===== + prefix = "encoder." 
+ + new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight") + new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias") + + # Convert block_in (MochiMidBlock3D) + for i in range(3): # layers_per_block[0] = 3 + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.5.bias" + ) + + # Convert down_blocks (MochiDownBlock3D) + down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3] + for block in range(3): + new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.0.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.0.bias" + ) + + for i in range(down_block_layers[block]): + # Convert resnets + new_state_dict[ + f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" + ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight") + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.0.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.2.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.2.bias" + ) + new_state_dict[ + f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" + ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight") + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.3.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.5.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.5.bias" + ) + + # Convert attentions + qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v + 
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias" + ) + + # Convert block_out (MochiMidBlock3D) + for i in range(3): # layers_per_block[-1] = 3 + # Convert resnets + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.5.bias" + ) + + # Convert attentions + qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q + new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k + new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v + new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.attn.out.weight" + ) + new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.attn.out.bias" + ) + new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.norm.weight" + ) + new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.norm.bias" + ) + + # Convert output layers + new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight") + new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias") + new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight") + + print("Remaining Encoder Keys:", encoder_state_dict.keys()) + + return new_state_dict + + +def main(args): + if args.dtype is None: + dtype = None + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + transformer = None + vae = None + + if args.transformer_checkpoint_path is not None: + 
converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers( + args.transformer_checkpoint_path + ) + transformer = MochiTransformer3DModel() + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + if dtype is not None: + transformer = transformer.to(dtype=dtype) + + if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None: + vae = AutoencoderKLMochi(latent_channels=12, out_channels=3) + converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers( + args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path + ) + vae.load_state_dict(converted_vae_state_dict, strict=True) + if dtype is not None: + vae = vae.to(dtype=dtype) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + # Apparently, the conversion does not work anymore without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + pipe = MochiPipeline( + scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True), + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ff59a3839552..fb6d22084bd6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -83,6 +83,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", @@ -102,6 +103,7 @@ "Kandinsky3UNet", "LatteTransformer3DModel", "LuminaNextDiT2DModel", + "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", "MultiAdapter", @@ -311,6 +313,7 @@ "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", + "MochiPipeline", "MusicLDMPipeline", "PaintByExamplePipeline", "PIAPipeline", @@ -565,6 +568,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, @@ -584,6 +588,7 @@ Kandinsky3UNet, LatteTransformer3DModel, LuminaNextDiT2DModel, + MochiTransformer3DModel, ModelMixin, MotionAdapter, MultiAdapter, @@ -772,6 +777,7 @@ LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldNormalsPipeline, + MochiPipeline, MusicLDMPipeline, PaintByExamplePipeline, PIAPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 38dd2819133d..518ab6df65c4 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -30,6 +30,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] @@ -58,6 +59,7 @@ _import_structure["transformers.transformer_allegro"] = 
["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] @@ -85,6 +87,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, @@ -110,6 +113,7 @@ HunyuanDiT2DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, + MochiTransformer3DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index fb24a36bae75..f4318fc3cd39 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -136,6 +136,7 @@ class SwiGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) self.activation = nn.SiLU() diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 20c5cf3d925e..da01b7a1edcd 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -120,14 +120,16 @@ def __init__( _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, out_dim: int = None, + out_context_dim: int = None, context_pre_only=None, pre_only=False, elementwise_affine: bool = True, + is_causal: bool = False, ): super().__init__() # To prevent circular import. - from .normalization import FP32LayerNorm, RMSNorm + from .normalization import FP32LayerNorm, LpNorm, RMSNorm self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads @@ -142,8 +144,10 @@ def __init__( self.dropout = dropout self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only + self.is_causal = is_causal # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -192,6 +196,9 @@ def __init__( elif qk_norm == "rms_norm": self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) else: raise ValueError(f"unknown qk_norm: {qk_norm}. 
Should be None,'layer_norm','fp32_layer_norm','rms_norm'") @@ -241,7 +248,7 @@ def __init__( self.to_out.append(nn.Dropout(dropout)) if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -1886,6 +1893,7 @@ def __call__( hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states @@ -2714,6 +2722,91 @@ def __call__( return hidden_states +class MochiVaeAttnProcessor2_0: + r""" + Attention processor used in Mochi VAE. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + is_single_frame = hidden_states.shape[1] == 1 + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if is_single_frame: + hidden_states = attn.to_v(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class StableAudioAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default 
if you're using PyTorch 2.0). This is @@ -3389,6 +3482,94 @@ def __call__( return hidden_states +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if hasattr(attn, "to_add_out"): + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
It uses diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 9628fe7f21b0..ba45d6671252 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -2,6 +2,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 68b49d72acc5..8575c7658605 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -94,11 +94,13 @@ def __init__( time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - self.pad_mode = pad_mode - time_pad = dilation * (time_kernel_size - 1) + (1 - stride) - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 + # TODO(aryan): configure calculation based on stride and dilation in the future. + # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi + time_pad = time_kernel_size - 1 + height_pad = (height_kernel_size - 1) // 2 + width_pad = (width_kernel_size - 1) // 2 + self.pad_mode = pad_mode self.height_pad = height_pad self.width_pad = width_pad self.time_pad = time_pad @@ -107,7 +109,7 @@ def __init__( self.temporal_dim = 2 self.time_kernel_size = time_kernel_size - stride = (stride, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) dilation = (dilation, 1, 1) self.conv = CogVideoXSafeConv3d( in_channels=in_channels, @@ -120,18 +122,24 @@ def __init__( def fake_context_parallel_forward( self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None ) -> torch.Tensor: - kernel_size = self.time_kernel_size - if kernel_size > 1: - cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - inputs = torch.cat(cached_inputs + [inputs], dim=2) + if self.pad_mode == "replicate": + inputs = F.pad(inputs, self.time_causal_padding, mode="replicate") + else: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + inputs = torch.cat(cached_inputs + [inputs], dim=2) return inputs def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: inputs = self.fake_context_parallel_forward(inputs, conv_cache) - conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - inputs = F.pad(inputs, padding_2d, mode="constant", value=0) + if self.pad_mode == "replicate": + conv_cache = None + else: + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + inputs = F.pad(inputs, padding_2d, mode="constant", value=0) output = self.conv(inputs) return output, conv_cache diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py new file mode 100644 index 000000000000..57e8b8f647ba --- /dev/null +++ 
b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -0,0 +1,1165 @@ +# Copyright 2024 The Mochi team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MochiChunkedGroupNorm3D(nn.Module): + r""" + Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group + normalization. + + Args: + num_channels (int): Number of channels expected in input + num_groups (int, optional): Number of groups to separate the channels into. Default: 32 + affine (bool, optional): If True, this module has learnable affine parameters. Default: True + chunk_size (int, optional): Size of each chunk for processing. Default: 8 + + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + affine: bool = True, + chunk_size: int = 8, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine) + self.chunk_size = chunk_size + + def forward(self, x: torch.Tensor = None) -> torch.Tensor: + batch_size = x.size(0) + + x = x.permute(0, 2, 1, 3, 4).flatten(0, 1) + output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0) + output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + return output + + +class MochiResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + act_fn: str = "swish", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(act_fn) + + self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" + ) + self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" + ) + + def forward( + self, + inputs: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) + + hidden_states = hidden_states + inputs + return hidden_states, new_conv_cache + + +class MochiDownBlock3D(nn.Module): + r""" + An downsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + add_attention: bool = True, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + self.conv_in = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion), + stride=(temporal_expansion, spatial_expansion, spatial_expansion), + pad_mode="replicate", + ) + + resnets = [] + norms = [] + attentions = [] + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=out_channels)) + if add_attention: + norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels)) + attentions.append( + Attention( + query_dim=out_channels, + heads=out_channels // 32, + dim_head=32, + qk_norm="l2", + is_causal=True, + processor=MochiVaeAttnProcessor2_0(), + ) + ) + else: + norms.append(None) + attentions.append(None) + + self.resnets = nn.ModuleList(resnets) + self.norms = nn.ModuleList(norms) + self.attentions = nn.ModuleList(attentions) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + chunk_size: int = 2**15, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states) + + for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): + conv_cache_key = f"resnet_{i}" + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + conv_cache=conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + if attn is not None: + residual = hidden_states + hidden_states = norm(hidden_states) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() + + # Perform attention in chunks to avoid following error: + # RuntimeError: CUDA error: invalid configuration argument + if hidden_states.size(0) <= chunk_size: + hidden_states = attn(hidden_states) + else: + hidden_states_chunks = [] + for i in range(0, hidden_states.size(0), chunk_size): + hidden_states_chunk = hidden_states[i : i + chunk_size] + hidden_states_chunk = attn(hidden_states_chunk) + hidden_states_chunks.append(hidden_states_chunk) + hidden_states = torch.cat(hidden_states_chunks) + + hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) + + hidden_states = residual + hidden_states + + return hidden_states, new_conv_cache + + +class MochiMidBlock3D(nn.Module): + r""" + A middle block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `3`): + Number of resnet blocks in the block. 
+ """ + + def __init__( + self, + in_channels: int, # 768 + num_layers: int = 3, + add_attention: bool = True, + ): + super().__init__() + + resnets = [] + norms = [] + attentions = [] + + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) + + if add_attention: + norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels)) + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // 32, + dim_head=32, + qk_norm="l2", + is_causal=True, + processor=MochiVaeAttnProcessor2_0(), + ) + ) + else: + norms.append(None) + attentions.append(None) + + self.resnets = nn.ModuleList(resnets) + self.norms = nn.ModuleList(norms) + self.attentions = nn.ModuleList(attentions) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `MochiMidBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): + conv_cache_key = f"resnet_{i}" + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + if attn is not None: + residual = hidden_states + hidden_states = norm(hidden_states) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() + hidden_states = attn(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) + + hidden_states = residual + hidden_states + + return hidden_states, new_conv_cache + + +class MochiUpBlock3D(nn.Module): + r""" + An upsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + resnets = [] + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) + self.resnets = nn.ModuleList(resnets) + + self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + conv_cache=conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + st = self.temporal_expansion + sh = self.spatial_expansion + sw = self.spatial_expansion + + # Reshape and unpatchify + hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw) + + return hidden_states, new_conv_cache + + +class FourierFeatures(nn.Module): + def __init__(self, start: int = 6, stop: int = 8, step: int = 1): + super().__init__() + + self.start = start + self.stop = stop + self.step = step + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `FourierFeatures` class.""" + + num_channels = inputs.shape[1] + num_freqs = (self.stop - self.start) // self.step + + freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device) + w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs] + w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1] + + # Interleaved repeat of input channels to match w + h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] + # Scale channels by frequency. + h = w * h + + return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1) + + +class MochiEncoder3D(nn.Module): + r""" + The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. 
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + add_attention_block: Tuple[bool, ...] = (False, True, True, True, True), + act_fn: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(act_fn) + + self.fourier_features = FourierFeatures() + self.proj_in = nn.Linear(in_channels, block_out_channels[0]) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0] + ) + + down_blocks = [] + for i in range(len(block_out_channels) - 1): + down_block = MochiDownBlock3D( + in_channels=block_out_channels[i], + out_channels=block_out_channels[i + 1], + num_layers=layers_per_block[i + 1], + temporal_expansion=temporal_expansions[i], + spatial_expansion=spatial_expansions[i], + add_attention=add_attention_block[i + 1], + ) + down_blocks.append(down_block) + self.down_blocks = nn.ModuleList(down_blocks) + + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1] + ) + self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1]) + self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False) + + def forward( + self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None + ) -> torch.Tensor: + r"""Forward method of the `MochiEncoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = self.fourier_features(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache["block_in"] = self.block_in( + hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = down_block( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states, new_conv_cache["block_out"] = self.block_out( + hidden_states, conv_cache=conv_cache.get("block_out") + ) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + 
hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + return hidden_states, new_conv_cache + + +class MochiDecoder3D(nn.Module): + r""" + The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. + temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ + + def __init__( + self, + in_channels: int, # 12 + out_channels: int, # 3 + block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + act_fn: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(act_fn) + + self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[-1], + num_layers=layers_per_block[-1], + add_attention=False, + ) + + up_blocks = [] + for i in range(len(block_out_channels) - 1): + up_block = MochiUpBlock3D( + in_channels=block_out_channels[-i - 1], + out_channels=block_out_channels[-i - 2], + num_layers=layers_per_block[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + ) + up_blocks.append(up_block) + self.up_blocks = nn.ModuleList(up_blocks) + + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[0], + num_layers=layers_per_block[0], + add_attention=False, + ) + self.proj_out = nn.Linear(block_out_channels[0], out_channels) + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None + ) -> torch.Tensor: + r"""Forward method of the `MochiDecoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = self.conv_in(hidden_states) + + # 1. 
Mid + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache["block_in"] = self.block_in( + hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = up_block( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states, new_conv_cache["block_out"] = self.block_out( + hidden_states, conv_cache=conv_cache.get("block_out") + ) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + return hidden_states, new_conv_cache + + +class AutoencoderKLMochi(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [Mochi 1 preview](https://github.com/genmoai/models). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["MochiResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 15, + out_channels: int = 3, + encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384), + decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768), + latent_channels: int = 12, + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + act_fn: str = "silu", + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + add_attention_block: Tuple[bool, ...] = (False, True, True, True, True), + latents_mean: Tuple[float, ...] 
= ( + -0.06730895953510081, + -0.038011381506090416, + -0.07477820912866141, + -0.05565264470995561, + 0.012767231469026969, + -0.04703542746246419, + 0.043896967884726704, + -0.09346305707025976, + -0.09918314763016893, + -0.008729793427399178, + -0.011931556316503654, + -0.0321993391887285, + ), + latents_std: Tuple[float, ...] = ( + 0.9263795028493863, + 0.9248894543193766, + 0.9393059390890617, + 0.959253732819592, + 0.8244560132752793, + 0.917259975397747, + 0.9294154431013696, + 1.3720942357788521, + 0.881393668867029, + 0.9168315692124348, + 0.9185249279345552, + 0.9274757570805041, + ), + scaling_factor: float = 1.0, + ): + super().__init__() + + self.encoder = MochiEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=encoder_block_out_channels, + layers_per_block=layers_per_block, + temporal_expansions=temporal_expansions, + spatial_expansions=spatial_expansions, + add_attention_block=add_attention_block, + act_fn=act_fn, + ) + self.decoder = MochiDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + layers_per_block=layers_per_block, + temporal_expansions=temporal_expansions, + spatial_expansions=spatial_expansions, + act_fn=act_fn, + ) + + self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1) + self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be used to determine how the number of output frames in the final decoded video. To maintain consistency with + # the original implementation, this defaults to `True`. + # - Original implementation (drop_last_temporal_frames=True): + # Output frames = (latent_frames - 1) * temporal_compression_ratio + 1 + # - Without dropping additional temporal upscaled frames (drop_last_temporal_frames=False): + # Output frames = latent_frames * temporal_compression_ratio + # The latter case is useful for frame packing and some training/finetuning scenarios where the additional. + self.drop_last_temporal_frames = True + + # This can be configured based on the amount of GPU memory available. + # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. 
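+ # Note: with the default `temporal_expansions=(1, 2, 3)` the temporal compression ratio is 6,
+ # so the two defaults below describe the same chunk on either side of the VAE
+ # (12 sample frames ~= 2 latent frames).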
+ self.num_sample_frames_batch_size = 12 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MochiEncoder3D, MochiDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _enable_framewise_encoding(self): + r""" + Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the + oneshot encoding implementation without current latent replicate padding. + + Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable + framewise encoding, encode a video, and try to decode it, there will be noticeable jittering effect. 
+ """ + self.use_framewise_encoding = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + module.pad_mode = "constant" + + def _enable_framewise_decoding(self): + r""" + Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the + oneshot decoding implementation without current latent replicate padding. + """ + self.use_framewise_decoding = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + module.pad_mode = "constant" + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + if self.use_framewise_encoding: + raise NotImplementedError( + "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " + "As intermediate frames are not independent from each other, they cannot be encoded frame-wise." + ) + else: + enc, _ = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if self.use_framewise_decoding: + conv_cache = None + dec = [] + + for i in range(0, num_frames, self.num_latent_frames_batch_size): + z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size] + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + + dec = torch.cat(dec, dim=2) + else: + dec, _ = self.decoder(z) + + if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio: + dec = dec[:, :, self.temporal_compression_ratio - 1 :] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 
+ + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + if self.use_framewise_encoding: + raise NotImplementedError( + "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " + "As intermediate frames are not independent from each other, they cannot be encoded frame-wise." + ) + else: + time, _ = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + if self.use_framewise_decoding: + time = [] + conv_cache = None + + for k in range(0, num_frames, self.num_latent_frames_batch_size): + tile = z[ + :, + :, + k : k + self.num_latent_frames_batch_size, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile) + + time = torch.cat(time, dim=2) + else: + time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) + + if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio: + time = time[:, :, self.temporal_compression_ratio - 1 :] + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 66917dce6107..7cbd958e1d6e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1356,6 +1356,41 @@ def forward(self, timestep, caption_feat, caption_mask): return conditioning +class MochiCombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + text_embed_dim: int, + time_embed_dim: int = 256, + 
num_attention_heads: int = 8, + ) -> None: + super().__init__() + + self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0) + self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim) + self.pooler = MochiAttentionPool( + num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim + ) + self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim) + + def forward( + self, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ): + time_proj = self.time_proj(timestep) + time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype)) + + pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask) + caption_proj = self.caption_proj(encoder_hidden_states) + + conditioning = time_emb + pooled_projections + return conditioning, caption_proj + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() @@ -1484,6 +1519,88 @@ def shape(x): return a[:, 0, :] # cls_token +class MochiAttentionPool(nn.Module): + def __init__( + self, + num_attention_heads: int, + embed_dim: int, + output_dim: Optional[int] = None, + ) -> None: + super().__init__() + + self.output_dim = output_dim or embed_dim + self.num_attention_heads = num_attention_heads + + self.to_kv = nn.Linear(embed_dim, 2 * embed_dim) + self.to_q = nn.Linear(embed_dim, embed_dim) + self.to_out = nn.Linear(embed_dim, self.output_dim) + + @staticmethod + def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. + assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + r""" + Args: + x (`torch.Tensor`): + Tensor of shape `(B, S, D)` of input tokens. + mask (`torch.Tensor`): + Boolean ensor of shape `(B, S)` indicating which tokens are not padding. + + Returns: + `torch.Tensor`: + `(B, D)` tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. 
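+ # The hidden size D is split into `num_attention_heads` heads; `kv` stacks the key and value
+ # projections along an extra dim, and the single pooled token provides the only query, so the
+ # scaled-dot-product attention below reduces the (1 + L) tokens to one pooled vector per head.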
+ head_dim = D // self.num_attention_heads + kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim) + + # Concatenate heads and run output. + x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x + + def get_fourier_embeds_from_boundingbox(embed_dim, box): """ Args: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 87dec66935da..817b3fff2ea6 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -234,6 +234,33 @@ def forward( return x, gate_msa, scale_mlp, gate_mlp +class MochiRMSNormZero(nn.Module): + r""" + Adaptive RMS Norm used in Mochi. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + + return hidden_states, gate_msa, scale_mlp, gate_mlp + + class AdaLayerNormSingle(nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). 
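The `MochiRMSNormZero` hunk above is the adaLN-Zero-style modulation used by the Mochi transformer blocks: a single projection of the conditioning embedding yields a scale and a gate for attention plus a scale and a gate for the feed-forward. A minimal sketch of that data flow with plain tensors (hypothetical sizes, not the actual diffusers module):

```python
import torch
import torch.nn.functional as F

batch, seq_len, dim = 2, 16, 3072                 # hypothetical sizes
hidden_states = torch.randn(batch, seq_len, dim)  # video tokens
emb = torch.randn(batch, dim)                     # pooled timestep/caption conditioning

# MochiRMSNormZero projects SiLU(emb) to 4 * dim and chunks it into four modulation tensors
linear = torch.nn.Linear(dim, 4 * dim)
scale_msa, gate_msa, scale_mlp, gate_mlp = linear(F.silu(emb)).chunk(4, dim=1)

# RMS-normalize without a learnable affine, then apply the attention scale
rms = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + 1e-5)
norm_hidden_states = rms * (1 + scale_msa[:, None])

# Downstream, MochiTransformerBlock applies the gates as tanh-gated residuals, e.g.
# hidden_states = hidden_states + attn_output * torch.tanh(gate_msa).unsqueeze(1)
```

The same pattern is reused for the text stream through `norm1_context`, which projects to `4 * pooled_projection_dim` instead.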
@@ -356,20 +383,21 @@ def __init__( out_dim: Optional[int] = None, ): super().__init__() + # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") - # linear_2 + + self.linear_2 = None if out_dim is not None: - self.linear_2 = nn.Linear( - embedding_dim, - out_dim, - bias=bias, - ) + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) def forward( self, @@ -526,3 +554,15 @@ def forward(self, x): gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * nx) + self.beta + x + + +class LpNorm(nn.Module): + def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): + super().__init__() + + self.p = p + self.dim = dim + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 873a2bbecf05..a2c087d708a4 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,5 +17,6 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py new file mode 100644 index 000000000000..7f4ad2b328fa --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -0,0 +1,387 @@ +# Copyright 2024 The Genmo team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
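+# NOTE: video tokens (`hidden_states`) and text tokens (`encoder_hidden_states`) are kept in separate
+# streams that only interact inside the joint attention of each `MochiTransformerBlock`; the final
+# block is built with `context_pre_only=True`, so the text stream is not updated there.
+# Spatio-temporal positions are injected through the learned-frequency 3D RoPE in `MochiRoPE`.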
+ +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class MochiTransformerBlock(nn.Module): + r""" + Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + context_pre_only (`bool`, defaults to `False`): + Whether or not to process context-related conditions with additional layers. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + pooled_projection_dim: int, + qk_norm: str = "rms_norm", + activation_fn: str = "swiglu", + context_pre_only: bool = False, + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.context_pre_only = context_pre_only + self.ff_inner_dim = (4 * dim * 2) // 3 + self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 + + self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) + + if not context_pre_only: + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) + else: + self.norm1_context = LuminaLayerNormContinuous( + embedding_dim=pooled_projection_dim, + conditioning_embedding_dim=dim, + eps=eps, + elementwise_affine=False, + norm_type="rms_norm", + out_dim=None, + ) + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=False, + qk_norm=qk_norm, + added_kv_proj_dim=pooled_projection_dim, + added_proj_bias=False, + out_dim=dim, + out_context_dim=pooled_projection_dim, + context_pre_only=context_pre_only, + processor=MochiAttnProcessor2_0(), + eps=eps, + elementwise_affine=True, + ) + + # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) + self.ff_context = None + if not context_pre_only: + self.ff_context = FeedForward( + pooled_projection_dim, + inner_dim=self.ff_context_inner_dim, + activation_fn=activation_fn, + bias=False, + ) + + self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = 
RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + if not self.context_pre_only: + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( + encoder_hidden_states, temb + ) + else: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + + attn_hidden_states, context_attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1) + + if not self.context_pre_only: + encoder_hidden_states = encoder_hidden_states + self.norm2_context( + context_attn_hidden_states + ) * torch.tanh(enc_gate_msa).unsqueeze(1) + norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh( + enc_gate_mlp + ).unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class MochiRoPE(nn.Module): + r""" + RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + base_height (`int`, defaults to `192`): + Base height used to compute interpolation scale for rotary positional embeddings. + base_width (`int`, defaults to `192`): + Base width used to compute interpolation scale for rotary positional embeddings. 
+ """ + + def __init__(self, base_height: int = 192, base_width: int = 192) -> None: + super().__init__() + + self.target_area = base_height * base_width + + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: + edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + return (edges[:-1] + edges[1:]) / 2 + + def _get_positions( + self, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + scale = (self.target_area / (height * width)) ** 0.5 + + t = torch.arange(num_frames, device=device, dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) + return positions + + def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float()) + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def forward( + self, + pos_frequencies: torch.Tensor, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = self._get_positions(num_frames, height, width, device, dtype) + rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) + return rope_cos, rope_sin + + +@maybe_allow_in_graph +class MochiTransformer3DModel(ModelMixin, ConfigMixin): + r""" + A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `48`): + The number of layers of Transformer blocks to use. + in_channels (`int`, defaults to `12`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `256`): + Output dimension of timestep embeddings. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + max_sequence_length (`int`, defaults to `256`): + The maximum sequence length of text embeddings supported. 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 48, + pooled_projection_dim: int = 1536, + in_channels: int = 12, + out_channels: Optional[int] = None, + qk_norm: str = "rms_norm", + text_embed_dim: int = 4096, + time_embed_dim: int = 256, + activation_fn: str = "swiglu", + max_sequence_length: int = 256, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + pos_embed_type=None, + ) + + self.time_embed = MochiCombinedTimestepCaptionEmbedding( + embedding_dim=inner_dim, + pooled_projection_dim=pooled_projection_dim, + text_embed_dim=text_embed_dim, + time_embed_dim=time_embed_dim, + num_attention_heads=8, + ) + + self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0)) + self.rope = MochiRoPE() + + self.transformer_blocks = nn.ModuleList( + [ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + context_pre_only=i == num_layers - 1, + ) + for i in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm" + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + return_dict: bool = True, + ) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = self.time_embed( + timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = self.rope( + self.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = 
self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 634088f1b51a..98574de1ad5f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -247,6 +247,7 @@ "MarigoldNormalsPipeline", ] ) + _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] @@ -571,6 +572,7 @@ MarigoldDepthPipeline, MarigoldNormalsPipeline, ) + from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .pag import ( AnimateDiffPAGPipeline, diff --git a/src/diffusers/pipelines/mochi/__init__.py b/src/diffusers/pipelines/mochi/__init__.py new file mode 100644 index 000000000000..a8fd4da9fd36 --- /dev/null +++ b/src/diffusers/pipelines/mochi/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_mochi"] = ["MochiPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_mochi import MochiPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py new file mode 100644 index 000000000000..7a9cc41e2dde --- /dev/null +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -0,0 +1,724 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
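+# NOTE: high-level flow of this pipeline: the prompt (and, for classifier-free guidance, the negative
+# prompt) is encoded with T5, latents of shape
+# (batch, num_channels_latents, (num_frames - 1) // 6 + 1, height // 8, width // 8) are sampled and
+# denoised with `FlowMatchEulerDiscreteScheduler` (the `linear_quadratic_schedule` helper below mirrors
+# the sigma schedule of the original Mochi repo), then decoded to pixels with the Mochi VAE.
+# With the default `num_frames=19` this corresponds to (19 - 1) // 6 + 1 = 4 latent frames.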
+ +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import MochiTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MochiPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MochiPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + >>> pipe.enable_vae_tiling() + >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." + >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0] + >>> export_to_video(frames, "mochi.mp4") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MochiPipeline(DiffusionPipeline): + r""" + The mochi pipeline for text-to-video generation. + + Reference: https://github.com/genmoai/models + + Args: + transformer ([`MochiTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: MochiTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + # TODO: determine these scaling factors from model parameters + self.vae_spatial_scale_factor = 8 + self.vae_temporal_scale_factor = 6 + self.patch_size = 2 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_height = 480 + self.default_width = 848 + + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, 
+ ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
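+        Tiling can be especially helpful for Mochi's default 480x848, multi-frame outputs, where decoding full
+        frames at once may exhaust GPU memory.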
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        num_frames,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # Latents are downsampled spatially by the VAE scale factor and compressed temporally, so `num_frames`
+        # pixel frames map to `(num_frames - 1) // vae_temporal_scale_factor + 1` latent frames.
+        height = height // self.vae_spatial_scale_factor
+        width = width // self.vae_spatial_scale_factor
+        num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
+
+        shape = (batch_size, num_channels_latents, num_frames, height, width)
+
+        if latents is not None:
+            return latents.to(device=device, dtype=dtype)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        return latents
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: int = 19,
+        num_inference_steps: int = 28,
+        timesteps: List[int] = None,
+        guidance_scale: float = 4.5,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            height (`int`, *optional*, defaults to `self.default_height`):
+                The height in pixels of the generated video. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to `self.default_width`):
+                The width in pixels of the generated video. This is set to 848 by default for the best results.
+            num_frames (`int`, defaults to `19`):
+                The number of video frames to generate.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, defaults to `4.5`):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate videos that are closely linked to the
+                text `prompt`, usually at the expense of lower video quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+                argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `256`):
+                Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
+                is returned where the first element is a list with the generated videos.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.default_height
+        width = width or self.default_width
+
+        # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timestep + # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 + threshold_noise = 0.025 + sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) + sigmas = np.array(sigmas) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # update the progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if output_type == "latent":
+            video = latents
+        else:
+            # unscale/denormalize the latents with the per-channel mean and std if available,
+            # otherwise fall back to the global scaling factor
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            video = self.vae.decode(latents, return_dict=False)[0]
+            video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return MochiPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/mochi/pipeline_output.py b/src/diffusers/pipelines/mochi/pipeline_output.py
new file mode 100644
index 000000000000..cc1437279496
--- /dev/null
+++ b/src/diffusers/pipelines/mochi/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class MochiPipelineOutput(BaseOutput):
+    r"""
+    Output class for Mochi pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 937cae2e47f5..c1096dbe0c29 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -71,6 +71,7 @@ def __init__( max_shift: Optional[float] = 1.15, base_image_seq_len: Optional[int] = 256, max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, ): timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) @@ -204,9 +205,15 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps.to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas self._step_index = None self._begin_index = None diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8a87b04a66cb..83d1d4270920 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLMochi(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MochiTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 83d160b08df4..8b4b158efd0a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1052,6 +1052,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class MochiPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class MusicLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff 
--git a/tests/models/transformers/test_models_transformer_mochi.py b/tests/models/transformers/test_models_transformer_mochi.py new file mode 100644 index 000000000000..fc1412c7cd31 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_mochi.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import MochiTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class MochiTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = MochiTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + } + + @property + def input_shape(self): + return (4, 2, 16, 16) + + @property + def output_shape(self): + return (4, 2, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 2, + "num_attention_heads": 2, + "attention_head_dim": 8, + "num_layers": 2, + "pooled_projection_dim": 16, + "in_channels": 4, + "out_channels": None, + "qk_norm": "rms_norm", + "text_embed_dim": 16, + "time_embed_dim": 4, + "activation_fn": "swiglu", + "max_sequence_length": 16, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"MochiTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/mochi/__init__.py b/tests/pipelines/mochi/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py new file mode 100644 index 000000000000..2192c171aa22 --- /dev/null +++ b/tests/pipelines/mochi/test_mochi.py @@ -0,0 +1,299 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MochiPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = MochiTransformer3DModel( + patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + pooled_projection_dim=16, + in_channels=12, + out_channels=None, + qk_norm="rms_norm", + text_embed_dim=32, + time_embed_dim=4, + activation_fn="swiglu", + max_sequence_length=16, + ) + transformer.pos_frequencies.data = transformer.pos_frequencies.new_full(transformer.pos_frequencies.shape, 0) + + torch.manual_seed(0) + vae = AutoencoderKLMochi( + latent_channels=12, + out_channels=3, + encoder_block_out_channels=(32, 32, 32, 32), + decoder_block_out_channels=(32, 32, 32, 32), + layers_per_block=(1, 1, 1, 1, 1), + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": 16, + "width": 16, + # 6 * k + 1 is the recommendation + "num_frames": 7, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (7, 3, 16, 16)) + 
expected_video = torch.randn(7, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = 
np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )
+
+    def test_vae_tiling(self, expected_diff_max: float = 0.2):
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe.to("cpu")
+        pipe.set_progress_bar_config(disable=None)
+
+        # Without tiling
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_without_tiling = pipe(**inputs)[0]
+
+        # With tiling
+        pipe.vae.enable_tiling(
+            tile_sample_min_height=96,
+            tile_sample_min_width=96,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
+        )
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_with_tiling = pipe(**inputs)[0]
+
+        self.assertLess(
+            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+            expected_diff_max,
+            "VAE tiling should not affect the inference results",
+        )
+
+
+@slow
+@require_torch_gpu
+class MochiPipelineIntegrationTests(unittest.TestCase):
+    prompt = "A painting of a squirrel eating a burger."
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_mochi(self):
+        generator = torch.Generator("cpu").manual_seed(0)
+
+        pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
+        pipe.enable_model_cpu_offload()
+        prompt = self.prompt
+
+        videos = pipe(
+            prompt=prompt,
+            height=480,
+            width=848,
+            num_frames=19,
+            generator=generator,
+            num_inference_steps=2,
+            output_type="pt",
+        ).frames
+
+        video = videos[0]
+        expected_video = torch.randn(1, 16, 480, 848, 3).numpy()
+
+        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        assert max_diff < 1e-3, f"Max diff is too high. got {max_diff}"
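For reviewers, a minimal end-to-end usage sketch of the pipeline added by this diff, mirroring the integration test above; the exact sampling settings and the `export_to_video` call are illustrative choices, not part of the patch:

```python
import torch

from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Load all Mochi components (T5 text encoder, MochiTransformer3DModel, AutoencoderKLMochi, flow-match scheduler)
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # offload idle components to keep peak GPU memory down
pipe.enable_vae_tiling()         # optional: tiled VAE decoding for large frames

video = pipe(
    prompt="A painting of a squirrel eating a burger.",
    height=480,              # pipeline default
    width=848,               # pipeline default
    num_frames=19,           # 6 * k + 1 frames is the recommended pattern
    num_inference_steps=28,  # matches the __call__ default
    guidance_scale=4.5,      # > 1.0 enables classifier-free guidance
).frames[0]

export_to_video(video, "mochi.mp4", fps=8)
```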