Add SDXL fine-tuning with LoRA #14

Merged
10 commits merged on Aug 14, 2023
30 changes: 30 additions & 0 deletions configs/finetune_lora_sdxl_pokemon_example.yaml
@@ -0,0 +1,30 @@
# This is a sample config for fine-tuning an SDXL 1.0 model with LoRA to produce a Pokemon LoRA model.

output:
  base_output_dir: output/

optimizer:
  learning_rate: 5.0e-4

dataset:
  dataset_name: lambdalabs/pokemon-blip-captions

# General
model: stabilityai/stable-diffusion-xl-base-1.0
vae_model: madebyollin/sdxl-vae-fp16-fix
seed: 1
train_text_encoder: False
gradient_accumulation_steps: 4
mixed_precision: fp16
xformers: True
gradient_checkpointing: True
max_train_steps: 4000
save_every_n_epochs: 1
save_every_n_steps: null
max_checkpoints: 100
validation_prompts:
- yoda
- astronaut
- yoda in a space suit
validate_every_n_epochs: 1
train_batch_size: 1
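
For reference, a minimal sketch of how a config like the one above is consumed, mirroring main() in the updated invoke_finetune_lora_sdxl.py shown further down (the print statement is illustrative):

import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import (
    FinetuneLoRASDXLConfig,
)

# Parse the example YAML and validate it against the SDXL config schema.
# Missing or mistyped fields raise a pydantic validation error here, before
# any models are downloaded.
with open("configs/finetune_lora_sdxl_pokemon_example.yaml", "r") as f:
    cfg = yaml.safe_load(f)

train_config = FinetuneLoRASDXLConfig(**cfg)
print(train_config.vae_model)  # madebyollin/sdxl-vae-fp16-fix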
src/invoke_training/lora/injection/stable_diffusion_v1.py → src/invoke_training/lora/injection/stable_diffusion.py
@@ -12,10 +12,10 @@
from invoke_training.lora.layers import LoRAConv2dLayer, LoRALinearLayer


def inject_lora_into_unet_sd1(
def inject_lora_into_unet(
unet: UNet2DConditionModel, include_non_attention_blocks: bool = False
) -> LoRALayerCollection:
"""Inject LoRA layers into a Stable Diffusion v1 UNet model.
"""Inject LoRA layers into a Stable Diffusion UNet model.

Args:
unet (UNet2DConditionModel): The UNet model to inject LoRA layers into.
@@ -45,7 +45,7 @@ def inject_lora_into_unet_sd1(
return lora_layers


def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel):
def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel, prefix: str = "lora_te"):
lora_layers = inject_lora_layers(
module=text_encoder,
lora_map={
@@ -54,17 +54,17 @@ def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel):
},
include_descendants_of={CLIPAttention, CLIPMLP},
exclude_descendants_of=None,
prefix="lora_te",
prefix=prefix,
dtype=torch.float32,
)

return lora_layers


def convert_lora_state_dict_to_kohya_format_sd1(
def convert_lora_state_dict_to_kohya_format(
state_dict: typing.Dict[str, torch.Tensor]
) -> typing.Dict[str, torch.Tensor]:
"""Convert a Stable Diffusion v1 LoRA state_dict from internal invoke-training format to kohya_ss format.
"""Convert a Stable Diffusion LoRA state_dict from internal invoke-training format to kohya_ss format.

Args:
state_dict (typing.Dict[str, torch.Tensor]): LoRA layer state_dict in invoke-training format.
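
For context, a sketch of how the renamed helpers fit together; the model loading and checkpoint name below are illustrative assumptions, while the function names and get_lora_state_dict() come from this PR:

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from invoke_training.lora.injection.stable_diffusion import (
    convert_lora_state_dict_to_kohya_format,
    inject_lora_into_clip_text_encoder,
    inject_lora_into_unet,
)

# Load a base model to inject into (checkpoint choice is illustrative).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

# Inject trainable LoRA layers. The new `prefix` argument lets a caller choose
# distinct key prefixes, e.g. for SDXL's two text encoders (an assumption here,
# not shown in this diff).
unet_lora = inject_lora_into_unet(unet)
te_lora = inject_lora_into_clip_text_encoder(text_encoder, prefix="lora_te")

# Merge both LoRA weight sets into a single kohya_ss-format state dict.
state_dict = {}
for lora in (unet_lora, te_lora):
    state_dict.update(convert_lora_state_dict_to_kohya_format(lora.get_lora_state_dict()))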
4 changes: 2 additions & 2 deletions src/invoke_training/scripts/invoke_finetune_lora_sdxl.py
@@ -4,7 +4,7 @@
import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import (
FinetuneLoRAConfig,
FinetuneLoRASDXLConfig,
)
from invoke_training.training.finetune_lora.finetune_lora_sdxl import run_training

@@ -27,7 +27,7 @@ def main():
with open(args.cfg_file, "r") as f:
cfg = yaml.safe_load(f)

train_config = FinetuneLoRAConfig(**cfg)
train_config = FinetuneLoRASDXLConfig(**cfg)

run_training(train_config)

src/invoke_training/training/finetune_lora/finetune_lora_config.py
@@ -167,3 +167,10 @@ class FinetuneLoRAConfig(BaseModel):

# The training batch size.
train_batch_size: int = 4


class FinetuneLoRASDXLConfig(FinetuneLoRAConfig):
# The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
# model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
# with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
vae_model: typing.Optional[str] = None
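
A standalone sketch of what the vae_model override implies at model-load time; the load_vae helper below is hypothetical and only illustrates the behavior described in the comment above:

import typing

from diffusers import AutoencoderKL


def load_vae(model: str, vae_model: typing.Optional[str]) -> AutoencoderKL:
    if vae_model is not None:
        # Standalone VAE repo, e.g. madebyollin/sdxl-vae-fp16-fix, which avoids
        # the fp16 NaN issue in the VAE bundled with SDXL 1.0.
        return AutoencoderKL.from_pretrained(vae_model)
    # Fall back to the VAE bundled with the base model.
    return AutoencoderKL.from_pretrained(model, subfolder="vae")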
18 changes: 10 additions & 8 deletions src/invoke_training/training/finetune_lora/finetune_lora_sd.py
@@ -18,10 +18,10 @@
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from invoke_training.lora.injection.stable_diffusion_v1 import (
convert_lora_state_dict_to_kohya_format_sd1,
from invoke_training.lora.injection.stable_diffusion import (
convert_lora_state_dict_to_kohya_format,
inject_lora_into_clip_text_encoder,
inject_lora_into_unet_sd1,
inject_lora_into_unet,
)
from invoke_training.training.finetune_lora.finetune_lora_config import (
FinetuneLoRAConfig,
@@ -36,8 +36,8 @@
check_base_model_version,
)
from invoke_training.training.shared.checkpoint_tracker import CheckpointTracker
from invoke_training.training.shared.datasets.image_caption_dataloader import (
build_image_caption_dataloader,
from invoke_training.training.shared.datasets.image_caption_sd_dataloader import (
build_image_caption_sd_dataloader,
)
from invoke_training.training.shared.serialization import save_state_dict

@@ -122,7 +122,7 @@ def _save_checkpoint(
state_dict = {}
for model_lora_layers in lora_layers.values():
model_state_dict = model_lora_layers.get_lora_state_dict()
model_kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(model_state_dict)
model_kohya_state_dict = convert_lora_state_dict_to_kohya_format(model_state_dict)
state_dict.update(model_kohya_state_dict)

save_state_dict(state_dict, save_path)
@@ -189,6 +189,8 @@ def _generate_validation_images(
prompt,
num_inference_steps=30,
generator=generator,
height=config.dataset.resolution,
width=config.dataset.resolution,
).images[0]
)

@@ -312,7 +314,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901

lora_layers = torch.nn.ModuleDict()
if config.train_unet:
lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks)
lora_layers["unet"] = inject_lora_into_unet(unet, config.train_unet_non_attention_blocks)
if config.train_text_encoder:
lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder)

@@ -324,7 +326,7 @@

optimizer = _initialize_optimizer(config, lora_layers.parameters())

data_loader = build_image_caption_dataloader(config.dataset, tokenizer, config.train_batch_size)
data_loader = build_image_caption_sd_dataloader(config.dataset, tokenizer, config.train_batch_size)

# TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
# by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
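
For context, the height/width change above pins validation images to the dataset resolution rather than the pipeline default; a standalone sketch (pipeline construction and the 512px resolution are illustrative assumptions, while the call arguments mirror the training code):

import torch
from diffusers import StableDiffusionPipeline

# Build a pipeline for validation. This is illustrative; the training code
# assembles the pipeline from the in-memory models that are being trained.
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

generator = torch.Generator().manual_seed(1)
image = pipeline(
    "yoda in a space suit",  # one of the validation_prompts
    num_inference_steps=30,
    generator=generator,
    height=512,  # config.dataset.resolution in the training code
    width=512,
).images[0]
image.save("validation_0.png")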