Merge pull request #36 from invoke-ai/feat/textual-inversion

Add Textual Inversion for SD1

RyanJDick authored Dec 6, 2023
2 parents 0b59346 + 31b050f commit c7eaaf2
Showing 13 changed files with 974 additions and 21 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
"invoke-dreambooth-lora-sd" = "invoke_training.scripts.invoke_dreambooth_lora_sd:main"
"invoke-dreambooth-lora-sdxl" = "invoke_training.scripts.invoke_dreambooth_lora_sdxl:main"
"invoke-generate-images" = "invoke_training.scripts.invoke_generate_images:main"
"invoke-textual-inversion-sd" = "invoke_training.scripts.invoke_textual_inversion_sd:main"

[project.urls]
"Homepage" = "https://github.com/invoke-ai/invoke-training"
38 changes: 38 additions & 0 deletions src/invoke_training/scripts/invoke_textual_inversion_sd.py
@@ -0,0 +1,38 @@
import argparse
from pathlib import Path

import yaml

from invoke_training.training.config.textual_inversion_config import (
TextualInversionConfig,
)
from invoke_training.training.textual_inversion.textual_inversion_sd import run_training


def parse_args():
parser = argparse.ArgumentParser(
description="Textual inversion training for Stable Diffusion v1 and v2 base models."
)
parser.add_argument(
"--cfg-file",
type=Path,
required=True,
help="Path to the YAML training config file. See `TextualInversionConfig` for the supported fields.",
)
return parser.parse_args()


def main():
args = parse_args()

# Load YAML config file.
with open(args.cfg_file, "r") as f:
cfg = yaml.safe_load(f)

train_config = TextualInversionConfig(**cfg)

run_training(train_config)


if __name__ == "__main__":
main()
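
The script simply loads the YAML file and validates it against `TextualInversionConfig` (added later in this commit). A minimal sketch of such a config file, with placeholder paths and values; the `optimizer` section is left empty here on the assumption that `OptimizerConfig` (not part of this diff) provides defaults for all of its fields:

# config.yaml (illustrative only)
model: runwayml/stable-diffusion-v1-5
placeholder_token: "<my-concept>"
initializer_token: dog
learnable_property: object
max_train_steps: 2000
train_batch_size: 4
output:
  base_output_dir: output/ti_sd1
optimizer: {}  # assumption: OptimizerConfig fields all have defaults
dataset:
  dataset_dir: /path/to/training/images

With the project installed, the run can then be launched with `invoke-textual-inversion-sd --cfg-file config.yaml`, the console script registered in pyproject.toml above.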
15 changes: 15 additions & 0 deletions src/invoke_training/training/config/data_config.py
@@ -72,3 +72,18 @@ class DreamBoothDataLoaderConfig(BaseModel):

# Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
dataloader_num_workers: int = 0


class TextualInversionDataLoaderConfig(BaseModel):
# The directory to load images from.
dataset_dir: str

# The image file extensions to include in the dataset.
# If None, then the following file extensions will be loaded: [".png", ".jpg", ".jpeg"].
image_file_extensions: typing.Optional[list[str]] = None

# The image transforms to apply to all instance and class dataset images.
image_transforms: ImageTransformConfig = ImageTransformConfig()

# Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
dataloader_num_workers: int = 0
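
A minimal sketch of the defaults this new config applies when only the required `dataset_dir` is supplied (import path taken from this file):

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig

loader_cfg = TextualInversionDataLoaderConfig(dataset_dir="/path/to/images")
print(loader_cfg.image_file_extensions)   # None -> dataset falls back to [".png", ".jpg", ".jpeg"]
print(loader_cfg.dataloader_num_workers)  # 0 -> data is loaded in the main process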
20 changes: 1 addition & 19 deletions src/invoke_training/training/config/finetune_lora_config.py
@@ -7,25 +7,7 @@
ImageCaptionDataLoaderConfig,
)
from invoke_training.training.config.optimizer_config import OptimizerConfig


class TrainingOutputConfig(BaseModel):
"""Configuration for a training run's output."""

# The output directory where the training outputs (model checkpoints, logs,
# intermediate predictions) will be written. A subdirectory will be created
# with a timestamp for each new training run.
base_output_dir: str

# The integration to report results and logs to ('all', 'tensorboard',
# 'wandb', or 'comet_ml'). This value is passed to Hugging Face Accelerate.
# See accelerate.Accelerator.log_with for more details.
report_to: typing.Optional[typing.Literal["all", "tensorboard", "wandb", "comet_ml"]] = "tensorboard"

# The file type to save the model as.
# Note that "ckpt" and "pt" are alternative file extensions for the same
# file format.
save_model_as: typing.Literal["ckpt", "pt", "safetensors"] = "safetensors"
from invoke_training.training.config.training_output_config import TrainingOutputConfig


class LoRATrainingConfig(BaseModel):
107 changes: 107 additions & 0 deletions src/invoke_training/training/config/textual_inversion_config.py
@@ -0,0 +1,107 @@
import typing

from pydantic import BaseModel

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig
from invoke_training.training.config.optimizer_config import OptimizerConfig
from invoke_training.training.config.training_output_config import TrainingOutputConfig


class TextualInversionTrainingConfig(BaseModel):
"""The base configuration for any Textual Inversion training run."""

# Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file.
# (E.g. 'runwayml/stable-diffusion-v1-5', 'stabilityai/stable-diffusion-xl-base-1.0',
# '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )
model: str = "runwayml/stable-diffusion-v1-5"

# A seed for reproducible training.
seed: typing.Optional[int] = None

# The number of textual inversion placeholder vectors that will be used to learn the concept.
num_vectors: int = 1

# The special word to associate the learned embeddings with. You must use this trigger word in your prompt at
# inference time.
# TODO(ryand): Rename to placeholder_str - seems more appropriate.
placeholder_token: str

# A vocabulary token to use as an initializer for the placeholder token(s). It should be a single word that roughly
# describes the object or style that you're trying to train on. Must map to a single tokenizer token. Either
# initializer_token or initial_embedding_file should be set.
initializer_token: typing.Optional[str] = None

# Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder
# token in the file must match the 'placeholder_token' field. Either initializer_token or initial_embedding_file
# should be set.
initial_embedding_file: typing.Optional[str] = None

# Whether you're training the model to learn a new "style" or a new "object".
learnable_property: typing.Literal["object", "style"] = "object"

# If True, the VAE will be applied to all of the images in the dataset before starting training and the results will
# be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
# speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
# non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
cache_vae_outputs: bool = False

# If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
# images. This reduces VRAM requirements at the cost of slower generation of validation images.
enable_cpu_offload_during_validation: bool = False

# The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
# Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
gradient_accumulation_steps: int = 1

# The mixed precision mode to use ('no', 'fp16', 'bf16', or 'fp8'). This value is passed to Hugging Face Accelerate.
# See accelerate.Accelerator for more details.
mixed_precision: typing.Optional[typing.Literal["no", "fp16", "bf16", "fp8"]] = None

# If true, use xformers for more efficient attention blocks.
xformers: bool = False

# Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
# gradient checkpointing slows down training by ~20%.
gradient_checkpointing: bool = False

# Total number of training steps to perform. (One training step is one gradient update.)
max_train_steps: int = 5000

# The interval (in epochs) at which to save checkpoints. If None, checkpoints won't be triggered by this setting.
# It is recommended to only set one of save_every_n_epochs and save_every_n_steps to a non-None value.
save_every_n_epochs: typing.Optional[int] = 1

# The interval (in steps) at which to save checkpoints. If None, checkpoints won't be triggered by this setting.
# It is recommended to only set one of save_every_n_epochs and save_every_n_steps to a non-None value.
save_every_n_steps: typing.Optional[int] = None

# The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
# limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
max_checkpoints: typing.Optional[int] = None

# The prediction_type that will be used for training. Choose 'epsilon' or 'v_prediction', or leave as None.
# If None, the scheduler's prediction type (`noise_scheduler.config.prediction_type`) is used.
prediction_type: typing.Optional[typing.Literal["epsilon", "v_prediction"]] = None

# Max gradient norm for clipping. Set to None for no clipping.
max_grad_norm: typing.Optional[float] = 1.0

# A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
# See also 'validate_every_n_epochs'.
validation_prompts: list[str] = []

# The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
# become quite slow if this number is too large.
num_validation_images_per_prompt: int = 4

# The interval (in epochs) at which validation images will be generated.
validate_every_n_epochs: int = 1

# The training batch size.
train_batch_size: int = 4


class TextualInversionConfig(TextualInversionTrainingConfig):
output: TrainingOutputConfig
optimizer: OptimizerConfig
dataset: TextualInversionDataLoaderConfig
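
Because `TextualInversionConfig` and its sub-configs are pydantic models, the nested sections of the loaded YAML dict are validated and coerced into `TrainingOutputConfig`, `OptimizerConfig`, and `TextualInversionDataLoaderConfig` automatically. A minimal sketch of the same construction done programmatically (the empty `optimizer` dict again assumes `OptimizerConfig` has defaults for its fields):

from invoke_training.training.config.textual_inversion_config import (
    TextualInversionConfig,
)

cfg = {
    "placeholder_token": "<my-concept>",
    "initializer_token": "dog",
    "output": {"base_output_dir": "output/ti_sd1"},
    "optimizer": {},  # assumption: OptimizerConfig fields all have defaults
    "dataset": {"dataset_dir": "/path/to/training/images"},
}
train_config = TextualInversionConfig(**cfg)
assert train_config.dataset.dataloader_num_workers == 0  # sub-model defaults applied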
22 changes: 22 additions & 0 deletions src/invoke_training/training/config/training_output_config.py
@@ -0,0 +1,22 @@
import typing

from pydantic import BaseModel


class TrainingOutputConfig(BaseModel):
"""Configuration for a training run's output."""

# The output directory where the training outputs (model checkpoints, logs,
# intermediate predictions) will be written. A subdirectory will be created
# with a timestamp for each new training run.
base_output_dir: str

# The integration to report results and logs to ('all', 'tensorboard',
# 'wandb', or 'comet_ml'). This value is passed to Hugging Face Accelerate.
# See accelerate.Accelerator.log_with for more details.
report_to: typing.Optional[typing.Literal["all", "tensorboard", "wandb", "comet_ml"]] = "tensorboard"

# The file type to save the model as.
# Note that "ckpt" and "pt" are alternative file extensions for the same
# file format.
save_model_as: typing.Literal["ckpt", "pt", "safetensors"] = "safetensors"
@@ -255,7 +255,7 @@ def train_forward(
text_encoder: CLIPTextModel,
unet: UNet2DConditionModel,
weight_dtype: torch.dtype,
):
) -> torch.Tensor:
"""Run the forward training pass for a single data_batch.
Returns:
@@ -289,7 +289,7 @@ def train_forward(
# The text_encoder_output may have been cached and included in the data_batch. If not, we calculate it here.
encoder_hidden_states = data_batch.get("text_encoder_output", None)
if encoder_hidden_states is None:
encoder_hidden_states = text_encoder(data_batch["caption_token_ids"])[0]
encoder_hidden_states = text_encoder(data_batch["caption_token_ids"])[0].to(dtype=weight_dtype)

# Get the target for loss depending on the prediction type.
if config.prediction_type is not None:
@@ -0,0 +1,151 @@
import typing

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig
from invoke_training.training.shared.data.datasets.image_dir_dataset import (
ImageDirDataset,
)
from invoke_training.training.shared.data.datasets.transform_dataset import (
TransformDataset,
)
from invoke_training.training.shared.data.transforms.drop_field_transform import (
DropFieldTransform,
)
from invoke_training.training.shared.data.transforms.load_cache_transform import (
LoadCacheTransform,
)
from invoke_training.training.shared.data.transforms.sd_image_transform import (
SDImageTransform,
)
from invoke_training.training.shared.data.transforms.sd_tokenize_transform import (
SDTokenizeTransform,
)
from invoke_training.training.shared.data.transforms.tensor_disk_cache import (
TensorDiskCache,
)
from invoke_training.training.shared.data.transforms.textual_inversion_caption_transform import (
TextualInversionCaptionTransform,
)


def _get_default_textual_inversion_prompt_templates(learnable_property: typing.Literal["object", "style"]) -> list[str]:
if learnable_property == "object":
return [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
elif learnable_property == "style":
return [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
else:
raise ValueError(f"Unrecognized learnable property type: '{learnable_property}'.")


def build_textual_inversion_sd_dataloader(
config: TextualInversionDataLoaderConfig,
placeholder_str: str,
learnable_property: typing.Literal["object", "style"],
tokenizer: typing.Optional[CLIPTokenizer],
batch_size: int,
vae_output_cache_dir: typing.Optional[str] = None,
shuffle: bool = True,
) -> DataLoader:
"""Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion v1/v2..
Args:
config (TextualInversionDataLoaderConfig): The dataset config.
placeholder_str (str): The placeholder string being trained.
learnable_property (str): One of ["object", "style"] indicating the type of training being performed.
tokenizer (CLIPTokenizer, optional): The tokenizer to apply to the captions. Can be None if
`text_encoder_output_cache_dir` is set.
batch_size (int): The DataLoader batch size.
vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
shuffle (bool, optional): Whether to shuffle the dataset order.
Returns:
DataLoader
"""

base_dataset = ImageDirDataset(image_dir=config.dataset_dir, image_extensions=config.image_file_extensions)

all_transforms = [
TextualInversionCaptionTransform(
field_name="caption",
placeholder_str=placeholder_str,
caption_templates=_get_default_textual_inversion_prompt_templates(learnable_property),
),
SDTokenizeTransform(tokenizer),
]

if vae_output_cache_dir is None:
all_transforms.append(
SDImageTransform(
resolution=config.image_transforms.resolution,
center_crop=config.image_transforms.center_crop,
random_flip=config.image_transforms.random_flip,
)
)
else:
vae_cache = TensorDiskCache(vae_output_cache_dir)
all_transforms.append(
LoadCacheTransform(
cache=vae_cache, cache_key_field="id", cache_field_to_output_field={"vae_output": "vae_output"}
)
)
# We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
all_transforms.append(DropFieldTransform("image"))

dataset = TransformDataset(base_dataset, all_transforms)

return DataLoader(
dataset,
shuffle=shuffle,
batch_size=batch_size,
num_workers=config.dataloader_num_workers,
)
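
A minimal usage sketch for the dataloader builder, assuming the function above is importable from its module (the module path is not visible in this view) and that a directory of training images exists at the placeholder path:

from transformers import CLIPTokenizer

from invoke_training.training.config.data_config import TextualInversionDataLoaderConfig

# Load the tokenizer from the base model's 'tokenizer' subfolder.
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

data_loader = build_textual_inversion_sd_dataloader(
    config=TextualInversionDataLoaderConfig(dataset_dir="/path/to/training/images"),
    placeholder_str="<my-concept>",
    learnable_property="object",
    tokenizer=tokenizer,
    batch_size=4,
)
# Each example's caption is produced by formatting one of the default prompt
# templates with the placeholder string (e.g. "a photo of a <my-concept>"),
# which is then tokenized by SDTokenizeTransform.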