From 637461849183506a4d6a4a0b5f14f20ff8b5d9df Mon Sep 17 00:00:00 2001
From: Panchovix
Date: Fri, 12 Jul 2024 13:26:37 -0400
Subject: [PATCH] Fix: multiple checkpoints at the same time

For now, when reaching the limit, we unload all the models instead of the
least used one, since it was giving issues.
---
 modules/initialize_util.py | 25 +++++++++----------------
 modules/sd_models.py       | 23 ++++++-----------------
 2 files changed, 15 insertions(+), 33 deletions(-)

diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index aaec95ebc..9f24988d2 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -173,30 +173,23 @@ def sigint_handler(sig, frame):
         # as then the coverage report won't be generated.

     signal.signal(signal.SIGINT, sigint_handler)

+
 def print_event_handler(event_name):
     print(f"Event handler triggered: {event_name}")
-
-
+
 def configure_opts_onchange():
     from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
     from modules.call_queue import wrap_queued_call
     from modules_forge import main_thread
-
-    def wrapped_callback(callback, event_name):
-        def wrapper():
-            result = callback()
-            print_event_handler(event_name)
-            return result
-        return wrapper
-
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(wrapped_callback(sd_models.reload_model_weights, "sd_model_checkpoint"))), call=False)
-    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(wrapped_callback(sd_vae.reload_vae_weights, "sd_vae"))), call=False)
-    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(wrapped_callback(sd_vae.reload_vae_weights, "sd_vae_overrides_per_model_preferences"))), call=False)
+
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights, print_event_handler("sd_model_checkpoint"))), call=False)
+    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights, print_event_handler("sd_vae"))), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights, print_event_handler("sd_vae_overrides_per_model_preferences"))), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
-    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: wrapped_callback(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model), "cross_attention_optimization")()), call=False)
-    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: wrapped_callback(sd_models.reload_model_weights, "fp8_storage")()), call=False)
-    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: wrapped_callback(lambda: sd_models.reload_model_weights(forced_reload=True), "cache_fp16_weight")()), call=False)
+    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: (sd_hijack.model_hijack.redo_hijack(shared.sd_model), print_event_handler("cross_attention_optimization"))), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: (sd_models.reload_model_weights(), print_event_handler("fp8_storage"))), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: (sd_models.reload_model_weights(forced_reload=True), print_event_handler("cache_fp16_weight"))), call=False)

     startup_timer.record("opts onchange")

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 24c762dbf..296f0738f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -21,7 +21,6 @@
 from ldm_patched.modules.ops import manual_cast
 from ldm_patched.modules import model_management as model_management
 import ldm_patched.modules.model_patcher
-import time


 model_dir = "Stable-diffusion"
@@ -518,7 +517,6 @@ def get_sd_model(self):

     def set_sd_model(self, v, already_loaded=False):
         self.sd_model = v
-        v.last_used = time.time()
         if already_loaded:
             sd_vae.base_vae = getattr(v, "base_vae", None)
             sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
@@ -558,33 +556,28 @@ def send_model_to_trash(m):

 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
+
     checkpoint_info = checkpoint_info or select_checkpoint()
     timer = Timer()

-    # Check if the model is already loaded as the primary model
     if model_data.sd_model and model_data.sd_model.filename == checkpoint_info.filename:
-        model_data.sd_model.last_used = time.time()  # Update last used time
         return model_data.sd_model

-    # Check if the model is already in the loaded models list
     for loaded_model in model_data.loaded_sd_models:
         if loaded_model.filename == checkpoint_info.filename:
             print(f" --> Using already loaded model {loaded_model.sd_checkpoint_info.title}: done in {timer.summary()}")
-            loaded_model.last_used = time.time()  # Update last used time
             model_data.set_sd_model(loaded_model, already_loaded=True)
             return loaded_model

-    # If we've reached the model limit, unload the least recently used model
     if len(model_data.loaded_sd_models) >= shared.opts.sd_checkpoints_limit:
-        least_recent_model = min(model_data.loaded_sd_models, key=lambda m: m.last_used)
-        print(f" ------------ Unloading least recently used model: {least_recent_model.sd_checkpoint_info.title} -------------")
-        model_data.loaded_sd_models.remove(least_recent_model)
-        if model_data.sd_model == least_recent_model:
-            model_data.sd_model = None
-        del least_recent_model
+        print(" ------------ Unloading all models... -------------")
+        model_data.sd_model = None
+        model_data.loaded_sd_models = []
+        model_management.unload_all_models()
         model_management.soft_empty_cache()
         gc.collect()
+
     timer.record("unload existing models")

     print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -595,7 +588,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

     sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
     sd_model.filename = checkpoint_info.filename
-    sd_model.last_used = time.time()  # Set initial last used time
     model_data.loaded_sd_models.append(sd_model)
     model_data.set_sd_model(sd_model)

@@ -610,16 +602,13 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("load VAE")

     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
-    timer.record("load textual inversion embeddings")

     script_callbacks.model_loaded_callback(sd_model)
-    timer.record("scripts callbacks")

     with torch.no_grad():
         sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
-    timer.record("calculate empty prompt")

     print(f"Model loaded in {timer.summary()}.")
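
Note: the eviction change in load_model() above reduces to the standalone
sketch below. It is illustrative only: ensure_checkpoint_capacity is a
hypothetical name, and model_data / model_management stand in for the
objects used in the diff (the unload_all_models and soft_empty_cache calls
are the same ones the patch invokes).

    import gc

    def ensure_checkpoint_capacity(model_data, model_management, limit):
        # Old behavior: evict only the least-recently-used model, tracked
        # via a per-model last_used timestamp. New behavior (this patch):
        # once the configured checkpoint limit is reached, drop every
        # loaded model at once.
        if len(model_data.loaded_sd_models) >= limit:
            print(" ------------ Unloading all models... -------------")
            model_data.sd_model = None
            model_data.loaded_sd_models = []
            model_management.unload_all_models()  # unload patched models from VRAM
            model_management.soft_empty_cache()   # release cached GPU memory
            gc.collect()                          # reclaim host-side memory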
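
The last three onchange hooks (cross_attention_optimization, fp8_storage,
cache_fp16_weight) now log after running their callback by evaluating a
tuple inside a zero-argument lambda, replacing the removed wrapped_callback
helper. A minimal sketch of that pattern, with a stand-in callback:

    def print_event_handler(event_name):
        print(f"Event handler triggered: {event_name}")

    def reload_model_weights():
        print("reloading weights...")  # stand-in for sd_models.reload_model_weights

    # Tuple elements are evaluated left to right when the lambda is called,
    # so the real callback runs first and the log line follows.
    on_change = lambda: (reload_model_weights(), print_event_handler("cache_fp16_weight"))
    on_change()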