diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index dcceda5ad21..c7563c2c203 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -32,7 +32,9 @@ ) from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.util.model_util import convert_bundle_to_flux_transformer_checkpoint +from invokeai.backend.model_manager.util.model_util import ( + convert_bundle_to_flux_transformer_checkpoint, +) from invokeai.backend.util.silence_warnings import SilenceWarnings try: @@ -193,6 +195,11 @@ def _load_from_singlefile( sd = load_file(model_path) if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: sd = convert_bundle_to_flux_transformer_checkpoint(sd) + new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()]) + self._ram_cache.make_room(new_sd_size) + for k in sd.keys(): + # We need to cast to bfloat16 due to it being the only currently supported dtype for inference + sd[k] = sd[k].to(torch.bfloat16) model.load_state_dict(sd, assign=True) return model