Fix: multiple checkpoints at the same time
For now, when the checkpoint limit is reached, we unload all loaded models instead of only the least recently used one, since evicting the least used model was causing issues.
Panchovix committed Jul 12, 2024
1 parent 7e1aa11 · commit 6374618
Showing 2 changed files with 15 additions and 33 deletions.
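The heart of the change is in load_model: when sd_checkpoints_limit is reached, the old code evicted only the least recently used checkpoint, while the new code unloads everything and starts over. Below is a minimal, self-contained sketch of the two policies; FakeModel, CHECKPOINT_LIMIT, and the helper names are illustrative stand-ins, not the repository's actual model_data or model_management objects.

```python
# Sketch only: simplified stand-ins for the webui's checkpoint cache.
import time
from dataclasses import dataclass, field


@dataclass
class FakeModel:
    title: str
    last_used: float = field(default_factory=time.time)


loaded_models: list[FakeModel] = []
CHECKPOINT_LIMIT = 2  # stand-in for shared.opts.sd_checkpoints_limit


def make_room_lru() -> None:
    # Old behavior: evict only the least recently used model.
    lru = min(loaded_models, key=lambda m: m.last_used)
    loaded_models.remove(lru)


def make_room_unload_all() -> None:
    # New behavior: drop every loaded model and start fresh.
    loaded_models.clear()


def load(title: str) -> FakeModel:
    for m in loaded_models:
        if m.title == title:           # already loaded: reuse it
            m.last_used = time.time()
            return m
    if len(loaded_models) >= CHECKPOINT_LIMIT:
        make_room_unload_all()         # was make_room_lru() before this commit
    model = FakeModel(title)
    loaded_models.append(model)
    return model
```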
modules/initialize_util.py (25 changes: 9 additions & 16 deletions)

```diff
@@ -173,30 +173,23 @@ def sigint_handler(sig, frame):
         # as then the coverage report won't be generated.
         signal.signal(signal.SIGINT, sigint_handler)
 
 
+def print_event_handler(event_name):
+    print(f"Event handler triggered: {event_name}")
+
+
 def configure_opts_onchange():
     from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
     from modules.call_queue import wrap_queued_call
     from modules_forge import main_thread
 
-    def wrapped_callback(callback, event_name):
-        def wrapper():
-            result = callback()
-            print_event_handler(event_name)
-            return result
-        return wrapper
-
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(wrapped_callback(sd_models.reload_model_weights, "sd_model_checkpoint"))), call=False)
-    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(wrapped_callback(sd_vae.reload_vae_weights, "sd_vae"))), call=False)
-    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(wrapped_callback(sd_vae.reload_vae_weights, "sd_vae_overrides_per_model_preferences"))), call=False)
-
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights, print_event_handler("sd_model_checkpoint"))), call=False)
+    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights, print_event_handler("sd_vae"))), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights, print_event_handler("sd_vae_overrides_per_model_preferences"))), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
-    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: wrapped_callback(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model), "cross_attention_optimization")()), call=False)
-    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: wrapped_callback(sd_models.reload_model_weights, "fp8_storage")()), call=False)
-    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: wrapped_callback(lambda: sd_models.reload_model_weights(forced_reload=True), "cache_fp16_weight")()), call=False)
+    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: (sd_hijack.model_hijack.redo_hijack(shared.sd_model), print_event_handler("cross_attention_optimization"))), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: (sd_models.reload_model_weights(), print_event_handler("fp8_storage"))), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: (sd_models.reload_model_weights(forced_reload=True), print_event_handler("cache_fp16_weight"))), call=False)
     startup_timer.record("opts onchange")
```
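The rewritten handlers lean on two details worth spelling out. For the last three options, the lambda body is a single tuple expression, so the reload and the log print are both evaluated, left to right, each time the option changes. For the first three, print_event_handler(...) is evaluated as a positional argument inside the lambda, printing first and then forwarding its None return value to the callback (harmless, assuming main_thread.run_and_wait_result passes extra positional arguments through to a callback whose first parameter defaults to None). A stand-alone sketch of the tuple idiom, with stand-in function bodies:

```python
# Sketch of the tuple-lambda idiom used in the new onchange handlers; these
# function bodies are stand-ins, not the webui's real callbacks. A lambda body
# must be a single expression, so the reload and the log are bundled into a
# tuple, whose elements are evaluated left to right when the lambda is called.
def reload_model_weights():
    print("reloading weights...")


def print_event_handler(event_name):
    print(f"Event handler triggered: {event_name}")


on_change = lambda: (reload_model_weights(), print_event_handler("fp8_storage"))

on_change()
# reloading weights...
# Event handler triggered: fp8_storage
```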
modules/sd_models.py (23 changes: 6 additions & 17 deletions)

```diff
@@ -21,7 +21,6 @@
 from ldm_patched.modules.ops import manual_cast
 from ldm_patched.modules import model_management as model_management
 import ldm_patched.modules.model_patcher
-import time
 
 
 model_dir = "Stable-diffusion"
@@ -518,7 +517,6 @@ def get_sd_model(self):
 
     def set_sd_model(self, v, already_loaded=False):
         self.sd_model = v
-        v.last_used = time.time()
        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
@@ -558,33 +556,28 @@ def send_model_to_trash(m):
 
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
-
     checkpoint_info = checkpoint_info or select_checkpoint()
-
     timer = Timer()
-
     # Check if the model is already loaded as the primary model
     if model_data.sd_model and model_data.sd_model.filename == checkpoint_info.filename:
-        model_data.sd_model.last_used = time.time()  # Update last used time
         return model_data.sd_model
 
     # Check if the model is already in the loaded models list
     for loaded_model in model_data.loaded_sd_models:
         if loaded_model.filename == checkpoint_info.filename:
             print(f" --> Using already loaded model {loaded_model.sd_checkpoint_info.title}: done in {timer.summary()}")
-            loaded_model.last_used = time.time()  # Update last used time
             model_data.set_sd_model(loaded_model, already_loaded=True)
             return loaded_model
 
     # If we've reached the model limit, unload the least recently used model
     if len(model_data.loaded_sd_models) >= shared.opts.sd_checkpoints_limit:
-        least_recent_model = min(model_data.loaded_sd_models, key=lambda m: m.last_used)
-        print(f" ------------ Unloading least recently used model: {least_recent_model.sd_checkpoint_info.title} -------------")
-        model_data.loaded_sd_models.remove(least_recent_model)
-        if model_data.sd_model == least_recent_model:
-            model_data.sd_model = None
-        del least_recent_model
+        print(" ------------ Unloading all models... -------------")
+        model_data.sd_model = None
+        model_data.loaded_sd_models = []
+        model_management.unload_all_models()
+        model_management.soft_empty_cache()
+        gc.collect()
         timer.record("unload existing models")
 
     print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
 
@@ -595,7 +588,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
     sd_model.filename = checkpoint_info.filename
-    sd_model.last_used = time.time()  # Set initial last used time
 
     model_data.loaded_sd_models.append(sd_model)
     model_data.set_sd_model(sd_model)
@@ -610,16 +602,13 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("load VAE")
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
-
     timer.record("load textual inversion embeddings")
 
     script_callbacks.model_loaded_callback(sd_model)
-
     timer.record("scripts callbacks")
 
     with torch.no_grad():
         sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
-
     timer.record("calculate empty prompt")
 
     print(f"Model loaded in {timer.summary()}.")
```
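Dropping the Python references alone would leave VRAM held by the allocator's cache, which is why the new branch follows up with unload_all_models(), soft_empty_cache(), and gc.collect(). A rough sketch of the general pattern, assuming a CUDA build of PyTorch; the real ldm_patched.modules.model_management functions do considerably more bookkeeping than this:

```python
# Sketch of the release-everything pattern behind the new unload branch;
# unload_everything is a hypothetical helper, not a function from the repo.
import gc

import torch


def unload_everything(loaded_models: list) -> None:
    loaded_models.clear()         # drop Python references to the models
    gc.collect()                  # reclaim unreachable objects and their tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
```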
