diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index b8651f721ca9..3acb5ae538a4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -188,9 +188,10 @@ def _generate( # Get unconditional embeddings batch_size = prompt_embeds.shape[0] if neg_prompt_ids is None: - neg_prompt_ids = self.prepare_inputs([""] * batch_size) - - neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params) + neg_prompt_embeds = jnp.zeros_like(prompt_embeds) + negative_pooled_embeds = jnp.zeros_like(pooled_embeds) + else: + neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params) add_time_ids = self._get_add_time_ids( (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype