Skip to content

Commit

Permalink
[Flax SDXL] fix zero out sdxl (#5203)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored Sep 27, 2023
1 parent a584d42 commit cac7ada
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,10 @@ def _generate(
# Get unconditional embeddings
batch_size = prompt_embeds.shape[0]
if neg_prompt_ids is None:
neg_prompt_ids = self.prepare_inputs([""] * batch_size)

neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
neg_prompt_embeds = jnp.zeros_like(prompt_embeds)
negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
else:
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)

add_time_ids = self._get_add_time_ids(
(height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
Expand Down

0 comments on commit cac7ada

Please sign in to comment.