-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from invoke-ai/ryan/lora-training
Add initial LoRA training script
- Loading branch information
Showing
23 changed files
with
1,218 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
output/ | ||
|
||
# pyenv | ||
.python-version | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,43 @@ | ||
# InvokeTraining | ||
# invoke-training | ||
|
||
A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI). | ||
|
||
**WARNING:** This repo is currently under construction. More details coming soon. | ||
|
||
## Developer Quick Start | ||
|
||
### Setup Development Environment | ||
1. (Optional) Create a python virtual environment. | ||
2. Install dependencies: `pip install -e .[test]`. | ||
3. Run tests: `pytest tests`. | ||
4. (Optional) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (black, ruff, isort) on `git commit`. | ||
5. (Optional) Set up `black`, `isort`, and `ruff` in your IDE of choice. | ||
1. Install dependencies: `pip install -e .[test]`. | ||
1. (Optional) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (black, ruff, isort) on `git commit`. | ||
1. (Optional) Set up `black`, `isort`, and `ruff` in your IDE of choice. | ||
|
||
### Unit Tests | ||
Run all unit tests with: | ||
```bash | ||
pytest tests/ | ||
``` | ||
|
||
There are some test 'markers' defined in [pyproject.toml](/pyproject.toml) that can be used to skip some tests. For example, the following command skips tests that require a GPU or require downloading model weights: | ||
```bash | ||
pytest tests/ -m "not cuda and not loads_model" | ||
``` | ||
|
||
### Train a LoRA | ||
The following steps explain how to train a basic Pokemon Style LoRA using the [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset, and how to use it in [InvokeAI](https://github.com/invoke-ai/InvokeAI). | ||
|
||
This training process has been tested on an Nvidia GPU with 8GB of VRAM. | ||
|
||
1. For this example, we will use the [lora_training_example.yaml]() config file. See [lora_training_config.py](/src/invoke_training/training/lora/lora_training_config.py) for the full list of supported LoRA training configs. | ||
2. Start training with `invoke-train-lora --cfg-file configs/lora_training_example.yaml`. | ||
3. Monitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. Here you can see generated images for fixed prompts throughout the training process. | ||
4. Select a checkpoint based on the quality of the generated images. As an example, we'll use the **Epoch 19** checkpoint. | ||
5. If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation. | ||
6. Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example: | ||
```bash | ||
cp output/1691088769.5694647/checkpoint_epoch-00000019.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000019.safetensors | ||
``` | ||
7. You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉 | ||
|
||
![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA.](images/invokeai_yoda_pokemon_lora.png) | ||
*Example image generated with the prompt "yoda" and Pokemon LoRA.* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# This is a sample config for training a Pokemon LoRA model. | ||
|
||
output: | ||
base_output_dir: output/ | ||
|
||
optimizer: | ||
learning_rate: 1.0e-3 | ||
|
||
dataset: | ||
name: lambdalabs/pokemon-blip-captions | ||
|
||
# General | ||
seed: 1 | ||
gradient_accumulation_steps: 1 | ||
mixed_precision: fp16 | ||
xformers: True | ||
gradient_checkpointing: True | ||
max_train_steps: 4000 | ||
save_every_n_epochs: 1 | ||
save_every_n_steps: null | ||
max_checkpoints: 100 | ||
validation_prompts: | ||
- yoda | ||
- astronaut | ||
- yoda in a space suit | ||
validate_every_n_epochs: 1 | ||
train_batch_size: 4 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,21 +5,31 @@ build-backend = "setuptools.build_meta" | |
[project] | ||
name = "invoke-training" | ||
version = "0.0.1" | ||
authors = [ | ||
{ name="The Invoke AI Team", email="[email protected]" }, | ||
] | ||
authors = [{ name = "The Invoke AI Team", email = "[email protected]" }] | ||
description = "A library for Stable Diffusion model training." | ||
readme = "README.md" | ||
requires-python = ">=3.9" | ||
license = {text = "Apache-2.0"} | ||
license = { text = "Apache-2.0" } | ||
classifiers = [ | ||
"Programming Language :: Python :: 3", | ||
"Operating System :: OS Independent", | ||
] | ||
dependencies = [ | ||
"accelerate~=0.21.0", | ||
"datasets~=2.14.3", | ||
"diffusers~=0.19.3", | ||
"torch~=2.0.1", | ||
"numpy", | ||
"pydantic", | ||
"pyyaml", | ||
"safetensors", | ||
"tensorboard", | ||
"torch>=2.0.1", | ||
"torchvision~=0.15.2", | ||
"tqdm", | ||
"transformers~=4.31.0", | ||
# Known issue with xformers 0.0.16 on some GPUs: | ||
# https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212 | ||
"xformers>=0.0.17", | ||
] | ||
|
||
[project.optional-dependencies] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
Oops, something went wrong.