diff --git a/docs/features/CONTROLNET.md b/docs/features/CONTROLNET.md
index 718b12b0f8b..ce7a4d2898e 100644
--- a/docs/features/CONTROLNET.md
+++ b/docs/features/CONTROLNET.md
@@ -166,7 +166,7 @@ There are several ways to install IP-Adapter models with an existing InvokeAI in
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
-3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
+3. **Advanced -- Not recommended** Manually downloading the IP-Adapter and Image Encoder files. Image Encoder folders should be placed in the `models\any\clip_vision` folder, and IP-Adapter model folders should be placed in the `ip-adapter` folder of the relevant base model folder in the Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter
diff --git a/docs/features/DATABASE.md b/docs/features/DATABASE.md
index 85829bef868..2d44dda595f 100644
--- a/docs/features/DATABASE.md
+++ b/docs/features/DATABASE.md
@@ -8,7 +8,7 @@ Invoke uses a SQLite database to store image, workflow, model, and execution dat
We take great care to ensure your data is safe, by utilizing transactions and a database migration system.
-Even so, when testing an prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prelease hiccups or databases schema changes will not cause problems for your data.
+Even so, when testing a prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prerelease hiccups or database schema changes will not cause problems for your data.
## Database Backup
diff --git a/docs/features/GALLERY.md b/docs/features/GALLERY.md
index cc84dbf704e..1c12f59c7a7 100644
--- a/docs/features/GALLERY.md
+++ b/docs/features/GALLERY.md
@@ -70,7 +70,7 @@ Each image also has a context menu (ctrl+click / right-click).
- ***Use Prompt **** this will load only the image's text prompts into the left-hand control panel
- ***Use Seed **** this will load only the image's Seed into the left-hand control panel
- ***Use All **** this will load all of the image's generation information into the left-hand control panel
-- ***Send to Image to Image*** this will put the image into the left-hand panel in the Image to Image tab ana automatically open it
+- ***Send to Image to Image*** this will put the image into the left-hand panel in the Image to Image tab and automatically open it
- ***Send to Unified Canvas*** This will (bold)replace whatever is already present(bold) in the Unified Canvas tab with the image and automatically open the tab
- ***Change Board*** this will oipen a small window that will let you move the image to a different board. This is the same as dragging the image to that board's thumbnail.
- ***Star Image*** this will add the image to the board's list of starred images that are always kept at the top of the gallery. This is the same as clicking on the star on the top right-hand side of the image that appears when you hover over the image with the mouse
diff --git a/docs/features/MODEL_MERGING.md b/docs/features/MODEL_MERGING.md
index e384662ef5d..1aea9472c0c 100644
--- a/docs/features/MODEL_MERGING.md
+++ b/docs/features/MODEL_MERGING.md
@@ -8,7 +8,7 @@ be used to teach an old model new tricks.
## How to Merge Models
-Model Merging can be be done by navigating to the Model Manager and clicking the "Merge Models" tab. From there, you can select the models and settings you want to use to merge th models.
+Model Merging can be done by navigating to the Model Manager and clicking the "Merge Models" tab. From there, you can select the models and settings you want to use to merge the models.
## Settings
diff --git a/docs/features/UNIFIED_CANVAS.md b/docs/features/UNIFIED_CANVAS.md
index 476d2009be2..f0aaa61e22e 100644
--- a/docs/features/UNIFIED_CANVAS.md
+++ b/docs/features/UNIFIED_CANVAS.md
@@ -232,7 +232,7 @@ clarity on the intent and common use cases we expect for utilizing them.
### Compositing / Seam Correction
When doing Inpainting or Outpainting, Invoke needs to merge the pixels generated
-by Stable Diffusion into your existing image. This is achieved through compositing - the area around the the boundary between your image and the new generation is
+by Stable Diffusion into your existing image. This is achieved through compositing - the area around the boundary between your image and the new generation is
automatically blended to produce a seamless output. In a fully automatic
process, a mask is generated to cover the boundary, and then the area of the boundary is
Inpainted.
@@ -242,13 +242,13 @@ help to alter the parameters that control the Compositing. A larger blur and
a blur setting have been noted as producing
consistently strong results . Strength of 0.7 is best for reducing hard seams.
-- **Mode** - What part of the image will have the the Compositing applied to it.
+- **Mode** - What part of the image will have the Compositing applied to it.
- **Mask edge** will apply Compositing to the edge of the masked area
- **Mask** will apply Compositing to the entire masked area
- **Unmasked** will apply Compositing to the entire image
- **Steps** - Number of generation steps that will occur during the Coherence Pass, similar to Denoising Steps. Higher step counts will generally have better results.
- **Strength** - How much noise is added for the Coherence Pass, similar to Denoising Strength. A strength of 0 will result in an unchanged image, while a strength of 1 will result in an image with a completely new area as defined by the Mode setting.
-- **Blur** - Adjusts the pixel radius of the the mask. A larger blur radius will cause the mask to extend past the visibly masked area, while too small of a blur radius will result in a mask that is smaller than the visibly masked area.
+- **Blur** - Adjusts the pixel radius of the mask. A larger blur radius will cause the mask to extend past the visibly masked area, while too small of a blur radius will result in a mask that is smaller than the visibly masked area.
- **Blur Method** - The method of blur applied to the masked area.
diff --git a/docs/features/UTILITIES.md b/docs/features/UTILITIES.md
index 2d62fe3a79f..ba0573ed98d 100644
--- a/docs/features/UTILITIES.md
+++ b/docs/features/UTILITIES.md
@@ -296,7 +296,7 @@ finding and fixing three problems that can arise over time:
into the database.
3. The thumbnail for an image is missing, again causing a black
- gallery thumbnail. This is fixed by running the "thumbnaiils"
+ gallery thumbnail. This is fixed by running the "thumbnails"
operation, which simply regenerates and re-registers the missing
thumbnail.
diff --git a/docs/nodes/communityNodes.md b/docs/nodes/communityNodes.md
index 296fbb7ee61..e885b516cb8 100644
--- a/docs/nodes/communityNodes.md
+++ b/docs/nodes/communityNodes.md
@@ -10,7 +10,7 @@ The suggested method is to use `git clone` to clone the repository the node is f
If you'd prefer, you can also just download the whole node folder from the linked repository and add it to the `nodes` folder.
-To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
+To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
- Community Nodes
+ [Adapters-Linked](#adapters-linked-nodes)
@@ -427,7 +427,7 @@ This node works best with SDXL models, especially as the style can be described
5. `Prompt Strength Combine` - Combines weighted prompts for .and()/.blend()
6. `CSV To Index String` - Gets a string from a CSV by index. Includes a Random index option
-The following Nodes are now included in v3.2 of Invoke and are nolonger in this set of tools.
+The following Nodes are now included in v3.2 of Invoke and are no longer in this set of tools.
- `Prompt Join` -> `String Join`
- `Prompt Join Three` -> `String Join Three`
- `Prompt Replace` -> `String Replace`
@@ -456,7 +456,7 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
### BriaAI Remove Background
-**Description**: Implements one click background removal with BriaAI's new version 1.4 model which seems to be be producing better results than any other previous background removal tool.
+**Description**: Implements one-click background removal with BriaAI's new version 1.4 model, which seems to produce better results than any previous background removal tool.
**Node Link:** https://github.com/blessedcoolant/invoke_bria_rmbg
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 8cbc5c00faf..e829829cf36 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -20,6 +20,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -81,9 +82,10 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_lora_text_encoder(
- text_encoder,
- loras=_lora_loader(),
+ LoRAPatcher.apply_lora_patches(
+ model=text_encoder,
+ patches=_lora_loader(),
+ prefix="lora_te_",
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -176,9 +178,9 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
- ModelPatcher.apply_lora(
+ LoRAPatcher.apply_lora_patches(
text_encoder,
- loras=_lora_loader(),
+ patches=_lora_loader(),
prefix=lora_prefix,
cached_weights=cached_weights,
),
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 71766994886..34295b5e229 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -37,6 +37,7 @@
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
- ModelPatcher.apply_lora_unet(
- unet,
- loras=_lora_loader(),
+ LoRAPatcher.apply_lora_patches(
+ model=unet,
+ patches=_lora_loader(),
+ prefix="lora_unet_",
cached_weights=cached_weights,
),
):
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 0260a2f476d..7035d62f365 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -1,4 +1,5 @@
-from typing import Callable, Optional
+from contextlib import ExitStack
+from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
@@ -29,6 +30,9 @@
pack,
unpack,
)
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -187,9 +191,41 @@ def _run_diffusion(
noise=noise,
)
- with transformer_info as transformer:
+ with (
+ transformer_info.model_on_device() as (cached_weights, transformer),
+ ExitStack() as exit_stack,
+ ):
assert isinstance(transformer, Flux)
+ config = transformer_info.config
+ assert config is not None
+
+ # Apply LoRA models to the transformer.
+ # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+ if config.format in [ModelFormat.Checkpoint]:
+ # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+ exit_stack.enter_context(
+ LoRAPatcher.apply_lora_patches(
+ model=transformer,
+ patches=self._lora_iterator(context),
+ prefix="",
+ cached_weights=cached_weights,
+ )
+ )
+ elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
+ # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower
+ # inference than directly patching the weights, but is agnostic to the quantization format.
+ exit_stack.enter_context(
+ LoRAPatcher.apply_lora_sidecar_patches(
+ model=transformer,
+ patches=self._lora_iterator(context),
+ prefix="",
+ dtype=inference_dtype,
+ )
+ )
+ else:
+ raise ValueError(f"Unsupported model format: {config.format}")
+
x = denoise(
model=transformer,
img=x,
@@ -247,6 +283,13 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
# `latents`.
return mask.expand_as(latents)
+ def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+ for lora in self.transformer.loras:
+ lora_info = context.models.load(lora.lora)
+ assert isinstance(lora_info.model, LoRAModelRaw)
+ yield (lora_info.model, lora.weight)
+ del lora_info
+
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
diff --git a/invokeai/app/invocations/flux_lora_loader.py b/invokeai/app/invocations/flux_lora_loader.py
new file mode 100644
index 00000000000..a12f21cb9af
--- /dev/null
+++ b/invokeai/app/invocations/flux_lora_loader.py
@@ -0,0 +1,53 @@
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+@invocation_output("flux_lora_loader_output")
+class FluxLoRALoaderOutput(BaseInvocationOutput):
+ """FLUX LoRA Loader Output"""
+
+ transformer: TransformerField = OutputField(
+ default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
+ )
+
+
+@invocation(
+ "flux_lora_loader",
+ title="FLUX LoRA",
+ tags=["lora", "model", "flux"],
+ category="model",
+ version="1.0.0",
+)
+class FluxLoRALoaderInvocation(BaseInvocation):
+ """Apply a LoRA model to a FLUX transformer."""
+
+ lora: ModelIdentifierField = InputField(
+ description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
+ )
+ weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
+ transformer: TransformerField = InputField(
+ description=FieldDescriptions.transformer,
+ input=Input.Connection,
+ title="FLUX Transformer",
+ )
+
+ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
+ lora_key = self.lora.key
+
+ if not context.models.exists(lora_key):
+ raise ValueError(f"Unknown lora: {lora_key}!")
+
+ if any(lora.lora.key == lora_key for lora in self.transformer.loras):
+ raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
+
+ transformer = self.transformer.model_copy(deep=True)
+ transformer.loras.append(
+ LoRAField(
+ lora=self.lora,
+ weight=self.weight,
+ )
+ )
+
+ return FluxLoRALoaderOutput(transformer=transformer)
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index ef32f7bc011..c0d0a4a7f76 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+ loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
- transformer=TransformerField(transformer=transformer),
+ transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 6285e67230b..556600b4128 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -23,7 +23,7 @@
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
- with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
+ with (
+ ExitStack() as exit_stack,
+ unet_info as unet,
+ LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+ ):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
diff --git a/invokeai/app/services/board_image_records/board_image_records_sqlite.py b/invokeai/app/services/board_image_records/board_image_records_sqlite.py
index 33ac76b06fb..c189c65c213 100644
--- a/invokeai/app/services/board_image_records/board_image_records_sqlite.py
+++ b/invokeai/app/services/board_image_records/board_image_records_sqlite.py
@@ -146,7 +146,11 @@ def get_image_count_for_board(self, board_id: str) -> int:
self._lock.acquire()
self._cursor.execute(
"""--sql
- SELECT COUNT(*) FROM board_images WHERE board_id = ?;
+ SELECT COUNT(*)
+ FROM board_images
+ INNER JOIN images ON board_images.image_name = images.image_name
+ WHERE images.is_intermediate = FALSE
+ AND board_images.board_id = ?;
""",
(board_id,),
)
diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py
index b0c2155a18a..2eafdfa2de9 100644
--- a/invokeai/app/services/image_records/image_records_sqlite.py
+++ b/invokeai/app/services/image_records/image_records_sqlite.py
@@ -407,6 +407,7 @@ def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
+ AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py
index 27578dd5d6b..9ce80dc0355 100644
--- a/invokeai/app/services/model_install/model_install_base.py
+++ b/invokeai/app/services/model_install/model_install_base.py
@@ -254,7 +254,7 @@ def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
is periodically cleared of infrequently-used entries when the model
converter runs.
- Note that this doesn't automaticallly install or register the model, but is
+ Note that this doesn't automatically install or register the model, but is
intended for use by nodes that need access to models that aren't directly
supported by InvokeAI. The downloading process takes advantage of the download queue
to avoid interrupting other operations.
diff --git a/invokeai/backend/lora/conversions/__init__.py b/invokeai/backend/lora/conversions/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
new file mode 100644
index 00000000000..5cdbd15c4ae
--- /dev/null
+++ b/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
@@ -0,0 +1,206 @@
+from typing import Dict
+
+import torch
+
+from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+
+
+def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
+ """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
+
+ This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
+ perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
+ """
+ # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
+ all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
+
+ # Next, check that this is likely a FLUX model by spot-checking a few keys.
+ expected_keys = [
+ "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
+ "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
+ "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
+ "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
+ ]
+ all_expected_keys_present = all(k in state_dict for k in expected_keys)
+
+ return all_keys_in_peft_format and all_expected_keys_present
+
+
+def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
+ """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
+
+ This function is based on:
+ https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
+ """
+ # Group keys by layer.
+ grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
+
+ # Remove the "transformer." prefix from all keys.
+ grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
+
+ # Constants for FLUX.1
+ num_double_layers = 19
+ num_single_layers = 38
+ # inner_dim = 3072
+ # mlp_ratio = 4.0
+
+ layers: dict[str, AnyLoRALayer] = {}
+
+ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
+ if src_key in grouped_state_dict:
+ src_layer_dict = grouped_state_dict.pop(src_key)
+ value = {
+ "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
+ "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
+ }
+ if alpha is not None:
+ value["alpha"] = torch.tensor(alpha)
+ layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
+ assert len(src_layer_dict) == 0
+
+ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
+ """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
+ stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
+ """
+ # We expect that either all src keys are present or none of them are. Verify this.
+ keys_present = [key in grouped_state_dict for key in src_keys]
+ assert all(keys_present) or not any(keys_present)
+
+ # If none of the keys are present, return early.
+ if not any(keys_present):
+ return
+
+ src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
+ sub_layers: list[LoRALayer] = []
+ for src_layer_dict in src_layer_dicts:
+ values = {
+ "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
+ "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
+ }
+ if alpha is not None:
+ values["alpha"] = torch.tensor(alpha)
+ sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+ assert len(src_layer_dict) == 0
+ layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
+
+ # time_text_embed.timestep_embedder -> time_in.
+ add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
+ add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")
+
+ # time_text_embed.text_embedder -> vector_in.
+ add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
+ add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")
+
+ # time_text_embed.guidance_embedder -> guidance_in.
+ add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
+ add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
+
+ # context_embedder -> txt_in.
+ add_lora_layer_if_present("context_embedder", "txt_in")
+
+ # x_embedder -> img_in.
+ add_lora_layer_if_present("x_embedder", "img_in")
+
+ # Double transformer blocks.
+ for i in range(num_double_layers):
+ # norms.
+ add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
+ add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")
+
+ # Q, K, V
+ add_qkv_lora_layer_if_present(
+ [
+ f"transformer_blocks.{i}.attn.to_q",
+ f"transformer_blocks.{i}.attn.to_k",
+ f"transformer_blocks.{i}.attn.to_v",
+ ],
+ f"double_blocks.{i}.img_attn.qkv",
+ )
+ add_qkv_lora_layer_if_present(
+ [
+ f"transformer_blocks.{i}.attn.add_q_proj",
+ f"transformer_blocks.{i}.attn.add_k_proj",
+ f"transformer_blocks.{i}.attn.add_v_proj",
+ ],
+ f"double_blocks.{i}.txt_attn.qkv",
+ )
+
+ # ff img_mlp
+ add_lora_layer_if_present(
+ f"transformer_blocks.{i}.ff.net.0.proj",
+ f"double_blocks.{i}.img_mlp.0",
+ )
+ add_lora_layer_if_present(
+ f"transformer_blocks.{i}.ff.net.2",
+ f"double_blocks.{i}.img_mlp.2",
+ )
+
+ # ff txt_mlp
+ add_lora_layer_if_present(
+ f"transformer_blocks.{i}.ff_context.net.0.proj",
+ f"double_blocks.{i}.txt_mlp.0",
+ )
+ add_lora_layer_if_present(
+ f"transformer_blocks.{i}.ff_context.net.2",
+ f"double_blocks.{i}.txt_mlp.2",
+ )
+
+ # output projections.
+ add_lora_layer_if_present(
+ f"transformer_blocks.{i}.attn.to_out.0",
+ f"double_blocks.{i}.img_attn.proj",
+ )
+ add_lora_layer_if_present(
+ f"transformer_blocks.{i}.attn.to_add_out",
+ f"double_blocks.{i}.txt_attn.proj",
+ )
+
+ # Single transformer blocks.
+ for i in range(num_single_layers):
+ # norms
+ add_lora_layer_if_present(
+ f"single_transformer_blocks.{i}.norm.linear",
+ f"single_blocks.{i}.modulation.lin",
+ )
+
+ # Q, K, V, mlp
+ add_qkv_lora_layer_if_present(
+ [
+ f"single_transformer_blocks.{i}.attn.to_q",
+ f"single_transformer_blocks.{i}.attn.to_k",
+ f"single_transformer_blocks.{i}.attn.to_v",
+ f"single_transformer_blocks.{i}.proj_mlp",
+ ],
+ f"single_blocks.{i}.linear1",
+ )
+
+ # Output projections.
+ add_lora_layer_if_present(
+ f"single_transformer_blocks.{i}.proj_out",
+ f"single_blocks.{i}.linear2",
+ )
+
+ # Final layer.
+ add_lora_layer_if_present("proj_out", "final_layer.linear")
+
+ # Assert that all keys were processed.
+ assert len(grouped_state_dict) == 0
+
+ return LoRAModelRaw(layers=layers)
+
+
+def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
+ """Groups the keys in the state dict by layer."""
+ layer_dict: dict[str, dict[str, torch.Tensor]] = {}
+ for key in state_dict:
+ # Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
+ parts = key.rsplit(".", maxsplit=2)
+ layer_name = parts[0]
+ key_name = ".".join(parts[1:])
+ if layer_name not in layer_dict:
+ layer_dict[layer_name] = {}
+ layer_dict[layer_name][key_name] = state_dict[key]
+ return layer_dict
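As a reading aid (not part of the patch), the following minimal sketch reproduces the grouping step performed by `_group_by_layer()` above on a Diffusers-format FLUX LoRA state dict. The single key and the tensor shapes are illustrative placeholders only.

```python
# Illustrative sketch of the key grouping done by _group_by_layer() above.
# The tensor shapes are placeholders, not taken from a real FLUX LoRA.
import torch

state_dict = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 16),
}

grouped: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
    # Split off the trailing "lora_A.weight" / "lora_B.weight" suffix.
    parts = key.rsplit(".", maxsplit=2)
    layer_name, param_name = parts[0], ".".join(parts[1:])
    grouped.setdefault(layer_name, {})[param_name] = value

print(list(grouped))
# ['transformer.single_transformer_blocks.0.attn.to_q']
print(sorted(grouped["transformer.single_transformer_blocks.0.attn.to_q"]))
# ['lora_A.weight', 'lora_B.weight']
```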
diff --git a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
new file mode 100644
index 00000000000..3e1ccf64938
--- /dev/null
+++ b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
@@ -0,0 +1,80 @@
+import re
+from typing import Any, Dict, TypeVar
+
+import torch
+
+from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+
+# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
+# Example keys:
+# lora_unet_double_blocks_0_img_attn_proj.alpha
+# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
+# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
+FLUX_KOHYA_KEY_REGEX = (
+ r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
+)
+
+
+def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
+ """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
+
+ This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
+ perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
+ """
+ return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
+
+
+def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
+ # Group keys by layer.
+ grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+ for key, value in state_dict.items():
+ layer_name, param_name = key.split(".", 1)
+ if layer_name not in grouped_state_dict:
+ grouped_state_dict[layer_name] = {}
+ grouped_state_dict[layer_name][param_name] = value
+
+ # Convert the state dict to the InvokeAI format.
+ grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
+
+ # Create LoRA layers.
+ layers: dict[str, AnyLoRALayer] = {}
+ for layer_key, layer_state_dict in grouped_state_dict.items():
+ layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+
+ # Create and return the LoRAModelRaw.
+ return LoRAModelRaw(layers=layers)
+
+
+T = TypeVar("T")
+
+
+def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+ """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+
+ Example key conversions:
+ "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
+ "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
+ """
+
+ def replace_func(match: re.Match[str]) -> str:
+ s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
+ if match.group(4):
+ s += f".{match.group(4)}"
+ return s
+
+ converted_dict: dict[str, T] = {}
+ for k, v in state_dict.items():
+ match = re.match(FLUX_KOHYA_KEY_REGEX, k)
+ if match:
+ new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
+ converted_dict[new_key] = v
+ else:
+ raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
+
+ return converted_dict
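For illustration (separate from the module above), here is how a single Kohya-style key maps through `FLUX_KOHYA_KEY_REGEX` and the replacement function. The regex and replacement logic mirror the code in this file; the sample key is taken from the docstring examples.

```python
# Minimal demonstration of the Kohya -> InvokeAI key conversion used above.
import re

FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)

def replace_func(match: re.Match[str]) -> str:
    # Rebuild the key with "." separators: "<blocks>.<index>.<module>[.<rest>]".
    s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
    if match.group(4):
        s += f".{match.group(4)}"
    return s

key = "lora_unet_double_blocks_0_img_attn_proj"
print(re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, key))
# double_blocks.0.img_attn.proj
```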
diff --git a/invokeai/backend/lora/conversions/sd_lora_conversion_utils.py b/invokeai/backend/lora/conversions/sd_lora_conversion_utils.py
new file mode 100644
index 00000000000..0563854ef07
--- /dev/null
+++ b/invokeai/backend/lora/conversions/sd_lora_conversion_utils.py
@@ -0,0 +1,29 @@
+from typing import Dict
+
+import torch
+
+from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+
+
+def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
+ grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
+
+ layers: dict[str, AnyLoRALayer] = {}
+ for layer_key, values in grouped_state_dict.items():
+ layers[layer_key] = any_lora_layer_from_state_dict(values)
+
+ return LoRAModelRaw(layers=layers)
+
+
+def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+ state_dict_grouped: Dict[str, Dict[str, torch.Tensor]] = {}
+
+ for key, value in state_dict.items():
+ stem, leaf = key.split(".", 1)
+ if stem not in state_dict_grouped:
+ state_dict_grouped[stem] = {}
+ state_dict_grouped[stem][leaf] = value
+
+ return state_dict_grouped
diff --git a/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py b/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py
new file mode 100644
index 00000000000..e3780a7e8a4
--- /dev/null
+++ b/invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py
@@ -0,0 +1,154 @@
+import bisect
+from typing import Dict, List, Tuple, TypeVar
+
+T = TypeVar("T")
+
+
+def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
+ """Convert the keys of an SDXL LoRA state_dict to diffusers format.
+
+ The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
+ diffusers format, then this function will have no effect.
+
+ This function is adapted from:
+ https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
+
+ Args:
+ state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
+
+ Raises:
+ ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
+
+ Returns:
+ Dict[str, Tensor]: The diffusers-format state_dict.
+ """
+ converted_count = 0 # The number of Stability AI keys converted to diffusers format.
+ not_converted_count = 0 # The number of keys that were not converted.
+
+ # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
+ # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
+ # `input_blocks_4_1_proj_in`.
+ stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
+ stability_unet_keys.sort()
+
+ new_state_dict: dict[str, T] = {}
+ for full_key, value in state_dict.items():
+ if full_key.startswith("lora_unet_"):
+ search_key = full_key.replace("lora_unet_", "")
+ # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
+ position = bisect.bisect_right(stability_unet_keys, search_key)
+ map_key = stability_unet_keys[position - 1]
+ # Now, check if the map_key *actually* matches the search_key.
+ if search_key.startswith(map_key):
+ new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
+ new_state_dict[new_key] = value
+ converted_count += 1
+ else:
+ new_state_dict[full_key] = value
+ not_converted_count += 1
+ elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+ # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
+ new_state_dict[full_key] = value
+ continue
+ else:
+ raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
+
+ if converted_count > 0 and not_converted_count > 0:
+ raise ValueError(
+ f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
+ f" not_converted={not_converted_count}"
+ )
+
+ return new_state_dict
+
+
+# code from
+# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
+def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
+ """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
+ unet_conversion_map_layer: list[tuple[str, str]] = []
+
+ for i in range(3): # num_blocks is 3 in sdxl
+ # loop over downblocks/upblocks
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ # if i > 0: commentout for sdxl
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0.", "norm1."),
+ ("in_layers.2.", "conv1."),
+ ("out_layers.0.", "norm2."),
+ ("out_layers.3.", "conv2."),
+ ("emb_layers.1.", "time_emb_proj."),
+ ("skip_connection.", "conv_shortcut."),
+ ]
+
+ unet_conversion_map: list[tuple[str, str]] = []
+ for sd, hf in unet_conversion_map_layer:
+ if "resnets" in hf:
+ for sd_res, hf_res in unet_conversion_map_resnet:
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
+ else:
+ unet_conversion_map.append((sd, hf))
+
+ for j in range(2):
+ hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+ sd_time_embed_prefix = f"time_embed.{j*2}."
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+ for j in range(2):
+ hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+ sd_label_embed_prefix = f"label_emb.0.{j*2}."
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
+ unet_conversion_map.append(("out.2.", "conv_out."))
+
+ return unet_conversion_map
+
+
+SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
+ sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
+}
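To make the prefix lookup above concrete, here is a small toy example of the `bisect_right` search used in `convert_sdxl_keys_to_diffusers_format()`. The key list is abbreviated to a handful of real map entries; the point is that the search key sorts immediately after its own prefix, so the element one position to the left is the only candidate.

```python
# Toy illustration of the bisect-based prefix search; the key list is a small
# abbreviated subset of the keys in SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP.
import bisect

stability_unet_keys = sorted(["input_blocks_4_1", "input_blocks_5_1", "middle_block_1"])

search_key = "input_blocks_4_1_proj_in"
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]

print(map_key)                         # input_blocks_4_1
print(search_key.startswith(map_key))  # True
```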
diff --git a/invokeai/backend/lora/layers/any_lora_layer.py b/invokeai/backend/lora/layers/any_lora_layer.py
index 630e2edd5a4..997fcd4e06f 100644
--- a/invokeai/backend/lora/layers/any_lora_layer.py
+++ b/invokeai/backend/lora/layers/any_lora_layer.py
@@ -1,5 +1,6 @@
from typing import Union
+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
@@ -7,4 +8,4 @@
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
diff --git a/invokeai/backend/lora/layers/concatenated_lora_layer.py b/invokeai/backend/lora/layers/concatenated_lora_layer.py
new file mode 100644
index 00000000000..d764843f5b4
--- /dev/null
+++ b/invokeai/backend/lora/layers/concatenated_lora_layer.py
@@ -0,0 +1,55 @@
+from typing import Optional, Sequence
+
+import torch
+
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
+from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+
+
+class ConcatenatedLoRALayer(LoRALayerBase):
+ """A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
+
+ This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
+ Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
+ stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
+ """
+
+ def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
+ super().__init__(alpha=None, bias=None)
+
+ self.lora_layers = lora_layers
+ self.concat_axis = concat_axis
+
+ def rank(self) -> int | None:
+ return None
+
+ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+ # TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
+ # require this value, we will need to implement chunking of the original weight tensor here.
+ # Note that we must apply the sub-layer scales here.
+ layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType]
+ return torch.cat(layer_weights, dim=self.concat_axis)
+
+ def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+ # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
+ # require this value, we will need to implement chunking of the original bias tensor here.
+ # Note that we must apply the sub-layer scales here.
+ layer_biases: list[torch.Tensor] = []
+ for lora_layer in self.lora_layers:
+ layer_bias = lora_layer.get_bias(None)
+ if layer_bias is not None:
+ layer_biases.append(layer_bias * lora_layer.scale())
+
+ if len(layer_biases) == 0:
+ return None
+
+ assert len(layer_biases) == len(self.lora_layers)
+ return torch.cat(layer_biases, dim=self.concat_axis)
+
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+ super().to(device=device, dtype=dtype)
+ for lora_layer in self.lora_layers:
+ lora_layer.to(device=device, dtype=dtype)
+
+ def calc_size(self) -> int:
+ return super().calc_size() + sum(lora_layer.calc_size() for lora_layer in self.lora_layers)
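As a shape check (illustrative dimensions only, assuming FLUX's inner dim of 3072 and an arbitrary LoRA rank of 16), this pure-torch sketch shows the concatenation that `ConcatenatedLoRALayer.get_weight()` relies on: three separate Q, K, V weight deltas are stacked along dim 0 so that they line up with a fused BFL qkv weight.

```python
# Pure-torch sketch of the Q/K/V delta concatenation performed by ConcatenatedLoRALayer.
# hidden=3072 matches FLUX's inner dim; rank=16 is an arbitrary example value.
import torch

rank, hidden = 16, 3072
deltas = []
for _ in ("q", "k", "v"):
    down = torch.randn(rank, hidden)  # lora_down / lora_A
    up = torch.randn(hidden, rank)    # lora_up / lora_B
    deltas.append(up @ down)          # per-projection weight delta: (hidden, hidden)

# The BFL fused qkv weight has shape (3 * hidden, hidden), so concatenate along dim 0.
fused_delta = torch.cat(deltas, dim=0)
print(fused_delta.shape)  # torch.Size([9216, 3072])
```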
diff --git a/invokeai/backend/lora/layers/full_layer.py b/invokeai/backend/lora/layers/full_layer.py
index 7d6611c20c1..af68a0b393f 100644
--- a/invokeai/backend/lora/layers/full_layer.py
+++ b/invokeai/backend/lora/layers/full_layer.py
@@ -3,35 +3,32 @@
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class FullLayer(LoRALayerBase):
- # bias handled in LoRALayerBase(calc_size, to)
- # weight: torch.Tensor
- # bias: Optional[torch.Tensor]
+ def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
+ super().__init__(alpha=None, bias=bias)
+ self.weight = torch.nn.Parameter(weight)
- def __init__(
- self,
- layer_key: str,
+ @classmethod
+ def from_state_dict_values(
+ cls,
values: Dict[str, torch.Tensor],
):
- super().__init__(layer_key, values)
+ layer = cls(weight=values["diff"], bias=values.get("diff_b", None))
+ cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
+ return layer
- self.weight = values["diff"]
- self.bias = values.get("diff_b", None)
-
- self.rank = None # unscaled
- self.check_keys(values, {"diff", "diff_b"})
+ def rank(self) -> int | None:
+ return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
- def calc_size(self) -> int:
- model_size = super().calc_size()
- model_size += self.weight.nelement() * self.weight.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
-
self.weight = self.weight.to(device=device, dtype=dtype)
+
+ def calc_size(self) -> int:
+ return super().calc_size() + calc_tensor_size(self.weight)
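A minimal usage sketch of the new `from_state_dict_values()` constructor pattern, using dummy tensors; `diff` and `diff_b` are the keys this layer type handles. This is illustrative only, not a test from the PR.

```python
# Illustrative construction of a FullLayer from raw "diff"/"diff_b" tensors.
import torch
from invokeai.backend.lora.layers.full_layer import FullLayer

values = {
    "diff": torch.zeros(320, 320),  # full weight delta (dummy shape)
    "diff_b": torch.zeros(320),     # optional bias delta (dummy shape)
}
layer = FullLayer.from_state_dict_values(values=values)

print(layer.rank())  # None -- full-weight deltas carry no LoRA rank
print(layer.get_weight(orig_weight=torch.zeros(320, 320)).shape)  # torch.Size([320, 320])
```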
diff --git a/invokeai/backend/lora/layers/ia3_layer.py b/invokeai/backend/lora/layers/ia3_layer.py
index a5b058e5a24..b2edb8f4a28 100644
--- a/invokeai/backend/lora/layers/ia3_layer.py
+++ b/invokeai/backend/lora/layers/ia3_layer.py
@@ -6,37 +6,53 @@
class IA3Layer(LoRALayerBase):
- # weight: torch.Tensor
- # on_input: torch.Tensor
+ """IA3 Layer
- def __init__(
- self,
- layer_key: str,
- values: Dict[str, torch.Tensor],
- ):
- super().__init__(layer_key, values)
+ Example model for testing this layer type: https://civitai.com/models/123930/gwendolyn-tennyson-ben-10-ia3
+ """
- self.weight = values["weight"]
- self.on_input = values["on_input"]
+ def __init__(self, weight: torch.Tensor, on_input: torch.Tensor, bias: Optional[torch.Tensor]):
+ super().__init__(alpha=None, bias=bias)
+ self.weight = weight
+ self.on_input = on_input
- self.rank = None # unscaled
- self.check_keys(values, {"weight", "on_input"})
+ def rank(self) -> int | None:
+ return None
+
+ @classmethod
+ def from_state_dict_values(
+ cls,
+ values: Dict[str, torch.Tensor],
+ ):
+ bias = cls._parse_bias(
+ values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+ )
+ layer = cls(
+ weight=values["weight"],
+ on_input=values["on_input"],
+ bias=bias,
+ )
+ cls.warn_on_unhandled_keys(
+ values=values,
+ handled_keys={
+ # Default keys.
+ "bias_indices",
+ "bias_values",
+ "bias_size",
+ # Layer-specific keys.
+ "weight",
+ "on_input",
+ },
+ )
+ return layer
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
- assert orig_weight is not None
return orig_weight * weight
- def calc_size(self) -> int:
- model_size = super().calc_size()
- model_size += self.weight.nelement() * self.weight.element_size()
- model_size += self.on_input.nelement() * self.on_input.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
- super().to(device=device, dtype=dtype)
-
- self.weight = self.weight.to(device=device, dtype=dtype)
- self.on_input = self.on_input.to(device=device, dtype=dtype)
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+ super().to(device, dtype)
+ self.weight = self.weight.to(device, dtype)
+ self.on_input = self.on_input.to(device, dtype)
diff --git a/invokeai/backend/lora/layers/loha_layer.py b/invokeai/backend/lora/layers/loha_layer.py
index 865fa672c6c..d3be51322ef 100644
--- a/invokeai/backend/lora/layers/loha_layer.py
+++ b/invokeai/backend/lora/layers/loha_layer.py
@@ -1,32 +1,69 @@
-from typing import Dict, Optional
+from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class LoHALayer(LoRALayerBase):
- # w1_a: torch.Tensor
- # w1_b: torch.Tensor
- # w2_a: torch.Tensor
- # w2_b: torch.Tensor
- # t1: Optional[torch.Tensor] = None
- # t2: Optional[torch.Tensor] = None
+ """LoHA LyCoris layer.
- def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
- super().__init__(layer_key, values)
+ Example model for testing this layer type: https://civitai.com/models/27397/loha-renoir-the-dappled-light-style
+ """
- self.w1_a = values["hada_w1_a"]
- self.w1_b = values["hada_w1_b"]
- self.w2_a = values["hada_w2_a"]
- self.w2_b = values["hada_w2_b"]
- self.t1 = values.get("hada_t1", None)
- self.t2 = values.get("hada_t2", None)
+ def __init__(
+ self,
+ w1_a: torch.Tensor,
+ w1_b: torch.Tensor,
+ w2_a: torch.Tensor,
+ w2_b: torch.Tensor,
+ t1: torch.Tensor | None,
+ t2: torch.Tensor | None,
+ alpha: float | None,
+ bias: torch.Tensor | None,
+ ):
+ super().__init__(alpha=alpha, bias=bias)
+ self.w1_a = w1_a
+ self.w1_b = w1_b
+ self.w2_a = w2_a
+ self.w2_b = w2_b
+ self.t1 = t1
+ self.t2 = t2
+ assert (self.t1 is None) == (self.t2 is None)
- self.rank = self.w1_b.shape[0]
- self.check_keys(
- values,
- {
+ def rank(self) -> int | None:
+ return self.w1_b.shape[0]
+
+ @classmethod
+ def from_state_dict_values(
+ cls,
+ values: Dict[str, torch.Tensor],
+ ):
+ alpha = cls._parse_alpha(values.get("alpha", None))
+ bias = cls._parse_bias(
+ values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+ )
+ layer = cls(
+ w1_a=values["hada_w1_a"],
+ w1_b=values["hada_w1_b"],
+ w2_a=values["hada_w2_a"],
+ w2_b=values["hada_w2_b"],
+ t1=values.get("hada_t1", None),
+ t2=values.get("hada_t2", None),
+ alpha=alpha,
+ bias=bias,
+ )
+
+ cls.warn_on_unhandled_keys(
+ values=values,
+ handled_keys={
+ # Default keys.
+ "alpha",
+ "bias_indices",
+ "bias_values",
+ "bias_size",
+ # Layer-specific keys.
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
@@ -36,10 +73,11 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
},
)
+ return layer
+
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
-
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
@@ -47,22 +85,14 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return weight
- def calc_size(self) -> int:
- model_size = super().calc_size()
- for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
- if val is not None:
- model_size += val.nelement() * val.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
-
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
- if self.t1 is not None:
- self.t1 = self.t1.to(device=device, dtype=dtype)
-
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
- if self.t2 is not None:
- self.t2 = self.t2.to(device=device, dtype=dtype)
+ self.t1 = self.t1.to(device=device, dtype=dtype) if self.t1 is not None else self.t1
+ self.t2 = self.t2.to(device=device, dtype=dtype) if self.t2 is not None else self.t2
+
+ def calc_size(self) -> int:
+ return super().calc_size() + calc_tensors_size([self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2])
diff --git a/invokeai/backend/lora/layers/lokr_layer.py b/invokeai/backend/lora/layers/lokr_layer.py
index 19d9e7f7a23..001194e8ee1 100644
--- a/invokeai/backend/lora/layers/lokr_layer.py
+++ b/invokeai/backend/lora/layers/lokr_layer.py
@@ -1,54 +1,82 @@
-from typing import Dict, Optional
+from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class LoKRLayer(LoRALayerBase):
- # w1: Optional[torch.Tensor] = None
- # w1_a: Optional[torch.Tensor] = None
- # w1_b: Optional[torch.Tensor] = None
- # w2: Optional[torch.Tensor] = None
- # w2_a: Optional[torch.Tensor] = None
- # w2_b: Optional[torch.Tensor] = None
- # t2: Optional[torch.Tensor] = None
+ """LoKR LyCoris layer.
+
+ Example model for testing this layer type: https://civitai.com/models/346747/lokrnekopara-allgirl-for-jru2
+ """
def __init__(
self,
- layer_key: str,
- values: Dict[str, torch.Tensor],
+ w1: torch.Tensor | None,
+ w1_a: torch.Tensor | None,
+ w1_b: torch.Tensor | None,
+ w2: torch.Tensor | None,
+ w2_a: torch.Tensor | None,
+ w2_b: torch.Tensor | None,
+ t2: torch.Tensor | None,
+ alpha: float | None,
+ bias: torch.Tensor | None,
):
- super().__init__(layer_key, values)
-
- self.w1 = values.get("lokr_w1", None)
- if self.w1 is None:
- self.w1_a = values["lokr_w1_a"]
- self.w1_b = values["lokr_w1_b"]
- else:
- self.w1_b = None
- self.w1_a = None
-
- self.w2 = values.get("lokr_w2", None)
- if self.w2 is None:
- self.w2_a = values["lokr_w2_a"]
- self.w2_b = values["lokr_w2_b"]
- else:
- self.w2_a = None
- self.w2_b = None
-
- self.t2 = values.get("lokr_t2", None)
-
+ super().__init__(alpha=alpha, bias=bias)
+ self.w1 = w1
+ self.w1_a = w1_a
+ self.w1_b = w1_b
+ self.w2 = w2
+ self.w2_a = w2_a
+ self.w2_b = w2_b
+ self.t2 = t2
+
+ # Validate parameters.
+ assert (self.w1 is None) != (self.w1_a is None)
+ assert (self.w1_a is None) == (self.w1_b is None)
+ assert (self.w2 is None) != (self.w2_a is None)
+ assert (self.w2_a is None) == (self.w2_b is None)
+
+ def rank(self) -> int | None:
if self.w1_b is not None:
- self.rank = self.w1_b.shape[0]
+ return self.w1_b.shape[0]
elif self.w2_b is not None:
- self.rank = self.w2_b.shape[0]
+ return self.w2_b.shape[0]
else:
- self.rank = None # unscaled
+ return None
+
+ @classmethod
+ def from_state_dict_values(
+ cls,
+ values: Dict[str, torch.Tensor],
+ ):
+ alpha = cls._parse_alpha(values.get("alpha", None))
+ bias = cls._parse_bias(
+ values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+ )
+ layer = cls(
+ w1=values.get("lokr_w1", None),
+ w1_a=values.get("lokr_w1_a", None),
+ w1_b=values.get("lokr_w1_b", None),
+ w2=values.get("lokr_w2", None),
+ w2_a=values.get("lokr_w2_a", None),
+ w2_b=values.get("lokr_w2_b", None),
+ t2=values.get("lokr_t2", None),
+ alpha=alpha,
+ bias=bias,
+ )
- self.check_keys(
+ cls.warn_on_unhandled_keys(
values,
{
+ # Default keys.
+ "alpha",
+ "bias_indices",
+ "bias_values",
+ "bias_size",
+ # Layer-specific keys.
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
@@ -59,8 +87,10 @@ def __init__(
},
)
+ return layer
+
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
- w1: Optional[torch.Tensor] = self.w1
+ w1 = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
@@ -78,37 +108,20 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
- assert w1 is not None
- assert w2 is not None
weight = torch.kron(w1, w2)
-
return weight
- def calc_size(self) -> int:
- model_size = super().calc_size()
- for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
- if val is not None:
- model_size += val.nelement() * val.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
+ self.w1 = self.w1.to(device=device, dtype=dtype) if self.w1 is not None else self.w1
+ self.w1_a = self.w1_a.to(device=device, dtype=dtype) if self.w1_a is not None else self.w1_a
+ self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.w1_b is not None else self.w1_b
+ self.w2 = self.w2.to(device=device, dtype=dtype) if self.w2 is not None else self.w2
+ self.w2_a = self.w2_a.to(device=device, dtype=dtype) if self.w2_a is not None else self.w2_a
+ self.w2_b = self.w2_b.to(device=device, dtype=dtype) if self.w2_b is not None else self.w2_b
+ self.t2 = self.t2.to(device=device, dtype=dtype) if self.t2 is not None else self.t2
- if self.w1 is not None:
- self.w1 = self.w1.to(device=device, dtype=dtype)
- else:
- assert self.w1_a is not None
- assert self.w1_b is not None
- self.w1_a = self.w1_a.to(device=device, dtype=dtype)
- self.w1_b = self.w1_b.to(device=device, dtype=dtype)
-
- if self.w2 is not None:
- self.w2 = self.w2.to(device=device, dtype=dtype)
- else:
- assert self.w2_a is not None
- assert self.w2_b is not None
- self.w2_a = self.w2_a.to(device=device, dtype=dtype)
- self.w2_b = self.w2_b.to(device=device, dtype=dtype)
-
- if self.t2 is not None:
- self.t2 = self.t2.to(device=device, dtype=dtype)
+ def calc_size(self) -> int:
+ return super().calc_size() + calc_tensors_size(
+ [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]
+ )
diff --git a/invokeai/backend/lora/layers/lora_layer.py b/invokeai/backend/lora/layers/lora_layer.py
index fe980059d71..95270e359c5 100644
--- a/invokeai/backend/lora/layers/lora_layer.py
+++ b/invokeai/backend/lora/layers/lora_layer.py
@@ -3,35 +3,61 @@
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
-# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
- # up: torch.Tensor
- # mid: Optional[torch.Tensor]
- # down: torch.Tensor
-
def __init__(
self,
- layer_key: str,
+ up: torch.Tensor,
+ mid: Optional[torch.Tensor],
+ down: torch.Tensor,
+ alpha: float | None,
+ bias: Optional[torch.Tensor],
+ ):
+ super().__init__(alpha, bias)
+ self.up = up
+ self.mid = mid
+ self.down = down
+
+ @classmethod
+ def from_state_dict_values(
+ cls,
values: Dict[str, torch.Tensor],
):
- super().__init__(layer_key, values)
+ alpha = cls._parse_alpha(values.get("alpha", None))
+ bias = cls._parse_bias(
+ values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+ )
- self.up = values["lora_up.weight"]
- self.down = values["lora_down.weight"]
- self.mid = values.get("lora_mid.weight", None)
+ layer = cls(
+ up=values["lora_up.weight"],
+ down=values["lora_down.weight"],
+ mid=values.get("lora_mid.weight", None),
+ alpha=alpha,
+ bias=bias,
+ )
- self.rank = self.down.shape[0]
- self.check_keys(
- values,
- {
+ cls.warn_on_unhandled_keys(
+ values=values,
+ handled_keys={
+ # Default keys.
+ "alpha",
+ "bias_indices",
+ "bias_values",
+ "bias_size",
+ # Layer-specific keys.
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
+ return layer
+
+ def rank(self) -> int:
+ return self.down.shape[0]
+
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
@@ -42,18 +68,12 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return weight
- def calc_size(self) -> int:
- model_size = super().calc_size()
- for val in [self.up, self.mid, self.down]:
- if val is not None:
- model_size += val.nelement() * val.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
-
self.up = self.up.to(device=device, dtype=dtype)
- self.down = self.down.to(device=device, dtype=dtype)
-
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
+ self.down = self.down.to(device=device, dtype=dtype)
+
+ def calc_size(self) -> int:
+ return super().calc_size() + calc_tensors_size([self.up, self.mid, self.down])
diff --git a/invokeai/backend/lora/layers/lora_layer_base.py b/invokeai/backend/lora/layers/lora_layer_base.py
index 363cff4979f..ce4ba308332 100644
--- a/invokeai/backend/lora/layers/lora_layer_base.py
+++ b/invokeai/backend/lora/layers/lora_layer_base.py
@@ -3,40 +3,48 @@
import torch
import invokeai.backend.util.logging as logger
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class LoRALayerBase:
- # rank: Optional[int]
- # alpha: Optional[float]
- # bias: Optional[torch.Tensor]
- # layer_key: str
-
- # @property
- # def scale(self):
- # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
-
- def __init__(
- self,
- layer_key: str,
- values: Dict[str, torch.Tensor],
- ):
- if "alpha" in values:
- self.alpha = values["alpha"].item()
- else:
- self.alpha = None
-
- if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
- self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
- values["bias_indices"],
- values["bias_values"],
- tuple(values["bias_size"]),
- )
+ """Base class for all LoRA-like patching layers."""
+
+ # Note: It is tempting to make this a torch.nn.Module sub-class and make all tensors 'torch.nn.Parameter's. Then we
+ # could inherit automatic .to(...) behavior for this class, its subclasses, and all sidecar layers that wrap a
+ # LoRALayerBase. We would also be able to implement a single calc_size() method that could be inherited by all
+ # subclasses. But, it turns out that the speed overhead of the default .to(...) implementation in torch.nn.Module is
+ # noticeable, so for now we have opted not to use torch.nn.Module.
+
+ def __init__(self, alpha: float | None, bias: torch.Tensor | None):
+ self._alpha = alpha
+ self.bias = bias
- else:
- self.bias = None
+ @classmethod
+ def _parse_bias(
+ cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
+ ) -> torch.Tensor | None:
+ assert (bias_indices is None) == (bias_values is None) == (bias_size is None)
- self.rank = None # set in layer implementation
- self.layer_key = layer_key
+ bias = None
+ if bias_indices is not None:
+ bias = torch.sparse_coo_tensor(bias_indices, bias_values, tuple(bias_size))
+ return bias
+
+ @classmethod
+ def _parse_alpha(
+ cls,
+ alpha: torch.Tensor | None,
+ ) -> float | None:
+ return alpha.item() if alpha is not None else None
+
+ def rank(self) -> int | None:
+ raise NotImplementedError()
+
+ def scale(self) -> float:
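+        """Return the scaling factor applied to this layer's weight contribution.
+
+        For example (illustrative numbers only): alpha=8.0 with rank=16 gives a scale of 0.5. If either value is
+        unavailable, the scale defaults to 1.0, which is equivalent to alpha == rank.
+        """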
+ rank = self.rank()
+ if self._alpha is None or rank is None:
+ return 1.0
+ return self._alpha / rank
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
@@ -51,24 +59,18 @@ def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor
params["bias"] = bias
return params
- def calc_size(self) -> int:
- model_size = 0
- for val in [self.bias]:
- if val is not None:
- model_size += val.nelement() * val.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
- if self.bias is not None:
- self.bias = self.bias.to(device=device, dtype=dtype)
-
- def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+ @classmethod
+ def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
- # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
- # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
- all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
- unknown_keys = set(values.keys()) - all_known_keys
+ unknown_keys = set(values.keys()) - handled_keys
if unknown_keys:
logger.warning(
- f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+                f"Unexpected keys found in LoRA/LyCORIS layer; the model might not work correctly. Unexpected keys: {unknown_keys}"
)
+
+ def calc_size(self) -> int:
+ return calc_tensors_size([self.bias])
+
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+ if self.bias is not None:
+ self.bias = self.bias.to(device=device, dtype=dtype)
diff --git a/invokeai/backend/lora/layers/norm_layer.py b/invokeai/backend/lora/layers/norm_layer.py
index 0c8c187d485..fa7c16b3045 100644
--- a/invokeai/backend/lora/layers/norm_layer.py
+++ b/invokeai/backend/lora/layers/norm_layer.py
@@ -1,37 +1,34 @@
-from typing import Dict, Optional
+from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class NormLayer(LoRALayerBase):
- # bias handled in LoRALayerBase(calc_size, to)
- # weight: torch.Tensor
- # bias: Optional[torch.Tensor]
+ def __init__(self, weight: torch.Tensor, bias: torch.Tensor | None):
+ super().__init__(alpha=None, bias=bias)
+ self.weight = weight
- def __init__(
- self,
- layer_key: str,
+ @classmethod
+ def from_state_dict_values(
+ cls,
values: Dict[str, torch.Tensor],
):
- super().__init__(layer_key, values)
+ layer = cls(weight=values["w_norm"], bias=values.get("b_norm", None))
+ cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
+ return layer
- self.weight = values["w_norm"]
- self.bias = values.get("b_norm", None)
-
- self.rank = None # unscaled
- self.check_keys(values, {"w_norm", "b_norm"})
+ def rank(self) -> int | None:
+ return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
- def calc_size(self) -> int:
- model_size = super().calc_size()
- model_size += self.weight.nelement() * self.weight.element_size()
- return model_size
-
- def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
-
self.weight = self.weight.to(device=device, dtype=dtype)
+
+ def calc_size(self) -> int:
+ return super().calc_size() + calc_tensor_size(self.weight)
diff --git a/invokeai/backend/lora/layers/utils.py b/invokeai/backend/lora/layers/utils.py
new file mode 100644
index 00000000000..24879abd9d7
--- /dev/null
+++ b/invokeai/backend/lora/layers/utils.py
@@ -0,0 +1,33 @@
+from typing import Dict
+
+import torch
+
+from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.full_layer import FullLayer
+from invokeai.backend.lora.layers.ia3_layer import IA3Layer
+from invokeai.backend.lora.layers.loha_layer import LoHALayer
+from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
+from invokeai.backend.lora.layers.norm_layer import NormLayer
+
+
+def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
+    # Detect layers according to LyCORIS detection logic (`weight_list_det`)
+ # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
+ if "lora_up.weight" in state_dict:
+ # LoRA a.k.a LoCon
+ return LoRALayer.from_state_dict_values(state_dict)
+ elif "hada_w1_a" in state_dict:
+ return LoHALayer.from_state_dict_values(state_dict)
+ elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
+ return LoKRLayer.from_state_dict_values(state_dict)
+ elif "diff" in state_dict:
+ # Full a.k.a Diff
+ return FullLayer.from_state_dict_values(state_dict)
+ elif "on_input" in state_dict:
+ return IA3Layer.from_state_dict_values(state_dict)
+ elif "w_norm" in state_dict:
+ return NormLayer.from_state_dict_values(state_dict)
+ else:
+ raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
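+
+
+# Example (an illustrative sketch only; the tensor shapes below are made up and do not come from a real model):
+#
+#   sd = {
+#       "lora_up.weight": torch.zeros(8, 4),
+#       "lora_down.weight": torch.zeros(4, 16),
+#       "alpha": torch.tensor(4.0),
+#   }
+#   layer = any_lora_layer_from_state_dict(sd)  # -> LoRALayer, detected via the "lora_up.weight" key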
diff --git a/invokeai/backend/lora/lora_model_raw.py b/invokeai/backend/lora/lora_model_raw.py
index 33b7076cba7..cc8f942bfeb 100644
--- a/invokeai/backend/lora/lora_model_raw.py
+++ b/invokeai/backend/lora/lora_model_raw.py
@@ -1,43 +1,17 @@
# Copyright (c) 2024 The InvokeAI Development team
-"""LoRA model support."""
-
-import bisect
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Mapping, Optional
import torch
-from safetensors.torch import load_file
-from typing_extensions import Self
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
-from invokeai.backend.lora.layers.full_layer import FullLayer
-from invokeai.backend.lora.layers.ia3_layer import IA3Layer
-from invokeai.backend.lora.layers.loha_layer import LoHALayer
-from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.norm_layer import NormLayer
-from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel): # (torch.nn.Module):
- _name: str
- layers: Dict[str, AnyLoRALayer]
-
- def __init__(
- self,
- name: str,
- layers: Dict[str, AnyLoRALayer],
- ):
- self._name = name
+ def __init__(self, layers: Mapping[str, AnyLoRALayer]):
self.layers = layers
- @property
- def name(self) -> str:
- return self._name
-
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
- # TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
@@ -46,234 +20,3 @@ def calc_size(self) -> int:
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
-
- @classmethod
- def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- """Convert the keys of an SDXL LoRA state_dict to diffusers format.
-
- The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
- diffusers format, then this function will have no effect.
-
- This function is adapted from:
- https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
-
- Args:
- state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
-
- Raises:
- ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
-
- Returns:
- Dict[str, Tensor]: The diffusers-format state_dict.
- """
- converted_count = 0 # The number of Stability AI keys converted to diffusers format.
- not_converted_count = 0 # The number of keys that were not converted.
-
- # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
- # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
- # `input_blocks_4_1_proj_in`.
- stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
- stability_unet_keys.sort()
-
- new_state_dict = {}
- for full_key, value in state_dict.items():
- if full_key.startswith("lora_unet_"):
- search_key = full_key.replace("lora_unet_", "")
- # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
- position = bisect.bisect_right(stability_unet_keys, search_key)
- map_key = stability_unet_keys[position - 1]
- # Now, check if the map_key *actually* matches the search_key.
- if search_key.startswith(map_key):
- new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
- new_state_dict[new_key] = value
- converted_count += 1
- else:
- new_state_dict[full_key] = value
- not_converted_count += 1
- elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
- # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
- new_state_dict[full_key] = value
- continue
- else:
- raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
-
- if converted_count > 0 and not_converted_count > 0:
- raise ValueError(
- f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
- f" not_converted={not_converted_count}"
- )
-
- return new_state_dict
-
- @classmethod
- def from_checkpoint(
- cls,
- file_path: Union[str, Path],
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- base_model: Optional[BaseModelType] = None,
- ) -> Self:
- device = device or torch.device("cpu")
- dtype = dtype or torch.float32
-
- if isinstance(file_path, str):
- file_path = Path(file_path)
-
- model = cls(
- name=file_path.stem,
- layers={},
- )
-
- if file_path.suffix == ".safetensors":
- sd = load_file(file_path.absolute().as_posix(), device="cpu")
- else:
- sd = torch.load(file_path, map_location="cpu")
-
- state_dict = cls._group_state(sd)
-
- if base_model == BaseModelType.StableDiffusionXL:
- state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
-
- for layer_key, values in state_dict.items():
- # Detect layers according to LyCORIS detection logic(`weight_list_det`)
- # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
-
- # lora and locon
- if "lora_up.weight" in values:
- layer: AnyLoRALayer = LoRALayer(layer_key, values)
-
- # loha
- elif "hada_w1_a" in values:
- layer = LoHALayer(layer_key, values)
-
- # lokr
- elif "lokr_w1" in values or "lokr_w1_a" in values:
- layer = LoKRLayer(layer_key, values)
-
- # diff
- elif "diff" in values:
- layer = FullLayer(layer_key, values)
-
- # ia3
- elif "on_input" in values:
- layer = IA3Layer(layer_key, values)
-
- # norms
- elif "w_norm" in values:
- layer = NormLayer(layer_key, values)
-
- else:
- print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
- raise Exception("Unknown lora format!")
-
- # lower memory consumption by removing already parsed layer values
- state_dict[layer_key].clear()
-
- layer.to(device=device, dtype=dtype)
- model.layers[layer_key] = layer
-
- return model
-
- @staticmethod
- def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
- state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
-
- for key, value in state_dict.items():
- stem, leaf = key.split(".", 1)
- if stem not in state_dict_groupped:
- state_dict_groupped[stem] = {}
- state_dict_groupped[stem][leaf] = value
-
- return state_dict_groupped
-
-
-# code from
-# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
-def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
- """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
- unet_conversion_map_layer = []
-
- for i in range(3): # num_blocks is 3 in sdxl
- # loop over downblocks/upblocks
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- # if i > 0: commentout for sdxl
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0.", "norm1."),
- ("in_layers.2.", "conv1."),
- ("out_layers.0.", "norm2."),
- ("out_layers.3.", "conv2."),
- ("emb_layers.1.", "time_emb_proj."),
- ("skip_connection.", "conv_shortcut."),
- ]
-
- unet_conversion_map = []
- for sd, hf in unet_conversion_map_layer:
- if "resnets" in hf:
- for sd_res, hf_res in unet_conversion_map_resnet:
- unet_conversion_map.append((sd + sd_res, hf + hf_res))
- else:
- unet_conversion_map.append((sd, hf))
-
- for j in range(2):
- hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
- sd_time_embed_prefix = f"time_embed.{j*2}."
- unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
-
- for j in range(2):
- hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
- sd_label_embed_prefix = f"label_emb.0.{j*2}."
- unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
-
- unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
- unet_conversion_map.append(("out.0.", "conv_norm_out."))
- unet_conversion_map.append(("out.2.", "conv_out."))
-
- return unet_conversion_map
-
-
-SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
- sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
-}
diff --git a/invokeai/backend/lora/lora_patcher.py b/invokeai/backend/lora/lora_patcher.py
new file mode 100644
index 00000000000..c0a584a81c0
--- /dev/null
+++ b/invokeai/backend/lora/lora_patcher.py
@@ -0,0 +1,302 @@
+from contextlib import contextmanager
+from typing import Dict, Iterable, Optional, Tuple
+
+import torch
+
+from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
+ ConcatenatedLoRALinearSidecarLayer,
+)
+from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
+from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
+
+
+class LoRAPatcher:
+ @staticmethod
+ @torch.no_grad()
+ @contextmanager
+ def apply_lora_patches(
+ model: torch.nn.Module,
+ patches: Iterable[Tuple[LoRAModelRaw, float]],
+ prefix: str,
+ cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+ ):
+ """Apply one or more LoRA patches to a model within a context manager.
+
+ Args:
+ model (torch.nn.Module): The model to patch.
+ patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
+ associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
+ all at once.
+ prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
+ cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
+ CPU RAM, for efficient unpatching purposes.
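+
+        Example (a hypothetical usage sketch; `unet`, `lora_1`, and `lora_2` are stand-in objects loaded elsewhere):
+            with LoRAPatcher.apply_lora_patches(
+                model=unet,
+                patches=[(lora_1, 0.7), (lora_2, 0.4)],
+                prefix="lora_unet_",
+            ):
+                ...  # Run inference with the patched model here.
+            # On exiting the context, the original weights are restored.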
+ """
+ original_weights = OriginalWeightsStorage(cached_weights)
+ try:
+ for patch, patch_weight in patches:
+ LoRAPatcher.apply_lora_patch(
+ model=model,
+ prefix=prefix,
+ patch=patch,
+ patch_weight=patch_weight,
+ original_weights=original_weights,
+ )
+ del patch
+
+ yield
+ finally:
+ for param_key, weight in original_weights.get_changed_weights():
+ model.get_parameter(param_key).copy_(weight)
+
+ @staticmethod
+ @torch.no_grad()
+ def apply_lora_patch(
+ model: torch.nn.Module,
+ prefix: str,
+ patch: LoRAModelRaw,
+ patch_weight: float,
+ original_weights: OriginalWeightsStorage,
+ ):
+ """Apply a single LoRA patch to a model.
+
+ Args:
+ model (torch.nn.Module): The model to patch.
+            prefix (str): A string prefix that precedes keys used in the LoRA's weight layers.
+ patch (LoRAModelRaw): The LoRA model to patch in.
+ patch_weight (float): The weight of the LoRA patch.
+ original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
+ """
+ if patch_weight == 0:
+ return
+
+ # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
+ # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
+ # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
+ # without searching, but some legacy code still uses flattened keys.
+ layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
+
+ prefix_len = len(prefix)
+
+ for layer_key, layer in patch.layers.items():
+ if not layer_key.startswith(prefix):
+ continue
+
+ module_key, module = LoRAPatcher._get_submodule(
+ model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
+ )
+
+ # All of the LoRA weight calculations will be done on the same device as the module weight.
+ # (Performance will be best if this is a CUDA device.)
+ device = module.weight.device
+ dtype = module.weight.dtype
+
+ layer_scale = layer.scale()
+
+ # We intentionally move to the target device first, then cast. Experimentally, this was found to
+ # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+ # same thing in a single call to '.to(...)'.
+ layer.to(device=device)
+ layer.to(dtype=torch.float32)
+
+ # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+ # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+ for param_name, lora_param_weight in layer.get_parameters(module).items():
+ param_key = module_key + "." + param_name
+ module_param = module.get_parameter(param_name)
+
+ # Save original weight
+ original_weights.save(param_key, module_param)
+
+ if module_param.shape != lora_param_weight.shape:
+ lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
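+                # In-place update: param <- param + patch_weight * layer_scale * delta. The delta is computed in
+                # float32 (the layer was cast above) and cast back to the module's dtype for the addition.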
+ lora_param_weight *= patch_weight * layer_scale
+ module_param += lora_param_weight.to(dtype=dtype)
+
+ layer.to(device=TorchDevice.CPU_DEVICE)
+
+ @staticmethod
+ @torch.no_grad()
+ @contextmanager
+ def apply_lora_sidecar_patches(
+ model: torch.nn.Module,
+ patches: Iterable[Tuple[LoRAModelRaw, float]],
+ prefix: str,
+ dtype: torch.dtype,
+ ):
+ """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
+        overhead compared to normal LoRA patching, but they allow LoRA layers to be applied to base layers in any
+ quantization format.
+
+ Args:
+ model (torch.nn.Module): The model to patch.
+ patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
+ associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
+ all at once.
+ prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
+ dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
+ since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
+ different from their compute dtype.
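+
+        Example (a hypothetical usage sketch; `quantized_model` and `lora_model` are stand-in objects, and the
+        prefix/dtype values are illustrative):
+            with LoRAPatcher.apply_lora_sidecar_patches(
+                model=quantized_model,
+                patches=[(lora_model, 1.0)],
+                prefix="lora_unet_",
+                dtype=torch.bfloat16,
+            ):
+                ...  # The affected modules are wrapped in LoRASidecarModule for the duration of the block.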
+ """
+ original_modules: dict[str, torch.nn.Module] = {}
+ try:
+ for patch, patch_weight in patches:
+ LoRAPatcher._apply_lora_sidecar_patch(
+ model=model,
+ prefix=prefix,
+ patch=patch,
+ patch_weight=patch_weight,
+ original_modules=original_modules,
+ dtype=dtype,
+ )
+ yield
+ finally:
+ # Restore original modules.
+ # Note: This logic assumes no nested modules in original_modules.
+ for module_key, orig_module in original_modules.items():
+ module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
+ parent_module = model.get_submodule(module_parent_key)
+ LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
+
+ @staticmethod
+ def _apply_lora_sidecar_patch(
+ model: torch.nn.Module,
+ patch: LoRAModelRaw,
+ patch_weight: float,
+ prefix: str,
+ original_modules: dict[str, torch.nn.Module],
+ dtype: torch.dtype,
+ ):
+ """Apply a single LoRA sidecar patch to a model."""
+
+ if patch_weight == 0:
+ return
+
+ # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
+ # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
+ # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
+ # without searching, but some legacy code still uses flattened keys.
+ layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
+
+ prefix_len = len(prefix)
+
+ for layer_key, layer in patch.layers.items():
+ if not layer_key.startswith(prefix):
+ continue
+
+ module_key, module = LoRAPatcher._get_submodule(
+ model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
+ )
+
+ # Initialize the LoRA sidecar layer.
+ lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
+
+ # Replace the original module with a LoRASidecarModule if it has not already been done.
+ if module_key in original_modules:
+ # The module has already been patched with a LoRASidecarModule. Append to it.
+ assert isinstance(module, LoRASidecarModule)
+ lora_sidecar_module = module
+ else:
+ # The module has not yet been patched with a LoRASidecarModule. Create one.
+ lora_sidecar_module = LoRASidecarModule(module, [])
+ original_modules[module_key] = module
+ module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
+ module_parent = model.get_submodule(module_parent_key)
+ LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
+
+ # Move the LoRA sidecar layer to the same device/dtype as the orig module.
+ # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
+ lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
+
+ # Add the LoRA sidecar layer to the LoRASidecarModule.
+ lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
+
+ @staticmethod
+ def _split_parent_key(module_key: str) -> tuple[str, str]:
+ """Split a module key into its parent key and module name.
+
+ Args:
+ module_key (str): The module key to split.
+
+ Returns:
+ tuple[str, str]: A tuple containing the parent key and module name.
+ """
+ split_key = module_key.rsplit(".", 1)
+ if len(split_key) == 2:
+ return tuple(split_key)
+ elif len(split_key) == 1:
+ return "", split_key[0]
+ else:
+ raise ValueError(f"Invalid module key: {module_key}")
+
+ @staticmethod
+ def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
+ # TODO(ryand): Add support for more original layer types and LoRA layer types.
+ if isinstance(orig_layer, torch.nn.Linear) or (
+ isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
+ ):
+ if isinstance(lora_layer, LoRALayer):
+ return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
+ elif isinstance(lora_layer, ConcatenatedLoRALayer):
+ return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
+ else:
+ raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
+ else:
+ raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
+
+ @staticmethod
+ def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
+ try:
+ submodule_index = int(module_name)
+ # If the module name is an integer, then we use the __setitem__ method to set the submodule.
+ parent_module[submodule_index] = submodule # type: ignore
+ except ValueError:
+ # If the module name is not an integer, then we use the setattr method to set the submodule.
+ setattr(parent_module, module_name, submodule)
+
+ @staticmethod
+ def _get_submodule(
+ model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
+ ) -> tuple[str, torch.nn.Module]:
+ """Get the submodule corresponding to the given layer key.
+
+ Args:
+ model (torch.nn.Module): The model to search.
+ layer_key (str): The layer key to search for.
+ layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
+ replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
+ directly without searching, but some legacy code still uses flattened keys.
+
+ Returns:
+ tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
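+
+        Example (illustrative key only, not taken from a real model): with layer_key_is_flattened=True, a key like
+        "down_blocks_0_attentions_0_proj_in" is resolved by greedily joining '_'-separated fragments until a
+        submodule match is found at each level, yielding the module key "down_blocks.0.attentions.0.proj_in".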
+ """
+ if not layer_key_is_flattened:
+ return layer_key, model.get_submodule(layer_key)
+
+ # Handle flattened keys.
+ assert "." not in layer_key
+
+ module = model
+ module_key = ""
+ key_parts = layer_key.split("_")
+
+ submodule_name = key_parts.pop(0)
+
+ while len(key_parts) > 0:
+ try:
+ module = module.get_submodule(submodule_name)
+ module_key += "." + submodule_name
+ submodule_name = key_parts.pop(0)
+ except Exception:
+ submodule_name += "_" + key_parts.pop(0)
+
+ module = module.get_submodule(submodule_name)
+ module_key = (module_key + "." + submodule_name).lstrip(".")
+
+ return module_key, module
diff --git a/invokeai/backend/lora/sidecar_layers/__init__.py b/invokeai/backend/lora/sidecar_layers/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/backend/lora/sidecar_layers/concatenated_lora/__init__.py b/invokeai/backend/lora/sidecar_layers/concatenated_lora/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/backend/lora/sidecar_layers/concatenated_lora/concatenated_lora_linear_sidecar_layer.py b/invokeai/backend/lora/sidecar_layers/concatenated_lora/concatenated_lora_linear_sidecar_layer.py
new file mode 100644
index 00000000000..aa924644487
--- /dev/null
+++ b/invokeai/backend/lora/sidecar_layers/concatenated_lora/concatenated_lora_linear_sidecar_layer.py
@@ -0,0 +1,34 @@
+import torch
+
+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+
+
+class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
+ def __init__(
+ self,
+ concatenated_lora_layer: ConcatenatedLoRALayer,
+ weight: float,
+ ):
+ super().__init__()
+
+ self._concatenated_lora_layer = concatenated_lora_layer
+ self._weight = weight
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
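+        # Each sub-LoRA produces one slice of the output features; the slices are concatenated along the last
+        # (feature) dimension to mirror the concatenated weight layout of the original layer.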
+ x_chunks: list[torch.Tensor] = []
+ for lora_layer in self._concatenated_lora_layer.lora_layers:
+ x_chunk = torch.nn.functional.linear(input, lora_layer.down)
+ if lora_layer.mid is not None:
+ x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
+ x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
+ x_chunk *= self._weight * lora_layer.scale()
+ x_chunks.append(x_chunk)
+
+ # TODO(ryand): Generalize to support concat_axis != 0.
+ assert self._concatenated_lora_layer.concat_axis == 0
+ x = torch.cat(x_chunks, dim=-1)
+ return x
+
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+ self._concatenated_lora_layer.to(device=device, dtype=dtype)
+ return self
diff --git a/invokeai/backend/lora/sidecar_layers/lora/__init__.py b/invokeai/backend/lora/sidecar_layers/lora/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py b/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py
new file mode 100644
index 00000000000..8bf96c97b61
--- /dev/null
+++ b/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py
@@ -0,0 +1,27 @@
+import torch
+
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
+
+
+class LoRALinearSidecarLayer(torch.nn.Module):
+ def __init__(
+ self,
+ lora_layer: LoRALayer,
+ weight: float,
+ ):
+ super().__init__()
+
+ self._lora_layer = lora_layer
+ self._weight = weight
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
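+        # Low-rank path: project down, optionally through the mid matrix, project back up (adding the optional
+        # LoRA bias), then scale by the patch weight and the layer's alpha/rank scale.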
+ x = torch.nn.functional.linear(x, self._lora_layer.down)
+ if self._lora_layer.mid is not None:
+ x = torch.nn.functional.linear(x, self._lora_layer.mid)
+ x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
+ x *= self._weight * self._lora_layer.scale()
+ return x
+
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+ self._lora_layer.to(device=device, dtype=dtype)
+ return self
diff --git a/invokeai/backend/lora/sidecar_layers/lora_sidecar_layer.py b/invokeai/backend/lora/sidecar_layers/lora_sidecar_layer.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py b/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py
new file mode 100644
index 00000000000..80cd9125edf
--- /dev/null
+++ b/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class LoRASidecarModule(torch.nn.Module):
+ """A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
+
+ def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
+ super().__init__()
+ self.orig_module = orig_module
+ self._lora_layers = lora_layers
+
+ def add_lora_layer(self, lora_layer: torch.nn.Module):
+ self._lora_layers.append(lora_layer)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
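+        # The sidecar layers are additive: output = orig_module(x) + sum_i lora_layer_i(x). For linear original
+        # layers this is mathematically equivalent to patching the weights directly.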
+ x = self.orig_module(input)
+ for lora_layer in self._lora_layers:
+ x += lora_layer(input)
+ return x
+
+ def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.orig_module.to(device=device, dtype=dtype)
+ for lora_layer in self._lora_layers:
+ lora_layer.to(device=device, dtype=dtype)
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 5a192196910..2ff26d5301a 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -5,8 +5,18 @@
from pathlib import Path
from typing import Optional
+import torch
+from safetensors.torch import load_file
+
from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
+ lora_model_from_flux_diffusers_state_dict,
+)
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
+ lora_model_from_flux_kohya_state_dict,
+)
+from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
+from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -45,14 +55,38 @@ def _load_model(
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
- model = LoRAModelRaw.from_checkpoint(
- file_path=model_path,
- dtype=self._torch_dtype,
- base_model=self._model_base,
- )
+
+ # Load the state dict from the model file.
+ if model_path.suffix == ".safetensors":
+ state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
+ else:
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ # Apply state_dict key conversions, if necessary.
+ if self._model_base == BaseModelType.StableDiffusionXL:
+ state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
+ model = lora_model_from_sd_state_dict(state_dict=state_dict)
+ elif self._model_base == BaseModelType.Flux:
+ if config.format == ModelFormat.Diffusers:
+ # HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
+ # distributed as a single file without the associated metadata containing the alpha value. We chose
+ # alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
+ # is a popular choice. For example, in the diffusers training scripts:
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
+ model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
+ elif config.format == ModelFormat.LyCORIS:
+ model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
+ else:
+ raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
+ elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
+ # Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
+ model = lora_model_from_sd_state_dict(state_dict=state_dict)
+ else:
+ raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
+
+ model.to(dtype=self._torch_dtype)
return model
- # override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
self._model_base = config.base
diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py
index b54670715f2..f7d20d20c65 100644
--- a/invokeai/backend/model_manager/load/model_util.py
+++ b/invokeai/backend/model_manager/load/model_util.py
@@ -20,6 +20,7 @@
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
@@ -83,10 +84,9 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
def calc_module_size(model: torch.nn.Module) -> int:
"""Calculate the size (in bytes) of a torch.nn.Module."""
- mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
- mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
- mem: int = mem_params + mem_bufs # in bytes
- return mem
+ mem_params = sum([calc_tensor_size(param) for param in model.parameters()])
+ mem_bufs = sum([calc_tensor_size(buf) for buf in model.buffers()])
+ return mem_params + mem_bufs
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index c4f51b464cb..48db855943d 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -10,6 +10,10 @@
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
+from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
+ is_state_dict_likely_in_flux_diffusers_format,
+)
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@@ -244,7 +248,9 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA
- elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
+ # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
+ # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
+ elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet
@@ -554,12 +560,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
- return ModelFormat("lycoris")
+ if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
+ # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
+ # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
+ # file, but the weight keys are in the diffusers format.
+ return ModelFormat.Diffusers
+ return ModelFormat.LyCORIS
def get_base_type(self) -> BaseModelType:
- checkpoint = self.checkpoint
- token_vector_length = lora_token_vector_length(checkpoint)
+ if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
+ self.checkpoint
+ ):
+ return BaseModelType.Flux
+ # If we've gotten here, we assume that the model is a Stable Diffusion model.
+ token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 0a2c64ba848..705ac6e685d 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,32 +5,18 @@
import pickle
from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
-from diffusers import OnnxRuntimeModel, UNet2DConditionModel
+from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
-from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
-
-"""
-loras = [
- (lora_model1, 0.7),
- (lora_model2, 0.4),
-]
-with LoRAHelper.apply_lora_unet(unet, loras):
- # unet with applied loras
-# unmodified unet
-
-"""
class ModelPatcher:
@@ -54,95 +40,6 @@ def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Ty
finally:
unet.set_attn_processor(unet_orig_processors)
- @staticmethod
- def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
- assert "." not in lora_key
-
- if not lora_key.startswith(prefix):
- raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
- module = model
- module_key = ""
- key_parts = lora_key[len(prefix) :].split("_")
-
- submodule_name = key_parts.pop(0)
-
- while len(key_parts) > 0:
- try:
- module = module.get_submodule(submodule_name)
- module_key += "." + submodule_name
- submodule_name = key_parts.pop(0)
- except Exception:
- submodule_name += "_" + key_parts.pop(0)
-
- module = module.get_submodule(submodule_name)
- module_key = (module_key + "." + submodule_name).lstrip(".")
-
- return (module_key, module)
-
- @classmethod
- @contextmanager
- def apply_lora_unet(
- cls,
- unet: UNet2DConditionModel,
- loras: Iterator[Tuple[LoRAModelRaw, float]],
- cached_weights: Optional[Dict[str, torch.Tensor]] = None,
- ) -> Generator[None, None, None]:
- with cls.apply_lora(
- unet,
- loras=loras,
- prefix="lora_unet_",
- cached_weights=cached_weights,
- ):
- yield
-
- @classmethod
- @contextmanager
- def apply_lora_text_encoder(
- cls,
- text_encoder: CLIPTextModel,
- loras: Iterator[Tuple[LoRAModelRaw, float]],
- cached_weights: Optional[Dict[str, torch.Tensor]] = None,
- ) -> Generator[None, None, None]:
- with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
- yield
-
- @classmethod
- @contextmanager
- def apply_lora(
- cls,
- model: AnyModel,
- loras: Iterator[Tuple[LoRAModelRaw, float]],
- prefix: str,
- cached_weights: Optional[Dict[str, torch.Tensor]] = None,
- ) -> Generator[None, None, None]:
- """
- Apply one or more LoRAs to a model.
-
- :param model: The model to patch.
- :param loras: An iterator that returns the LoRA to patch in and its patch weight.
- :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
- :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
- """
- original_weights = OriginalWeightsStorage(cached_weights)
- try:
- for lora_model, lora_weight in loras:
- LoRAExt.patch_model(
- model=model,
- prefix=prefix,
- lora=lora_model,
- lora_weight=lora_weight,
- original_weights=original_weights,
- )
- del lora_model
-
- yield
-
- finally:
- with torch.no_grad():
- for param_key, weight in original_weights.get_changed_weights():
- model.get_parameter(param_key).copy_(weight)
-
@classmethod
@contextmanager
def apply_ti(
@@ -282,26 +179,6 @@ def apply_freeu(
class ONNXModelPatcher:
- @classmethod
- @contextmanager
- def apply_lora_unet(
- cls,
- unet: OnnxRuntimeModel,
- loras: Iterator[Tuple[LoRAModelRaw, float]],
- ) -> None:
- with cls.apply_lora(unet, loras, "lora_unet_"):
- yield
-
- @classmethod
- @contextmanager
- def apply_lora_text_encoder(
- cls,
- text_encoder: OnnxRuntimeModel,
- loras: List[Tuple[LoRAModelRaw, float]],
- ) -> None:
- with cls.apply_lora(text_encoder, loras, "lora_te_"):
- yield
-
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod
diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py
index b36f0b94524..bbc394a5cbe 100644
--- a/invokeai/backend/stable_diffusion/extensions/lora.py
+++ b/invokeai/backend/stable_diffusion/extensions/lora.py
@@ -1,14 +1,13 @@
from __future__ import annotations
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING
-import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
-from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@@ -31,107 +30,14 @@ def __init__(
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
- self.patch_model(
+ assert isinstance(lora_model, LoRAModelRaw)
+ LoRAPatcher.apply_lora_patch(
model=unet,
prefix="lora_unet_",
- lora=lora_model,
- lora_weight=self._weight,
+ patch=lora_model,
+ patch_weight=self._weight,
original_weights=original_weights,
)
del lora_model
yield
-
- @classmethod
- @torch.no_grad()
- def patch_model(
- cls,
- model: torch.nn.Module,
- prefix: str,
- lora: LoRAModelRaw,
- lora_weight: float,
- original_weights: OriginalWeightsStorage,
- ):
- """
- Apply one or more LoRAs to a model.
- :param model: The model to patch.
- :param lora: LoRA model to patch in.
- :param lora_weight: LoRA patch weight.
- :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
- :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
- """
-
- if lora_weight == 0:
- return
-
- # assert lora.device.type == "cpu"
- for layer_key, layer in lora.layers.items():
- if not layer_key.startswith(prefix):
- continue
-
- # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
- # should be improved in the following ways:
- # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
- # LoRA model is applied.
- # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
- # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
- # weights to have valid keys.
- assert isinstance(model, torch.nn.Module)
- module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
- # All of the LoRA weight calculations will be done on the same device as the module weight.
- # (Performance will be best if this is a CUDA device.)
- device = module.weight.device
- dtype = module.weight.dtype
-
- layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
- # We intentionally move to the target device first, then cast. Experimentally, this was found to
- # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
- # same thing in a single call to '.to(...)'.
- layer.to(device=device)
- layer.to(dtype=torch.float32)
-
- # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
- # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
- for param_name, lora_param_weight in layer.get_parameters(module).items():
- param_key = module_key + "." + param_name
- module_param = module.get_parameter(param_name)
-
- # save original weight
- original_weights.save(param_key, module_param)
-
- if module_param.shape != lora_param_weight.shape:
- # TODO: debug on lycoris
- lora_param_weight = lora_param_weight.reshape(module_param.shape)
-
- lora_param_weight *= lora_weight * layer_scale
- module_param += lora_param_weight.to(dtype=dtype)
-
- layer.to(device=TorchDevice.CPU_DEVICE)
-
- @staticmethod
- def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
- assert "." not in lora_key
-
- if not lora_key.startswith(prefix):
- raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
- module = model
- module_key = ""
- key_parts = lora_key[len(prefix) :].split("_")
-
- submodule_name = key_parts.pop(0)
-
- while len(key_parts) > 0:
- try:
- module = module.get_submodule(submodule_name)
- module_key += "." + submodule_name
- submodule_name = key_parts.pop(0)
- except Exception:
- submodule_name += "_" + key_parts.pop(0)
-
- module = module.get_submodule(submodule_name)
- module_key = (module_key + "." + submodule_name).lstrip(".")
-
- return (module_key, module)
diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py
index 0345478b975..b83d769a8d1 100644
--- a/invokeai/backend/textual_inversion.py
+++ b/invokeai/backend/textual_inversion.py
@@ -10,6 +10,7 @@
from typing_extensions import Self
from invokeai.backend.raw_model import RawModel
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class TextualInversionModelRaw(RawModel):
@@ -74,11 +75,7 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
def calc_size(self) -> int:
"""Get the size of this model in bytes."""
- embedding_size = self.embedding.element_size() * self.embedding.nelement()
- embedding_2_size = 0
- if self.embedding_2 is not None:
- embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
- return embedding_size + embedding_2_size
+ return calc_tensors_size([self.embedding, self.embedding_2])
class TextualInversionManager(BaseTextualInversionManager):
diff --git a/invokeai/backend/util/calc_tensor_size.py b/invokeai/backend/util/calc_tensor_size.py
new file mode 100644
index 00000000000..70b99cd8849
--- /dev/null
+++ b/invokeai/backend/util/calc_tensor_size.py
@@ -0,0 +1,11 @@
+import torch
+
+
+def calc_tensor_size(t: torch.Tensor) -> int:
+ """Calculate the size of a tensor in bytes."""
+ return t.nelement() * t.element_size()
+
+
+def calc_tensors_size(tensors: list[torch.Tensor | None]) -> int:
+ """Calculate the size of a list of tensors in bytes."""
+ return sum(calc_tensor_size(t) for t in tensors if t is not None)
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index bfcf0bf842f..f773b1f7b4b 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -353,230 +353,288 @@
"closeViewer": "Close Viewer"
},
"hotkeys": {
+ "hotkeys": "Hotkeys",
"searchHotkeys": "Search Hotkeys",
"clearSearch": "Clear Search",
"noHotkeysFound": "No Hotkeys Found",
- "acceptStagingImage": {
- "desc": "Accept Current Staging Area Image",
- "title": "Accept Staging Image"
- },
- "addNodes": {
- "desc": "Opens the add node menu",
- "title": "Add Nodes"
- },
- "appHotkeys": "App",
- "cancel": {
- "desc": "Cancel current queue item",
- "title": "Cancel"
- },
- "cancelAndClear": {
- "desc": "Cancel current queue item and clear all pending items",
- "title": "Cancel and Clear"
- },
- "changeTabs": {
- "desc": "Switch to another workspace",
- "title": "Change Tabs"
- },
- "clearMask": {
- "desc": "Clear the entire mask",
- "title": "Clear Mask"
- },
- "closePanels": {
- "desc": "Closes open panels",
- "title": "Close Panels"
- },
- "colorPicker": {
- "desc": "Selects the canvas color picker",
- "title": "Select Color Picker"
- },
- "consoleToggle": {
- "desc": "Open and close console",
- "title": "Console Toggle"
- },
- "copyToClipboard": {
- "desc": "Copy current canvas to clipboard",
- "title": "Copy to Clipboard"
- },
- "decreaseBrushOpacity": {
- "desc": "Decreases the opacity of the canvas brush",
- "title": "Decrease Brush Opacity"
- },
- "decreaseBrushSize": {
- "desc": "Decreases the size of the canvas brush/eraser",
- "title": "Decrease Brush Size"
- },
- "decreaseGalleryThumbSize": {
- "desc": "Decreases gallery thumbnails size",
- "title": "Decrease Gallery Image Size"
- },
- "deleteImage": {
- "desc": "Delete the current image",
- "title": "Delete Image"
- },
- "downloadImage": {
- "desc": "Download current canvas",
- "title": "Download Image"
- },
- "eraseBoundingBox": {
- "desc": "Erases the bounding box area",
- "title": "Erase Bounding Box"
- },
- "fillBoundingBox": {
- "desc": "Fills the bounding box with brush color",
- "title": "Fill Bounding Box"
- },
- "focusPrompt": {
- "desc": "Focus the prompt input area",
- "title": "Focus Prompt"
- },
- "galleryHotkeys": "Gallery",
- "generalHotkeys": "General",
- "hideMask": {
- "desc": "Hide and unhide mask",
- "title": "Hide Mask"
- },
- "increaseBrushOpacity": {
- "desc": "Increases the opacity of the canvas brush",
- "title": "Increase Brush Opacity"
+ "app": {
+ "title": "App",
+ "invoke": {
+ "title": "Invoke",
+ "desc": "Queue a generation, adding it to the end of the queue."
+ },
+ "invokeFront": {
+ "title": "Invoke (Front)",
+ "desc": "Queue a generation, adding it to the front of the queue."
+ },
+ "cancelQueueItem": {
+ "title": "Cancel",
+ "desc": "Cancel the currently processing queue item."
+ },
+ "clearQueue": {
+ "title": "Clear Queue",
+ "desc": "Cancel and clear all queue items."
+ },
+ "selectCanvasTab": {
+ "title": "Select the Canvas Tab",
+ "desc": "Selects the Canvas tab."
+ },
+ "selectUpscalingTab": {
+ "title": "Select the Upscaling Tab",
+ "desc": "Selects the Upscaling tab."
+ },
+ "selectWorkflowsTab": {
+ "title": "Select the Workflows Tab",
+ "desc": "Selects the Workflows tab."
+ },
+ "selectModelsTab": {
+ "title": "Select the Models Tab",
+ "desc": "Selects the Models tab."
+ },
+ "selectQueueTab": {
+ "title": "Select the Queue Tab",
+ "desc": "Selects the Queue tab."
+ },
+ "focusPrompt": {
+ "title": "Focus Prompt",
+ "desc": "Move cursor focus to the positive prompt."
+ },
+ "toggleLeftPanel": {
+ "title": "Toggle Left Panel",
+ "desc": "Show or hide the left panel."
+ },
+ "toggleRightPanel": {
+ "title": "Toggle Right Panel",
+ "desc": "Show or hide the right panel."
+ },
+ "resetPanelLayout": {
+ "title": "Reset Panel Layout",
+ "desc": "Reset the left and right panels to their default size and layout."
+ },
+ "togglePanels": {
+ "title": "Toggle Panels",
+ "desc": "Show or hide both left and right panels at once."
+ }
},
- "increaseBrushSize": {
- "desc": "Increases the size of the canvas brush/eraser",
- "title": "Increase Brush Size"
+ "canvas": {
+ "title": "Canvas",
+ "selectBrushTool": {
+ "title": "Brush Tool",
+ "desc": "Select the brush tool."
+ },
+ "selectBboxTool": {
+ "title": "Bbox Tool",
+ "desc": "Select the bounding box tool."
+ },
+ "decrementToolWidth": {
+ "title": "Decrement Tool Width",
+ "desc": "Decrement the brush or eraser tool width, whichever is selected."
+ },
+ "incrementToolWidth": {
+ "title": "Increment Tool Width",
+ "desc": "Increment the brush or eraser tool width, whichever is selected."
+ },
+ "selectColorPickerTool": {
+ "title": "Color Picker Tool",
+ "desc": "Select the color picker tool."
+ },
+ "selectEraserTool": {
+ "title": "Eraser Tool",
+ "desc": "Select the eraser tool."
+ },
+ "selectMoveTool": {
+ "title": "Move Tool",
+ "desc": "Select the move tool."
+ },
+ "selectRectTool": {
+ "title": "Rect Tool",
+ "desc": "Select the rect tool."
+ },
+ "selectViewTool": {
+ "title": "View Tool",
+ "desc": "Select the view tool."
+ },
+ "fitLayersToCanvas": {
+ "title": "Fit Layers to Canvas",
+ "desc": "Scale and position the view to fit all visible layers."
+ },
+ "setZoomTo100Percent": {
+ "title": "Zoom to 100%",
+ "desc": "Set the canvas zoom to 100%."
+ },
+ "setZoomTo200Percent": {
+ "title": "Zoom to 200%",
+ "desc": "Set the canvas zoom to 200%."
+ },
+ "setZoomTo400Percent": {
+ "title": "Zoom to 400%",
+ "desc": "Set the canvas zoom to 400%."
+ },
+ "setZoomTo800Percent": {
+ "title": "Zoom to 800%",
+ "desc": "Set the canvas zoom to 800%."
+ },
+ "quickSwitch": {
+ "title": "Layer Quick Switch",
+ "desc": "Switch between the last two selected layers. If a layer is bookmarked, always switch between it and the last non-bookmarked layer."
+ },
+ "deleteSelected": {
+ "title": "Delete Layer",
+ "desc": "Delete the selected layer."
+ },
+ "resetSelected": {
+ "title": "Reset Layer",
+ "desc": "Reset the selected layer. Only applies to Inpaint Mask and Regional Guidance."
+ },
+ "undo": {
+ "title": "Undo",
+ "desc": "Undo the last canvas action."
+ },
+ "redo": {
+ "title": "Redo",
+ "desc": "Redo the last canvas action."
+ },
+ "nextEntity": {
+ "title": "Next Layer",
+ "desc": "Select the next layer in the list."
+ },
+ "prevEntity": {
+ "title": "Prev Layer",
+ "desc": "Select the previous layer in the list."
+ },
+ "setFillToWhite": {
+ "title": "Set Color to White",
+ "desc": "Set the current tool color to white."
+ }
},
- "increaseGalleryThumbSize": {
- "desc": "Increases gallery thumbnails size",
- "title": "Increase Gallery Image Size"
+ "workflows": {
+ "title": "Workflows",
+ "addNode": {
+ "title": "Add Node",
+        "desc": "Open the add node menu."
+ },
+ "copySelection": {
+ "title": "Copy",
+        "desc": "Copy selected nodes and edges."
+ },
+ "pasteSelection": {
+ "title": "Paste",
+        "desc": "Paste copied nodes and edges."
+ },
+ "pasteSelectionWithEdges": {
+ "title": "Paste with Edges",
+        "desc": "Paste copied nodes, edges, and all edges connected to copied nodes."
+ },
+ "selectAll": {
+ "title": "Select All",
+        "desc": "Select all nodes and edges."
+ },
+ "deleteSelection": {
+ "title": "Delete",
+        "desc": "Delete selected nodes and edges."
+ },
+ "undo": {
+ "title": "Undo",
+        "desc": "Undo the last workflow action."
+ },
+ "redo": {
+ "title": "Redo",
+        "desc": "Redo the last workflow action."
+ }
},
- "invoke": {
- "desc": "Generate an image",
- "title": "Invoke"
+ "viewer": {
+ "title": "Image Viewer",
+ "toggleViewer": {
+ "title": "Show/Hide Image Viewer",
+ "desc": "Show or hide the image viewer. Only available on the Canvas tab."
+ },
+ "swapImages": {
+ "title": "Swap Comparison Images",
+ "desc": "Swap the images being compared."
+ },
+ "nextComparisonMode": {
+ "title": "Next Comparison Mode",
+ "desc": "Cycle through comparison modes."
+ },
+ "loadWorkflow": {
+ "title": "Load Workflow",
+ "desc": "Load the current image's saved workflow (if it has one)."
+ },
+ "recallAll": {
+ "title": "Recall All Metadata",
+ "desc": "Recall all metadata for the current image."
+ },
+ "recallSeed": {
+ "title": "Recall Seed",
+ "desc": "Recall the seed for the current image."
+ },
+ "recallPrompts": {
+ "title": "Recall Prompts",
+ "desc": "Recall the positive and negative prompts for the current image."
+ },
+ "remix": {
+ "title": "Remix",
+ "desc": "Recall all metadata except for the seed for the current image."
+ },
+ "useSize": {
+ "title": "Use Size",
+ "desc": "Use the current image's size as the bbox size."
+ },
+ "runPostprocessing": {
+ "title": "Run Postprocessing",
+ "desc": "Run the selected postprocessing on the current image."
+ },
+ "toggleMetadata": {
+ "title": "Show/Hide Metadata",
+ "desc": "Show or hide the current image's metadata overlay."
+ }
},
- "keyboardShortcuts": "Hotkeys",
- "maximizeWorkSpace": {
- "desc": "Close panels and maximize work area",
- "title": "Maximize Workspace"
- },
- "mergeVisible": {
- "desc": "Merge all visible layers of canvas",
- "title": "Merge Visible"
- },
- "moveTool": {
- "desc": "Allows canvas navigation",
- "title": "Move Tool"
- },
- "nextImage": {
- "desc": "Display the next image in gallery",
- "title": "Next Image"
- },
- "nextStagingImage": {
- "desc": "Next Staging Area Image",
- "title": "Next Staging Image"
- },
- "nodesHotkeys": "Nodes",
- "pinOptions": {
- "desc": "Pin the options panel",
- "title": "Pin Options"
- },
- "previousImage": {
- "desc": "Display the previous image in gallery",
- "title": "Previous Image"
- },
- "previousStagingImage": {
- "desc": "Previous Staging Area Image",
- "title": "Previous Staging Image"
- },
- "quickToggleMove": {
- "desc": "Temporarily toggles Move mode",
- "title": "Quick Toggle Move"
- },
- "redoStroke": {
- "desc": "Redo a brush stroke",
- "title": "Redo Stroke"
- },
- "resetView": {
- "desc": "Reset Canvas View",
- "title": "Reset View"
- },
- "restoreFaces": {
- "desc": "Restore the current image",
- "title": "Restore Faces"
- },
- "saveToGallery": {
- "desc": "Save current canvas to gallery",
- "title": "Save To Gallery"
- },
- "selectBrush": {
- "desc": "Selects the canvas brush",
- "title": "Select Brush"
- },
- "selectEraser": {
- "desc": "Selects the canvas eraser",
- "title": "Select Eraser"
- },
- "sendToImageToImage": {
- "desc": "Send current image to Image to Image",
- "title": "Send To Image To Image"
- },
- "remixImage": {
- "desc": "Use all parameters except seed from the current image",
- "title": "Remix image"
- },
- "setParameters": {
- "desc": "Use all parameters of the current image",
- "title": "Set Parameters"
- },
- "setPrompt": {
- "desc": "Use the prompt of the current image",
- "title": "Set Prompt"
- },
- "setSeed": {
- "desc": "Use the seed of the current image",
- "title": "Set Seed"
- },
- "showHideBoundingBox": {
- "desc": "Toggle visibility of bounding box",
- "title": "Show/Hide Bounding Box"
- },
- "showInfo": {
- "desc": "Show metadata info of the current image",
- "title": "Show Info"
- },
- "toggleGallery": {
- "desc": "Open and close the gallery drawer",
- "title": "Toggle Gallery"
- },
- "toggleOptions": {
- "desc": "Open and close the options panel",
- "title": "Toggle Options"
- },
- "toggleOptionsAndGallery": {
- "desc": "Open and close the options and gallery panels",
- "title": "Toggle Options and Gallery"
- },
- "resetOptionsAndGallery": {
- "desc": "Resets the options and gallery panels",
- "title": "Reset Options and Gallery"
- },
- "toggleLayer": {
- "desc": "Toggles mask/base layer selection",
- "title": "Toggle Layer"
- },
- "toggleSnap": {
- "desc": "Toggles Snap to Grid",
- "title": "Toggle Snap"
- },
- "undoStroke": {
- "desc": "Undo a brush stroke",
- "title": "Undo Stroke"
- },
- "unifiedCanvasHotkeys": "Unified Canvas",
- "postProcess": {
- "desc": "Process the current image using the selected post-processing model",
- "title": "Process Image"
- },
- "toggleViewer": {
- "desc": "Switches between the Image Viewer and workspace for the current tab.",
- "title": "Toggle Image Viewer"
+ "gallery": {
+ "title": "Gallery",
+ "selectAllOnPage": {
+ "title": "Select All On Page",
+ "desc": "Select all images on the current page."
+ },
+ "clearSelection": {
+ "title": "Clear Selection",
+ "desc": "Clear the current selection, if any."
+ },
+ "galleryNavUp": {
+ "title": "Navigate Up",
+ "desc": "Navigate up in the gallery grid, selecting that image. If at the top of the page, go to the previous page."
+ },
+ "galleryNavRight": {
+ "title": "Navigate Right",
+ "desc": "Navigate right in the gallery grid, selecting that image. If at the last image of the row, go to the next row. If at the last image of the page, go to the next page."
+ },
+ "galleryNavDown": {
+ "title": "Navigate Down",
+ "desc": "Navigate down in the gallery grid, selecting that image. If at the bottom of the page, go to the next page."
+ },
+ "galleryNavLeft": {
+ "title": "Navigate Left",
+ "desc": "Navigate left in the gallery grid, selecting that image. If at the first image of the row, go to the previous row. If at the first image of the page, go to the previous page."
+ },
+ "galleryNavUpAlt": {
+ "title": "Navigate Up (Compare Image)",
+ "desc": "Same as Navigate Up, but selects the compare image, opening compare mode if it isn't already open."
+ },
+ "galleryNavRightAlt": {
+ "title": "Navigate Right (Compare Image)",
+ "desc": "Same as Navigate Right, but selects the compare image, opening compare mode if it isn't already open."
+ },
+ "galleryNavDownAlt": {
+ "title": "Navigate Down (Compare Image)",
+ "desc": "Same as Navigate Down, but selects the compare image, opening compare mode if it isn't already open."
+ },
+ "galleryNavLeftAlt": {
+ "title": "Navigate Left (Compare Image)",
+ "desc": "Same as Navigate Left, but selects the compare image, opening compare mode if it isn't already open."
+ },
+ "deleteSelection": {
+ "title": "Delete",
+ "desc": "Delete all selected images. By default, you will be prompted to confirm deletion. If the images are currently in use in the app, you will be warned."
+ }
}
},
"metadata": {
@@ -883,7 +941,8 @@
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
- "modelAccessError": "Unable to find model {{key}}, resetting to default"
+ "modelAccessError": "Unable to find model {{key}}, resetting to default",
+ "saveToGallery": "Save To Gallery"
},
"parameters": {
"aspect": "Aspect",
@@ -1669,11 +1728,13 @@
"referenceImage": "Reference Image",
"regionalReferenceImage": "Regional Reference Image",
"globalReferenceImage": "Global Reference Image",
- "sendingToCanvas": "Sending to Canvas",
- "sendingToGallery": "Sending to Gallery",
+ "sendingToCanvas": "Staging Generations on Canvas",
+ "sendingToGallery": "Sending Generations to Gallery",
"sendToGallery": "Send To Gallery",
"sendToGalleryDesc": "Pressing Invoke generates and saves a unique image to your gallery.",
"sendToCanvas": "Send To Canvas",
+ "newLayerFromImage": "New Layer from Image",
+ "newCanvasFromImage": "New Canvas from Image",
"copyToClipboard": "Copy to Clipboard",
"sendToCanvasDesc": "Pressing Invoke stages your work in progress on the canvas.",
"viewProgressInViewer": "View progress and outputs in the Image Viewer.",
@@ -1871,7 +1932,8 @@
"preserveMask": {
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
- }
+ },
+ "showOnlyRasterLayersWhileStaging": "Show Only Raster Layers While Staging"
},
"HUD": {
"bbox": "Bbox",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts
index c209574520e..0624c20393c 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts
@@ -1,7 +1,8 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasReset, rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
+import { canvasReset } from 'features/controlLayers/store/actions';
+import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { stagingAreaImageAccepted, stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts
index 7da7ac99de6..60a8ea814fe 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts
@@ -4,7 +4,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
-import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
+import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners.ts
index a2cd92e6d43..68d7595bd9a 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners.ts
@@ -40,19 +40,22 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
});
};
-// const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
-// state.canvas.present.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
-// if (
-// imageObject?.image.image_name === imageDTO.image_name ||
-// processedImageObject?.image.image_name === imageDTO.image_name
-// ) {
-// dispatch(caImageChanged({ id, imageDTO: null }));
-// dispatch(caProcessedImageChanged({ id, imageDTO: null }));
-// }
-// });
-// };
-
-const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
+const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
+ selectCanvasSlice(state).controlLayers.entities.forEach(({ id, objects }) => {
+ let shouldDelete = false;
+ for (const obj of objects) {
+ if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
+ shouldDelete = true;
+ break;
+ }
+ }
+ if (shouldDelete) {
+ dispatch(entityDeleted({ entityIdentifier: { id, type: 'control_layer' } }));
+ }
+ });
+};
+
+const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
@@ -60,7 +63,7 @@ const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO
});
};
-const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
+const deleteRasterLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
@@ -124,9 +127,9 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
}
deleteNodesImages(state, dispatch, imageDTO);
- // deleteControlAdapterImages(state, dispatch, imageDTO);
- deleteIPAdapterImages(state, dispatch, imageDTO);
- deleteLayerImages(state, dispatch, imageDTO);
+ deleteReferenceImages(state, dispatch, imageDTO);
+ deleteRasterLayerImages(state, dispatch, imageDTO);
+ deleteControlLayerImages(state, dispatch, imageDTO);
} catch {
// no-op
} finally {
@@ -165,9 +168,9 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
imageDTOs.forEach((imageDTO) => {
deleteNodesImages(state, dispatch, imageDTO);
- // deleteControlAdapterImages(state, dispatch, imageDTO);
- deleteIPAdapterImages(state, dispatch, imageDTO);
- deleteLayerImages(state, dispatch, imageDTO);
+ deleteControlLayerImages(state, dispatch, imageDTO);
+ deleteReferenceImages(state, dispatch, imageDTO);
+ deleteRasterLayerImages(state, dispatch, imageDTO);
});
} catch {
// no-op
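
The new `deleteControlLayerImages` and the renamed `deleteRasterLayerImages` repeat the same scan: walk the entity's `objects` and delete the entity if any image object references the deleted image. The shared check, condensed with simplified stand-in types (not the real `CanvasRasterLayerState` or its control-layer counterpart):

```ts
// Simplified sketch of the scan repeated by the layer helpers above; these types are
// illustrative stand-ins, not the real canvas entity/object unions.
type ImageObject = { type: 'image'; image: { image_name: string } };
type BrushObject = { type: 'brush_line'; points: number[] };
type LayerLike = { id: string; objects: Array<ImageObject | BrushObject> };

export const layerReferencesImage = (layer: LayerLike, imageName: string): boolean =>
  layer.objects.some((obj) => obj.type === 'image' && obj.image.image_name === imageName);

// Each helper then amounts to:
//   entities.filter((layer) => layerReferencesImage(layer, imageDTO.image_name))
//           .forEach(({ id }) => dispatch(entityDeleted({ entityIdentifier: { id, type } })));
```
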
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
index 13a256ad4e8..39a6494d20a 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
@@ -1,9 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
+import { bboxOptimalDimensionChanged, bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
+import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
+import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@@ -68,6 +71,11 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
+ // When staging, we don't want to change the bbox, but we must keep the optimal dimension in sync.
+ dispatch(bboxOptimalDimensionChanged({ optimalDimension: getOptimalDimension(newModel) }));
+ if (!selectIsStaging(state)) {
+ dispatch(bboxSyncedToOptimalDimension());
+ }
},
});
};
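
The staging guard added here reappears in the modelsLoaded listener below, and setDefaultSettings applies the same `!isStaging` check before resizing the bbox. Condensed into one place, the rule looks roughly like this; `syncBboxToModel` is a hypothetical helper used only for illustration, not something this diff adds:

```ts
import type { AppDispatch, RootState } from 'app/store/store';
import { bboxOptimalDimensionChanged, bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';

// Sketch of the rule shared by the model-selection and models-loaded listeners;
// this helper does not exist in the diff.
export const syncBboxToModel = (
  model: Parameters<typeof getOptimalDimension>[0],
  state: RootState,
  dispatch: AppDispatch
) => {
  // Always keep the optimal dimension in sync with the newly selected model...
  dispatch(bboxOptimalDimensionChanged({ optimalDimension: getOptimalDimension(model) }));
  // ...but never resize the bbox while generations are being staged on the canvas.
  if (!selectIsStaging(state)) {
    dispatch(bboxSyncedToOptimalDimension());
  }
};
```
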
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
index d19020cbf13..e6662d6f2aa 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
@@ -3,12 +3,13 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import {
- bboxHeightChanged,
- bboxWidthChanged,
+ bboxOptimalDimensionChanged,
+ bboxSyncedToOptimalDimension,
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
+import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {
clipEmbedModelSelected,
@@ -20,10 +21,9 @@ import {
} from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
-import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
-import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
+import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
@@ -95,15 +95,11 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
- const { bbox } = selectCanvasSlice(state);
- const optimalDimension = getOptimalDimension(defaultModelInList);
- if (getIsSizeOptimal(bbox.rect.width, bbox.rect.height, optimalDimension)) {
- return;
+ // When staging, we don't want to change the bbox, but we must keep the optimal dimension in sync.
+ dispatch(bboxOptimalDimensionChanged({ optimalDimension: getOptimalDimension(defaultModelInList) }));
+ if (!selectIsStaging(state)) {
+ dispatch(bboxSyncedToOptimalDimension());
}
- const { width, height } = calculateNewSize(bbox.aspectRatio.value, optimalDimension * optimalDimension);
-
- dispatch(bboxWidthChanged({ width }));
- dispatch(bboxHeightChanged({ height }));
return;
}
}
@@ -116,6 +112,11 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
}
dispatch(modelChanged({ model: result.data, previousModel: currentModel }));
+ // When staging, we don't want to change the bbox, but we must keep the optimal dimension in sync.
+ dispatch(bboxOptimalDimensionChanged({ optimalDimension: getOptimalDimension(result.data) }));
+ if (!selectIsStaging(state)) {
+ dispatch(bboxSyncedToOptimalDimension());
+ }
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts
index 42e17b938b3..87eefe8602c 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts
@@ -1,5 +1,6 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
+import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
setCfgRescaleMultiplier,
setCfgScale,
@@ -96,13 +97,15 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
const setSizeOptions = { updateAspectRatio: true, clamp: true };
- if (width) {
+
+ const isStaging = selectIsStaging(getState());
+ if (!isStaging && width) {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
}
- if (height) {
+ if (!isStaging && height) {
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
diff --git a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts b/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts
index 6715f53d8ee..cacfbd0c627 100644
--- a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts
+++ b/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts
@@ -3,36 +3,38 @@ import { addScope, removeScope, setScopes } from 'common/hooks/interactionScopes
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { useInvoke } from 'features/queue/hooks/useInvoke';
+import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
-import { useHotkeys } from 'react-hotkeys-hook';
export const useGlobalHotkeys = () => {
const dispatch = useAppDispatch();
const isModelManagerEnabled = useFeatureStatus('modelManager');
const queue = useInvoke();
- useHotkeys(
- ['ctrl+enter', 'meta+enter'],
- queue.queueBack,
- {
+ useRegisteredHotkeys({
+ id: 'invoke',
+ category: 'app',
+ callback: queue.queueBack,
+ options: {
enabled: !queue.isDisabled && !queue.isLoading,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
- [queue]
- );
+ dependencies: [queue],
+ });
- useHotkeys(
- ['ctrl+shift+enter', 'meta+shift+enter'],
- queue.queueFront,
- {
+ useRegisteredHotkeys({
+ id: 'invokeFront',
+ category: 'app',
+ callback: queue.queueFront,
+ options: {
enabled: !queue.isDisabled && !queue.isLoading,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
- [queue]
- );
+ dependencies: [queue],
+ });
const {
cancelQueueItem,
@@ -40,75 +42,83 @@ export const useGlobalHotkeys = () => {
isLoading: isLoadingCancelQueueItem,
} = useCancelCurrentQueueItem();
- useHotkeys(
- ['shift+x'],
- cancelQueueItem,
- {
+ useRegisteredHotkeys({
+ id: 'cancelQueueItem',
+ category: 'app',
+ callback: cancelQueueItem,
+ options: {
enabled: !isDisabledCancelQueueItem && !isLoadingCancelQueueItem,
preventDefault: true,
},
- [cancelQueueItem, isDisabledCancelQueueItem, isLoadingCancelQueueItem]
- );
+ dependencies: [cancelQueueItem, isDisabledCancelQueueItem, isLoadingCancelQueueItem],
+ });
const { clearQueue, isDisabled: isDisabledClearQueue, isLoading: isLoadingClearQueue } = useClearQueue();
- useHotkeys(
- ['ctrl+shift+x', 'meta+shift+x'],
- clearQueue,
- {
+ useRegisteredHotkeys({
+ id: 'clearQueue',
+ category: 'app',
+ callback: clearQueue,
+ options: {
enabled: !isDisabledClearQueue && !isLoadingClearQueue,
preventDefault: true,
},
- [clearQueue, isDisabledClearQueue, isLoadingClearQueue]
- );
+ dependencies: [clearQueue, isDisabledClearQueue, isLoadingClearQueue],
+ });
- useHotkeys(
- '1',
- () => {
+ useRegisteredHotkeys({
+ id: 'selectCanvasTab',
+ category: 'app',
+ callback: () => {
dispatch(setActiveTab('canvas'));
addScope('canvas');
removeScope('workflows');
},
- [dispatch]
- );
+ dependencies: [dispatch],
+ });
- useHotkeys(
- '2',
- () => {
+ useRegisteredHotkeys({
+ id: 'selectUpscalingTab',
+ category: 'app',
+ callback: () => {
dispatch(setActiveTab('upscaling'));
removeScope('canvas');
removeScope('workflows');
},
- [dispatch]
- );
+ dependencies: [dispatch],
+ });
- useHotkeys(
- '3',
- () => {
+ useRegisteredHotkeys({
+ id: 'selectWorkflowsTab',
+ category: 'app',
+ callback: () => {
dispatch(setActiveTab('workflows'));
removeScope('canvas');
addScope('workflows');
},
- [dispatch]
- );
+ dependencies: [dispatch],
+ });
- useHotkeys(
- '4',
- () => {
- if (isModelManagerEnabled) {
- dispatch(setActiveTab('models'));
- setScopes([]);
- }
+ useRegisteredHotkeys({
+ id: 'selectModelsTab',
+ category: 'app',
+ callback: () => {
+ dispatch(setActiveTab('models'));
+ setScopes([]);
+ },
+ options: {
+ enabled: isModelManagerEnabled,
},
- [dispatch, isModelManagerEnabled]
- );
+ dependencies: [dispatch, isModelManagerEnabled],
+ });
- useHotkeys(
- isModelManagerEnabled ? '5' : '4',
- () => {
+ useRegisteredHotkeys({
+ id: 'selectQueueTab',
+ category: 'app',
+ callback: () => {
dispatch(setActiveTab('queue'));
setScopes([]);
},
- [dispatch, isModelManagerEnabled]
- );
+ dependencies: [dispatch, isModelManagerEnabled],
+ });
};
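
The hunk above swaps hard-coded `useHotkeys` bindings for `useRegisteredHotkeys`, identified by the same `category`/`id` pairs that the restructured `en.json` uses (`hotkeys.<category>.<id>.title` / `.desc`). The hook's implementation lives in `useHotkeyData.ts` and is not part of this diff, so the sketch below only infers its shape from the call sites; the `BINDINGS` table is an illustrative stand-in populated with the key strings from the removed lines, not the real registry.

```ts
// Hypothetical sketch of useRegisteredHotkeys, inferred from the call sites above.
// The real hook lives in features/system/components/HotkeysModal/useHotkeyData.ts.
import { useHotkeys } from 'react-hotkeys-hook';

type UseRegisteredHotkeysArg = {
  id: string;       // e.g. 'invoke' -> en.json entry hotkeys.app.invoke
  category: string; // e.g. 'app'
  callback: () => void;
  options?: Parameters<typeof useHotkeys>[2];
  dependencies?: unknown[];
};

// Illustrative registry; the bindings shown are the ones removed from useGlobalHotkeys above.
const BINDINGS: Record<string, string[]> = {
  'app.invoke': ['ctrl+enter', 'meta+enter'],
  'app.invokeFront': ['ctrl+shift+enter', 'meta+shift+enter'],
  'app.cancelQueueItem': ['shift+x'],
  'app.clearQueue': ['ctrl+shift+x', 'meta+shift+x'],
};

export const useRegisteredHotkeys = ({ id, category, callback, options, dependencies }: UseRegisteredHotkeysArg) => {
  // The Hotkeys modal can label the same entry from the restructured locale strings, e.g.
  //   t(`hotkeys.${category}.${id}.title`) -> "Invoke"
  //   t(`hotkeys.${category}.${id}.desc`)  -> "Queue a generation, adding it to the end of the queue."
  return useHotkeys(BINDINGS[`${category}.${id}`] ?? [], callback, options, dependencies);
};
```
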
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo.tsx
index 11df592ad57..b8ba97680e6 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo.tsx
@@ -11,10 +11,11 @@ import {
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean';
+import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
- setRightPanelTabToGallery,
- setRightPanelTabToLayers,
-} from 'features/controlLayers/components/CanvasRightPanel';
+ selectCanvasRightPanelGalleryTab,
+ selectCanvasRightPanelLayersTab,
+} from 'features/controlLayers/store/ephemeral';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCurrentDestination } from 'features/queue/hooks/useCurrentDestination';
import { selectShowSendingToAlerts, showSendingToAlertsChanged } from 'features/system/store/systemSlice';
@@ -29,7 +30,7 @@ const ActivateImageViewerButton = (props: PropsWithChildren) => {
const imageViewer = useImageViewer();
const onClick = useCallback(() => {
imageViewer.open();
- setRightPanelTabToGallery();
+ selectCanvasRightPanelGalleryTab();
}, [imageViewer]);
return (