Skip to content

Commit a0acbdc

Browse files
bghira, bghira, a-r-r-o-w, sayakpaul
authored
fix for #7365, prevent pipelines from overriding provided prompt embeds (#7926)
* fix for #7365, prevent pipelines from overriding provided prompt embeds
* fix-copies
* fix implementation
* update

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: sayakpaul <[email protected]>
1 parent 5655b22 commit a0acbdc

27 files changed

+154
-54
lines changed

examples/community/lpw_stable_diffusion_xl.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -827,7 +827,9 @@ def encode_prompt(
827827
)
828828

829829
# We are only ALWAYS interested in the pooled output of the final text encoder
830-
pooled_prompt_embeds = prompt_embeds[0]
830+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
831+
pooled_prompt_embeds = prompt_embeds[0]
832+
831833
prompt_embeds = prompt_embeds.hidden_states[-2]
832834

833835
prompt_embeds_list.append(prompt_embeds)
@@ -879,7 +881,8 @@ def encode_prompt(
879881
output_hidden_states=True,
880882
)
881883
# We are only ALWAYS interested in the pooled output of the final text encoder
882-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
884+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
885+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
883886
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
884887

885888
negative_prompt_embeds_list.append(negative_prompt_embeds)

examples/community/pipeline_demofusion_sdxl.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -290,7 +290,9 @@ def encode_prompt(
290290
)
291291

292292
# We are only ALWAYS interested in the pooled output of the final text encoder
293-
pooled_prompt_embeds = prompt_embeds[0]
293+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
294+
pooled_prompt_embeds = prompt_embeds[0]
295+
294296
prompt_embeds = prompt_embeds.hidden_states[-2]
295297

296298
prompt_embeds_list.append(prompt_embeds)
@@ -342,7 +344,8 @@ def encode_prompt(
342344
output_hidden_states=True,
343345
)
344346
# We are only ALWAYS interested in the pooled output of the final text encoder
345-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
347+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
348+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
346349
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
347350

348351
negative_prompt_embeds_list.append(negative_prompt_embeds)

examples/community/pipeline_sdxl_style_aligned.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -628,7 +628,9 @@ def encode_prompt(
628628
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
629629

630630
# We are only ALWAYS interested in the pooled output of the final text encoder
631-
pooled_prompt_embeds = prompt_embeds[0]
631+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
632+
pooled_prompt_embeds = prompt_embeds[0]
633+
632634
if clip_skip is None:
633635
prompt_embeds = prompt_embeds.hidden_states[-2]
634636
else:
@@ -688,7 +690,8 @@ def encode_prompt(
688690
output_hidden_states=True,
689691
)
690692
# We are only ALWAYS interested in the pooled output of the final text encoder
691-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
693+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
694+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
692695
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
693696

694697
negative_prompt_embeds_list.append(negative_prompt_embeds)

examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -359,7 +359,9 @@ def encode_prompt(
359359
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
360360

361361
# We are only ALWAYS interested in the pooled output of the final text encoder
362-
pooled_prompt_embeds = prompt_embeds[0]
362+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
363+
pooled_prompt_embeds = prompt_embeds[0]
364+
363365
if clip_skip is None:
364366
prompt_embeds = prompt_embeds.hidden_states[-2]
365367
else:
@@ -419,7 +421,8 @@ def encode_prompt(
419421
output_hidden_states=True,
420422
)
421423
# We are only ALWAYS interested in the pooled output of the final text encoder
422-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
424+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
425+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
423426
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
424427

425428
negative_prompt_embeds_list.append(negative_prompt_embeds)

examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -507,7 +507,9 @@ def encode_prompt(
507507
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
508508

509509
# We are only ALWAYS interested in the pooled output of the final text encoder
510-
pooled_prompt_embeds = prompt_embeds[0]
510+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
511+
pooled_prompt_embeds = prompt_embeds[0]
512+
511513
if clip_skip is None:
512514
prompt_embeds = prompt_embeds.hidden_states[-2]
513515
else:
@@ -567,7 +569,8 @@ def encode_prompt(
567569
output_hidden_states=True,
568570
)
569571
# We are only ALWAYS interested in the pooled output of the final text encoder
570-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
572+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
573+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
571574
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
572575

573576
negative_prompt_embeds_list.append(negative_prompt_embeds)

examples/community/pipeline_stable_diffusion_xl_differential_img2img.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,9 @@ def encode_prompt(
394394
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
395395

396396
# We are only ALWAYS interested in the pooled output of the final text encoder
397-
pooled_prompt_embeds = prompt_embeds[0]
397+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
398+
pooled_prompt_embeds = prompt_embeds[0]
399+
398400
if clip_skip is None:
399401
prompt_embeds = prompt_embeds.hidden_states[-2]
400402
else:
@@ -454,7 +456,8 @@ def encode_prompt(
454456
output_hidden_states=True,
455457
)
456458
# We are only ALWAYS interested in the pooled output of the final text encoder
457-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
459+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
460+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
458461
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
459462

460463
negative_prompt_embeds_list.append(negative_prompt_embeds)

examples/community/pipeline_stable_diffusion_xl_ipex.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -390,7 +390,9 @@ def encode_prompt(
390390
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
391391

392392
# We are only ALWAYS interested in the pooled output of the final text encoder
393-
pooled_prompt_embeds = prompt_embeds[0]
393+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
394+
pooled_prompt_embeds = prompt_embeds[0]
395+
394396
if clip_skip is None:
395397
prompt_embeds = prompt_embeds.hidden_states[-2]
396398
else:
@@ -450,7 +452,8 @@ def encode_prompt(
450452
output_hidden_states=True,
451453
)
452454
# We are only ALWAYS interested in the pooled output of the final text encoder
453-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
455+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
456+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
454457
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
455458

456459
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -438,7 +438,9 @@ def encode_prompt(
438438
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
439439

440440
# We are only ALWAYS interested in the pooled output of the final text encoder
441-
pooled_prompt_embeds = prompt_embeds[0]
441+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
442+
pooled_prompt_embeds = prompt_embeds[0]
443+
442444
if clip_skip is None:
443445
prompt_embeds = prompt_embeds.hidden_states[-2]
444446
else:
@@ -497,8 +499,10 @@ def encode_prompt(
497499
uncond_input.input_ids.to(device),
498500
output_hidden_states=True,
499501
)
502+
500503
# We are only ALWAYS interested in the pooled output of the final text encoder
501-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
504+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
505+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
502506
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
503507

504508
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -406,7 +406,9 @@ def encode_prompt(
406406
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
407407

408408
# We are only ALWAYS interested in the pooled output of the final text encoder
409-
pooled_prompt_embeds = prompt_embeds[0]
409+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
410+
pooled_prompt_embeds = prompt_embeds[0]
411+
410412
if clip_skip is None:
411413
prompt_embeds = prompt_embeds.hidden_states[-2]
412414
else:
@@ -465,8 +467,10 @@ def encode_prompt(
465467
uncond_input.input_ids.to(device),
466468
output_hidden_states=True,
467469
)
470+
468471
# We are only ALWAYS interested in the pooled output of the final text encoder
469-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
472+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
473+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
470474
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
471475

472476
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -415,7 +415,9 @@ def encode_prompt(
415415
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
416416

417417
# We are only ALWAYS interested in the pooled output of the final text encoder
418-
pooled_prompt_embeds = prompt_embeds[0]
418+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
419+
pooled_prompt_embeds = prompt_embeds[0]
420+
419421
if clip_skip is None:
420422
prompt_embeds = prompt_embeds.hidden_states[-2]
421423
else:
@@ -474,8 +476,10 @@ def encode_prompt(
474476
uncond_input.input_ids.to(device),
475477
output_hidden_states=True,
476478
)
479+
477480
# We are only ALWAYS interested in the pooled output of the final text encoder
478-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
481+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
482+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
479483
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
480484

481485
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -408,7 +408,9 @@ def encode_prompt(
408408
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
409409

410410
# We are only ALWAYS interested in the pooled output of the final text encoder
411-
pooled_prompt_embeds = prompt_embeds[0]
411+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
412+
pooled_prompt_embeds = prompt_embeds[0]
413+
412414
if clip_skip is None:
413415
prompt_embeds = prompt_embeds.hidden_states[-2]
414416
else:
@@ -467,8 +469,10 @@ def encode_prompt(
467469
uncond_input.input_ids.to(device),
468470
output_hidden_states=True,
469471
)
472+
470473
# We are only ALWAYS interested in the pooled output of the final text encoder
471-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
474+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
475+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
472476
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
473477

474478
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -388,7 +388,9 @@ def encode_prompt(
388388
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
389389

390390
# We are only ALWAYS interested in the pooled output of the final text encoder
391-
pooled_prompt_embeds = prompt_embeds[0]
391+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
392+
pooled_prompt_embeds = prompt_embeds[0]
393+
392394
if clip_skip is None:
393395
prompt_embeds = prompt_embeds.hidden_states[-2]
394396
else:
@@ -447,8 +449,10 @@ def encode_prompt(
447449
uncond_input.input_ids.to(device),
448450
output_hidden_states=True,
449451
)
452+
450453
# We are only ALWAYS interested in the pooled output of the final text encoder
451-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
454+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
455+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
452456
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
453457

454458
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,9 @@ def encode_prompt(
397397
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
398398

399399
# We are only ALWAYS interested in the pooled output of the final text encoder
400-
pooled_prompt_embeds = prompt_embeds[0]
400+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
401+
pooled_prompt_embeds = prompt_embeds[0]
402+
401403
if clip_skip is None:
402404
prompt_embeds = prompt_embeds.hidden_states[-2]
403405
else:
@@ -456,8 +458,10 @@ def encode_prompt(
456458
uncond_input.input_ids.to(device),
457459
output_hidden_states=True,
458460
)
461+
459462
# We are only ALWAYS interested in the pooled output of the final text encoder
460-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
463+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
464+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
461465
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
462466

463467
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -422,7 +422,9 @@ def encode_prompt(
422422
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
423423

424424
# We are only ALWAYS interested in the pooled output of the final text encoder
425-
pooled_prompt_embeds = prompt_embeds[0]
425+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
426+
pooled_prompt_embeds = prompt_embeds[0]
427+
426428
if clip_skip is None:
427429
prompt_embeds = prompt_embeds.hidden_states[-2]
428430
else:
@@ -481,8 +483,10 @@ def encode_prompt(
481483
uncond_input.input_ids.to(device),
482484
output_hidden_states=True,
483485
)
486+
484487
# We are only ALWAYS interested in the pooled output of the final text encoder
485-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
488+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
489+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
486490
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
487491

488492
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,9 @@ def encode_prompt(
336336
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
337337

338338
# We are only ALWAYS interested in the pooled output of the final text encoder
339-
pooled_prompt_embeds = prompt_embeds[0]
339+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
340+
pooled_prompt_embeds = prompt_embeds[0]
341+
340342
if clip_skip is None:
341343
prompt_embeds = prompt_embeds.hidden_states[-2]
342344
else:
@@ -395,8 +397,10 @@ def encode_prompt(
395397
uncond_input.input_ids.to(device),
396398
output_hidden_states=True,
397399
)
400+
398401
# We are only ALWAYS interested in the pooled output of the final text encoder
399-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
402+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
403+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
400404
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
401405

402406
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,9 @@ def encode_prompt(
421421
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
422422

423423
# We are only ALWAYS interested in the pooled output of the final text encoder
424-
pooled_prompt_embeds = prompt_embeds[0]
424+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
425+
pooled_prompt_embeds = prompt_embeds[0]
426+
425427
if clip_skip is None:
426428
prompt_embeds = prompt_embeds.hidden_states[-2]
427429
else:
@@ -480,8 +482,10 @@ def encode_prompt(
480482
uncond_input.input_ids.to(device),
481483
output_hidden_states=True,
482484
)
485+
483486
# We are only ALWAYS interested in the pooled output of the final text encoder
484-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
487+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
488+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
485489
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
486490

487491
negative_prompt_embeds_list.append(negative_prompt_embeds)

0 commit comments

Comments
 (0)