Merge pull request #35 from invoke-ai/ryan/dreambooth-caching
Add VAE caching to DreamBooth scripts
RyanJDick authored Sep 12, 2023
2 parents b670baf + 0af01fc commit 0b59346
Showing 9 changed files with 197 additions and 52 deletions.
29 changes: 25 additions & 4 deletions src/invoke_training/training/dreambooth_lora/dreambooth_lora_sd.py
@@ -1,6 +1,7 @@
import json
import math
import os
import tempfile
import time

import torch
@@ -16,6 +17,7 @@
)
from invoke_training.training.config.finetune_lora_config import DreamBoothLoRAConfig
from invoke_training.training.finetune_lora.finetune_lora_sd import (
cache_vae_outputs,
generate_validation_images,
load_models,
train_forward,
@@ -88,14 +90,32 @@ def run_training(config: DreamBoothLoRAConfig): # noqa: C901
text_encoder.to(accelerator.device, dtype=weight_dtype)

# Prepare VAE output cache.
# vae_output_cache_dir_name = None
vae_output_cache_dir_name = None
if config.cache_vae_outputs:
if config.instance_dataset.image_transforms.random_flip:
if config.dataset.image_transforms.random_flip:
raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
if not config.instance_dataset.image_transforms.center_crop:
if not config.dataset.image_transforms.center_crop:
raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

raise NotImplementedError("'cache_vae_outputs' is not yet supported in DreamBooth training.")
# We use a temporary directory for the cache. The directory will automatically be cleaned up when
# tmp_vae_output_cache_dir is destroyed.
tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
if accelerator.is_local_main_process:
# Only the main process should populate the cache.
logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
vae.to(accelerator.device, dtype=weight_dtype)
data_loader = build_dreambooth_sd_dataloader(
data_loader_config=config.dataset,
tokenizer=tokenizer,
batch_size=config.train_batch_size,
shuffle=False,
sequential_batching=True,
)
cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
# Move the VAE back to the CPU, because it is not needed for training.
vae.to("cpu")
accelerator.wait_for_everyone()
else:
vae.to(accelerator.device, dtype=weight_dtype)

@@ -143,6 +163,7 @@ def run_training(config: DreamBoothLoRAConfig): # noqa: C901
data_loader_config=config.dataset,
tokenizer=tokenizer,
batch_size=config.train_batch_size,
vae_output_cache_dir=vae_output_cache_dir_name,
)

# TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
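The block added above follows a common Hugging Face Accelerate pattern for one-time preprocessing: only the local main process encodes the dataset and writes the cache into a temporary directory, while every other process blocks on wait_for_everyone() until the cache is complete. A minimal sketch of that synchronization pattern, where encode_and_cache is a hypothetical stand-in for cache_vae_outputs():

import tempfile

import torch
from accelerate import Accelerator


def encode_and_cache(cache_dir: str) -> None:
    # Hypothetical stand-in for cache_vae_outputs(): encode every image once and
    # write the result into cache_dir, keyed by example id.
    torch.save({"vae_output": torch.zeros(4, 64, 64)}, f"{cache_dir}/instance_0.pt")


accelerator = Accelerator()

# The cache lives in a temporary directory that is removed when this object is
# destroyed, so nothing persists after the training run.
tmp_cache_dir = tempfile.TemporaryDirectory()

if accelerator.is_local_main_process:
    # Only one process per node populates the cache, which avoids duplicate work
    # and concurrent writes to the same files.
    encode_and_cache(tmp_cache_dir.name)

# Every process blocks here, so none of them reads the cache before it is complete.
accelerator.wait_for_everyone()

Moving the VAE back to the CPU afterwards is what delivers the VRAM savings that motivate the cache: during training, the latents come from disk instead of a resident VAE.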
src/invoke_training/training/dreambooth_lora/dreambooth_lora_sdxl.py
@@ -1,6 +1,7 @@
import json
import math
import os
import tempfile
import time

import torch
@@ -18,6 +19,7 @@
DreamBoothLoRASDXLConfig,
)
from invoke_training.training.finetune_lora.finetune_lora_sdxl import (
cache_vae_outputs,
generate_validation_images,
load_models,
train_forward,
@@ -90,14 +92,33 @@ def run_training(config: DreamBoothLoRASDXLConfig): # noqa: C901
text_encoder_2.to(accelerator.device, dtype=weight_dtype)

# Prepare VAE output cache.
# vae_output_cache_dir_name = None
vae_output_cache_dir_name = None
if config.cache_vae_outputs:
if config.dataset.image_transforms.random_flip:
raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
if not config.dataset.image_transforms.center_crop:
raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

raise NotImplementedError("'cache_vae_outputs' is not yet supported in DreamBooth training.")
# We use a temporary directory for the cache. The directory will automatically be cleaned up when
# tmp_vae_output_cache_dir is destroyed.
tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
if accelerator.is_local_main_process:
# Only the main process should populate the cache.
logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
vae.to(accelerator.device, dtype=weight_dtype)
data_loader = build_dreambooth_sdxl_dataloader(
data_loader_config=config.dataset,
tokenizer_1=tokenizer_1,
tokenizer_2=tokenizer_2,
batch_size=config.train_batch_size,
shuffle=False,
sequential_batching=True,
)
cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
# Move the VAE back to the CPU, because it is not needed for training.
vae.to("cpu")
accelerator.wait_for_everyone()
else:
vae.to(accelerator.device, dtype=weight_dtype)

@@ -152,6 +173,7 @@ def run_training(config: DreamBoothLoRASDXLConfig): # noqa: C901
tokenizer_1=tokenizer_1,
tokenizer_2=tokenizer_2,
batch_size=config.train_batch_size,
vae_output_cache_dir=vae_output_cache_dir_name,
)

# TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
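The quantity being cached is the VAE encoder output for each training image, so it is computed once instead of on every epoch; this is also why the config checks above reject random_flip and require center_crop, since a cached latent can only stand in for an image when the image-to-latent mapping is deterministic. A rough sketch of a single encode pass with diffusers (the loop body of cache_vae_outputs is not shown in this diff, so treat the details, including the model id, as assumptions):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; any Stable Diffusion model with a "vae" subfolder works.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.eval()


@torch.no_grad()
def encode_images(images: torch.Tensor) -> torch.Tensor:
    # images: (B, 3, H, W) scaled to [-1, 1], as Stable Diffusion VAEs expect.
    posterior = vae.encode(images.to(vae.device, dtype=vae.dtype)).latent_dist
    return posterior.sample()  # latents of shape (B, 4, H/8, W/8); .mode() is the deterministic alternative


latents = encode_images(torch.zeros(1, 3, 512, 512))
print(latents.shape)  # torch.Size([1, 4, 64, 64])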
17 changes: 8 additions & 9 deletions src/invoke_training/training/finetune_lora/finetune_lora_sd.py
@@ -7,6 +7,7 @@

import numpy as np
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.hooks import remove_hook_from_module
from accelerate.utils import set_seed
@@ -118,19 +119,14 @@ def cache_text_encoder_outputs(
cache.save(data_batch["id"][i], {"text_encoder_output": text_encoder_output_batch[i]})


def cache_vae_outputs(cache_dir: str, config: FinetuneLoRAConfig, tokenizer: CLIPTokenizer, vae: AutoencoderKL):
def cache_vae_outputs(cache_dir: str, data_loader: torch.utils.data.DataLoader, vae: AutoencoderKL):
"""Run the VAE on all images in the dataset and cache the results to disk.
Args:
cache_dir (str): The directory where the results will be cached.
config (FinetuneLoRAConfig): Training config.
tokenizer (CLIPTokenizer): The tokenizer.
data_loader (DataLoader): The data loader.
vae (AutoencoderKL): The VAE.
"""
data_loader = build_image_caption_sd_dataloader(
config.dataset, tokenizer, batch_size=config.train_batch_size, shuffle=False
)

cache = TensorDiskCache(cache_dir)

for data_batch in tqdm(data_loader):
@@ -396,10 +392,13 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901
tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
if accelerator.is_local_main_process:
# Only the main process should to populate the cache.
# Only the main process should populate the cache.
logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
vae.to(accelerator.device, dtype=weight_dtype)
cache_vae_outputs(vae_output_cache_dir_name, config, tokenizer, vae)
data_loader = build_image_caption_sd_dataloader(
config.dataset, tokenizer, batch_size=config.train_batch_size, shuffle=False
)
cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
# Move the VAE back to the CPU, because it is not needed for training.
vae.to("cpu")
accelerator.wait_for_everyone()
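The signature change above decouples cache_vae_outputs() from the training config and tokenizer: the caller now builds whatever DataLoader it wants (image-caption or DreamBooth) and passes it in, and the function only needs batches that carry an "id" field plus image tensors. Conceptually, the cache itself is a per-example key-value store of tensors on disk. A toy stand-in showing the round trip (a sketch, not the repo's TensorDiskCache, whose read API is not visible in this diff):

import os

import torch


class ToyTensorCache:
    """Minimal per-example tensor cache: one file per key under cache_dir."""

    def __init__(self, cache_dir: str):
        self._dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def save(self, key: str, tensors: dict) -> None:
        torch.save(tensors, os.path.join(self._dir, f"{key}.pt"))

    def load(self, key: str) -> dict:
        return torch.load(os.path.join(self._dir, f"{key}.pt"))


cache = ToyTensorCache("/tmp/vae_cache")
cache.save("instance_0", {"vae_output": torch.randn(4, 64, 64)})
latents = cache.load("instance_0")["vae_output"]  # read back at training time

The id_prefix arguments added to ImageDirDataset further down ("instance_" / "class_") presumably keep these keys unique across the merged instance and class datasets, so cache entries from the two directories cannot collide.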
14 changes: 6 additions & 8 deletions src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py
@@ -7,6 +7,7 @@

import numpy as np
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.hooks import remove_hook_from_module
from accelerate.utils import set_seed
@@ -207,9 +208,7 @@ def cache_text_encoder_outputs(

def cache_vae_outputs(
cache_dir: str,
config: FinetuneLoRASDXLConfig,
tokenizer_1: PreTrainedTokenizer,
tokenizer_2: PreTrainedTokenizer,
data_loader: torch.utils.data.DataLoader,
vae: AutoencoderKL,
):
"""Run the VAE on all images in the dataset and cache the results to disk.
@@ -220,10 +219,6 @@ def cache_vae_outputs(
tokenizer (CLIPTokenizer): The tokenizer.
vae (AutoencoderKL): The VAE.
"""
data_loader = build_image_caption_sdxl_dataloader(
config.dataset, tokenizer_1, tokenizer_2, config.train_batch_size, shuffle=False
)

cache = TensorDiskCache(cache_dir)

for data_batch in tqdm(data_loader):
@@ -521,7 +516,10 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901
# Only the main process should populate the cache.
logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
vae.to(accelerator.device, dtype=weight_dtype)
cache_vae_outputs(vae_output_cache_dir_name, config, tokenizer_1, tokenizer_2, vae)
data_loader = build_image_caption_sdxl_dataloader(
config.dataset, tokenizer_1, tokenizer_2, config.train_batch_size, shuffle=False
)
cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
# Move the VAE back to the CPU, because it is not needed for training.
vae.to("cpu")
accelerator.wait_for_everyone()
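Both fine-tune scripts share the same caching-loop shape: iterate the data loader and save one cache entry per example, keyed by data_batch["id"][i] (as in the cache_text_encoder_outputs loop shown above). That indexing works because PyTorch's default collation turns a batch of {"id": str, "image": Tensor} dicts into a single dict holding a list of ids and a stacked image tensor, as this small self-contained example shows (field names mirror the diff; the dataset is synthetic):

import torch
from torch.utils.data import DataLoader, Dataset


class TinyImageDataset(Dataset):
    # Synthetic stand-in for an image directory dataset.
    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> dict:
        return {"id": f"instance_{idx}", "image": torch.zeros(3, 64, 64)}


loader = DataLoader(TinyImageDataset(), batch_size=2)
batch = next(iter(loader))
print(batch["id"])           # ['instance_0', 'instance_1']
print(batch["image"].shape)  # torch.Size([2, 3, 64, 64])

The remaining hunks below come from the shared DreamBooth SD data-loader module, where build_dreambooth_sd_dataloader gains the vae_output_cache_dir and sequential_batching parameters.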
@@ -18,24 +18,50 @@
from invoke_training.training.shared.data.transforms.constant_field_transform import (
ConstantFieldTransform,
)
from invoke_training.training.shared.data.transforms.drop_field_transform import (
DropFieldTransform,
)
from invoke_training.training.shared.data.transforms.load_cache_transform import (
LoadCacheTransform,
)
from invoke_training.training.shared.data.transforms.sd_image_transform import (
SDImageTransform,
)
from invoke_training.training.shared.data.transforms.sd_tokenize_transform import (
SDTokenizeTransform,
)
from invoke_training.training.shared.data.transforms.tensor_disk_cache import (
TensorDiskCache,
)


def build_dreambooth_sd_dataloader(
data_loader_config: DreamBoothDataLoaderConfig,
tokenizer: typing.Optional[CLIPTokenizer],
batch_size: int,
vae_output_cache_dir: typing.Optional[str] = None,
shuffle: bool = True,
sequential_batching: bool = False,
) -> DataLoader:
"""Construct a DataLoader for a DreamBooth dataset for Stable Diffusion v1/v2.."""
"""Construct a DataLoader for a DreamBooth dataset for Stable Diffusion v1/v2.
Args:
data_loader_config (DreamBoothDataLoaderConfig):
tokenizer (typing.Optional[CLIPTokenizer]):
batch_size (int):
vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
shuffle (bool, optional): Whether to shuffle the dataset order.
sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than
interleaving class and instance examples. This is intended to be used when processing the entire dataset for
caching purposes. Defaults to False.
Returns:
DataLoader
"""

# 1. Prepare instance dataset
instance_dataset = ImageDirDataset(data_loader_config.instance_data_dir)
instance_dataset = ImageDirDataset(data_loader_config.instance_data_dir, id_prefix="instance_")
instance_dataset = TransformDataset(
instance_dataset,
[
@@ -48,7 +74,7 @@ def build_dreambooth_sd_dataloader(
# 2. Prepare class dataset.
class_dataset = None
if data_loader_config.class_data_dir is not None:
class_dataset = ImageDirDataset(data_loader_config.class_data_dir)
class_dataset = ImageDirDataset(data_loader_config.class_data_dir, id_prefix="class_")
class_dataset = TransformDataset(
class_dataset,
[
@@ -60,31 +86,51 @@

# 3. Merge instance dataset and class dataset.
merged_dataset = ConcatDataset(datasets)
all_transforms = [
SDImageTransform(
resolution=data_loader_config.image_transforms.resolution,
center_crop=data_loader_config.image_transforms.center_crop,
random_flip=data_loader_config.image_transforms.random_flip,
),
SDTokenizeTransform(tokenizer),
]
all_transforms = [SDTokenizeTransform(tokenizer)]
if vae_output_cache_dir is None:
all_transforms.append(
SDImageTransform(
resolution=data_loader_config.image_transforms.resolution,
center_crop=data_loader_config.image_transforms.center_crop,
random_flip=data_loader_config.image_transforms.random_flip,
)
)
else:
vae_cache = TensorDiskCache(vae_output_cache_dir)
all_transforms.append(
LoadCacheTransform(
cache=vae_cache, cache_key_field="id", cache_field_to_output_field={"vae_output": "vae_output"}
)
)
# We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
all_transforms.append(DropFieldTransform("image"))
merged_dataset = TransformDataset(merged_dataset, all_transforms)

# 4. Prepare instance dataset sampler. Note that the instance_dataset comes first in the merged_dataset.
# 4. If sequential_batching is enabled, return a basic data loader that iterates over examples sequentially (without
# interleaving class and instance examples). This is typically only used when preparing the data cache.
if sequential_batching:
return DataLoader(
merged_dataset,
batch_size=batch_size,
num_workers=data_loader_config.dataloader_num_workers,
shuffle=shuffle,
)

# 5. Prepare instance dataset sampler. Note that the instance_dataset comes first in the merged_dataset.
samplers = []
if shuffle:
samplers.append(ShuffledRangeSampler(0, len(instance_dataset)))
else:
samplers.append(SequentialRangeSampler(0, len(instance_dataset)))

# 5. Prepare class dataset sampler. Note that the class_dataset comes first in the merged_dataset.
# 6. Prepare class dataset sampler. Note that the class_dataset comes after the instance_dataset in the merged_dataset.
if class_dataset is not None:
if shuffle:
samplers.append(ShuffledRangeSampler(len(instance_dataset), len(instance_dataset) + len(class_dataset)))
else:
samplers.append(SequentialRangeSampler(len(instance_dataset), len(instance_dataset) + len(class_dataset)))

# 6. Interleave instance and class samplers.
# 7. Interleave instance and class samplers.
interleaved_sampler = InterleavedSampler(samplers)

return DataLoader(
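To summarize the data-loader changes: the transform list now branches on vae_output_cache_dir. Without a cache, images flow through SDImageTransform as before; with a cache, LoadCacheTransform attaches the cached latent by example id and DropFieldTransform removes the raw image so batch collation never sees a PIL object. A toy version of that cache-loading step, written as a sketch rather than the repo's implementation (the in-memory dict and the "caption" field are illustrative):

import torch


class ToyLoadCacheTransform:
    """Attach a cached tensor to an example dict, keyed by the example's 'id' field."""

    def __init__(self, cache: dict):
        # A real implementation would read from disk; a plain dict keeps the sketch small.
        self._cache = cache

    def __call__(self, example: dict) -> dict:
        example["vae_output"] = self._cache[example["id"]]
        example.pop("image", None)  # the raw image is no longer needed once latents are cached
        return example


cache = {"instance_0": torch.randn(4, 64, 64)}
example = {"id": "instance_0", "image": object(), "caption": "a photo of sks dog"}
example = ToyLoadCacheTransform(cache)(example)
print(sorted(example.keys()))  # ['caption', 'id', 'vae_output']

When sequential_batching is True, the builder returns a plain (optionally shuffled) DataLoader over the merged dataset, which is what the cache-population pass uses; during training it falls back to the interleaved instance/class sampler shown above.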