From fd428ee0c59d53b9ed0da0e34fffedbb246704e5 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 8 Aug 2023 11:21:18 -0400 Subject: [PATCH] Add LoRA training of non-attention UNet blocks. --- .../lora/injection/stable_diffusion_v1.py | 13 +++-- .../training/lora/lora_training.py | 2 +- .../training/lora/lora_training_config.py | 5 ++ .../injection/test_stable_diffusion_v1.py | 48 ++++++++++++++----- 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/invoke_training/lora/injection/stable_diffusion_v1.py b/src/invoke_training/lora/injection/stable_diffusion_v1.py index 622ecafa..440a870f 100644 --- a/src/invoke_training/lora/injection/stable_diffusion_v1.py +++ b/src/invoke_training/lora/injection/stable_diffusion_v1.py @@ -3,6 +3,7 @@ import torch from diffusers.models import Transformer2DModel, UNet2DConditionModel from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from transformers import CLIPTextModel from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention @@ -11,15 +12,21 @@ from invoke_training.lora.layers import LoRAConv2dLayer, LoRALinearLayer -def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection: +def inject_lora_into_unet_sd1( + unet: UNet2DConditionModel, include_non_attention_blocks: bool = False +) -> LoRALayerCollection: """Inject LoRA layers into a Stable Diffusion v1 UNet model. Args: unet (UNet2DConditionModel): The UNet model to inject LoRA layers into. - + include_non_attention_blocks (bool, optional): Whether to inject LoRA layers into the linear/conv layers of the + non-attention blocks (`ResnetBlock2D`, `Downsample2D`, `Upsample2D`). Defaults to False. Returns: LoRALayerCollection: The LoRA layers that were added to the UNet. """ + include_descendants_of = {Transformer2DModel} + if include_non_attention_blocks: + include_descendants_of.update({ResnetBlock2D, Downsample2D, Upsample2D}) lora_layers = inject_lora_layers( module=unet, @@ -29,7 +36,7 @@ def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection torch.nn.Conv2d: LoRAConv2dLayer, LoRACompatibleConv: LoRAConv2dLayer, }, - include_descendants_of={Transformer2DModel}, + include_descendants_of=include_descendants_of, exclude_descendants_of=None, prefix="lora_unet", dtype=torch.float32, diff --git a/src/invoke_training/training/lora/lora_training.py b/src/invoke_training/training/lora/lora_training.py index bbb9dfc0..4bd0ed18 100644 --- a/src/invoke_training/training/lora/lora_training.py +++ b/src/invoke_training/training/lora/lora_training.py @@ -383,7 +383,7 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901 lora_layers = torch.nn.ModuleDict() if config.train_unet: - lora_layers["unet"] = inject_lora_into_unet_sd1(unet) + lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks) if config.train_text_encoder: lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder) diff --git a/src/invoke_training/training/lora/lora_training_config.py b/src/invoke_training/training/lora/lora_training_config.py index cf0e4a5c..09ca6961 100644 --- a/src/invoke_training/training/lora/lora_training_config.py +++ b/src/invoke_training/training/lora/lora_training_config.py @@ -113,6 +113,11 @@ class LoRATrainingConfig(BaseModel): # Whether to add LoRA layers to the text encoder and train it. 
train_text_encoder: bool = True + # Whether to inject LoRA layers into the non-attention UNet blocks for training. Enabling will produce a more + # expressive LoRA model at the cost of slower training, higher training VRAM requirements, and a larger LoRA weight + # file. + train_unet_non_attention_blocks: bool = False + # The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face # Accelerate. This is an alternative to increasing the batch size when training with limited VRAM. gradient_accumulation_steps: int = 1 diff --git a/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py b/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py index e32c849e..6209a40a 100644 --- a/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py +++ b/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py @@ -10,16 +10,19 @@ ) -@pytest.mark.loads_model -def test_inject_lora_into_unet_sd1_smoke(): - """Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model.""" - unet = UNet2DConditionModel.from_pretrained( +@pytest.fixture +def unet(): + return UNet2DConditionModel.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="unet", local_files_only=True, revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0", ) + +@pytest.mark.loads_model +def test_inject_lora_into_unet_sd1_smoke(unet): + """Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model.""" lora_layers = inject_lora_into_unet_sd1(unet) # These assertions are based on a manual check of the injected layers and comparison against the behaviour of @@ -31,6 +34,34 @@ def test_inject_lora_into_unet_sd1_smoke(): ) +@pytest.mark.loads_model +def test_inject_lora_into_unet_sd1_non_attention_layers_smoke(unet): + """Smoke test of inject_lora_into_unet_sd1(..., include_non_attention_blocks=True) on full SD 1.5 model.""" + lora_layers = inject_lora_into_unet_sd1(unet, include_non_attention_blocks=True) + + # These assertions are based on a manual check of the injected layers and comparison against the behaviour of + # kohya_ss. They are included here to force another manual review after any future breaking change. + assert len(lora_layers) == 278 + for layer_name in lora_layers._names: + assert layer_name.endswith( + ( + "to_q", + "to_k", + "to_v", + "to_out.0", + "ff.net.0.proj", + "ff.net.2", + ".proj_in", + ".proj_out", + ".conv1", + ".conv2", + ".time_emb_proj", + ".conv", + ".conv_shortcut", + ) + ) + + @pytest.mark.loads_model def test_inject_lora_into_clip_text_encoder_smoke(): """Smoke test of inject_lora_into_clip_text_encoder(...) on full SD 1.5 model.""" @@ -51,15 +82,8 @@ def test_inject_lora_into_clip_text_encoder_smoke(): @pytest.mark.loads_model -def test_convert_lora_state_dict_to_kohya_format_sd1_smoke(): +def test_convert_lora_state_dict_to_kohya_format_sd1_smoke(unet): """Smoke test of convert_lora_state_dict_to_kohya_format_sd1(...) with full SD 1.5 model.""" - unet = UNet2DConditionModel.from_pretrained( - "runwayml/stable-diffusion-v1-5", - subfolder="unet", - local_files_only=True, - revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0", - ) - lora_layers = inject_lora_into_unet_sd1(unet) lora_state_dict = lora_layers.get_lora_state_dict() kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_state_dict)
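
Usage note: a minimal sketch of driving the new include_non_attention_blocks flag directly, mirroring the smoke tests in this patch. The checkpoint name and the 278-layer count are copied from the test file; the import path is inferred from the src/ layout shown in the diffstat and should be treated as an assumption.

    from diffusers.models import UNet2DConditionModel

    # Import path inferred from src/invoke_training/lora/injection/stable_diffusion_v1.py above.
    from invoke_training.lora.injection.stable_diffusion_v1 import inject_lora_into_unet_sd1

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
    )

    # Inject LoRA layers into the attention blocks as well as the non-attention blocks
    # (ResnetBlock2D, Downsample2D, Upsample2D) added by this patch.
    lora_layers = inject_lora_into_unet_sd1(unet, include_non_attention_blocks=True)
    assert len(lora_layers) == 278  # Layer count asserted by the smoke test above.

    # The injected layers can be gathered into a state dict for saving or conversion.
    lora_state_dict = lora_layers.get_lora_state_dict()

In a full training run the same behaviour is driven by the new train_unet_non_attention_blocks config field, which run_lora_training forwards to inject_lora_into_unet_sd1.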