From a16b555d476f6064c72d4a36244f0d0298dbbb2f Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 5 Sep 2024 15:29:41 -0400
Subject: [PATCH] Simplify flux model dtype conversion in model loader

---
 .../model_manager/load/model_loaders/flux.py | 28 +++----------
 .../backend/model_manager/util/model_util.py |  5 ----
 2 files changed, 4 insertions(+), 29 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index fb732a48da0..c7563c2c203 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-import gc
 from pathlib import Path
 from typing import Optional
 
@@ -35,7 +34,6 @@
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
-    convert_sd_entry_to_bfloat16,
 )
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
@@ -197,30 +195,12 @@ def _load_from_singlefile(
             sd = load_file(model_path)
             if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
                 sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-            futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
-            cache_updated = False
+            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+            self._ram_cache.make_room(new_sd_size)
             for k in sd.keys():
-                v = sd[k]
-                if v.dtype != torch.bfloat16:
-                    if not cache_updated:
-                        # For the first iteration we are just requesting the current size of the state dict
-                        # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
-                        # This should be refined in the future if not removed entirely when we support more data types
-                        sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
-                        self._ram_cache.make_room(sd_size)
-                        cache_updated = True
-                    futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
-                # Clean up unused variables
-                del v
-            gc.collect()  # Force garbage collection to free memory
-            for future in futures:
-                k, v = torch.jit.wait(future)
-                sd[k] = v
-                del k, v
-            del futures
-            gc.collect()  # Force garbage collection to free memory
+                # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+                sd[k] = sd[k].to(torch.bfloat16)
             model.load_state_dict(sd, assign=True)
-
         return model
diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py
index 353ebec1c9c..fd904ac3358 100644
--- a/invokeai/backend/model_manager/util/model_util.py
+++ b/invokeai/backend/model_manager/util/model_util.py
@@ -159,8 +159,3 @@ def convert_bundle_to_flux_transformer_checkpoint(
         del transformer_state_dict[k]
 
     return original_state_dict
-
-
-@torch.jit.script
-def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
-    return key, tensor.to(torch.bfloat16, copy=False)
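
Reviewer note: the new loading path boils down to the short sketch below. It is illustrative only: cast_state_dict_to_bfloat16 and its make_room parameter are hypothetical stand-ins for the loader method and the self._ram_cache.make_room hook used in the patch; the sizing and casting logic mirrors the added lines above.

from typing import Callable, Dict

import torch


def cast_state_dict_to_bfloat16(
    sd: Dict[str, torch.Tensor],
    make_room: Callable[[int], None],
) -> Dict[str, torch.Tensor]:
    # Reserve RAM for the *target* dtype up front: after conversion every tensor
    # occupies nelement() * torch.bfloat16.itemsize (2) bytes, regardless of
    # whether it started out as float8, float16, or float32.
    new_sd_size = sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())
    make_room(new_sd_size)
    for k in sd.keys():
        # Tensor.to() returns the original tensor when the dtype already matches,
        # so tensors that are already bfloat16 are not copied.
        sd[k] = sd[k].to(torch.bfloat16)
    return sd

Compared with the removed torch.jit.fork path, this is a single sequential pass with no explicit gc.collect() calls; load_state_dict(sd, assign=True) then reuses the converted tensors directly instead of copying them into the module's parameters.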