From 991d264da4356d389f1bb7e0c9cfa5f35b3a6cb7 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Wed, 4 Sep 2024 14:27:47 -0400
Subject: [PATCH 1/5] Cast tensors in unquantized flux models to bfloat16 during loading

---
 invokeai/backend/model_manager/load/model_loaders/flux.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index dcceda5ad21..2d54911753d 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -193,6 +193,10 @@ def _load_from_singlefile(
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
+        for k, v in sd.items():
+            if v.dtype == torch.bfloat16:
+                continue
+            sd[k] = v.to(dtype=torch.bfloat16)
         model.load_state_dict(sd, assign=True)
         return model

From 8150a58804730cba8719059c71e596a24ea93833 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 5 Sep 2024 13:06:07 -0400
Subject: [PATCH 2/5] Update flux transformer loader to more efficiently use and release memory during upcasting

---
 .../model_manager/load/model_loaders/flux.py | 31 ++++++++++++++++---
 .../backend/model_manager/util/model_util.py |  5 +++
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 2d54911753d..cd48cefb803 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,11 +1,13 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
+import gc
 from pathlib import Path
 from typing import Optional
 
 import accelerate
 import torch
+from pympler import asizeof
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
@@ -32,7 +34,10 @@
 )
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
-from invokeai.backend.model_manager.util.model_util import convert_bundle_to_flux_transformer_checkpoint
+from invokeai.backend.model_manager.util.model_util import (
+    convert_bundle_to_flux_transformer_checkpoint,
+    convert_sd_entry_to_bfloat16,
+)
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -193,11 +198,27 @@ def _load_from_singlefile(
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-        for k, v in sd.items():
-            if v.dtype == torch.bfloat16:
-                continue
-            sd[k] = v.to(dtype=torch.bfloat16)
+        futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
+        sd_size = asizeof.asizeof(sd)
+        cache_updated = False
+        for k in sd.keys():
+            v = sd[k]
+            if v.dtype != torch.bfloat16:
+                if not cache_updated:
+                    self._ram_cache.make_room(sd_size)
+                    cache_updated = True
+                futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
+            # Clean up unused variables
+            del v
+        gc.collect()  # Force garbage collection to free memory
+        for future in futures:
+            k, v = torch.jit.wait(future)
+            sd[k] = v
+            del k, v
+        del futures
+        gc.collect()  # Force garbage collection to free memory
         model.load_state_dict(sd, assign=True)
+
         return model

diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py
index fd904ac3358..353ebec1c9c 100644
--- a/invokeai/backend/model_manager/util/model_util.py
+++ b/invokeai/backend/model_manager/util/model_util.py
@@ -159,3 +159,8 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]
 
     return original_state_dict
+
+
+@torch.jit.script
+def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
+    return key, tensor.to(torch.bfloat16, copy=False)

From 289ee1274fa9cb16f94e278d71bd84e6fe187e0e Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 5 Sep 2024 13:12:25 -0400
Subject: [PATCH 3/5] Add comment explaining the cache make room call

---
 invokeai/backend/model_manager/load/model_loaders/flux.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index cd48cefb803..934fffbbf3e 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -199,6 +199,9 @@ def _load_from_singlefile(
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
         futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
+        # For the first iteration we are just requesting the current size of the state dict
+        # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
+        # This should be refined in the future if not removed entirely when we support more data types
         sd_size = asizeof.asizeof(sd)
         cache_updated = False
         for k in sd.keys():

From 816aac81df9c6573269a070a137e30134e2b9823 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 5 Sep 2024 14:44:42 -0400
Subject: [PATCH 4/5] Remove dependency of asizeof

---
 .../backend/model_manager/load/model_loaders/flux.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 934fffbbf3e..fb732a48da0 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -7,7 +7,6 @@
 
 import accelerate
 import torch
-from pympler import asizeof
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
@@ -199,15 +198,15 @@ def _load_from_singlefile(
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
         futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
-        # For the first iteration we are just requesting the current size of the state dict
-        # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
-        # This should be refined in the future if not removed entirely when we support more data types
-        sd_size = asizeof.asizeof(sd)
         cache_updated = False
         for k in sd.keys():
            v = sd[k]
            if v.dtype != torch.bfloat16:
                if not cache_updated:
+                    # For the first iteration we are just requesting the current size of the state dict
+                    # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
+                    # This should be refined in the future if not removed entirely when we support more data types
+                    sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
                     self._ram_cache.make_room(sd_size)
                     cache_updated = True
                 futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
             # Clean up unused variables
             del v
         gc.collect()  # Force garbage collection to free memory

From f08e942830f28de3f1400674328411374cf4a2e7 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 5 Sep 2024 15:29:41 -0400
Subject: [PATCH 5/5] Simplify flux model dtype conversion in model loader

---
 .../model_manager/load/model_loaders/flux.py | 28 +++----------------
 .../backend/model_manager/util/model_util.py |  5 ----
 2 files changed, 4 insertions(+), 29 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index fb732a48da0..c7563c2c203 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-import gc
 from pathlib import Path
 from typing import Optional
 
@@ -35,7 +34,6 @@
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
-    convert_sd_entry_to_bfloat16,
 )
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
@@ -197,30 +195,12 @@ def _load_from_singlefile(
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-        futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
-        cache_updated = False
+        new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+        self._ram_cache.make_room(new_sd_size)
         for k in sd.keys():
-            v = sd[k]
-            if v.dtype != torch.bfloat16:
-                if not cache_updated:
-                    # For the first iteration we are just requesting the current size of the state dict
-                    # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
-                    # This should be refined in the future if not removed entirely when we support more data types
-                    sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
-                    self._ram_cache.make_room(sd_size)
-                    cache_updated = True
-                futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
-            # Clean up unused variables
-            del v
-        gc.collect()  # Force garbage collection to free memory
-        for future in futures:
-            k, v = torch.jit.wait(future)
-            sd[k] = v
-            del k, v
-        del futures
-        gc.collect()  # Force garbage collection to free memory
+            # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+            sd[k] = sd[k].to(torch.bfloat16)
         model.load_state_dict(sd, assign=True)
-
         return model

diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py
index 353ebec1c9c..fd904ac3358 100644
--- a/invokeai/backend/model_manager/util/model_util.py
+++ b/invokeai/backend/model_manager/util/model_util.py
@@ -159,8 +159,3 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]
 
     return original_state_dict
-
-
-@torch.jit.script
-def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
-    return key, tensor.to(torch.bfloat16, copy=False)
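
A minimal, self-contained sketch (not part of the patch series) of the size-estimation and casting approach that PATCH 5/5 settles on. The reserve_ram helper below is a hypothetical stand-in for the loader's self._ram_cache.make_room(...) call, and the toy state dict uses illustrative FLUX-style key names and shapes rather than a real checkpoint; torch.bfloat16.itemsize (2 bytes) is assumed to be available, as in the PyTorch versions these patches target.

import torch


def reserve_ram(num_bytes: int) -> None:
    # Hypothetical stand-in for self._ram_cache.make_room(num_bytes): a real
    # implementation would evict cached models until num_bytes can be allocated.
    print(f"reserving {num_bytes / 1e6:.1f} MB of RAM")


def cast_state_dict_to_bfloat16(sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Estimate the size the state dict will occupy *after* the cast: every tensor
    # takes nelement() * 2 bytes as bfloat16, so a float8-quantized checkpoint
    # roughly doubles in size while a float32 one shrinks by half.
    new_sd_size = sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())
    reserve_ram(new_sd_size)
    for k in sd.keys():
        # bfloat16 is the dtype this loader uses for FLUX inference
        sd[k] = sd[k].to(torch.bfloat16)
    return sd


if __name__ == "__main__":
    toy_sd = {
        "double_blocks.0.img_attn.qkv.weight": torch.randn(3072, 3072, dtype=torch.float16),
        "double_blocks.0.img_attn.proj.bias": torch.randn(3072, dtype=torch.float32),
    }
    toy_sd = cast_state_dict_to_bfloat16(toy_sd)
    print({k: str(v.dtype) for k, v in toy_sd.items()})

Reserving the post-cast size before the loop mirrors the intent of the make_room call in the patches: the RAM cache gets a chance to evict other models before the conversion allocates the new bfloat16 copies.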