File tree Expand file tree Collapse file tree 27 files changed +154
-54
lines changed
stable_diffusion_k_diffusion Expand file tree Collapse file tree 27 files changed +154
-54
lines changed Original file line number Diff line number Diff line change @@ -827,7 +827,9 @@ def encode_prompt(
827
827
)
828
828
829
829
# We are only ALWAYS interested in the pooled output of the final text encoder
830
- pooled_prompt_embeds = prompt_embeds [0 ]
830
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
831
+ pooled_prompt_embeds = prompt_embeds [0 ]
832
+
831
833
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
832
834
833
835
prompt_embeds_list .append (prompt_embeds )
@@ -879,7 +881,8 @@ def encode_prompt(
879
881
output_hidden_states = True ,
880
882
)
881
883
# We are only ALWAYS interested in the pooled output of the final text encoder
882
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
884
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
885
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
883
886
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
884
887
885
888
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -290,7 +290,9 @@ def encode_prompt(
290
290
)
291
291
292
292
# We are only ALWAYS interested in the pooled output of the final text encoder
293
- pooled_prompt_embeds = prompt_embeds [0 ]
293
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
294
+ pooled_prompt_embeds = prompt_embeds [0 ]
295
+
294
296
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
295
297
296
298
prompt_embeds_list .append (prompt_embeds )
@@ -342,7 +344,8 @@ def encode_prompt(
342
344
output_hidden_states = True ,
343
345
)
344
346
# We are only ALWAYS interested in the pooled output of the final text encoder
345
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
347
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
348
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
346
349
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
347
350
348
351
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -628,7 +628,9 @@ def encode_prompt(
628
628
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
629
629
630
630
# We are only ALWAYS interested in the pooled output of the final text encoder
631
- pooled_prompt_embeds = prompt_embeds [0 ]
631
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
632
+ pooled_prompt_embeds = prompt_embeds [0 ]
633
+
632
634
if clip_skip is None :
633
635
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
634
636
else :
@@ -688,7 +690,8 @@ def encode_prompt(
688
690
output_hidden_states = True ,
689
691
)
690
692
# We are only ALWAYS interested in the pooled output of the final text encoder
691
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
693
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
694
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
692
695
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
693
696
694
697
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -359,7 +359,9 @@ def encode_prompt(
359
359
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
360
360
361
361
# We are only ALWAYS interested in the pooled output of the final text encoder
362
- pooled_prompt_embeds = prompt_embeds [0 ]
362
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
363
+ pooled_prompt_embeds = prompt_embeds [0 ]
364
+
363
365
if clip_skip is None :
364
366
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
365
367
else :
@@ -419,7 +421,8 @@ def encode_prompt(
419
421
output_hidden_states = True ,
420
422
)
421
423
# We are only ALWAYS interested in the pooled output of the final text encoder
422
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
424
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
425
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
423
426
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
424
427
425
428
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -507,7 +507,9 @@ def encode_prompt(
507
507
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
508
508
509
509
# We are only ALWAYS interested in the pooled output of the final text encoder
510
- pooled_prompt_embeds = prompt_embeds [0 ]
510
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
511
+ pooled_prompt_embeds = prompt_embeds [0 ]
512
+
511
513
if clip_skip is None :
512
514
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
513
515
else :
@@ -567,7 +569,8 @@ def encode_prompt(
567
569
output_hidden_states = True ,
568
570
)
569
571
# We are only ALWAYS interested in the pooled output of the final text encoder
570
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
572
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
573
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
571
574
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
572
575
573
576
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -394,7 +394,9 @@ def encode_prompt(
394
394
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
395
395
396
396
# We are only ALWAYS interested in the pooled output of the final text encoder
397
- pooled_prompt_embeds = prompt_embeds [0 ]
397
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
398
+ pooled_prompt_embeds = prompt_embeds [0 ]
399
+
398
400
if clip_skip is None :
399
401
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
400
402
else :
@@ -454,7 +456,8 @@ def encode_prompt(
454
456
output_hidden_states = True ,
455
457
)
456
458
# We are only ALWAYS interested in the pooled output of the final text encoder
457
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
459
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
460
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
458
461
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
459
462
460
463
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -390,7 +390,9 @@ def encode_prompt(
390
390
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
391
391
392
392
# We are only ALWAYS interested in the pooled output of the final text encoder
393
- pooled_prompt_embeds = prompt_embeds [0 ]
393
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
394
+ pooled_prompt_embeds = prompt_embeds [0 ]
395
+
394
396
if clip_skip is None :
395
397
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
396
398
else :
@@ -450,7 +452,8 @@ def encode_prompt(
450
452
output_hidden_states = True ,
451
453
)
452
454
# We are only ALWAYS interested in the pooled output of the final text encoder
453
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
455
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
456
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
454
457
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
455
458
456
459
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -438,7 +438,9 @@ def encode_prompt(
438
438
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
439
439
440
440
# We are only ALWAYS interested in the pooled output of the final text encoder
441
- pooled_prompt_embeds = prompt_embeds [0 ]
441
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
442
+ pooled_prompt_embeds = prompt_embeds [0 ]
443
+
442
444
if clip_skip is None :
443
445
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
444
446
else :
@@ -497,8 +499,10 @@ def encode_prompt(
497
499
uncond_input .input_ids .to (device ),
498
500
output_hidden_states = True ,
499
501
)
502
+
500
503
# We are only ALWAYS interested in the pooled output of the final text encoder
501
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
504
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
505
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
502
506
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
503
507
504
508
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -406,7 +406,9 @@ def encode_prompt(
406
406
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
407
407
408
408
# We are only ALWAYS interested in the pooled output of the final text encoder
409
- pooled_prompt_embeds = prompt_embeds [0 ]
409
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
410
+ pooled_prompt_embeds = prompt_embeds [0 ]
411
+
410
412
if clip_skip is None :
411
413
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
412
414
else :
@@ -465,8 +467,10 @@ def encode_prompt(
465
467
uncond_input .input_ids .to (device ),
466
468
output_hidden_states = True ,
467
469
)
470
+
468
471
# We are only ALWAYS interested in the pooled output of the final text encoder
469
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
472
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
473
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
470
474
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
471
475
472
476
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -415,7 +415,9 @@ def encode_prompt(
415
415
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
416
416
417
417
# We are only ALWAYS interested in the pooled output of the final text encoder
418
- pooled_prompt_embeds = prompt_embeds [0 ]
418
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
419
+ pooled_prompt_embeds = prompt_embeds [0 ]
420
+
419
421
if clip_skip is None :
420
422
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
421
423
else :
@@ -474,8 +476,10 @@ def encode_prompt(
474
476
uncond_input .input_ids .to (device ),
475
477
output_hidden_states = True ,
476
478
)
479
+
477
480
# We are only ALWAYS interested in the pooled output of the final text encoder
478
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
481
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
482
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
479
483
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
480
484
481
485
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -408,7 +408,9 @@ def encode_prompt(
408
408
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
409
409
410
410
# We are only ALWAYS interested in the pooled output of the final text encoder
411
- pooled_prompt_embeds = prompt_embeds [0 ]
411
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
412
+ pooled_prompt_embeds = prompt_embeds [0 ]
413
+
412
414
if clip_skip is None :
413
415
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
414
416
else :
@@ -467,8 +469,10 @@ def encode_prompt(
467
469
uncond_input .input_ids .to (device ),
468
470
output_hidden_states = True ,
469
471
)
472
+
470
473
# We are only ALWAYS interested in the pooled output of the final text encoder
471
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
474
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
475
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
472
476
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
473
477
474
478
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -388,7 +388,9 @@ def encode_prompt(
388
388
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
389
389
390
390
# We are only ALWAYS interested in the pooled output of the final text encoder
391
- pooled_prompt_embeds = prompt_embeds [0 ]
391
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
392
+ pooled_prompt_embeds = prompt_embeds [0 ]
393
+
392
394
if clip_skip is None :
393
395
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
394
396
else :
@@ -447,8 +449,10 @@ def encode_prompt(
447
449
uncond_input .input_ids .to (device ),
448
450
output_hidden_states = True ,
449
451
)
452
+
450
453
# We are only ALWAYS interested in the pooled output of the final text encoder
451
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
454
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
455
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
452
456
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
453
457
454
458
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -397,7 +397,9 @@ def encode_prompt(
397
397
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
398
398
399
399
# We are only ALWAYS interested in the pooled output of the final text encoder
400
- pooled_prompt_embeds = prompt_embeds [0 ]
400
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
401
+ pooled_prompt_embeds = prompt_embeds [0 ]
402
+
401
403
if clip_skip is None :
402
404
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
403
405
else :
@@ -456,8 +458,10 @@ def encode_prompt(
456
458
uncond_input .input_ids .to (device ),
457
459
output_hidden_states = True ,
458
460
)
461
+
459
462
# We are only ALWAYS interested in the pooled output of the final text encoder
460
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
463
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
464
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
461
465
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
462
466
463
467
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -422,7 +422,9 @@ def encode_prompt(
422
422
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
423
423
424
424
# We are only ALWAYS interested in the pooled output of the final text encoder
425
- pooled_prompt_embeds = prompt_embeds [0 ]
425
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
426
+ pooled_prompt_embeds = prompt_embeds [0 ]
427
+
426
428
if clip_skip is None :
427
429
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
428
430
else :
@@ -481,8 +483,10 @@ def encode_prompt(
481
483
uncond_input .input_ids .to (device ),
482
484
output_hidden_states = True ,
483
485
)
486
+
484
487
# We are only ALWAYS interested in the pooled output of the final text encoder
485
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
488
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
489
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
486
490
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
487
491
488
492
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -336,7 +336,9 @@ def encode_prompt(
336
336
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
337
337
338
338
# We are only ALWAYS interested in the pooled output of the final text encoder
339
- pooled_prompt_embeds = prompt_embeds [0 ]
339
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
340
+ pooled_prompt_embeds = prompt_embeds [0 ]
341
+
340
342
if clip_skip is None :
341
343
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
342
344
else :
@@ -395,8 +397,10 @@ def encode_prompt(
395
397
uncond_input .input_ids .to (device ),
396
398
output_hidden_states = True ,
397
399
)
400
+
398
401
# We are only ALWAYS interested in the pooled output of the final text encoder
399
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
402
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
403
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
400
404
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
401
405
402
406
negative_prompt_embeds_list .append (negative_prompt_embeds )
Original file line number Diff line number Diff line change @@ -421,7 +421,9 @@ def encode_prompt(
421
421
prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
422
422
423
423
# We are only ALWAYS interested in the pooled output of the final text encoder
424
- pooled_prompt_embeds = prompt_embeds [0 ]
424
+ if pooled_prompt_embeds is None and prompt_embeds [0 ].ndim == 2 :
425
+ pooled_prompt_embeds = prompt_embeds [0 ]
426
+
425
427
if clip_skip is None :
426
428
prompt_embeds = prompt_embeds .hidden_states [- 2 ]
427
429
else :
@@ -480,8 +482,10 @@ def encode_prompt(
480
482
uncond_input .input_ids .to (device ),
481
483
output_hidden_states = True ,
482
484
)
485
+
483
486
# We are only ALWAYS interested in the pooled output of the final text encoder
484
- negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
487
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds [0 ].ndim == 2 :
488
+ negative_pooled_prompt_embeds = negative_prompt_embeds [0 ]
485
489
negative_prompt_embeds = negative_prompt_embeds .hidden_states [- 2 ]
486
490
487
491
negative_prompt_embeds_list .append (negative_prompt_embeds )
You can’t perform that action at this time.
0 commit comments