Commit fed9da9
Update flux transformer loader to more efficiently use and release memory during upcasting
brandonrising committed Sep 5, 2024
1 parent 67d4861 commit fed9da9
Showing 2 changed files with 31 additions and 5 deletions.
31 changes: 26 additions & 5 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,11 +1,13 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
+import gc
 from pathlib import Path
 from typing import Optional
 
 import accelerate
 import torch
+from pympler import asizeof
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

@@ -32,7 +34,10 @@
 )
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
-from invokeai.backend.model_manager.util.model_util import convert_bundle_to_flux_transformer_checkpoint
+from invokeai.backend.model_manager.util.model_util import (
+    convert_bundle_to_flux_transformer_checkpoint,
+    convert_sd_entry_to_bfloat16,
+)
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -193,11 +198,27 @@ def _load_from_singlefile(
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-        for k, v in sd.items():
-            if v.dtype == torch.bfloat16:
-                continue
-            sd[k] = v.to(dtype=torch.bfloat16)
+        futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
+        sd_size = asizeof.asizeof(sd)
+        cache_updated = False
+        for k in sd.keys():
+            v = sd[k]
+            if v.dtype != torch.bfloat16:
+                if not cache_updated:
+                    self._ram_cache.make_room(sd_size)
+                    cache_updated = True
+                futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
+
+        # Clean up unused variables
+        del v
+        gc.collect()  # Force garbage collection to free memory
+        for future in futures:
+            k, v = torch.jit.wait(future)
+            sd[k] = v
+        del k, v
+        del futures
+        gc.collect()  # Force garbage collection to free memory
         model.load_state_dict(sd, assign=True)
 
         return model
 

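Taken together, the new loader code estimates the state dict's in-RAM footprint with pympler's asizeof, asks the model cache to evict that much space once before any conversion starts, and then upcasts the remaining non-bfloat16 tensors in parallel with torch.jit.fork / torch.jit.wait instead of converting them serially. Below is a minimal, standalone sketch of that fork/wait pattern, not the InvokeAI loader itself; the names to_bfloat16 and upcast_state_dict are hypothetical.

import torch


@torch.jit.script
def to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
    # copy=False returns the input tensor unchanged if it is already bfloat16.
    return key, tensor.to(torch.bfloat16, copy=False)


def upcast_state_dict(sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Schedule one background task per tensor that still needs conversion.
    futures = [
        torch.jit.fork(to_bfloat16, k, v)
        for k, v in sd.items()
        if v.dtype != torch.bfloat16
    ]
    # torch.jit.wait blocks until the corresponding task finishes.
    for future in futures:
        k, v = torch.jit.wait(future)
        sd[k] = v  # Replace in place so the old tensor can be freed promptly.
    return sd


sd = {"a": torch.randn(8, 8), "b": torch.randn(8, 8, dtype=torch.bfloat16)}
sd = upcast_state_dict(sd)
assert all(v.dtype == torch.bfloat16 for v in sd.values())

Replacing each entry in place, as the commit does, keeps only one extra copy of a tensor alive at a time; the explicit gc.collect() calls in the commit then reclaim the original float32/float16 tensors before load_state_dict runs.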
5 changes: 5 additions & 0 deletions invokeai/backend/model_manager/util/model_util.py
@@ -159,3 +159,8 @@ def convert_bundle_to_flux_transformer_checkpoint(
        del transformer_state_dict[k]
 
    return original_state_dict
+
+
+@torch.jit.script
+def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
+    return key, tensor.to(torch.bfloat16, copy=False)
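One note on the helper: with copy=False, Tensor.to returns the input tensor itself when it already has the target dtype and device, so entries that are already bfloat16 pass through with no new allocation; only tensors that actually change dtype get copied. A quick illustration of that behavior:

import torch

t = torch.ones(2, dtype=torch.bfloat16)
assert t.to(torch.bfloat16, copy=False) is t  # same tensor, no allocation

f = torch.ones(2, dtype=torch.float32)
c = f.to(torch.bfloat16, copy=False)
assert c.dtype == torch.bfloat16 and c is not f  # real conversion allocates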
