Merge pull request #36 from invoke-ai/feat/textual-inversion
Add Textual Inversion for SD1
Showing 13 changed files with 974 additions and 21 deletions.
src/invoke_training/scripts/invoke_textual_inversion_sd.py (38 additions & 0 deletions)
```python
import argparse
from pathlib import Path

import yaml

from invoke_training.training.config.textual_inversion_config import (
    TextualInversionConfig,
)
from invoke_training.training.textual_inversion.textual_inversion_sd import run_training


def parse_args():
    parser = argparse.ArgumentParser(
        description="Textual inversion training for Stable Diffusion v1 and v2 base models."
    )
    parser.add_argument(
        "--cfg-file",
        type=Path,
        required=True,
        help="Path to the YAML training config file. See `TextualInversionConfig` for the supported fields.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Load YAML config file.
    with open(args.cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    train_config = TextualInversionConfig(**cfg)

    run_training(train_config)


if __name__ == "__main__":
    main()
```
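The script is driven entirely by the YAML config. Below is a hedged sketch of that flow in Python: the top-level keys come from `TextualInversionConfig` (added later in this diff), while the `optimizer` and `dataset` sub-fields are assumptions, since `OptimizerConfig` and the data config are not among the files shown here.

```python
# Sketch of what the script does with a config file (illustrative; the
# optimizer/dataset sub-fields are assumptions not confirmed by this diff).
import yaml

from invoke_training.training.config.textual_inversion_config import TextualInversionConfig

cfg_text = """
model: runwayml/stable-diffusion-v1-5
placeholder_token: "<my-concept>"
initializer_token: cat
learnable_property: object
max_train_steps: 3000
output:
  base_output_dir: output/ti_runs
optimizer: {}  # assumption: OptimizerConfig provides usable defaults
dataset:
  dataset_dir: data/my_concept
"""

train_config = TextualInversionConfig(**yaml.safe_load(cfg_text))
# run_training(train_config) would then launch the run, exactly as main() does.
```

The script itself would then be invoked as `python src/invoke_training/scripts/invoke_textual_inversion_sd.py --cfg-file path/to/config.yaml`.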
src/invoke_training/training/config/textual_inversion_config.py (107 additions & 0 deletions)
```python
import typing

from pydantic import BaseModel

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig
from invoke_training.training.config.optimizer_config import OptimizerConfig
from invoke_training.training.config.training_output_config import TrainingOutputConfig


class TextualInversionTrainingConfig(BaseModel):
    """The base configuration for any Textual Inversion training run."""

    # Name or path of the base model to train. Can be in diffusers format, or a single Stable Diffusion checkpoint
    # file. (E.g. 'runwayml/stable-diffusion-v1-5', 'stabilityai/stable-diffusion-xl-base-1.0',
    # '/path/to/realisticVisionV51_v51VAE.safetensors', etc.)
    model: str = "runwayml/stable-diffusion-v1-5"

    # A seed for reproducible training.
    seed: typing.Optional[int] = None

    # The number of textual inversion placeholder vectors that will be used to learn the concept.
    num_vectors: int = 1

    # The special word to associate the learned embeddings with. You must use this trigger word in your prompt at
    # inference time.
    # TODO(ryand): Rename to placeholder_str - seems more appropriate.
    placeholder_token: str

    # A vocabulary token to use as an initializer for the placeholder token(s). It should be a single word that
    # roughly describes the object or style that you're trying to train on. Must map to a single tokenizer token.
    # Either initializer_token or initial_embedding_file should be set.
    initializer_token: typing.Optional[str] = None

    # Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder
    # token in the file must match the 'placeholder_token' field. Either initializer_token or initial_embedding_file
    # should be set.
    initial_embedding_file: typing.Optional[str] = None

    # Whether you're training the model to learn a new "style" or a new "object".
    learnable_property: typing.Literal["object", "style"] = "object"

    # If True, the VAE will be applied to all of the images in the dataset before starting training, and the results
    # will be cached to disk. This reduces the VRAM requirements during training (the VAE doesn't have to be kept in
    # VRAM) and speeds up training (the VAE encoding step doesn't have to be re-run each step). This option can only
    # be enabled if all non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
    cache_vae_outputs: bool = False

    # If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    # images. This reduces VRAM requirements at the cost of slower generation of validation images.
    enable_cpu_offload_during_validation: bool = False

    # The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    # Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    gradient_accumulation_steps: int = 1

    # The mixed precision mode to use ('no', 'fp16', 'bf16', or 'fp8'). This value is passed to Hugging Face
    # Accelerate. See accelerate.Accelerator for more details.
    mixed_precision: typing.Optional[typing.Literal["no", "fp16", "bf16", "fp8"]] = None

    # If True, use xformers for more efficient attention blocks.
    xformers: bool = False

    # Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    # gradient checkpointing slows down training by ~20%.
    gradient_checkpointing: bool = False

    # Total number of training steps to perform. (One training step is one gradient update.)
    max_train_steps: int = 5000

    # The interval (in epochs) at which to save checkpoints. If None, checkpoints won't be triggered by this
    # setting. It is recommended to set only one of save_every_n_epochs and save_every_n_steps to a non-None value.
    save_every_n_epochs: typing.Optional[int] = 1

    # The interval (in steps) at which to save checkpoints. If None, checkpoints won't be triggered by this setting.
    # It is recommended to set only one of save_every_n_epochs and save_every_n_steps to a non-None value.
    save_every_n_steps: typing.Optional[int] = None

    # The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    # limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    max_checkpoints: typing.Optional[int] = None

    # The prediction_type that will be used for training. Choose between 'epsilon' and 'v_prediction', or leave
    # None. If None, the prediction type of the scheduler (noise_scheduler.config.prediction_type) is used.
    prediction_type: typing.Optional[typing.Literal["epsilon", "v_prediction"]] = None

    # Max gradient norm for clipping. Set to None for no clipping.
    max_grad_norm: typing.Optional[float] = 1.0

    # A list of prompts that will be used to generate images throughout training for the purpose of tracking
    # progress. See also 'validate_every_n_epochs'.
    validation_prompts: list[str] = []

    # The number of validation images to generate for each prompt in 'validation_prompts'. Careful: validation can
    # become quite slow if this number is too large.
    num_validation_images_per_prompt: int = 4

    # The interval (in epochs) at which validation images will be generated.
    validate_every_n_epochs: int = 1

    # The training batch size.
    train_batch_size: int = 4


class TextualInversionConfig(TextualInversionTrainingConfig):
    output: TrainingOutputConfig
    optimizer: OptimizerConfig
    dataset: TextualInversionDataLoaderConfig
```
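For context on the `initializer_token` / `num_vectors` fields: the training loop added by this PR (not shown in this excerpt) is responsible for creating and seeding the placeholder embeddings. The standard textual-inversion initialization looks roughly like the sketch below. This is background, not code from this PR; it uses only the public transformers API, and the model name and token values are illustrative placeholders.

```python
# Background sketch: add `num_vectors` placeholder tokens to the tokenizer and
# seed each new embedding with the embedding of the initializer word.
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

placeholder_token = "<my-concept>"
num_vectors = 2
placeholder_tokens = [placeholder_token] + [f"{placeholder_token}_{i}" for i in range(1, num_vectors)]
tokenizer.add_tokens(placeholder_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))

# The initializer must map to a single tokenizer token, as the config comment requires.
initializer_ids = tokenizer.encode("cat", add_special_tokens=False)
assert len(initializer_ids) == 1

embeddings = text_encoder.get_input_embeddings().weight.data
for token_id in tokenizer.convert_tokens_to_ids(placeholder_tokens):
    embeddings[token_id] = embeddings[initializer_ids[0]].clone()
```

During training, only these placeholder rows of the embedding matrix are optimized; the rest of the model stays frozen.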
src/invoke_training/training/config/training_output_config.py (22 additions & 0 deletions)
```python
import typing

from pydantic import BaseModel


class TrainingOutputConfig(BaseModel):
    """Configuration for a training run's output."""

    # The output directory where the training outputs (model checkpoints, logs,
    # intermediate predictions) will be written. A subdirectory will be created
    # with a timestamp for each new training run.
    base_output_dir: str

    # The integration to report results and logs to ('all', 'tensorboard',
    # 'wandb', or 'comet_ml'). This value is passed to Hugging Face Accelerate.
    # See accelerate.Accelerator.log_with for more details.
    report_to: typing.Optional[typing.Literal["all", "tensorboard", "wandb", "comet_ml"]] = "tensorboard"

    # The file type to save the model as.
    # Note that "ckpt" and "pt" are alternative file extensions for the same
    # file format.
    save_model_as: typing.Literal["ckpt", "pt", "safetensors"] = "safetensors"
```
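Because these configs are pydantic models, invalid values fail at construction time rather than partway through a training run. A small illustrative check (not part of the PR):

```python
# Unknown Literal values are rejected when the config is constructed, so a typo
# in the YAML config fails fast.
from pydantic import ValidationError

from invoke_training.training.config.training_output_config import TrainingOutputConfig

ok = TrainingOutputConfig(base_output_dir="output/run1", save_model_as="safetensors")
print(ok.report_to)  # "tensorboard" (the default)

try:
    TrainingOutputConfig(base_output_dir="output/run1", save_model_as="onnx")
except ValidationError as e:
    print(e)  # 'onnx' is not a permitted value for save_model_as
```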
src/invoke_training/training/shared/data/data_loaders/textual_inversion_sd_dataloader.py (151 additions & 0 deletions)
```python
import typing

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig
from invoke_training.training.shared.data.datasets.image_dir_dataset import (
    ImageDirDataset,
)
from invoke_training.training.shared.data.datasets.transform_dataset import (
    TransformDataset,
)
from invoke_training.training.shared.data.transforms.drop_field_transform import (
    DropFieldTransform,
)
from invoke_training.training.shared.data.transforms.load_cache_transform import (
    LoadCacheTransform,
)
from invoke_training.training.shared.data.transforms.sd_image_transform import (
    SDImageTransform,
)
from invoke_training.training.shared.data.transforms.sd_tokenize_transform import (
    SDTokenizeTransform,
)
from invoke_training.training.shared.data.transforms.tensor_disk_cache import (
    TensorDiskCache,
)
from invoke_training.training.shared.data.transforms.textual_inversion_caption_transform import (
    TextualInversionCaptionTransform,
)


def _get_default_textual_inversion_prompt_templates(learnable_property: typing.Literal["object", "style"]) -> list[str]:
    if learnable_property == "object":
        return [
            "a photo of a {}",
            "a rendering of a {}",
            "a cropped photo of the {}",
            "the photo of a {}",
            "a photo of a clean {}",
            "a photo of a dirty {}",
            "a dark photo of the {}",
            "a photo of my {}",
            "a photo of the cool {}",
            "a close-up photo of a {}",
            "a bright photo of the {}",
            "a cropped photo of a {}",
            "a photo of the {}",
            "a good photo of the {}",
            "a photo of one {}",
            "a close-up photo of the {}",
            "a rendition of the {}",
            "a photo of the clean {}",
            "a rendition of a {}",
            "a photo of a nice {}",
            "a good photo of a {}",
            "a photo of the nice {}",
            "a photo of the small {}",
            "a photo of the weird {}",
            "a photo of the large {}",
            "a photo of a cool {}",
            "a photo of a small {}",
        ]
    elif learnable_property == "style":
        return [
            "a painting in the style of {}",
            "a rendering in the style of {}",
            "a cropped painting in the style of {}",
            "the painting in the style of {}",
            "a clean painting in the style of {}",
            "a dirty painting in the style of {}",
            "a dark painting in the style of {}",
            "a picture in the style of {}",
            "a cool painting in the style of {}",
            "a close-up painting in the style of {}",
            "a bright painting in the style of {}",
            "a cropped painting in the style of {}",
            "a good painting in the style of {}",
            "a close-up painting in the style of {}",
            "a rendition in the style of {}",
            "a nice painting in the style of {}",
            "a small painting in the style of {}",
            "a weird painting in the style of {}",
            "a large painting in the style of {}",
        ]
    else:
        raise ValueError(f"Unrecognized learnable property type: '{learnable_property}'.")


def build_textual_inversion_sd_dataloader(
    config: TextualInversionDataLoaderConfig,
    placeholder_str: str,
    learnable_property: typing.Literal["object", "style"],
    tokenizer: typing.Optional[CLIPTokenizer],
    batch_size: int,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion v1/v2.

    Args:
        config (TextualInversionDataLoaderConfig): The dataset config.
        placeholder_str (str): The placeholder string being trained.
        learnable_property (str): One of ["object", "style"] indicating the type of training being performed.
        tokenizer (CLIPTokenizer, optional): The tokenizer to apply to the captions. Can be None if
            `text_encoder_output_cache_dir` is set.
        batch_size (int): The DataLoader batch size.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from.
            If set, then the image augmentation transforms will be skipped, and the image will not be copied to
            VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader
    """
    base_dataset = ImageDirDataset(image_dir=config.dataset_dir, image_extensions=config.image_file_extensions)

    all_transforms = [
        TextualInversionCaptionTransform(
            field_name="caption",
            placeholder_str=placeholder_str,
            caption_templates=_get_default_textual_inversion_prompt_templates(learnable_property),
        ),
        SDTokenizeTransform(tokenizer),
    ]

    if vae_output_cache_dir is None:
        all_transforms.append(
            SDImageTransform(
                resolution=config.image_transforms.resolution,
                center_crop=config.image_transforms.center_crop,
                random_flip=config.image_transforms.random_flip,
            )
        )
    else:
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache, cache_key_field="id", cache_field_to_output_field={"vae_output": "vae_output"}
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    dataset = TransformDataset(base_dataset, all_transforms)

    return DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        num_workers=config.dataloader_num_workers,
    )
```
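To show how the pieces fit together, here is a hedged usage sketch of the new dataloader. The `TextualInversionDataLoaderConfig` values are assumptions (only the field names referenced in this file are certain, and other fields are assumed to have defaults), and the tokenizer is loaded with the standard transformers API.

```python
# Illustrative usage of the new dataloader (a sketch, not code from this PR).
from transformers import CLIPTokenizer

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig
from invoke_training.training.shared.data.data_loaders.textual_inversion_sd_dataloader import (
    build_textual_inversion_sd_dataloader,
)

# Assumption: dataset_dir is the only required field; the rest use defaults.
config = TextualInversionDataLoaderConfig(dataset_dir="data/my_concept")
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

data_loader = build_textual_inversion_sd_dataloader(
    config=config,
    placeholder_str="<my-concept>",
    learnable_property="object",
    tokenizer=tokenizer,
    batch_size=4,
)

# Each batch should contain tokenized captions (a prompt template with the
# placeholder string substituted in) plus the transformed image tensors.
batch = next(iter(data_loader))
print(batch.keys())
```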