@@ -443,6 +443,10 @@ class StableDiffusionGGML {
443
443
diffusion_model->alloc_params_buffer ();
444
444
diffusion_model->get_param_tensors (tensors);
445
445
446
+ if (sd_version_is_unet_edit (version)) {
447
+ vae_decode_only = false ;
448
+ }
449
+
446
450
if (high_noise_diffusion_model) {
447
451
high_noise_diffusion_model->alloc_params_buffer ();
448
452
high_noise_diffusion_model->get_param_tensors (tensors);
@@ -748,15 +752,15 @@ class StableDiffusionGGML {
748
752
denoiser->scheduler ->version = version;
749
753
break ;
750
754
case SGM_UNIFORM:
751
- LOG_INFO (" Running with SGM Uniform schedule" );
752
- denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
753
- denoiser->scheduler ->version = version;
754
- break ;
755
+ LOG_INFO (" Running with SGM Uniform schedule" );
756
+ denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
757
+ denoiser->scheduler ->version = version;
758
+ break ;
755
759
case SIMPLE:
756
- LOG_INFO (" Running with Simple schedule" );
757
- denoiser->scheduler = std::make_shared<SimpleSchedule>();
758
- denoiser->scheduler ->version = version;
759
- break ;
760
+ LOG_INFO (" Running with Simple schedule" );
761
+ denoiser->scheduler = std::make_shared<SimpleSchedule>();
762
+ denoiser->scheduler ->version = version;
763
+ break ;
760
764
case SMOOTHSTEP:
761
765
LOG_INFO (" Running with SmoothStep scheduler" );
762
766
denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
@@ -1053,7 +1057,7 @@ class StableDiffusionGGML {
1053
1057
ggml_tensor* denoise_mask = NULL ,
1054
1058
ggml_tensor* vace_context = NULL ,
1055
1059
float vace_strength = 1 .f) {
1056
- if (shifted_timestep > 0 && !sd_version_is_sdxl (version)) {
1060
+ if (shifted_timestep > 0 && !sd_version_is_sdxl (version)) {
1057
1061
LOG_WARN (" timestep shifting is only supported for SDXL models!" );
1058
1062
shifted_timestep = 0 ;
1059
1063
}
@@ -1127,7 +1131,7 @@ class StableDiffusionGGML {
1127
1131
} else {
1128
1132
timesteps_vec.assign (1 , t);
1129
1133
}
1130
-
1134
+
1131
1135
timesteps_vec = process_timesteps (timesteps_vec, init_latent, denoise_mask);
1132
1136
auto timesteps = vector_to_ggml_tensor (work_ctx, timesteps_vec);
1133
1137
std::vector<float > guidance_vec (1 , guidance.distilled_guidance );
@@ -2387,19 +2391,35 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
2387
2391
init_latent = generate_init_latent (sd_ctx, work_ctx, width, height);
2388
2392
}
2389
2393
2390
- if (sd_img_gen_params->ref_images_count > 0 ) {
2394
+ sd_guidance_params_t guidance = sd_img_gen_params->sample_params .guidance ;
2395
+ std::vector<sd_image_t *> ref_images;
2396
+ for (int i = 0 ; i < sd_img_gen_params->ref_images_count ; i++) {
2397
+ ref_images.push_back (&sd_img_gen_params->ref_images [i]);
2398
+ }
2399
+
2400
+ std::vector<uint8_t > empty_image_data;
2401
+ sd_image_t empty_image = {(uint32_t )width, (uint32_t )height, 3 , nullptr };
2402
+ if (ref_images.empty () && sd_version_is_unet_edit (sd_ctx->sd ->version )) {
2403
+ LOG_WARN (" This model needs at least one reference image; using an empty reference" );
2404
+ empty_image_data.resize (width * height * 3 );
2405
+ ref_images.push_back (&empty_image);
2406
+ empty_image.data = empty_image_data.data ();
2407
+ guidance.img_cfg = 0 .f ;
2408
+ }
2409
+
2410
+ if (ref_images.size () > 0 ) {
2391
2411
LOG_INFO (" EDIT mode" );
2392
2412
}
2393
2413
2394
2414
std::vector<ggml_tensor*> ref_latents;
2395
- for (int i = 0 ; i < sd_img_gen_params-> ref_images_count ; i++) {
2415
+ for (int i = 0 ; i < ref_images. size () ; i++) {
2396
2416
ggml_tensor* img = ggml_new_tensor_4d (work_ctx,
2397
2417
GGML_TYPE_F32,
2398
- sd_img_gen_params-> ref_images [i]. width ,
2399
- sd_img_gen_params-> ref_images [i]. height ,
2418
+ ref_images[i]-> width ,
2419
+ ref_images[i]-> height ,
2400
2420
3 ,
2401
2421
1 );
2402
- sd_image_to_tensor (sd_img_gen_params-> ref_images [i], img);
2422
+ sd_image_to_tensor (* ref_images[i], img);
2403
2423
2404
2424
ggml_tensor* latent = NULL ;
2405
2425
if (sd_ctx->sd ->use_tiny_autoencoder ) {
@@ -2437,7 +2457,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
2437
2457
SAFE_STR (sd_img_gen_params->prompt ),
2438
2458
SAFE_STR (sd_img_gen_params->negative_prompt ),
2439
2459
sd_img_gen_params->clip_skip ,
2440
- sd_img_gen_params-> sample_params . guidance ,
2460
+ guidance,
2441
2461
sd_img_gen_params->sample_params .eta ,
2442
2462
sd_img_gen_params->sample_params .shifted_timestep ,
2443
2463
width,
0 commit comments