Skip to content

Commit

Permalink
Remove dependency of asizeof
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Sep 5, 2024
1 parent 667188d commit 416e0d1
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import accelerate
import torch
from pympler import asizeof
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

Expand Down Expand Up @@ -199,15 +198,15 @@ def _load_from_singlefile(
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
# For the first iteration we are just requesting the current size of the state dict
# This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
# This should be refined in the future if not removed entirely when we support more data types
sd_size = asizeof.asizeof(sd)
cache_updated = False
for k in sd.keys():
v = sd[k]
if v.dtype != torch.bfloat16:
if not cache_updated:
# For the first iteration we are just requesting the current size of the state dict
# This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
# This should be refined in the future if not removed entirely when we support more data types
sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
self._ram_cache.make_room(sd_size)
cache_updated = True
futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
Expand Down

0 comments on commit 416e0d1

Please sign in to comment.