FLUX LoRA Support #6847

Merged · 49 commits · Sep 18, 2024

Commits (49)
7a80d9e
Add state_dict keys for two FLUX LoRA formats to be used in unit tests.
RyanJDick Sep 3, 2024
c41bd59
WIP - Initial logic for kohya FLUX LoRA conversion.
RyanJDick Sep 3, 2024
ade75b4
Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with…
RyanJDick Sep 4, 2024
fc380f0
Start moving SDXL-specific LoRA conversions out of the general-purpos…
RyanJDick Sep 4, 2024
d0d91ea
Fix type errors in sdxl_lora_conversion_utils.py
RyanJDick Sep 4, 2024
8518ae9
Remove unused LoRAModelRaw.name attribute.
RyanJDick Sep 4, 2024
04b37e6
Move the responsibilities of 1) state_dict loading from file, and 2) …
RyanJDick Sep 4, 2024
7b5befa
Update convert_flux_kohya_state_dict_to_invoke_format() to raise an e…
RyanJDick Sep 4, 2024
00e5686
Add utility function for detecting whether a state_dict is in the FLU…
RyanJDick Sep 4, 2024
db61ec4
Get probing of FLUX LoRA kohya models working.
RyanJDick Sep 4, 2024
01a15b4
WIP - add invocations to support FLUX LORAs.
RyanJDick Sep 4, 2024
50c9410
WIP
RyanJDick Sep 4, 2024
92b8477
Fixup FLUX LoRA unit tests.
RyanJDick Sep 5, 2024
cf9f30c
Rename flux_kohya_lora_conversion_utils.py
RyanJDick Sep 5, 2024
dc09171
WIP on supporting diffusers format FLUX LoRAs.
RyanJDick Sep 5, 2024
bb528d9
Add ConcatenateLoRALayer class.
RyanJDick Sep 9, 2024
bb917ae
(minor) Rename test file.
RyanJDick Sep 9, 2024
040cc28
First draft of lora_model_from_flux_diffusers_state_dict(...).
RyanJDick Sep 9, 2024
534e938
Add unit test for lora_model_from_flux_diffusers_state_dict(...).
RyanJDick Sep 9, 2024
31a8757
Add is_state_dict_likely_in_flux_diffusers_format(...) function with …
RyanJDick Sep 9, 2024
42d6dd3
Add utility test function for creating a dummy state_dict.
RyanJDick Sep 9, 2024
5800e60
Add model probe support for FLUX LoRA models in Diffusers format.
RyanJDick Sep 9, 2024
552a5b0
Add a check that all keys are handled in the FLUX Diffusers LoRA load…
RyanJDick Sep 9, 2024
aac97e1
General cleanup/documentation.
RyanJDick Sep 9, 2024
ddda60c
Rename peft/ -> lora/
RyanJDick Sep 10, 2024
ee5d8f6
lora_layer_from_state_dict(...) -> any_lora_layer_from_state_dict(...)
RyanJDick Sep 10, 2024
fef26a5
Consolidate all LoRA patching logic in the LoRAPatcher.
RyanJDick Sep 10, 2024
705173b
Remove unused layer_key property from LoRALayerBase.
RyanJDick Sep 10, 2024
2ff4dae
Add util functions calc_tensor_size(...) and calc_tensors_size(...).
RyanJDick Sep 10, 2024
049ce18
WIP - adding LoRA sidecar layers
RyanJDick Sep 10, 2024
3e12ac9
WIP - LoRA sidecar layers.
RyanJDick Sep 11, 2024
f5f8944
Bug fixes to get LoRA sidecar patching working for the first time.
RyanJDick Sep 11, 2024
45bc8fc
WIP - Implement sidecar LoRA layers using functional API.
RyanJDick Sep 11, 2024
10c3c61
Get diffusers FLUX LoRA working as sidecar patch on quantized model.
RyanJDick Sep 11, 2024
81fbaf2
Assume LoRA alpha=8 for FLUX diffusers PEFT LoRAs.
RyanJDick Sep 12, 2024
9438ea6
Update all lycoris layer types to use the new torch.nn.Module base cl…
RyanJDick Sep 12, 2024
5bb0c79
Add links to test models for loha, lokr, ia3.
RyanJDick Sep 12, 2024
7ce41bf
Fixup unit tests.
RyanJDick Sep 12, 2024
ae41651
Remove LoRA conv sidecar layers until they are needed and properly te…
RyanJDick Sep 12, 2024
61d3d56
Minor cleanup and documentation updates.
RyanJDick Sep 13, 2024
ba3ba3c
Add unit tests for LoRALinearSidecarLayer and ConcatenatedLoRALinearS…
RyanJDick Sep 13, 2024
02f27c7
Add unit tests for LoRAPatcher.apply_lora_sidecar_patches(...) and fi…
RyanJDick Sep 13, 2024
9466824
Delete duplicate file that was accidentally kept during rebase.
RyanJDick Sep 13, 2024
b1cf5e9
Replace 'torch.device("meta")' with 'accelerate.init_empty_weights()'…
RyanJDick Sep 13, 2024
78efed4
Revert change of make all LoRA layers torch.nn.Module's. While the co…
RyanJDick Sep 13, 2024
d51f2c5
Add bias to LoRA sidecar layer unit tests.
RyanJDick Sep 13, 2024
e88d3cf
Assume alpha=rank for FLUX diffusers PEFT LoRA models.
RyanJDick Sep 16, 2024
2934e31
Fix bug when applying multiple LoRA models via apply_lora_sidecar_pat…
RyanJDick Sep 16, 2024
3d6f60f
Merge branch 'main' into ryan/flux-lora-quantized
RyanJDick Sep 18, 2024
12 changes: 7 additions & 5 deletions invokeai/app/invocations/compel.py
@@ -20,6 +20,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -81,9 +82,10 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
-                loras=_lora_loader(),
+            LoRAPatcher.apply_lora_patches(
+                model=text_encoder,
+                patches=_lora_loader(),
+                prefix="lora_te_",
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -176,9 +178,9 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
+            LoRAPatcher.apply_lora_patches(
                 text_encoder,
-                loras=_lora_loader(),
+                patches=_lora_loader(),
                 prefix=lora_prefix,
                 cached_weights=cached_weights,
             ),
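Conceptually, LoRAPatcher.apply_lora_patches merges each LoRA delta directly into a module's weights for the duration of inference and restores the cached originals on exit. A minimal sketch of that direct-patching idea, assuming the standard LoRA update delta = (alpha / rank) * weight * (up @ down); the helper name and signature below are illustrative, not InvokeAI's actual API:

```python
import contextlib

import torch


@contextlib.contextmanager
def patch_linear_with_lora(linear: torch.nn.Linear, down: torch.Tensor, up: torch.Tensor,
                           alpha: float, weight: float):
    """Merge a LoRA delta into a Linear layer's weight; restore the cached weight on exit."""
    rank = down.shape[0]
    cached_weight = linear.weight.data.clone()  # analogous to cached_weights above
    try:
        with torch.no_grad():
            linear.weight += (alpha / rank) * weight * (up @ down)
        yield linear
    finally:
        linear.weight.data = cached_weight  # unpatch


# Toy usage: a rank-4 LoRA on an 8x8 layer, applied at weight 0.75.
layer = torch.nn.Linear(8, 8)
down = torch.randn(4, 8)
up = torch.randn(8, 4)
with patch_linear_with_lora(layer, down, up, alpha=4.0, weight=0.75):
    _ = layer(torch.randn(1, 8))
```

Because the weights are edited in place, patched inference runs at full speed; the trade-off is that the weights must be directly editable, which is why the quantized case in flux_denoise.py below takes a different route.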
8 changes: 5 additions & 3 deletions invokeai/app/invocations/denoise_latents.py
@@ -37,6 +37,7 @@
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(
-                unet,
-                loras=_lora_loader(),
+            LoRAPatcher.apply_lora_patches(
+                model=unet,
+                patches=_lora_loader(),
+                prefix="lora_unet_",
                 cached_weights=cached_weights,
             ),
         ):
47 changes: 45 additions & 2 deletions invokeai/app/invocations/flux_denoise.py
@@ -1,4 +1,5 @@
-from typing import Callable, Optional
+from contextlib import ExitStack
+from typing import Callable, Iterator, Optional, Tuple

 import torch
 import torchvision.transforms as tv_transforms
@@ -29,6 +30,9 @@
     pack,
     unpack,
 )
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -187,9 +191,41 @@ def _run_diffusion(
             noise=noise,
         )

-        with transformer_info as transformer:
+        with (
+            transformer_info.model_on_device() as (cached_weights, transformer),
+            ExitStack() as exit_stack,
+        ):
             assert isinstance(transformer, Flux)

+            config = transformer_info.config
+            assert config is not None
+
+            # Apply LoRA models to the transformer.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if config.format in [ModelFormat.Checkpoint]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=transformer,
+                        patches=self._lora_iterator(context),
+                        prefix="",
+                        cached_weights=cached_weights,
+                    )
+                )
+            elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
+                # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower
+                # inference than directly patching the weights, but is agnostic to the quantization format.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_sidecar_patches(
+                        model=transformer,
+                        patches=self._lora_iterator(context),
+                        prefix="",
+                        dtype=inference_dtype,
+                    )
+                )
+            else:
+                raise ValueError(f"Unsupported model format: {config.format}")
+
             x = denoise(
                 model=transformer,
                 img=x,
@@ -247,6 +283,13 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
         # `latents`.
         return mask.expand_as(latents)

+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in self.transformer.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
+
     def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
             state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
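apply_lora_sidecar_patches takes the other route described in the comments above: bitsandbytes-quantized weights cannot be edited in place, so the LoRA contribution is instead added in the output space of each wrapped layer, leaving the base weights untouched. A rough sketch of a sidecar layer with illustrative names (the PR's LoRALinearSidecarLayer differs in detail):

```python
import torch


class LoRASidecarLinear(torch.nn.Module):
    """Wrap a frozen (possibly quantized) linear layer and add a LoRA residual to its output."""

    def __init__(self, base: torch.nn.Module, down: torch.Tensor, up: torch.Tensor,
                 alpha: float, weight: float):
        super().__init__()
        self.base = base
        self.down = torch.nn.Parameter(down, requires_grad=False)
        self.up = torch.nn.Parameter(up, requires_grad=False)
        self.scale = (alpha / down.shape[0]) * weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The base layer runs with its own (possibly quantized) kernel; the LoRA
        # path is an ordinary matmul, so the approach is quantization-format agnostic.
        return self.base(x) + self.scale * (x @ self.down.T @ self.up.T)


# Toy usage: wrap a layer instead of mutating its weights.
base = torch.nn.Linear(8, 8)
patched = LoRASidecarLinear(base, down=torch.randn(4, 8), up=torch.randn(8, 4),
                            alpha=4.0, weight=0.75)
_ = patched(torch.randn(1, 8))
```

The extra matmuls on every forward pass are why the comment above notes that sidecar patching is slower than directly patching the weights.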
53 changes: 53 additions & 0 deletions invokeai/app/invocations/flux_lora_loader.py
@@ -0,0 +1,53 @@
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+@invocation_output("flux_lora_loader_output")
+class FluxLoRALoaderOutput(BaseInvocationOutput):
+    """FLUX LoRA Loader Output"""
+
+    transformer: TransformerField = OutputField(
+        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
+    )
+
+
+@invocation(
+    "flux_lora_loader",
+    title="FLUX LoRA",
+    tags=["lora", "model", "flux"],
+    category="model",
+    version="1.0.0",
+)
+class FluxLoRALoaderInvocation(BaseInvocation):
+    """Apply a LoRA model to a FLUX transformer."""
+
+    lora: ModelIdentifierField = InputField(
+        description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
+    )
+    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.transformer,
+        input=Input.Connection,
+        title="FLUX Transformer",
+    )
+
+    def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
+        lora_key = self.lora.key
+
+        if not context.models.exists(lora_key):
+            raise ValueError(f"Unknown lora: {lora_key}!")
+
+        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
+            raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
+
+        transformer = self.transformer.model_copy(deep=True)
+        transformer.loras.append(
+            LoRAField(
+                lora=self.lora,
+                weight=self.weight,
+            )
+        )
+
+        return FluxLoRALoaderOutput(transformer=transformer)
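Note that invoke() deep-copies the incoming TransformerField before appending, so each loader node produces a fresh output rather than mutating upstream graph state, and chained loaders accumulate LoRAs. The pattern, reduced to a self-contained sketch with simplified stand-in field types:

```python
from pydantic import BaseModel


class LoRAField(BaseModel):
    lora_key: str
    weight: float


class TransformerField(BaseModel):
    transformer: str
    loras: list[LoRAField] = []


def add_lora(field: TransformerField, key: str, weight: float) -> TransformerField:
    if any(lora.lora_key == key for lora in field.loras):
        raise ValueError(f'LoRA "{key}" already applied to transformer.')
    out = field.model_copy(deep=True)  # never mutate the upstream node's output
    out.loras.append(LoRAField(lora_key=key, weight=weight))
    return out


# Chaining two loader nodes accumulates LoRAs without touching earlier outputs.
base = TransformerField(transformer="flux-dev")
styled = add_lora(base, "style_lora", 0.75)
detailed = add_lora(styled, "detail_lora", 0.5)
assert base.loras == [] and len(detailed.loras) == 2
```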
3 changes: 2 additions & 1 deletion invokeai/app/invocations/model.py
@@ -69,6 +69,7 @@ class CLIPField(BaseModel):

 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


 class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         assert isinstance(transformer_config, CheckpointConfigBase)

         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer),
+            transformer=TransformerField(transformer=transformer, loras=[]),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
(file name not shown in this capture)
@@ -23,7 +23,7 @@
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         # Load the UNet model.
         unet_info = context.models.load(self.unet.unet)

-        with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
+        with (
+            ExitStack() as exit_stack,
+            unet_info as unet,
+            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+        ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None: