Merge pull request #13 from invoke-ai/ryan/sdxl-prep-rename-stuff
Prep for SDXL support: rename training modes
RyanJDick authored Aug 14, 2023
2 parents 456e6e9 + 4767ece commit f8995d0
Showing 12 changed files with 111 additions and 50 deletions.
18 changes: 15 additions & 3 deletions README.md
@@ -4,6 +4,18 @@ A library for training custom Stable Diffusion models (fine-tuning, LoRA training

**WARNING:** This repo is currently under construction. More details coming soon.

## Training Modes

- Finetune *(Not implemented yet)*
- Finetune with LoRA
- Stable Diffusion v1/v2: `invoke-finetune-lora-sd`
- Stable Diffusion XL: `invoke-finetune-lora-sdxl`
- DreamBooth *(Not implemented yet)*
- DreamBooth with LoRA *(Not implemented yet)*
- Textual Inversion *(Not implemented yet)*
- Pivotal Tuning Inversion *(Not implemented yet)*
- Pivotal Tuning Inversion with LoRA *(Not implemented yet)*

## Developer Quick Start

### Setup Development Environment
@@ -23,13 +35,13 @@ There are some test 'markers' defined in [pyproject.toml](/pyproject.toml) that
```bash
pytest tests/ -m "not cuda and not loads_model"
```

-### Train a LoRA
+### Finetune a Stable Diffusion model with LoRA
The following steps explain how to train a basic Pokemon Style LoRA using the [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset, and how to use it in [InvokeAI](https://github.com/invoke-ai/InvokeAI).

This training process has been tested on an Nvidia GPU with 8GB of VRAM.

-1. For this example, we will use the [lora_training_example.yaml]() config file. See [lora_training_config.py](/src/invoke_training/training/lora/lora_training_config.py) for the full list of supported LoRA training configs.
-2. Start training with `invoke-train-lora --cfg-file configs/lora_training_example.yaml`.
+1. For this example, we will use the [finetune_lora_sd_pokemon_example.yaml](/configs/finetune_lora_sd_pokemon_example.yaml) config file. See [lora_training_config.py](/src/invoke_training/training/lora/lora_training_config.py) for the full list of supported LoRA training configs.
+2. Start training with `invoke-finetune-lora-sd --cfg-file configs/finetune_lora_sd_pokemon_example.yaml`.
3. Monitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. Here you can see generated images for fixed prompts throughout the training process.
4. Select a checkpoint based on the quality of the generated images. As an example, we'll use the **Epoch 19** checkpoint.
5. If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
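
Steps 2 and 3 boil down to two commands; a copy-paste sketch, assuming you run from the repository root and keep the config's default `output/` directory:

```bash
# Step 2: launch training with the example config.
invoke-finetune-lora-sd --cfg-file configs/finetune_lora_sd_pokemon_example.yaml

# Step 3: monitor progress (including generated sample images) at http://localhost:6006.
tensorboard --logdir output/
```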
@@ -1,4 +1,4 @@
-# This is a sample config for training a Pokemon LoRA model.
+# This is a sample config for finetuning a Stable Diffusion 1.5 model with LoRA to produce a Pokemon LoRA model.

output:
  base_output_dir: output/
@@ -10,6 +10,7 @@ dataset:
  dataset_name: lambdalabs/pokemon-blip-captions

# General
+model: runwayml/stable-diffusion-v1-5
seed: 1
gradient_accumulation_steps: 1
mixed_precision: fp16
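
Assembled from the lines visible in this diff, the example config now reads roughly like the sketch below; other supported fields are omitted here (`FinetuneLoRAConfig` defines the full schema):

```yaml
# Sketch reconstructed from the diff fragments above; not the complete file.
output:
  base_output_dir: output/

dataset:
  dataset_name: lambdalabs/pokemon-blip-captions

# General
model: runwayml/stable-diffusion-v1-5
seed: 1
gradient_accumulation_steps: 1
mixed_precision: fp16
```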
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -43,7 +43,8 @@ dependencies = [
]

[project.scripts]
"invoke-train-lora" = "invoke_training.scripts.invoke_train_lora:main"
"invoke-finetune-lora-sd" = "invoke_training.scripts.invoke_finetune_lora_sd:main"
"invoke-finetune-lora-sdxl" = "invoke_training.scripts.invoke_finetune_lora_sdxl:main"

[project.urls]
"Homepage" = "https://github.com/invoke-ai/invoke-training"
36 changes: 36 additions & 0 deletions src/invoke_training/scripts/invoke_finetune_lora_sd.py
@@ -0,0 +1,36 @@
import argparse
from pathlib import Path

import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import (
    FinetuneLoRAConfig,
)
from invoke_training.training.finetune_lora.finetune_lora_sd import run_training


def parse_args():
    parser = argparse.ArgumentParser(description="Finetuning with LoRA for Stable Diffusion v1 and v2 base models.")
    parser.add_argument(
        "--cfg-file",
        type=Path,
        required=True,
        help="Path to the YAML training config file. See `FinetuneLoRAConfig` for the supported fields.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Load YAML config file.
    with open(args.cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    train_config = FinetuneLoRAConfig(**cfg)

    run_training(train_config)


if __name__ == "__main__":
    main()
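
Because the console script is a thin wrapper, the same run can also be started from Python. A minimal sketch using only the module's own imports and the example config path from the README:

```python
import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import FinetuneLoRAConfig
from invoke_training.training.finetune_lora.finetune_lora_sd import run_training

# Load the YAML config, validate it via the pydantic schema, and start training.
with open("configs/finetune_lora_sd_pokemon_example.yaml", "r") as f:
    run_training(FinetuneLoRAConfig(**yaml.safe_load(f)))
```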
36 changes: 36 additions & 0 deletions src/invoke_training/scripts/invoke_finetune_lora_sdxl.py
@@ -0,0 +1,36 @@
import argparse
from pathlib import Path

import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import (
    FinetuneLoRAConfig,
)
from invoke_training.training.finetune_lora.finetune_lora_sdxl import run_training


def parse_args():
    parser = argparse.ArgumentParser(description="Finetuning with LoRA for Stable Diffusion XL models.")
    parser.add_argument(
        "--cfg-file",
        type=Path,
        required=True,
        help="Path to the YAML training config file. See `FinetuneLoRAConfig` for the supported fields.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Load YAML config file.
    with open(args.cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    train_config = FinetuneLoRAConfig(**cfg)

    run_training(train_config)


if __name__ == "__main__":
    main()
34 changes: 0 additions & 34 deletions src/invoke_training/scripts/invoke_train_lora.py

This file was deleted.

@@ -88,7 +88,7 @@ class DatasetConfig(BaseModel):
    dataloader_num_workers: int = 0


-class LoRATrainingConfig(BaseModel):
+class FinetuneLoRAConfig(BaseModel):
    """The configuration for a LoRA training run."""

    output: TrainingOutputConfig
@@ -23,7 +23,9 @@
    inject_lora_into_clip_text_encoder,
    inject_lora_into_unet_sd1,
)
-from invoke_training.training.lora.lora_training_config import LoRATrainingConfig
+from invoke_training.training.finetune_lora.finetune_lora_config import (
+    FinetuneLoRAConfig,
+)
from invoke_training.training.shared.accelerator_utils import (
    get_mixed_precision_dtype,
    initialize_accelerator,
@@ -42,13 +44,13 @@

def _load_models(
    accelerator: Accelerator,
-    config: LoRATrainingConfig,
+    config: FinetuneLoRAConfig,
) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]:
"""Load all models required for training from disk, transfer them to the
target training device and cast their weight dtypes.
    Args:
-        config (LoRATrainingConfig): The LoRA training run config.
+        config (FinetuneLoRAConfig): The LoRA training run config.
        logger (logging.Logger): A logger.

    Returns:
@@ -85,7 +87,7 @@ def _load_models(
    return tokenizer, noise_scheduler, text_encoder, vae, unet


-def _initialize_optimizer(config: LoRATrainingConfig, trainable_params: list) -> torch.optim.Optimizer:
+def _initialize_optimizer(config: FinetuneLoRAConfig, trainable_params: list) -> torch.optim.Optimizer:
    """Initialize an optimizer based on the config."""
    return torch.optim.AdamW(
        trainable_params,
@@ -137,7 +139,7 @@ def _generate_validation_images(
    tokenizer: CLIPTokenizer,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
-    config: LoRATrainingConfig,
+    config: FinetuneLoRAConfig,
    logger: logging.Logger,
):
"""Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout
@@ -152,7 +154,7 @@ def _generate_validation_images(
        tokenizer (CLIPTokenizer):
        noise_scheduler (DDPMScheduler):
        unet (UNet2DConditionModel):
-        config (LoRATrainingConfig): Training configs.
+        config (FinetuneLoRAConfig): Training configs.
        logger (logging.Logger): Logger.
    """
    logger.info("Generating validation images.")
@@ -217,7 +219,7 @@ def _generate_validation_images(


def _train_forward(
-    config: LoRATrainingConfig,
+    config: FinetuneLoRAConfig,
    data_batch: dict,
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
@@ -271,7 +273,7 @@ def _train_forward(
    return torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")


-def run_lora_training(config: LoRATrainingConfig): # noqa: C901
+def run_training(config: FinetuneLoRAConfig): # noqa: C901
    # Give a clear error message if an unsupported base model was chosen.
    check_base_model_version(
        {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},
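
For context on the `_train_forward` rename above: the MSE it returns is the standard denoising objective. The sketch below is illustrative only, not the repo's exact code; the real function also performs VAE latent encoding, caption conditioning, and config-driven target selection.

```python
import torch
from diffusers import DDPMScheduler, UNet2DConditionModel


def schematic_denoising_loss(
    unet: UNet2DConditionModel,
    noise_scheduler: DDPMScheduler,
    latents: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
) -> torch.Tensor:
    # Sample Gaussian noise and a random timestep for each latent in the batch.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    # Forward-diffuse the latents to the sampled timesteps.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    # Predict the noise; for epsilon-prediction models the target is the noise itself.
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    return torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
```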
@@ -0,0 +1,7 @@
from invoke_training.training.finetune_lora.finetune_lora_config import (
    FinetuneLoRAConfig,
)


def run_training(config: FinetuneLoRAConfig):
    raise NotImplementedError("finetune_lora_sdxl is not implemented.")
@@ -1,7 +1,7 @@
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

-from invoke_training.training.lora.lora_training_config import DatasetConfig
+from invoke_training.training.finetune_lora.finetune_lora_config import DatasetConfig
from invoke_training.training.shared.datasets.hf_dir_image_caption_reader import (
    HFDirImageCaptionReader,
)
@@ -4,7 +4,7 @@
import torch
from transformers import CLIPTokenizer

-from invoke_training.training.lora.lora_training_config import DatasetConfig
+from invoke_training.training.finetune_lora.finetune_lora_config import DatasetConfig
from invoke_training.training.shared.datasets.image_caption_dataloader import (
    build_image_caption_dataloader,
)
