Merge pull request #9 from invoke-ai/ryan/text-encoder-lora
Add support for training a CLIP text encoder with LoRA
RyanJDick authored Aug 8, 2023
2 parents bcf1a91 + fe4b6b1 commit b7e1afb
Showing 4 changed files with 71 additions and 16 deletions.
18 changes: 18 additions & 0 deletions src/invoke_training/lora/injection/stable_diffusion_v1.py
@@ -3,6 +3,8 @@
import torch
from diffusers.models import Transformer2DModel, UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from transformers import CLIPTextModel
from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.injection.utils import inject_lora_layers
@@ -36,6 +38,22 @@ def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection
    return lora_layers


def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel):
    lora_layers = inject_lora_layers(
        module=text_encoder,
        lora_map={
            torch.nn.Linear: LoRALinearLayer,
            torch.nn.Conv2d: LoRAConv2dLayer,
        },
        include_descendants_of={CLIPAttention, CLIPMLP},
        exclude_descendants_of=None,
        prefix="lora_te",
        dtype=torch.float32,
    )

    return lora_layers


def convert_lora_state_dict_to_kohya_format_sd1(
    state_dict: typing.Dict[str, torch.Tensor]
) -> typing.Dict[str, torch.Tensor]:
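For context, a minimal usage sketch of the new helper (not part of this commit): the model id mirrors the smoke test below, the optimizer choice and learning rate are placeholders, and `parameters()` / `get_lora_state_dict()` are the same methods the training script relies on.

```python
import torch
from transformers import CLIPTextModel

from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_clip_text_encoder,
)

# Load the SD 1.5 text encoder and inject LoRA layers into its attention and MLP Linears.
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
lora_layers = inject_lora_into_clip_text_encoder(text_encoder)

# Only the injected LoRA parameters are handed to the optimizer; the base weights are left alone here.
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)

# The collection can be exported in kohya_ss format; keys carry the "lora_te" prefix set above.
kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_layers.get_lora_state_dict())
```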
42 changes: 26 additions & 16 deletions src/invoke_training/training/lora/lora_training.py
@@ -22,9 +22,9 @@
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_clip_text_encoder,
    inject_lora_into_unet_sd1,
)
from invoke_training.training.lora.lora_training_config import LoRATrainingConfig
@@ -178,26 +178,32 @@ def _initialize_optimizer(config: LoRATrainingConfig, trainable_params: list) ->

def _save_checkpoint(
    idx: int,
    lora_layers: LoRALayerCollection,
    lora_layers: torch.nn.ModuleDict,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
):
    """Save a checkpoint. Old checkpoints are deleted if necessary to respect the config.max_checkpoints config.

    Args:
        idx (int): The checkpoint index (typically step count or epoch).
        lora_layers (LoRALayerCollection): The LoRA layers to save.
        lora_layers (torch.nn.ModuleDict): The LoRA layers to save in a ModuleDict mapping keys to
            `LoRALayerCollection`s.
        logger (logging.Logger): Logger.
        checkpoint_tracker (CheckpointTracker): The checkpoint tracker.
    """
    # Prune checkpoints and get new checkpoint path.
    num_pruned = checkpoint_tracker.prune(1)
    if num_pruned > 0:
        logger.info(f"Pruned {num_pruned} checkpoint(s).")
    save_path = checkpoint_tracker.get_path(idx)

    state_dict = lora_layers.get_lora_state_dict()
    kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(state_dict)
    save_state_dict(kohya_state_dict, save_path)
    state_dict = {}
    for model_lora_layers in lora_layers.values():
        model_state_dict = model_lora_layers.get_lora_state_dict()
        model_kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(model_state_dict)
        state_dict.update(model_kohya_state_dict)

    save_state_dict(state_dict, save_path)
    # accelerator.save_state(save_path)
    logger.info(f"Saved state to '{save_path}'.")

@@ -373,15 +379,19 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901
    logger.info("Loading models.")
    tokenizer, noise_scheduler, text_encoder, vae, unet = _load_models(accelerator, config)

    unet_lora_layers = inject_lora_into_unet_sd1(unet)
    lora_layers = torch.nn.ModuleDict()
    if config.train_unet:
        lora_layers["unet"] = inject_lora_into_unet_sd1(unet)
    if config.train_text_encoder:
        lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder)

    if config.xformers:
        import xformers  # noqa: F401

        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    optimizer = _initialize_optimizer(config, unet_lora_layers.parameters())
    optimizer = _initialize_optimizer(config, lora_layers.parameters())

    data_loader = initialize_hf_dataloader(config.dataset, accelerator, tokenizer, config.train_batch_size)
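Grouping the per-model collections in a `torch.nn.ModuleDict` lets the rest of the script treat them as a single trainable unit: one `.parameters()` call feeds the optimizer and gradient clipping, and `.train()` / `accelerator.prepare(...)` apply to every entry. A tiny illustration, with plain `Linear` layers standing in for the real `LoRALayerCollection`s:

```python
import torch

lora_layers = torch.nn.ModuleDict({
    "unet": torch.nn.Linear(4, 4),          # stand-in for the UNet LoRA collection
    "text_encoder": torch.nn.Linear(4, 4),  # stand-in for the text encoder LoRA collection
})

# .parameters() chains both entries (2 weights + 2 biases here) ...
assert len(list(lora_layers.parameters())) == 4
# ... and .train() propagates to every entry.
lora_layers.train()
assert all(module.training for module in lora_layers.values())
```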

@@ -400,12 +410,12 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901
    prepared_result: tuple[
        UNet2DConditionModel,
        CLIPTextModel,
        LoRALayerCollection,
        torch.nn.ModuleDict,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(unet, text_encoder, unet_lora_layers, optimizer, data_loader, lr_scheduler)
    unet, text_encoder, unet_lora_layers, optimizer, data_loader, lr_scheduler = prepared_result
    ] = accelerator.prepare(unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler)
    unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler = prepared_result

    # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation
    # (i.e. takes into account gradient accumulation steps).
@@ -452,11 +462,11 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        unet_lora_layers.train()
        lora_layers.train()

        train_loss = 0.0
        for data_batch in data_loader:
            with accelerator.accumulate(unet_lora_layers):
            with accelerator.accumulate(lora_layers):
                loss = _train_forward(
                    config,
                    data_batch,
@@ -475,7 +485,7 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901
                # Backpropagate.
                accelerator.backward(loss)
                if accelerator.sync_gradients and config.max_grad_norm is not None:
                    params_to_clip = unet_lora_layers.parameters()
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
@@ -491,7 +501,7 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901
                if config.save_every_n_steps is not None and (global_step + 1) % config.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        _save_checkpoint(global_step + 1, unet_lora_layers, logger, step_checkpoint_tracker)
                        _save_checkpoint(global_step + 1, lora_layers, logger, step_checkpoint_tracker)

            logs = {
                "step_loss": loss.detach().item(),
@@ -505,7 +515,7 @@ def run_lora_training(config: LoRATrainingConfig): # noqa: C901
        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and (epoch + 1) % config.save_every_n_epochs == 0:
            if accelerator.is_main_process:
                _save_checkpoint(epoch + 1, unet_lora_layers, logger, epoch_checkpoint_tracker)
                _save_checkpoint(epoch + 1, lora_layers, logger, epoch_checkpoint_tracker)
            accelerator.wait_for_everyone()

        # Generate validation images every n epochs.
6 changes: 6 additions & 0 deletions src/invoke_training/training/lora/lora_training_config.py
@@ -107,6 +107,12 @@ class LoRATrainingConfig(BaseModel):
    # A seed for reproducible training.
    seed: typing.Optional[int] = None

    # Whether to add LoRA layers to the UNet model and train it.
    train_unet: bool = True

    # Whether to add LoRA layers to the text encoder and train it.
    train_text_encoder: bool = True

    # The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    # Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    gradient_accumulation_steps: int = 1
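Both new flags default to `True`, so configs written before this change will now train LoRA layers for the UNet and the text encoder unless they opt out explicitly. Sketched as override fragments only (every other required `LoRATrainingConfig` field is omitted):

```python
# Override fragments, not complete configs.
text_encoder_only = {"train_unet": False, "train_text_encoder": True}
unet_only = {"train_unet": True, "train_text_encoder": False}  # matches the pre-PR behaviour
```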
21 changes: 21 additions & 0 deletions tests/invoke_training/lora/injection/test_stable_diffusion_v1.py
@@ -1,9 +1,11 @@
import pytest
import torch
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel

from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_clip_text_encoder,
    inject_lora_into_unet_sd1,
)

@@ -29,6 +31,25 @@ def test_inject_lora_into_unet_sd1_smoke():
    )


@pytest.mark.loads_model
def test_inject_lora_into_clip_text_encoder_smoke():
    """Smoke test of inject_lora_into_clip_text_encoder(...) on full SD 1.5 model."""
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="text_encoder",
        local_files_only=True,
        revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
    )

    lora_layers = inject_lora_into_clip_text_encoder(text_encoder)

    # These assertions are based on a manual check of the injected layers and comparison against the behaviour of
    # kohya_ss. They are included here to force another manual review after any future breaking change.
    assert len(lora_layers) == 72  # 216 / 3
    for layer_name in lora_layers._names:
        assert layer_name.endswith(("mlp.fc1", "mlp.fc2", "k_proj", "out_proj", "q_proj", "v_proj"))
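The expected count of 72 in the assertion above follows from the SD 1.5 text encoder architecture (CLIP ViT-L/14, 12 transformer layers) with six LoRA target Linears per layer:

```python
# 12 CLIP text-encoder layers x (q_proj, k_proj, v_proj, out_proj, mlp.fc1, mlp.fc2)
num_layers = 12
targets_per_layer = 6
assert num_layers * targets_per_layer == 72
```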


@pytest.mark.loads_model
def test_convert_lora_state_dict_to_kohya_format_sd1_smoke():
"""Smoke test of convert_lora_state_dict_to_kohya_format_sd1(...) with full SD 1.5 model."""
