Skip to content

Commit

Permalink
convert_bundle_to_flux_transformer_checkpoint now removes processed keys to decrease memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Sep 4, 2024
1 parent d10d258 commit d20335d
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion invokeai/backend/model_manager/util/model_util.py
Original file line number Diff line number Diff line change
def convert_bundle_to_flux_transformer_checkpoint(
    transformer_state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Extract the FLUX transformer weights from a bundled checkpoint state dict.

    Every key prefixed with ``model.diffusion_model`` is copied into a new dict
    with that prefix stripped. Tensors whose key ends in ``scale`` are cast to
    bfloat16 (scale math must be done at bfloat16 due to our current flux model
    support limitations at inference time). Processed entries are then deleted
    from *transformer_state_dict* in place to decrease peak memory usage, while
    unrelated keys are left behind in case other model state dicts need to be
    pulled from the same bundle.

    Args:
        transformer_state_dict: Bundled checkpoint state dict; mutated in place
            (processed keys are removed).

    Returns:
        A new state dict containing only the transformer tensors, with the
        ``model.diffusion_model.`` prefix removed from each key.
    """
    original_state_dict: dict[str, torch.Tensor] = {}

    # Snapshot the keys so the source dict can be mutated while we walk it.
    for k in list(transformer_state_dict.keys()):
        if not k.startswith("model.diffusion_model"):
            continue
        # pop() both retrieves the tensor and drops the bundle's reference,
        # so processed entries stop being doubly referenced as soon as possible.
        v = transformer_state_dict.pop(k)
        if k.endswith("scale"):
            # Scale math must be done at bfloat16 due to our current flux model
            # support limitations at inference time.
            v = v.to(dtype=torch.bfloat16)
        original_state_dict[k.replace("model.diffusion_model.", "")] = v

    return original_state_dict

0 comments on commit d20335d

Please sign in to comment.