From 504d0903741b5d5584805b1ccf15b9e3079fc936 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 4 Sep 2023 15:51:19 -0400 Subject: [PATCH 1/4] Bugfix: do not hash() the key in LoadCacheTransform. --- .../training/shared/data/transforms/load_cache_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/invoke_training/training/shared/data/transforms/load_cache_transform.py b/src/invoke_training/training/shared/data/transforms/load_cache_transform.py index b4bb5ecc..dc531e20 100644 --- a/src/invoke_training/training/shared/data/transforms/load_cache_transform.py +++ b/src/invoke_training/training/shared/data/transforms/load_cache_transform.py @@ -26,7 +26,7 @@ def __init__( def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: key = data[self._cache_key_field] - cache_data = self._cache.load(hash(key)) + cache_data = self._cache.load(key) for src, dst in self._cache_field_to_output_field.items(): data[dst] = cache_data[src] From 76e542d7b63750634d51f34db8bc25f676368f5b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 4 Sep 2023 15:54:20 -0400 Subject: [PATCH 2/4] Add support for an id_prefix in ImageDirDataset. 
--- .../training/shared/data/datasets/image_dir_dataset.py | 6 ++++-- .../training/shared/data/datasets/test_image_dir_dataset.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/invoke_training/training/shared/data/datasets/image_dir_dataset.py b/src/invoke_training/training/shared/data/datasets/image_dir_dataset.py index 392e92fa..9b8db1af 100644 --- a/src/invoke_training/training/shared/data/datasets/image_dir_dataset.py +++ b/src/invoke_training/training/shared/data/datasets/image_dir_dataset.py @@ -8,15 +8,17 @@ class ImageDirDataset(torch.utils.data.Dataset): """A dataset that loads image files from a directory.""" - def __init__(self, image_dir: str, image_extensions: typing.Optional[list[str]] = None): + def __init__(self, image_dir: str, id_prefix: str = "", image_extensions: typing.Optional[list[str]] = None): """Initialize an ImageDirDataset Args: image_dir (str): The directory to load images from. + id_prefix (str): A prefix added to the 'id' field in every example. image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not case-sensitive). Defaults to [".jpg", ".jpeg", ".png"]. """ super().__init__() + self._id_prefix = id_prefix if image_extensions is None: image_extensions = [".jpg", ".jpeg", ".png"] image_extensions = [ext.lower() for ext in image_extensions] @@ -34,4 +36,4 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale # images. 
- return {"id": idx, "image": Image.open(self._image_paths[idx]).convert("RGB")} + return {"id": f"{self._id_prefix}{idx}", "image": Image.open(self._image_paths[idx]).convert("RGB")} diff --git a/tests/invoke_training/training/shared/data/datasets/test_image_dir_dataset.py b/tests/invoke_training/training/shared/data/datasets/test_image_dir_dataset.py index 47d8e608..7934d64e 100644 --- a/tests/invoke_training/training/shared/data/datasets/test_image_dir_dataset.py +++ b/tests/invoke_training/training/shared/data/datasets/test_image_dir_dataset.py @@ -21,4 +21,4 @@ def test_image_dir_dataset_getitem(image_dir): # noqa: F811 assert set(example.keys()) == {"image", "id"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" - assert example["id"] == 0 + assert example["id"] == "0" From 54cb7abbc05ab94da72b917bd66b3f677d8821cb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 4 Sep 2023 15:56:53 -0400 Subject: [PATCH 3/4] Add VAE caching to SD DreamBooth script. 
--- .../dreambooth_lora/dreambooth_lora_sd.py | 29 +++++++- .../finetune_lora/finetune_lora_sd.py | 17 ++--- .../data_loaders/dreambooth_sd_dataloader.py | 74 +++++++++++++++---- 3 files changed, 93 insertions(+), 27 deletions(-) diff --git a/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sd.py b/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sd.py index 7e7b3b94..6949eec2 100644 --- a/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sd.py +++ b/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sd.py @@ -1,6 +1,7 @@ import json import math import os +import tempfile import time import torch @@ -16,6 +17,7 @@ ) from invoke_training.training.config.finetune_lora_config import DreamBoothLoRAConfig from invoke_training.training.finetune_lora.finetune_lora_sd import ( + cache_vae_outputs, generate_validation_images, load_models, train_forward, @@ -88,14 +90,32 @@ def run_training(config: DreamBoothLoRAConfig): # noqa: C901 text_encoder.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. - # vae_output_cache_dir_name = None + vae_output_cache_dir_name = None if config.cache_vae_outputs: - if config.instance_dataset.image_transforms.random_flip: + if config.dataset.image_transforms.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") - if not config.instance_dataset.image_transforms.center_crop: + if not config.dataset.image_transforms.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") - raise NotImplementedError("'cache_vae_outputs' is not yet supported in DreamBooth training.") + # We use a temporary directory for the cache. The directory will automatically be cleaned up when + # tmp_vae_output_cache_dir is destroyed. 
+ tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() + vae_output_cache_dir_name = tmp_vae_output_cache_dir.name + if accelerator.is_local_main_process: + # Only the main process should populate the cache. + logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") + vae.to(accelerator.device, dtype=weight_dtype) + data_loader = build_dreambooth_sd_dataloader( + data_loader_config=config.dataset, + tokenizer=tokenizer, + batch_size=config.train_batch_size, + shuffle=False, + sequential_batching=True, + ) + cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) + # Move the VAE back to the CPU, because it is not needed for training. + vae.to("cpu") + accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) @@ -143,6 +163,7 @@ def run_training(config: DreamBoothLoRAConfig): # noqa: C901 data_loader_config=config.dataset, tokenizer=tokenizer, batch_size=config.train_batch_size, + vae_output_cache_dir=vae_output_cache_dir_name, ) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py index e1f2ecf7..09dccfce 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py @@ -7,6 +7,7 @@ import numpy as np import torch +import torch.utils.data from accelerate import Accelerator from accelerate.hooks import remove_hook_from_module from accelerate.utils import set_seed @@ -118,19 +119,14 @@ def cache_text_encoder_outputs( cache.save(data_batch["id"][i], {"text_encoder_output": text_encoder_output_batch[i]}) -def cache_vae_outputs(cache_dir: str, config: FinetuneLoRAConfig, tokenizer: CLIPTokenizer, vae: AutoencoderKL): +def cache_vae_outputs(cache_dir: str, data_loader: torch.utils.data.DataLoader, vae: AutoencoderKL): """Run the 
VAE on all images in the dataset and cache the results to disk. Args: cache_dir (str): The directory where the results will be cached. - config (FinetuneLoRAConfig): Training config. - tokenizer (CLIPTokenizer): The tokenizer. + data_loader (DataLoader): The data loader. vae (AutoencoderKL): The VAE. """ - data_loader = build_image_caption_sd_dataloader( - config.dataset, tokenizer, batch_size=config.train_batch_size, shuffle=False - ) - cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): @@ -396,10 +392,13 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: - # Only the main process should to populate the cache. + # Only the main process should populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) - cache_vae_outputs(vae_output_cache_dir_name, config, tokenizer, vae) + data_loader = build_image_caption_sd_dataloader( + config.dataset, tokenizer, batch_size=config.train_batch_size, shuffle=False + ) + cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. 
vae.to("cpu") accelerator.wait_for_everyone() diff --git a/src/invoke_training/training/shared/data/data_loaders/dreambooth_sd_dataloader.py b/src/invoke_training/training/shared/data/data_loaders/dreambooth_sd_dataloader.py index ae547a9c..7f2a619e 100644 --- a/src/invoke_training/training/shared/data/data_loaders/dreambooth_sd_dataloader.py +++ b/src/invoke_training/training/shared/data/data_loaders/dreambooth_sd_dataloader.py @@ -18,24 +18,50 @@ from invoke_training.training.shared.data.transforms.constant_field_transform import ( ConstantFieldTransform, ) +from invoke_training.training.shared.data.transforms.drop_field_transform import ( + DropFieldTransform, +) +from invoke_training.training.shared.data.transforms.load_cache_transform import ( + LoadCacheTransform, +) from invoke_training.training.shared.data.transforms.sd_image_transform import ( SDImageTransform, ) from invoke_training.training.shared.data.transforms.sd_tokenize_transform import ( SDTokenizeTransform, ) +from invoke_training.training.shared.data.transforms.tensor_disk_cache import ( + TensorDiskCache, +) def build_dreambooth_sd_dataloader( data_loader_config: DreamBoothDataLoaderConfig, tokenizer: typing.Optional[CLIPTokenizer], batch_size: int, + vae_output_cache_dir: typing.Optional[str] = None, shuffle: bool = True, + sequential_batching: bool = False, ) -> DataLoader: - """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion v1/v2..""" + """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion v1/v2. + + Args: + data_loader_config (DreamBoothDataLoaderConfig): + tokenizer (typing.Optional[CLIPTokenizer]): + batch_size (int): + vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If + set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM. + shuffle (bool, optional): Whether to shuffle the dataset order. 
+ sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than + interleaving class and instance examples. This is intended to be used when processing the entire dataset for + caching purposes. Defaults to False. + + Returns: + DataLoader + """ # 1. Prepare instance dataset - instance_dataset = ImageDirDataset(data_loader_config.instance_data_dir) + instance_dataset = ImageDirDataset(data_loader_config.instance_data_dir, id_prefix="instance_") instance_dataset = TransformDataset( instance_dataset, [ @@ -48,7 +74,7 @@ def build_dreambooth_sd_dataloader( # 2. Prepare class dataset. class_dataset = None if data_loader_config.class_data_dir is not None: - class_dataset = ImageDirDataset(data_loader_config.class_data_dir) + class_dataset = ImageDirDataset(data_loader_config.class_data_dir, id_prefix="class_") class_dataset = TransformDataset( class_dataset, [ @@ -60,31 +86,51 @@ def build_dreambooth_sd_dataloader( # 3. Merge instance dataset and class dataset. merged_dataset = ConcatDataset(datasets) - all_transforms = [ - SDImageTransform( - resolution=data_loader_config.image_transforms.resolution, - center_crop=data_loader_config.image_transforms.center_crop, - random_flip=data_loader_config.image_transforms.random_flip, - ), - SDTokenizeTransform(tokenizer), - ] + all_transforms = [SDTokenizeTransform(tokenizer)] + if vae_output_cache_dir is None: + all_transforms.append( + SDImageTransform( + resolution=data_loader_config.image_transforms.resolution, + center_crop=data_loader_config.image_transforms.center_crop, + random_flip=data_loader_config.image_transforms.random_flip, + ) + ) + else: + vae_cache = TensorDiskCache(vae_output_cache_dir) + all_transforms.append( + LoadCacheTransform( + cache=vae_cache, cache_key_field="id", cache_field_to_output_field={"vae_output": "vae_output"} + ) + ) + # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation. 
+ all_transforms.append(DropFieldTransform("image")) merged_dataset = TransformDataset(merged_dataset, all_transforms) - # 4. Prepare instance dataset sampler. Note that the instance_dataset comes first in the merged_dataset. + # 4. If sequential_batching is enabled, return a basic data loader that iterates over examples sequentially (without + # interleaving class and instance examples). This is typically only used when preparing the data cache. + if sequential_batching: + return DataLoader( + merged_dataset, + batch_size=batch_size, + num_workers=data_loader_config.dataloader_num_workers, + shuffle=shuffle, + ) + + # 5. Prepare instance dataset sampler. Note that the instance_dataset comes first in the merged_dataset. samplers = [] if shuffle: samplers.append(SequentialRangeSampler(0, len(instance_dataset))) else: samplers.append(ShuffledRangeSampler(0, len(instance_dataset))) - # 5. Prepare class dataset sampler. Note that the class_dataset comes first in the merged_dataset. + # 6. Prepare class dataset sampler. Note that the class_dataset comes first in the merged_dataset. if class_dataset is not None: if shuffle: samplers.append(SequentialRangeSampler(len(instance_dataset), len(instance_dataset) + len(class_dataset))) else: samplers.append(ShuffledRangeSampler(len(instance_dataset), len(instance_dataset) + len(class_dataset))) - # 6. Interleave instance and class samplers. + # 7. Interleave instance and class samplers. interleaved_sampler = InterleavedSampler(samplers) return DataLoader( From 0af01fcbb778b6afd1a38bb86fbcbb7271b53f32 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 4 Sep 2023 16:36:07 -0400 Subject: [PATCH 4/4] Add VAE caching to SDXL DreamBooth script. 
--- .../dreambooth_lora/dreambooth_lora_sdxl.py | 26 +++++- .../finetune_lora/finetune_lora_sdxl.py | 14 ++-- .../dreambooth_sdxl_dataloader.py | 79 ++++++++++++++++--- 3 files changed, 98 insertions(+), 21 deletions(-) diff --git a/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sdxl.py b/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sdxl.py index f45c3c41..e35b072f 100644 --- a/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sdxl.py +++ b/src/invoke_training/training/dreambooth_lora/dreambooth_lora_sdxl.py @@ -1,6 +1,7 @@ import json import math import os +import tempfile import time import torch @@ -18,6 +19,7 @@ DreamBoothLoRASDXLConfig, ) from invoke_training.training.finetune_lora.finetune_lora_sdxl import ( + cache_vae_outputs, generate_validation_images, load_models, train_forward, @@ -90,14 +92,33 @@ def run_training(config: DreamBoothLoRASDXLConfig): # noqa: C901 text_encoder_2.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. - # vae_output_cache_dir_name = None + vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.dataset.image_transforms.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.dataset.image_transforms.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") - raise NotImplementedError("'cache_vae_outputs' is not yet supported in DreamBooth training.") + # We use a temporary directory for the cache. The directory will automatically be cleaned up when + # tmp_vae_output_cache_dir is destroyed. + tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() + vae_output_cache_dir_name = tmp_vae_output_cache_dir.name + if accelerator.is_local_main_process: + # Only the main process should populate the cache. 
+ logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") + vae.to(accelerator.device, dtype=weight_dtype) + data_loader = build_dreambooth_sdxl_dataloader( + data_loader_config=config.dataset, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, + batch_size=config.train_batch_size, + shuffle=False, + sequential_batching=True, + ) + cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) + # Move the VAE back to the CPU, because it is not needed for training. + vae.to("cpu") + accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) @@ -152,6 +173,7 @@ def run_training(config: DreamBoothLoRASDXLConfig): # noqa: C901 tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, batch_size=config.train_batch_size, + vae_output_cache_dir=vae_output_cache_dir_name, ) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index 5d11d1a2..3b8739c1 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -7,6 +7,7 @@ import numpy as np import torch +import torch.utils.data from accelerate import Accelerator from accelerate.hooks import remove_hook_from_module from accelerate.utils import set_seed @@ -207,9 +208,7 @@ def cache_text_encoder_outputs( def cache_vae_outputs( cache_dir: str, - config: FinetuneLoRASDXLConfig, - tokenizer_1: PreTrainedTokenizer, - tokenizer_2: PreTrainedTokenizer, + data_loader: torch.utils.data.DataLoader, vae: AutoencoderKL, ): """Run the VAE on all images in the dataset and cache the results to disk. @@ -220,10 +219,6 @@ def cache_vae_outputs( tokenizer (CLIPTokenizer): The tokenizer. vae (AutoencoderKL): The VAE. 
""" - data_loader = build_image_caption_sdxl_dataloader( - config.dataset, tokenizer_1, tokenizer_2, config.train_batch_size, shuffle=False - ) - cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): @@ -521,7 +516,10 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 # Only the main process should to populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) - cache_vae_outputs(vae_output_cache_dir_name, config, tokenizer_1, tokenizer_2, vae) + data_loader = build_image_caption_sdxl_dataloader( + config.dataset, tokenizer_1, tokenizer_2, config.train_batch_size, shuffle=False + ) + cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. vae.to("cpu") accelerator.wait_for_everyone() diff --git a/src/invoke_training/training/shared/data/data_loaders/dreambooth_sdxl_dataloader.py b/src/invoke_training/training/shared/data/data_loaders/dreambooth_sdxl_dataloader.py index b749cd47..15b8ed57 100644 --- a/src/invoke_training/training/shared/data/data_loaders/dreambooth_sdxl_dataloader.py +++ b/src/invoke_training/training/shared/data/data_loaders/dreambooth_sdxl_dataloader.py @@ -1,3 +1,5 @@ +import typing + from torch.utils.data import ConcatDataset, DataLoader from transformers import PreTrainedTokenizer @@ -19,12 +21,21 @@ from invoke_training.training.shared.data.transforms.constant_field_transform import ( ConstantFieldTransform, ) +from invoke_training.training.shared.data.transforms.drop_field_transform import ( + DropFieldTransform, +) +from invoke_training.training.shared.data.transforms.load_cache_transform import ( + LoadCacheTransform, +) from invoke_training.training.shared.data.transforms.sd_tokenize_transform import ( SDTokenizeTransform, ) from invoke_training.training.shared.data.transforms.sdxl_image_transform import ( SDXLImageTransform, ) +from 
invoke_training.training.shared.data.transforms.tensor_disk_cache import ( + TensorDiskCache, +) def build_dreambooth_sdxl_dataloader( @@ -32,12 +43,29 @@ def build_dreambooth_sdxl_dataloader( tokenizer_1: PreTrainedTokenizer, tokenizer_2: PreTrainedTokenizer, batch_size: int, + vae_output_cache_dir: typing.Optional[str] = None, shuffle: bool = True, + sequential_batching: bool = False, ) -> DataLoader: - """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion XL.""" + """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion XL. + + Args: + data_loader_config (DreamBoothDataLoaderConfig): + tokenizer_1 (PreTrainedTokenizer): Tokenizer 1. + tokenizer_2 (PreTrainedTokenizer): Tokenizer 2. + batch_size (int): + vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If + set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM. + shuffle (bool, optional): Whether to shuffle the dataset order. + sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than + interleaving class and instance examples. This is intended to be used when processing the entire dataset for + caching purposes. Defaults to False. + Returns: + DataLoader + """ # 1. Prepare instance dataset - instance_dataset = ImageDirDataset(data_loader_config.instance_data_dir) + instance_dataset = ImageDirDataset(data_loader_config.instance_data_dir, id_prefix="instance_") instance_dataset = TransformDataset( instance_dataset, [ @@ -50,7 +78,7 @@ def build_dreambooth_sdxl_dataloader( # 2. Prepare class dataset. 
class_dataset = None if data_loader_config.class_data_dir is not None: - class_dataset = ImageDirDataset(data_loader_config.class_data_dir) + class_dataset = ImageDirDataset(data_loader_config.class_data_dir, id_prefix="class_") class_dataset = TransformDataset( class_dataset, [ @@ -63,31 +91,60 @@ def build_dreambooth_sdxl_dataloader( # 3. Merge instance dataset and class dataset. merged_dataset = ConcatDataset(datasets) all_transforms = [ - SDXLImageTransform( - resolution=data_loader_config.image_transforms.resolution, - center_crop=data_loader_config.image_transforms.center_crop, - random_flip=data_loader_config.image_transforms.random_flip, - ), SDTokenizeTransform(tokenizer_1, src_caption_key="caption", dst_token_key="caption_token_ids_1"), SDTokenizeTransform(tokenizer_2, src_caption_key="caption", dst_token_key="caption_token_ids_2"), ] + if vae_output_cache_dir is None: + all_transforms.append( + SDXLImageTransform( + resolution=data_loader_config.image_transforms.resolution, + center_crop=data_loader_config.image_transforms.center_crop, + random_flip=data_loader_config.image_transforms.random_flip, + ) + ) + else: + vae_cache = TensorDiskCache(vae_output_cache_dir) + all_transforms.append( + LoadCacheTransform( + cache=vae_cache, + cache_key_field="id", + cache_field_to_output_field={ + "vae_output": "vae_output", + "original_size_hw": "original_size_hw", + "crop_top_left_yx": "crop_top_left_yx", + }, + ) + ) + # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation. + all_transforms.append(DropFieldTransform("image")) merged_dataset = TransformDataset(merged_dataset, all_transforms) - # 4. Prepare instance dataset sampler. Note that the instance_dataset comes first in the merged_dataset. + # 4. If sequential_batching is enabled, return a basic data loader that iterates over examples sequentially (without + # interleaving class and instance examples). This is typically only used when preparing the data cache. 
+ if sequential_batching: + return DataLoader( + merged_dataset, + collate_fn=sdxl_image_caption_collate_fn, + batch_size=batch_size, + num_workers=data_loader_config.dataloader_num_workers, + shuffle=shuffle, + ) + + # 5. Prepare instance dataset sampler. Note that the instance_dataset comes first in the merged_dataset. samplers = [] if shuffle: samplers.append(SequentialRangeSampler(0, len(instance_dataset))) else: samplers.append(ShuffledRangeSampler(0, len(instance_dataset))) - # 5. Prepare class dataset sampler. Note that the class_dataset comes first in the merged_dataset. + # 6. Prepare class dataset sampler. Note that the class_dataset comes first in the merged_dataset. if class_dataset is not None: if shuffle: samplers.append(SequentialRangeSampler(len(instance_dataset), len(instance_dataset) + len(class_dataset))) else: samplers.append(ShuffledRangeSampler(len(instance_dataset), len(instance_dataset) + len(class_dataset))) - # 6. Interleave instance and class samplers. + # 7. Interleave instance and class samplers. interleaved_sampler = InterleavedSampler(samplers) return DataLoader(