Commit
Simplify flux model dtype conversion in model loader
brandonrising committed Sep 5, 2024
1 parent 6667c39 commit a16b555
Showing 2 changed files with 4 additions and 29 deletions.
invokeai/backend/model_manager/load/model_loaders/flux.py (28 changes: 4 additions & 24 deletions)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-import gc
 from pathlib import Path
 from typing import Optional
 
@@ -35,7 +34,6 @@
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
-    convert_sd_entry_to_bfloat16,
 )
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
@@ -197,30 +195,12 @@ def _load_from_singlefile(
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
             sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-        futures: list[torch.jit.Future[tuple[str, torch.Tensor]]] = []
-        cache_updated = False
+        new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+        self._ram_cache.make_room(new_sd_size)
         for k in sd.keys():
-            v = sd[k]
-            if v.dtype != torch.bfloat16:
-                if not cache_updated:
-                    # For the first iteration we are just requesting the current size of the state dict
-                    # This is due to an expected doubling of the tensor sizes in memory after converting float8 -> float16
-                    # This should be refined in the future if not removed entirely when we support more data types
-                    sd_size = sum([ten.nelement() * ten.element_size() for ten in sd.values()])
-                    self._ram_cache.make_room(sd_size)
-                    cache_updated = True
-                futures.append(torch.jit.fork(convert_sd_entry_to_bfloat16, k, v))
-        # Clean up unused variables
-        del v
-        gc.collect() # Force garbage collection to free memory
-        for future in futures:
-            k, v = torch.jit.wait(future)
-            sd[k] = v
-        del k, v
-        del futures
-        gc.collect() # Force garbage collection to free memory
+            # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+            sd[k] = sd[k].to(torch.bfloat16)
         model.load_state_dict(sd, assign=True)
 
         return model
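Note: for readers outside the codebase, the simplified path above boils down to the sketch below. It is illustrative only: the function name and the ram_cache parameter are stand-ins, with make_room mirroring the self._ram_cache.make_room call in the diff.

import torch
from safetensors.torch import load_file

def load_checkpoint_as_bfloat16(model_path: str, ram_cache) -> dict[str, torch.Tensor]:
    # Read the checkpoint tensors from disk (possibly stored as float8/float16).
    sd = load_file(model_path)
    # Reserve RAM for the post-cast size up front: once converted, every tensor
    # occupies nelement() * torch.bfloat16.itemsize (2 bytes per element).
    new_sd_size = sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())
    ram_cache.make_room(new_sd_size)
    # Cast each entry in place; bfloat16 is the only dtype currently supported for inference.
    for k in sd.keys():
        sd[k] = sd[k].to(torch.bfloat16)
    return sd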


invokeai/backend/model_manager/util/model_util.py (5 changes: 0 additions & 5 deletions)
@@ -159,8 +159,3 @@ def convert_bundle_to_flux_transformer_checkpoint(
         del transformer_state_dict[k]
 
     return original_state_dict
-
-
-@torch.jit.script
-def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
-    return key, tensor.to(torch.bfloat16, copy=False)
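The helper removed here was only needed to feed torch.jit.fork; a plain .to(torch.bfloat16) loop produces identical tensors, which is consistent with the simplification in this commit. A small self-contained check (illustrative, not repository code):

import torch

def convert_sd_entry_to_bfloat16(key: str, tensor: torch.Tensor):
    # The removed helper: cast one state-dict entry, returning (key, tensor).
    return key, tensor.to(torch.bfloat16, copy=False)

sd = {f"w{i}": torch.randn(8, 8, dtype=torch.float16) for i in range(4)}

# Old approach: fork one task per tensor, then wait on each future.
futures = [torch.jit.fork(convert_sd_entry_to_bfloat16, k, v) for k, v in sd.items()]
converted_old = dict(torch.jit.wait(f) for f in futures)

# New approach: a direct cast in a plain loop.
converted_new = {k: v.to(torch.bfloat16) for k, v in sd.items()}

assert all(torch.equal(converted_old[k], converted_new[k]) for k in sd)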
