Skip to content

Commit

Permalink
Merge pull request #14 from invoke-ai/ryan/sdxl-lora-2
Browse files Browse the repository at this point in the history
Add SDXL fine-tuning with LoRA
  • Loading branch information
RyanJDick authored Aug 14, 2023
2 parents f8995d0 + 859c082 commit c41a4ac
Show file tree
Hide file tree
Showing 17 changed files with 1,178 additions and 191 deletions.
30 changes: 30 additions & 0 deletions configs/finetune_lora_sdxl_pokemon_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This is a sample config for fine-tuning an SDXL 1.0 model with LoRA to produce a Pokemon LoRA model.

output:
base_output_dir: output/

optimizer:
learning_rate: 5.0e-4

dataset:
dataset_name: lambdalabs/pokemon-blip-captions

# General
model: stabilityai/stable-diffusion-xl-base-1.0
vae_model: madebyollin/sdxl-vae-fp16-fix
seed: 1
train_text_encoder: False
gradient_accumulation_steps: 4
mixed_precision: fp16
xformers: True
gradient_checkpointing: True
max_train_steps: 4000
save_every_n_epochs: 1
save_every_n_steps: null
max_checkpoints: 100
validation_prompts:
- yoda
- astronaut
- yoda in a space suit
validate_every_n_epochs: 1
train_batch_size: 1
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from invoke_training.lora.layers import LoRAConv2dLayer, LoRALinearLayer


def inject_lora_into_unet_sd1(
def inject_lora_into_unet(
unet: UNet2DConditionModel, include_non_attention_blocks: bool = False
) -> LoRALayerCollection:
"""Inject LoRA layers into a Stable Diffusion v1 UNet model.
"""Inject LoRA layers into a Stable Diffusion UNet model.
Args:
unet (UNet2DConditionModel): The UNet model to inject LoRA layers into.
Expand Down Expand Up @@ -45,7 +45,7 @@ def inject_lora_into_unet_sd1(
return lora_layers


def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel):
def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel, prefix: str = "lora_te"):
lora_layers = inject_lora_layers(
module=text_encoder,
lora_map={
Expand All @@ -54,17 +54,17 @@ def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel):
},
include_descendants_of={CLIPAttention, CLIPMLP},
exclude_descendants_of=None,
prefix="lora_te",
prefix=prefix,
dtype=torch.float32,
)

return lora_layers


def convert_lora_state_dict_to_kohya_format_sd1(
def convert_lora_state_dict_to_kohya_format(
state_dict: typing.Dict[str, torch.Tensor]
) -> typing.Dict[str, torch.Tensor]:
"""Convert a Stable Diffusion v1 LoRA state_dict from internal invoke-training format to kohya_ss format.
"""Convert a Stable Diffusion LoRA state_dict from internal invoke-training format to kohya_ss format.
Args:
state_dict (typing.Dict[str, torch.Tensor]): LoRA layer state_dict in invoke-training format.
Expand Down
4 changes: 2 additions & 2 deletions src/invoke_training/scripts/invoke_finetune_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import (
FinetuneLoRAConfig,
FinetuneLoRASDXLConfig,
)
from invoke_training.training.finetune_lora.finetune_lora_sdxl import run_training

Expand All @@ -27,7 +27,7 @@ def main():
with open(args.cfg_file, "r") as f:
cfg = yaml.safe_load(f)

train_config = FinetuneLoRAConfig(**cfg)
train_config = FinetuneLoRASDXLConfig(**cfg)

run_training(train_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,10 @@ class FinetuneLoRAConfig(BaseModel):

# The training batch size.
train_batch_size: int = 4


class FinetuneLoRASDXLConfig(FinetuneLoRAConfig):
    """Configuration for LoRA fine-tuning of a Stable Diffusion XL (SDXL) model.

    Inherits all options from ``FinetuneLoRAConfig`` and adds SDXL-specific
    settings.
    """

    # The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
    # model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
    # with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    # When None, the VAE bundled with the base model is used.
    vae_model: typing.Optional[str] = None
18 changes: 10 additions & 8 deletions src/invoke_training/training/finetune_lora/finetune_lora_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from invoke_training.lora.injection.stable_diffusion_v1 import (
convert_lora_state_dict_to_kohya_format_sd1,
from invoke_training.lora.injection.stable_diffusion import (
convert_lora_state_dict_to_kohya_format,
inject_lora_into_clip_text_encoder,
inject_lora_into_unet_sd1,
inject_lora_into_unet,
)
from invoke_training.training.finetune_lora.finetune_lora_config import (
FinetuneLoRAConfig,
Expand All @@ -36,8 +36,8 @@
check_base_model_version,
)
from invoke_training.training.shared.checkpoint_tracker import CheckpointTracker
from invoke_training.training.shared.datasets.image_caption_dataloader import (
build_image_caption_dataloader,
from invoke_training.training.shared.datasets.image_caption_sd_dataloader import (
build_image_caption_sd_dataloader,
)
from invoke_training.training.shared.serialization import save_state_dict

Expand Down Expand Up @@ -122,7 +122,7 @@ def _save_checkpoint(
state_dict = {}
for model_lora_layers in lora_layers.values():
model_state_dict = model_lora_layers.get_lora_state_dict()
model_kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(model_state_dict)
model_kohya_state_dict = convert_lora_state_dict_to_kohya_format(model_state_dict)
state_dict.update(model_kohya_state_dict)

save_state_dict(state_dict, save_path)
Expand Down Expand Up @@ -189,6 +189,8 @@ def _generate_validation_images(
prompt,
num_inference_steps=30,
generator=generator,
height=config.dataset.resolution,
width=config.dataset.resolution,
).images[0]
)

Expand Down Expand Up @@ -312,7 +314,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901

lora_layers = torch.nn.ModuleDict()
if config.train_unet:
lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks)
lora_layers["unet"] = inject_lora_into_unet(unet, config.train_unet_non_attention_blocks)
if config.train_text_encoder:
lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder)

Expand All @@ -324,7 +326,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901

optimizer = _initialize_optimizer(config, lora_layers.parameters())

data_loader = build_image_caption_dataloader(config.dataset, tokenizer, config.train_batch_size)
data_loader = build_image_caption_sd_dataloader(config.dataset, tokenizer, config.train_batch_size)

# TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
# by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
Expand Down
Loading

0 comments on commit c41a4ac

Please sign in to comment.