From a0a542702869f963baeea63847f17a3145ca2287 Mon Sep 17 00:00:00 2001 From: Nan <510934379@qq.com> Date: Wed, 19 Jun 2024 06:41:18 +0800 Subject: [PATCH] [SD3] Fix mismatched shape when num_images_per_prompt > 1 and T5 is not used (text_encoder_3=None) (#8558) * fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None * style * fix copies --------- Co-authored-by: YiYi Xu Co-authored-by: yiyixuxu --- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 6 +++++- .../pipeline_stable_diffusion_3_img2img.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 04a150d52c0a..b8fd0b907684 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -217,7 +217,11 @@ def _get_t5_prompt_embeds( if self.text_encoder_3 is None: return torch.zeros( - (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim), + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), device=device, dtype=dtype, ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 1ad8a8f3c42b..0cb61729ba7b 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -232,7 +232,11 @@ def _get_t5_prompt_embeds( if self.text_encoder_3 is None: return torch.zeros( - (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim), + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), 
device=device, dtype=dtype, )