LoRA training of non-attention UNet layers #11

Merged · 1 commit · Aug 10, 2023
13 changes: 10 additions & 3 deletions src/invoke_training/lora/injection/stable_diffusion_v1.py
@@ -3,6 +3,7 @@
import torch
from diffusers.models import Transformer2DModel, UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from transformers import CLIPTextModel
from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention

@@ -11,15 +12,21 @@
from invoke_training.lora.layers import LoRAConv2dLayer, LoRALinearLayer


def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection:
def inject_lora_into_unet_sd1(
    unet: UNet2DConditionModel, include_non_attention_blocks: bool = False
) -> LoRALayerCollection:
    """Inject LoRA layers into a Stable Diffusion v1 UNet model.

    Args:
        unet (UNet2DConditionModel): The UNet model to inject LoRA layers into.

        include_non_attention_blocks (bool, optional): Whether to inject LoRA layers into the linear/conv layers of the
            non-attention blocks (`ResnetBlock2D`, `Downsample2D`, `Upsample2D`). Defaults to False.
    Returns:
        LoRALayerCollection: The LoRA layers that were added to the UNet.
    """
    include_descendants_of = {Transformer2DModel}
    if include_non_attention_blocks:
        include_descendants_of.update({ResnetBlock2D, Downsample2D, Upsample2D})

    lora_layers = inject_lora_layers(
        module=unet,
@@ -29,7 +36,7 @@ def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection
            torch.nn.Conv2d: LoRAConv2dLayer,
            LoRACompatibleConv: LoRAConv2dLayer,
        },
        include_descendants_of={Transformer2DModel},
        include_descendants_of=include_descendants_of,
        exclude_descendants_of=None,
        prefix="lora_unet",
        dtype=torch.float32,
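The new flag is opt-in at the call site. A minimal usage sketch, assuming the same SD 1.5 checkpoint and revision that the tests below pin (loading the model this way is an illustration, not part of this diff):

```python
from diffusers.models import UNet2DConditionModel

from invoke_training.lora.injection.stable_diffusion_v1 import inject_lora_into_unet_sd1

# Load the SD 1.5 UNet (the tests below pin this checkpoint/revision).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
)

# Default: LoRA layers are injected only under Transformer2DModel (attention) blocks.
# With include_non_attention_blocks=True, descendants of ResnetBlock2D, Downsample2D,
# and Upsample2D are wrapped as well.
lora_layers = inject_lora_into_unet_sd1(unet, include_non_attention_blocks=True)

# The new smoke test below expects 278 injected layers for this configuration.
print(len(lora_layers))
```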
2 changes: 1 addition & 1 deletion src/invoke_training/training/lora/lora_training.py
@@ -383,7 +383,7 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901

    lora_layers = torch.nn.ModuleDict()
    if config.train_unet:
        lora_layers["unet"] = inject_lora_into_unet_sd1(unet)
        lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks)
    if config.train_text_encoder:
        lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder)

5 changes: 5 additions & 0 deletions src/invoke_training/training/lora/lora_training_config.py
@@ -113,6 +113,11 @@ class LoRATrainingConfig(BaseModel):
    # Whether to add LoRA layers to the text encoder and train it.
    train_text_encoder: bool = True

    # Whether to inject LoRA layers into the non-attention UNet blocks for training. Enabling this will produce a more
    # expressive LoRA model at the cost of slower training, higher training VRAM requirements, and a larger LoRA weight
    # file.
    train_unet_non_attention_blocks: bool = False

    # The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    # Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    gradient_accumulation_steps: int = 1
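For context, a sketch of how the new field might be set when constructing a training config in code. Only the fields visible in this diff are shown; `LoRATrainingConfig` has other fields (model, dataset, optimizer, output, ...) that a real run would still need to populate, so treat this as illustrative rather than a complete, validating config:

```python
from invoke_training.training.lora.lora_training_config import LoRATrainingConfig

# Illustrative only: the remaining fields are omitted and assumed to have usable
# defaults, which may not hold for a real training run.
config = LoRATrainingConfig(
    train_unet=True,
    train_text_encoder=False,
    # New in this PR: also inject LoRA into ResnetBlock2D/Downsample2D/Upsample2D layers.
    train_unet_non_attention_blocks=True,
    gradient_accumulation_steps=1,
)
```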
48 changes: 36 additions & 12 deletions tests/invoke_training/lora/injection/test_stable_diffusion_v1.py
@@ -10,16 +10,19 @@
)


@pytest.mark.loads_model
def test_inject_lora_into_unet_sd1_smoke():
    """Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model."""
    unet = UNet2DConditionModel.from_pretrained(
@pytest.fixture
def unet():
    return UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        local_files_only=True,
        revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
    )


@pytest.mark.loads_model
def test_inject_lora_into_unet_sd1_smoke(unet):
"""Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model."""
lora_layers = inject_lora_into_unet_sd1(unet)

# These assertions are based on a manual check of the injected layers and comparison against the behaviour of
Expand All @@ -31,6 +34,34 @@ def test_inject_lora_into_unet_sd1_smoke():
)


@pytest.mark.loads_model
def test_inject_lora_into_unet_sd1_non_attention_layers_smoke(unet):
    """Smoke test of inject_lora_into_unet_sd1(..., include_non_attention_blocks=True) on full SD 1.5 model."""
    lora_layers = inject_lora_into_unet_sd1(unet, include_non_attention_blocks=True)

    # These assertions are based on a manual check of the injected layers and comparison against the behaviour of
    # kohya_ss. They are included here to force another manual review after any future breaking change.
    assert len(lora_layers) == 278
    for layer_name in lora_layers._names:
        assert layer_name.endswith(
            (
                "to_q",
                "to_k",
                "to_v",
                "to_out.0",
                "ff.net.0.proj",
                "ff.net.2",
                ".proj_in",
                ".proj_out",
                ".conv1",
                ".conv2",
                ".time_emb_proj",
                ".conv",
                ".conv_shortcut",
            )
        )


@pytest.mark.loads_model
def test_inject_lora_into_clip_text_encoder_smoke():
"""Smoke test of inject_lora_into_clip_text_encoder(...) on full SD 1.5 model."""
Expand All @@ -51,15 +82,8 @@ def test_inject_lora_into_clip_text_encoder_smoke():


@pytest.mark.loads_model
def test_convert_lora_state_dict_to_kohya_format_sd1_smoke():
def test_convert_lora_state_dict_to_kohya_format_sd1_smoke(unet):
"""Smoke test of convert_lora_state_dict_to_kohya_format_sd1(...) with full SD 1.5 model."""
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
local_files_only=True,
revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
)

    lora_layers = inject_lora_into_unet_sd1(unet)
    lora_state_dict = lora_layers.get_lora_state_dict()
    kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_state_dict)
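End to end, the injected layers can be exported in the kohya_ss key format that the last test exercises. A hedged sketch: the `inject_*`, `get_lora_state_dict`, and `convert_*` calls come from the code above, while the import location of the conversion helper, the fp16 cast, and saving to `.safetensors` are assumptions for illustration:

```python
import torch
from diffusers.models import UNet2DConditionModel
from safetensors.torch import save_file

from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_unet_sd1,
)

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_layers = inject_lora_into_unet_sd1(unet, include_non_attention_blocks=True)

# ... a training loop would update the LoRA weights here ...

# Convert the LoRA state dict to kohya_ss key naming and save it.
kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_layers.get_lora_state_dict())
save_file({k: v.detach().to(torch.float16) for k, v in kohya_state_dict.items()}, "lora.safetensors")
```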