Skip to content

Commit

Permalink
Make zeroing prompt embeds for Mochi Pipeline configurable (#10284)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
DN6 authored Dec 18, 2024
1 parent b389f33 commit 8304adc
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/diffusers/pipelines/mochi/pipeline_mochi.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3DModel,
force_zeros_for_empty_prompt: bool = False,
):
super().__init__()

Expand All @@ -205,10 +206,11 @@ def __init__(

self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
)
self.default_height = 480
self.default_width = 848
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

def _get_t5_prompt_embeds(
self,
Expand Down Expand Up @@ -236,7 +238,11 @@ def _get_t5_prompt_embeds(
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
if prompt == "" or prompt[-1] == "":

# The original Mochi implementation zeros out empty negative prompts
# but this can lead to overflow when placing the entire pipeline under the autocast context
# adding this here so that we can enable zeroing prompts if necessary
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
text_input_ids = torch.zeros_like(text_input_ids, device=device)
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)

Expand Down

0 comments on commit 8304adc

Please sign in to comment.