Skip to content

Commit 9dae12f

Browse files
committed
fix: default to empty reference on pix2pix models to avoid segfault
1 parent 9c7cb44 commit 9dae12f

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

stable-diffusion.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,10 @@ class StableDiffusionGGML {
448448
diffusion_model->alloc_params_buffer();
449449
diffusion_model->get_param_tensors(tensors);
450450

451+
if (sd_version_is_unet_edit(version)) {
452+
vae_decode_only = false;
453+
}
454+
451455
if (high_noise_diffusion_model) {
452456
high_noise_diffusion_model->alloc_params_buffer();
453457
high_noise_diffusion_model->get_param_tensors(tensors);
@@ -2319,23 +2323,36 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23192323
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
23202324
}
23212325

2326+
sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance;
2327+
std::vector<sd_image_t*> ref_images;
2328+
for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
2329+
ref_images.push_back(&sd_img_gen_params->ref_images[i]);
2330+
}
2331+
2332+
std::vector<uint8_t> empty_image_data;
2333+
sd_image_t empty_image = {(uint32_t)width, (uint32_t)height, 3, nullptr};
2334+
if (ref_images.empty() && sd_version_is_unet_edit(sd_ctx->sd->version))
2335+
{
2336+
LOG_WARN("This model needs at least one reference image; using an empty reference");
2337+
empty_image_data.reserve(width * height * 3);
2338+
ref_images.push_back(&empty_image);
2339+
empty_image.data = empty_image_data.data();
2340+
guidance.img_cfg = 0.f;
2341+
}
2342+
23222343
if (sd_img_gen_params->ref_images_count > 0) {
23232344
LOG_INFO("EDIT mode");
23242345
}
2325-
else if (sd_ctx->sd->version == VERSION_SD1_PIX2PIX || sd_ctx->sd->version == VERSION_SDXL_PIX2PIX) {
2326-
LOG_ERROR("This model needs at least one reference image");
2327-
return NULL;
2328-
}
23292346

23302347
std::vector<ggml_tensor*> ref_latents;
2331-
for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
2348+
for (int i = 0; i < ref_images.size(); i++) {
23322349
ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
23332350
GGML_TYPE_F32,
2334-
sd_img_gen_params->ref_images[i].width,
2335-
sd_img_gen_params->ref_images[i].height,
2351+
ref_images[i]->width,
2352+
ref_images[i]->height,
23362353
3,
23372354
1);
2338-
sd_image_to_tensor(sd_img_gen_params->ref_images[i].data, img);
2355+
sd_image_to_tensor(ref_images[i]->data, img);
23392356

23402357
ggml_tensor* latent = NULL;
23412358
if (sd_ctx->sd->use_tiny_autoencoder) {
@@ -2368,7 +2385,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23682385
SAFE_STR(sd_img_gen_params->prompt),
23692386
SAFE_STR(sd_img_gen_params->negative_prompt),
23702387
sd_img_gen_params->clip_skip,
2371-
sd_img_gen_params->sample_params.guidance,
2388+
guidance,
23722389
sd_img_gen_params->sample_params.eta,
23732390
width,
23742391
height,

0 commit comments

Comments
 (0)