Commit

Handle ggml aborts
SkutteOleg committed Aug 27, 2024
1 parent 409e17d commit 9637ddc
Showing 4 changed files with 121 additions and 87 deletions.
.gitignore: 5 changes (4 additions & 1 deletion)
@@ -10,4 +10,7 @@ test/
*.gguf
output*.png
models*
- *.log
+ *.log
+ /cmake-build-debug
+ /.idea
+ /cmake-build-release
.gitmodules: 3 changes (2 additions & 1 deletion)
@@ -1,3 +1,4 @@
[submodule "ggml"]
path = ggml
- url = https://github.com/ggerganov/ggml.git
+ url = https://github.com/SkutteOleg/ggml.git
+ branch = dreamio
ggml: 2 changes (1 addition & 1 deletion)
Submodule ggml updated 2 files
+4 −0 include/ggml.h
+20 −8 src/ggml.c
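The submodule bump itself is only summarized above (+4 −0 in include/ggml.h, +20 −8 in src/ggml.c), so the ggml-side change is not visible in this diff. Judging from the symbols the stable-diffusion.cpp changes below rely on (ggml_error_jmp_buf, ggml_error_jmp_set, ggml_get_error_message), the forked ggml presumably turns its fatal-error path into a longjmp back to the caller whenever a recovery point is armed, instead of calling abort(). A minimal sketch of that assumed mechanism follows; the names ggml_handle_fatal and ggml_error_message, the buffer size, and the file layout are illustrative guesses, not the fork's actual code:

    /* Sketch only: assumed additions to include/ggml.h */
    #include <setjmp.h>
    #include <stdio.h>
    #include <stdlib.h>

    extern jmp_buf ggml_error_jmp_buf;   /* recovery point armed by the caller */
    extern int     ggml_error_jmp_set;   /* non-zero while the recovery point is armed */
    const char*    ggml_get_error_message(void);

    /* Sketch only: assumed shape of the change in src/ggml.c */
    jmp_buf ggml_error_jmp_buf;
    int     ggml_error_jmp_set = 0;
    static char ggml_error_message[256];

    /* Hypothetical helper; the fork's assert/abort path presumably does something similar. */
    static void ggml_handle_fatal(const char* msg) {
        if (ggml_error_jmp_set) {
            snprintf(ggml_error_message, sizeof(ggml_error_message), "%s", msg);
            longjmp(ggml_error_jmp_buf, 1);  /* unwind back to the setjmp() in the caller */
        }
        abort();  /* previous behavior when no recovery point is armed */
    }

    const char* ggml_get_error_message(void) {
        return ggml_error_message;
    }

With a hook like this in place, a caller that arms the jump buffer before invoking ggml can turn what used to be a hard process abort into an ordinary error return, which is what the wrappers below do.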
stable-diffusion.cpp: 198 changes (114 additions & 84 deletions)
@@ -15,6 +15,7 @@
#include "pmid.hpp"
#include "tae.hpp"
#include "vae.hpp"
#include <setjmp.h>

#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
@@ -1027,32 +1028,36 @@ SDError new_sd_ctx(sd_ctx_t** sd_ctx,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu) {
try {
*sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (*sd_ctx == NULL) {
return SD_ERROR_MEMORY_ALLOCATION;
}
ggml_error_jmp_set = 1;
if (setjmp(ggml_error_jmp_buf) == 0) {
try {
*sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (*sd_ctx == NULL) {
ggml_error_jmp_set = 0;
return SD_ERROR_MEMORY_ALLOCATION;
}

std::string model_path(model_path_c_str);
std::string clip_l_path(clip_l_path_c_str);
std::string t5xxl_path(t5xxl_path_c_str);
std::string diffusion_model_path(diffusion_model_path_c_str);
std::string vae_path(vae_path_c_str);
std::string taesd_path(taesd_path_c_str);
std::string control_net_path(control_net_path_c_str);
std::string embd_path(embed_dir_c_str);
std::string id_embd_path(id_embed_dir_c_str);
std::string lora_model_dir(lora_model_dir_c_str);

(*sd_ctx)->sd = new StableDiffusionGGML(n_threads,
vae_decode_only,
free_params_immediately,
lora_model_dir,
rng_type);
if ((*sd_ctx)->sd == NULL) {
free(*sd_ctx);
return SD_ERROR_MEMORY_ALLOCATION;
}
std::string model_path(model_path_c_str);
std::string clip_l_path(clip_l_path_c_str);
std::string t5xxl_path(t5xxl_path_c_str);
std::string diffusion_model_path(diffusion_model_path_c_str);
std::string vae_path(vae_path_c_str);
std::string taesd_path(taesd_path_c_str);
std::string control_net_path(control_net_path_c_str);
std::string embd_path(embed_dir_c_str);
std::string id_embd_path(id_embed_dir_c_str);
std::string lora_model_dir(lora_model_dir_c_str);

(*sd_ctx)->sd = new StableDiffusionGGML(n_threads,
vae_decode_only,
free_params_immediately,
lora_model_dir,
rng_type);
if ((*sd_ctx)->sd == NULL) {
free(*sd_ctx);
ggml_error_jmp_set = 0;
return SD_ERROR_MEMORY_ALLOCATION;
}

if (!(*sd_ctx)->sd->load_from_file(model_path,
clip_l_path,
@@ -1071,15 +1076,24 @@ SDError new_sd_ctx(sd_ctx_t** sd_ctx,
keep_vae_on_cpu)) {
delete (*sd_ctx)->sd;
free(*sd_ctx);
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
}

return SD_SUCCESS;
} catch (const std::exception& e) {
LOG_ERROR("Exception in new_sd_ctx: %s", e.what());
return SD_ERROR_PROCESSING;
} catch (...) {
LOG_ERROR("Unknown exception in new_sd_ctx");
ggml_error_jmp_set = 0;
return SD_SUCCESS;
} catch (const std::exception& e) {
LOG_ERROR("Exception in new_sd_ctx: %s", e.what());
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
} catch (...) {
LOG_ERROR("Unknown exception in new_sd_ctx");
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
}
} else {
LOG_ERROR("GGML error in new_sd_ctx: %s", ggml_get_error_message());
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
}
}
@@ -1405,55 +1419,59 @@ SDError txt2img(sd_ctx_t* sd_ctx,
bool normalize_input,
const char* input_id_images_path_c_str) {
struct ggml_context* work_ctx = NULL;
try {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
return SD_ERROR_INVALID_CONTEXT;
}
ggml_error_jmp_set = 1;
if (setjmp(ggml_error_jmp_buf) == 0) {
try {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
ggml_error_jmp_set = 0;
return SD_ERROR_INVALID_CONTEXT;
}

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B) {
params.mem_size *= 3;
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
params.mem_size *= 4;
}
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);

work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return SD_ERROR_MEMORY_ALLOCATION;
}
}
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);

work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
ggml_error_jmp_set = 0;
return SD_ERROR_MEMORY_ALLOCATION;
}

size_t t0 = ggml_time_ms();
size_t t0 = ggml_time_ms();

std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);

int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B) {
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B) {
ggml_set_f32(init_latent, 0.0609f);
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
}
} else {
ggml_set_f32(init_latent, 0.f);
}

*result_images = generate_image(sd_ctx,
work_ctx,
Expand All @@ -1475,30 +1493,42 @@ SDError txt2img(sd_ctx_t* sd_ctx,
normalize_input,
input_id_images_path_c_str);

if (*result_images == NULL) {
ggml_free(work_ctx);
return SD_ERROR_PROCESSING;
}
if (*result_images == NULL) {
ggml_free(work_ctx);
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
}

*result_count = batch_count;
*result_count = batch_count;

size_t t1 = ggml_time_ms();
size_t t1 = ggml_time_ms();

LOG_INFO("txt2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);
LOG_INFO("txt2img completed in %.2fs", (t1 - t0) * 1.0f / 1000);

ggml_free(work_ctx);
return SD_SUCCESS;
} catch (const std::exception& e) {
LOG_ERROR("Exception in txt2img: %s", e.what());
if (work_ctx) {
ggml_free(work_ctx);
ggml_error_jmp_set = 0;
return SD_SUCCESS;
} catch (const std::exception& e) {
LOG_ERROR("Exception in txt2img: %s", e.what());
if (work_ctx) {
ggml_free(work_ctx);
}
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
} catch (...) {
LOG_ERROR("Unknown exception in txt2img");
if (work_ctx) {
ggml_free(work_ctx);
}
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
}
return SD_ERROR_PROCESSING;
} catch (...) {
LOG_ERROR("Unknown exception in txt2img");
} else {
LOG_ERROR("GGML error in txt2img: %s", ggml_get_error_message());
if (work_ctx) {
ggml_free(work_ctx);
}
ggml_error_jmp_set = 0;
return SD_ERROR_PROCESSING;
}
}
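
Taken together, the stable-diffusion.cpp changes wrap each public entry point shown here (new_sd_ctx and txt2img) in the same guard: arm ggml_error_jmp_set, establish the recovery point with setjmp, run the original body inside the existing try/catch, clear the flag on every exit path, and report a recovered ggml abort as SD_ERROR_PROCESSING using ggml_get_error_message(). A distilled sketch of that pattern, with sd_guarded_api_call and do_work as placeholders (SDError, its values, LOG_ERROR and the ggml_error_* symbols are the ones used in the diff above and are assumed to come from the project's headers):

    #include <setjmp.h>
    #include <exception>

    void do_work();  // placeholder for the original function body

    SDError sd_guarded_api_call() {               // placeholder entry point
        ggml_error_jmp_set = 1;                   // arm the recovery point
        if (setjmp(ggml_error_jmp_buf) == 0) {    // normal path
            try {
                do_work();                        // ggml calls may longjmp out of here
                ggml_error_jmp_set = 0;           // disarm before every return
                return SD_SUCCESS;
            } catch (const std::exception& e) {
                LOG_ERROR("Exception: %s", e.what());
                ggml_error_jmp_set = 0;
                return SD_ERROR_PROCESSING;
            } catch (...) {
                LOG_ERROR("Unknown exception");
                ggml_error_jmp_set = 0;
                return SD_ERROR_PROCESSING;
            }
        } else {                                  // a ggml abort longjmp'd back here
            LOG_ERROR("GGML error: %s", ggml_get_error_message());
            ggml_error_jmp_set = 0;
            return SD_ERROR_PROCESSING;
        }
    }

One caveat with this approach: longjmp-ing out of C++ frames skips destructors (and is formally undefined behavior when non-trivial destructors would have run), so a jump taken deep inside model loading or sampling can leak backend allocations; the trade-off, presumably accepted here, is that the host process no longer dies on a ggml abort.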
