diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 04a150d52c0a..b8fd0b907684 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -217,7 +217,11 @@ def _get_t5_prompt_embeds(
 
         if self.text_encoder_3 is None:
             return torch.zeros(
-                (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
+                (
+                    batch_size * num_images_per_prompt,
+                    self.tokenizer_max_length,
+                    self.transformer.config.joint_attention_dim,
+                ),
                 device=device,
                 dtype=dtype,
             )
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 1ad8a8f3c42b..0cb61729ba7b 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -232,7 +232,11 @@ def _get_t5_prompt_embeds(
 
         if self.text_encoder_3 is None:
             return torch.zeros(
-                (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
+                (
+                    batch_size * num_images_per_prompt,
+                    self.tokenizer_max_length,
+                    self.transformer.config.joint_attention_dim,
+                ),
                 device=device,
                 dtype=dtype,
             )
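
The change widens the batch dimension of the zero placeholder that `_get_t5_prompt_embeds` returns when no T5 text encoder is loaded, so it matches the CLIP prompt embeddings, which are already repeated once per requested image. Below is a minimal sketch of the shape logic, not the pipeline code itself; the function name `t5_placeholder_embeds` and the default sizes are illustrative assumptions, not values read from a real model config.

```python
import torch


def t5_placeholder_embeds(
    batch_size: int,
    num_images_per_prompt: int,
    max_length: int = 256,
    joint_attention_dim: int = 4096,
) -> torch.Tensor:
    # The CLIP prompt embeddings are repeated once per requested image, so the
    # zero placeholder standing in for the missing T5 encoder must use the same
    # expanded batch dimension; otherwise concatenating the two embeddings
    # fails with a shape mismatch whenever num_images_per_prompt > 1.
    return torch.zeros(
        (batch_size * num_images_per_prompt, max_length, joint_attention_dim)
    )


# Example: 2 prompts with 4 images each -> the placeholder batch dimension
# must be 8 to line up with the repeated CLIP embeddings.
placeholder = t5_placeholder_embeds(batch_size=2, num_images_per_prompt=4)
assert placeholder.shape == (8, 256, 4096)
```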