Merge pull request #5 from invoke-ai/ryan/lora-injection-sd1
Add basic linear LoRA support for Stable Diffusion v1 UNet
Showing 14 changed files with 304 additions and 21 deletions.
src/invoke_training/lora/injection/lora_layer_collection.py (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
import typing

import torch

from invoke_training.lora.layers import BaseLoRALayer


class LoRALayerCollection(torch.nn.Module):
    """A collection of LoRA layers (with names). Typically used to perform operations on a group of LoRA layers during
    training.
    """

    def __init__(self):
        super().__init__()

        # A torch.nn.ModuleDict may seem like a more natural choice here, but it does not allow keys that contain '.'
        # characters. Using a standard python dict is also inconvenient, because it would be ignored by torch.nn.Module
        # methods such as `.parameters()` and `.train()`.
        self._layers = torch.nn.ModuleList()
        self._names = []

    def add_layer(self, layer: BaseLoRALayer, name: str):
        self._layers.append(layer)
        self._names.append(name)

    def __len__(self):
        return len(self._layers)

    def get_lora_state_dict(self) -> typing.Dict[str, torch.Tensor]:
        """A custom alternative to .state_dict() that uses the layer names provided to add_layer(...) as key
        prefixes.
        """
        state_dict: typing.Dict[str, torch.Tensor] = {}

        for name, layer in zip(self._names, self._layers):
            layer_state_dict = layer.state_dict()
            for key, state in layer_state_dict.items():
                full_key = name + "." + key
                if full_key in state_dict:
                    raise RuntimeError(f"Multiple state elements map to the same key: '{full_key}'.")
                state_dict[full_key] = state

        return state_dict
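For illustration, a minimal usage sketch of the collection (the layer dimensions are borrowed from the tests below; the layer name and the optimizer step are hypothetical, not part of this diff):

import torch

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.layers import LoRALinearLayer

lora_layers = LoRALayerCollection()
# Names may contain '.' characters, which torch.nn.ModuleDict would reject.
lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_unet.some_block.to_q")

# Keys are the add_layer(...) names joined with each layer's own state keys,
# e.g. 'lora_unet.some_block.to_q._down.weight'.
state_dict = lora_layers.get_lora_state_dict()

# Because the collection is a torch.nn.Module, the usual Module machinery
# (e.g. .parameters(), .train()) works on the whole group of LoRA layers:
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)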
src/invoke_training/lora/injection/stable_diffusion_v1.py (74 additions, 0 deletions)

@@ -0,0 +1,74 @@
import typing

import torch
from diffusers.models import Transformer2DModel, UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleLinear

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.injection.utils import inject_lora_layers
from invoke_training.lora.layers import LoRALinearLayer


def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection:
    """Inject LoRA layers into a Stable Diffusion v1 UNet model.

    Args:
        unet (UNet2DConditionModel): The UNet model to inject LoRA layers into.

    Returns:
        LoRALayerCollection: The LoRA layers that were added to the UNet.
    """
    lora_layers = inject_lora_layers(
        module=unet,
        lora_map={torch.nn.Linear: LoRALinearLayer, LoRACompatibleLinear: LoRALinearLayer},
        include_descendants_of={Transformer2DModel},
        exclude_descendants_of=None,
        prefix="lora_unet",
    )

    return lora_layers


def convert_lora_state_dict_to_kohya_format_sd1(
    state_dict: typing.Dict[str, torch.Tensor]
) -> typing.Dict[str, torch.Tensor]:
    """Convert a Stable Diffusion v1 LoRA state_dict from internal invoke-training format to kohya_ss format.

    Args:
        state_dict (typing.Dict[str, torch.Tensor]): LoRA layer state_dict in invoke-training format.

    Raises:
        ValueError: If state_dict contains unexpected keys.
        RuntimeError: If two input keys map to the same output kohya_ss key.

    Returns:
        typing.Dict[str, torch.Tensor]: LoRA layer state_dict in kohya_ss format.
    """
    new_state_dict = {}

    # The following logic converts state_dict keys from the internal invoke-training format to kohya_ss format.
    # Example conversion:
    #   from: 'lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight'
    #   to:   'lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight'
    for key, val in state_dict.items():
        if key.endswith("._up.weight"):
            key_start = key.removesuffix("._up.weight")
            key_end = ".lora_up.weight"
        elif key.endswith("._down.weight"):
            key_start = key.removesuffix("._down.weight")
            key_end = ".lora_down.weight"
        elif key.endswith(".alpha"):
            key_start = key.removesuffix(".alpha")
            key_end = ".alpha"
        else:
            raise ValueError(f"Unexpected key in state_dict: '{key}'.")

        new_key = key_start.replace(".", "_") + key_end

        if new_key in new_state_dict:
            raise RuntimeError(f"Multiple input keys map to the same kohya_ss key: '{new_key}'.")

        new_state_dict[new_key] = val

    return new_state_dict
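Taken together, the intended workflow appears to be inject, train, convert, save. A sketch under the assumption that training happens elsewhere; the safetensors save step is illustrative only and not part of this diff:

from diffusers.models import UNet2DConditionModel

from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_unet_sd1,
)

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Add trainable LoRA layers to the UNet's attention and feed-forward Linear layers.
lora_layers = inject_lora_into_unet_sd1(unet)

# ... optimize lora_layers.parameters() in a training loop ...

# Export in kohya_ss format so the result is loadable by common SD tooling.
kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_layers.get_lora_state_dict())

# Hypothetical save step (not in this diff):
# import safetensors.torch
# safetensors.torch.save_file(kohya_state_dict, "lora.safetensors")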
tests/invoke_training/lora/injection/test_lora_layer_collection.py (37 additions, 0 deletions)
@@ -0,0 +1,37 @@
import pytest

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.layers import LoRALinearLayer


def test_lora_layer_collection_state_dict():
    """Test the behavior of LoRALayerCollection.get_lora_state_dict()."""
    lora_layers = LoRALayerCollection()

    lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_layer_1")
    lora_layers.add_layer(LoRALinearLayer(16, 32), "lora_layer_2")

    state_dict = lora_layers.get_lora_state_dict()

    expected_state_keys = {
        "lora_layer_1._down.weight",
        "lora_layer_1._up.weight",
        "lora_layer_1.alpha",
        "lora_layer_2._down.weight",
        "lora_layer_2._up.weight",
        "lora_layer_2.alpha",
    }
    assert set(state_dict.keys()) == expected_state_keys


def test_lora_layer_collection_state_dict_conflicting_keys():
    """Test that LoRALayerCollection.get_lora_state_dict() raises an exception if state Tensors have conflicting
    keys.
    """
    lora_layers = LoRALayerCollection()

    lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_layer_1")
    lora_layers.add_layer(LoRALinearLayer(16, 32), "lora_layer_1")  # Insert same layer type with same key.

    with pytest.raises(RuntimeError):
        _ = lora_layers.get_lora_state_dict()
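The `_down.weight`, `_up.weight`, and `alpha` keys asserted above come from LoRALinearLayer, which lives elsewhere in the repository and is not part of this diff. As a reference point, a minimal sketch of the standard LoRA linear decomposition that would yield those state keys; the rank default, scaling, and initialization here are assumptions, not the repo's actual implementation:

import torch

class LoRALinearLayerSketch(torch.nn.Module):
    """Hypothetical sketch: y = up(down(x)) * (alpha / rank)."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        # Submodules named '_down' and '_up' produce the '_down.weight' and
        # '_up.weight' state_dict keys.
        self._down = torch.nn.Linear(in_features, rank, bias=False)
        self._up = torch.nn.Linear(rank, out_features, bias=False)
        # A non-trainable parameter named 'alpha' produces the 'alpha' key.
        self.alpha = torch.nn.Parameter(torch.tensor(float(rank)), requires_grad=False)
        self.rank = rank

        # Typical LoRA init: down is random, up starts at zero so the injected
        # layer initially contributes nothing to the base model's output.
        torch.nn.init.normal_(self._down.weight, std=1.0 / rank)
        torch.nn.init.zeros_(self._up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._up(self._down(x)) * (self.alpha / self.rank)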
tests/invoke_training/lora/injection/test_stable_diffusion_v1.py (99 additions, 0 deletions)
@@ -0,0 +1,99 @@
import pytest
import torch
from diffusers.models import UNet2DConditionModel

from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_unet_sd1,
)


@pytest.mark.loads_model
def test_inject_lora_into_unet_sd1_smoke():
    """Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model."""
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        local_files_only=True,
        revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
    )

    lora_layers = inject_lora_into_unet_sd1(unet)

    # These assertions are based on a manual check of the injected layers and comparison against the behaviour of
    # kohya_ss. They are included here to force another manual review after any future breaking change.
    assert len(lora_layers) == 160
    # assert len(lora_layers) == 192  # TODO(ryand): Enable this check once conv layers are added.
    for layer_name in lora_layers._names:
        assert layer_name.endswith(("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"))


@pytest.mark.loads_model
def test_convert_lora_state_dict_to_kohya_format_sd1_smoke():
    """Smoke test of convert_lora_state_dict_to_kohya_format_sd1(...) with full SD 1.5 model."""
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        local_files_only=True,
        revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
    )

    lora_layers = inject_lora_into_unet_sd1(unet)
    lora_state_dict = lora_layers.get_lora_state_dict()
    kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_state_dict)

    # These assertions are based on a manual check of the injected layers and comparison against the behaviour of
    # kohya_ss. They are included here to force another manual review after any future breaking change.
    assert len(kohya_state_dict) == 160 * 3
    for key in kohya_state_dict.keys():
        assert key.startswith("lora_unet_")
        assert key.endswith((".lora_down.weight", ".lora_up.weight", ".alpha"))


def test_convert_lora_state_dict_to_kohya_format_sd1():
    """Basic test of convert_lora_state_dict_to_kohya_format_sd1(...)."""
    down_weight = torch.Tensor(4, 2)
    up_weight = torch.Tensor(2, 4)
    alpha = torch.Tensor([1.0])
    in_state_dict = {
        "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": down_weight,
        "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._up.weight": up_weight,
        "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.alpha": alpha,
    }

    out_state_dict = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict)

    expected_out_state_dict = {
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight": down_weight,
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight": up_weight,
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha": alpha,
    }

    assert out_state_dict == expected_out_state_dict


def test_convert_lora_state_dict_to_kohya_format_sd1_unexpected_key():
    """Test that convert_lora_state_dict_to_kohya_format_sd1(...) raises an exception if it receives an unexpected
    key.
    """
    in_state_dict = {
        "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.unexpected": torch.Tensor(4, 2),
    }

    with pytest.raises(ValueError):
        _ = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict)


def test_convert_lora_state_dict_to_kohya_format_sd1_conflicting_keys():
    """Test that convert_lora_state_dict_to_kohya_format_sd1(...) raises an exception if multiple keys map to the same
    output key.
    """
    # Note: There are differences in the '.' and '_' characters of these keys, but they both map to the same output
    # kohya_ss keys.
    in_state_dict = {
        "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": torch.Tensor(4, 2),
        "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1_to_q._down.weight": torch.Tensor(4, 2),
    }

    with pytest.raises(RuntimeError):
        _ = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict)
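A side note on the `loads_model` marker used by the smoke tests above: pytest warns about unregistered custom markers, so the project presumably registers it in its test configuration. A plausible registration sketch (an assumption; the actual config is not shown in this diff):

# conftest.py (sketch, not part of this diff)
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "loads_model: marks tests that load a full pretrained model from disk",
    )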