Skip to content

Commit

Permalink
fix for #7365, prevent pipelines from overriding provided prompt embeds (#7926)
Browse files Browse the repository at this point in the history

* fix for #7365, prevent pipelines from overriding provided prompt embeds

* fix-copies

* fix implementation

* update

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: sayakpaul <[email protected]>
  • Loading branch information
4 people authored Jan 8, 2025
1 parent 5655b22 commit a0acbdc
Show file tree
Hide file tree
Showing 27 changed files with 154 additions and 54 deletions.
7 changes: 5 additions & 2 deletions examples/community/lpw_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,9 @@ def encode_prompt(
)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

prompt_embeds = prompt_embeds.hidden_states[-2]

prompt_embeds_list.append(prompt_embeds)
Expand Down Expand Up @@ -879,7 +881,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
7 changes: 5 additions & 2 deletions examples/community/pipeline_demofusion_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ def encode_prompt(
)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

prompt_embeds = prompt_embeds.hidden_states[-2]

prompt_embeds_list.append(prompt_embeds)
Expand Down Expand Up @@ -342,7 +344,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
7 changes: 5 additions & 2 deletions examples/community/pipeline_sdxl_style_aligned.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -688,7 +690,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -419,7 +421,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -567,7 +569,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -454,7 +456,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
7 changes: 5 additions & 2 deletions examples/community/pipeline_stable_diffusion_xl_ipex.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -450,7 +452,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -497,8 +499,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -465,8 +467,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -474,8 +476,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -467,8 +469,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -447,8 +449,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -456,8 +458,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -481,8 +483,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -395,8 +397,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
Expand Down Expand Up @@ -480,8 +482,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)
Expand Down
Loading

0 comments on commit a0acbdc

Please sign in to comment.