From 8304adce2aa171f0328c882001ba76891ee661d2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 18 Dec 2024 18:32:53 +0530 Subject: [PATCH] Make zeroing prompt embeds for Mochi Pipeline configurable (#10284) update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 937575d26f98..aac4e32e33f0 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -188,6 +188,7 @@ def __init__( text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, transformer: MochiTransformer3DModel, + force_zeros_for_empty_prompt: bool = False, ): super().__init__() @@ -205,10 +206,11 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256 ) self.default_height = 480 self.default_width = 848 + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) def _get_t5_prompt_embeds( self, @@ -236,7 +238,11 @@ def _get_t5_prompt_embeds( text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) - if prompt == "" or prompt[-1] == "": + + # The original Mochi implementation zeros out empty negative prompts + # but this can lead to overflow when placing the entire pipeline under the autocast context + # adding this here so that we can enable zeroing prompts if necessary + if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""): text_input_ids = torch.zeros_like(text_input_ids, device=device) prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)