From a6d5ceb0cdfd52dcf7aeff4bed19e436c7dd0685 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 9 Aug 2023 11:53:37 -0400 Subject: [PATCH 01/10] Copy finetune_lora_sd.py to finetune_lora_sdxl.py as a starting point for the SDXL finetune w/ LoRA script. --- .../finetune_lora/finetune_lora_sdxl.py | 465 +++++++++++++++++- 1 file changed, 463 insertions(+), 2 deletions(-) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index fa7f8060..9793c744 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -1,7 +1,468 @@ +import json +import logging +import math +import os +import time + +import numpy as np +import torch +from accelerate import Accelerator +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from invoke_training.lora.injection.stable_diffusion_v1 import ( + convert_lora_state_dict_to_kohya_format_sd1, + inject_lora_into_clip_text_encoder, + inject_lora_into_unet_sd1, +) from invoke_training.training.finetune_lora.finetune_lora_config import ( FinetuneLoRAConfig, ) +from invoke_training.training.shared.accelerator_utils import ( + get_mixed_precision_dtype, + initialize_accelerator, + initialize_logging, +) +from invoke_training.training.shared.base_model_version import ( + BaseModelVersionEnum, + check_base_model_version, +) +from invoke_training.training.shared.checkpoint_tracker import CheckpointTracker +from invoke_training.training.shared.datasets.image_caption_dataloader import ( + build_image_caption_dataloader, +) +from invoke_training.training.shared.serialization import save_state_dict + + +def _load_models( + accelerator: Accelerator, + config: FinetuneLoRAConfig, +) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]: + """Load all models required for training from disk, transfer them to the + target training device and cast their weight dtypes. + + Args: + config (FinetuneLoRAConfig): The LoRA training run config. + logger (logging.Logger): A logger. + + Returns: + tuple[ + CLIPTokenizer, + DDPMScheduler, + CLIPTextModel, + AutoencoderKL, + UNet2DConditionModel, + ]: A tuple of loaded models. + """ + weight_dtype = get_mixed_precision_dtype(accelerator) + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(config.model, subfolder="tokenizer") + noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(config.model, subfolder="scheduler") + text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(config.model, subfolder="text_encoder") + vae: AutoencoderKL = AutoencoderKL.from_pretrained(config.model, subfolder="vae") + unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(config.model, subfolder="unet") + + # Disable gradient calculation for model weights to save memory. + text_encoder.requires_grad_(False) + vae.requires_grad_(False) + unet.requires_grad_(False) + + # Put models in 'eval' mode. 
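    # Note: only the LoRA layers injected later in this script are trained; the base text encoder, VAE, and
    # UNet stay frozen, so they are kept in eval mode and cast to the mixed-precision weight dtype below.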
+ text_encoder.eval() + vae.eval() + unet.eval() + + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + + return tokenizer, noise_scheduler, text_encoder, vae, unet + + +def _initialize_optimizer(config: FinetuneLoRAConfig, trainable_params: list) -> torch.optim.Optimizer: + """Initialize an optimizer based on the config.""" + return torch.optim.AdamW( + trainable_params, + lr=config.optimizer.learning_rate, + betas=(config.optimizer.adam_beta1, config.optimizer.adam_beta2), + weight_decay=config.optimizer.adam_weight_decay, + eps=config.optimizer.adam_epsilon, + ) + + +def _save_checkpoint( + idx: int, + lora_layers: torch.nn.ModuleDict, + logger: logging.Logger, + checkpoint_tracker: CheckpointTracker, +): + """Save a checkpoint. Old checkpoints are deleted if necessary to respect the config.max_checkpoints config. + + Args: + idx (int): The checkpoint index (typically step count or epoch). + lora_layers (torch.nn.ModuleDict): The LoRA layers to save in a ModuleDict mapping keys to + `LoRALayerCollection`s. + logger (logging.Logger): Logger. + checkpoint_tracker (CheckpointTracker): The checkpoint tracker. + """ + # Prune checkpoints and get new checkpoint path. + num_pruned = checkpoint_tracker.prune(1) + if num_pruned > 0: + logger.info(f"Pruned {num_pruned} checkpoint(s).") + save_path = checkpoint_tracker.get_path(idx) + + state_dict = {} + for model_lora_layers in lora_layers.values(): + model_state_dict = model_lora_layers.get_lora_state_dict() + model_kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(model_state_dict) + state_dict.update(model_kohya_state_dict) + + save_state_dict(state_dict, save_path) + # accelerator.save_state(save_path) + logger.info(f"Saved state to '{save_path}'.") + + +def _generate_validation_images( + epoch: int, + out_dir: str, + accelerator: Accelerator, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + noise_scheduler: DDPMScheduler, + unet: UNet2DConditionModel, + config: FinetuneLoRAConfig, + logger: logging.Logger, +): + """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout + training. + + Args: + epoch (int): Epoch number, for reporting purposes. + out_dir (str): The output directory where the validation images will be stored. + accelerator (Accelerator): Accelerator + vae (AutoencoderKL): + text_encoder (CLIPTextModel): + tokenizer (CLIPTokenizer): + noise_scheduler (DDPMScheduler): + unet (UNet2DConditionModel): + config (FinetuneLoRAConfig): Training configs. + logger (logging.Logger): Logger. + """ + logger.info("Generating validation images.") + + # Create pipeline. + pipeline = StableDiffusionPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=noise_scheduler, + safety_checker=None, + feature_extractor=None, + # TODO(ryand): Add safety checker support. + requires_safety_checker=False, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # Run inference. 
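    # Validation runs with gradients disabled and under the accelerator's autocast context; the images are
    # generated with whatever LoRA weights are currently injected, so they track training progress over time.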
+ with torch.no_grad(): + for prompt_idx, prompt in enumerate(config.validation_prompts): + generator = torch.Generator(device=accelerator.device) + if config.seed is not None: + generator = generator.manual_seed(config.seed) + + images = [] + for _ in range(config.num_validation_images_per_prompt): + with accelerator.autocast(): + images.append( + pipeline( + prompt, + num_inference_steps=30, + generator=generator, + ).images[0] + ) + + # Save images to disk. + validation_dir = os.path.join( + out_dir, + "validation", + f"epoch_{epoch:0>8}", + f"prompt_{prompt_idx:0>4}", + ) + os.makedirs(validation_dir) + for image_idx, image in enumerate(images): + image.save(os.path.join(validation_dir, f"{image_idx:0>4}.jpg")) + + # Log images to trackers. Currently, only tensorboard is supported. + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images( + f"validation (prompt {prompt_idx})", + np_images, + epoch, + dataformats="NHWC", + ) + + del pipeline + torch.cuda.empty_cache() + + +def _train_forward( + config: FinetuneLoRAConfig, + data_batch: dict, + vae: AutoencoderKL, + noise_scheduler: DDPMScheduler, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + weight_dtype: torch.dtype, +): + """Run the forward training pass for a single data_batch. + + Returns: + torch.Tensor: Loss + """ + # Convert images to latent space. + latents = vae.encode(data_batch["image"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents. + noise = torch.randn_like(latents) + + batch_size = latents.shape[0] + # Sample a random timestep for each image. + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (batch_size,), + device=latents.device, + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep (this is the forward + # diffusion process). + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning. + encoder_hidden_states = text_encoder(data_batch["caption_token_ids"])[0] + + # Get the target for loss depending on the prediction type. + if config.prediction_type is not None: + # Set the prediction_type of scheduler if it's defined in config. + noise_scheduler.register_to_config(prediction_type=config.prediction_type) + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual. + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + return torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean") + + +def run_training(config: FinetuneLoRAConfig): # noqa: C901 + # Give a clear error message if an unsupported base model was chosen. + check_base_model_version( + {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2}, + config.model, + local_files_only=False, + ) + + # Create a timestamped directory for all outputs. 
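    # With the example configs later in this patch series (base_output_dir: output/), each run lands in a
    # directory like 'output/<unix timestamp>/', e.g. 'output/1691600000.123456/'.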
+ out_dir = os.path.join(config.output.base_output_dir, f"{time.time()}") + os.makedirs(out_dir) + + accelerator = initialize_accelerator( + out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.output.report_to + ) + logger = initialize_logging(__name__, accelerator) + + # Set the accelerate seed. + if config.seed is not None: + set_seed(config.seed) + + # Log the accelerator configuration from every process to help with debugging. + logger.info(accelerator.state, main_process_only=False) + + logger.info("Starting LoRA Training.") + logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") + logger.info(f"Output dir: '{out_dir}'") + + # Write the configuration to disk. + with open(os.path.join(out_dir, "config.json"), "w") as f: + json.dump(config.dict(), f, indent=2, default=str) + + weight_dtype = get_mixed_precision_dtype(accelerator) + + logger.info("Loading models.") + tokenizer, noise_scheduler, text_encoder, vae, unet = _load_models(accelerator, config) + + lora_layers = torch.nn.ModuleDict() + if config.train_unet: + lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks) + if config.train_text_encoder: + lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder) + + if config.xformers: + import xformers # noqa: F401 + + unet.enable_xformers_memory_efficient_attention() + vae.enable_xformers_memory_efficient_attention() + + optimizer = _initialize_optimizer(config, lora_layers.parameters()) + + data_loader = build_image_caption_dataloader(config.dataset, tokenizer, config.train_batch_size) + + # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps + # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears + # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process + # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), + # so the scaling here simply reverses that behaviour. + lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( + config.optimizer.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=config.optimizer.lr_warmup_steps * accelerator.num_processes, + num_training_steps=config.max_train_steps * accelerator.num_processes, + ) + + prepared_result: tuple[ + UNet2DConditionModel, + CLIPTextModel, + torch.nn.ModuleDict, + torch.optim.Optimizer, + torch.utils.data.DataLoader, + torch.optim.lr_scheduler.LRScheduler, + ] = accelerator.prepare(unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler) + unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler = prepared_result + + # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation + # (i.e. takes into account gradient accumulation steps). + # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when + # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. 
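    # Worked example using the numbers that appear elsewhere in this patch series (an 833-example dataset,
    # train_batch_size=1, gradient_accumulation_steps=4, max_train_steps=4000):
    #   len(data_loader) = 833  ->  num_steps_per_epoch = ceil(833 / 4) = 209
    #   num_train_epochs = ceil(4000 / 209) = 20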
+ num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.max_train_steps / num_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers("lora_training") + + epoch_checkpoint_tracker = CheckpointTracker( + base_dir=out_dir, + prefix="checkpoint_epoch", + extension=f".{config.output.save_model_as}", + max_checkpoints=config.max_checkpoints, + ) + + step_checkpoint_tracker = CheckpointTracker( + base_dir=out_dir, + prefix="checkpoint_step", + extension=f".{config.output.save_model_as}", + max_checkpoints=config.max_checkpoints, + ) + + # Train! + total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(data_loader)}") + logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") + logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") + logger.info(f" Parallel processes = {accelerator.num_processes}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Total optimization steps = {config.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + progress_bar = tqdm( + range(global_step, config.max_train_steps), + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, num_train_epochs): + lora_layers.train() + + train_loss = 0.0 + for data_batch in data_loader: + with accelerator.accumulate(lora_layers): + loss = _train_forward( + config, + data_batch, + vae, + noise_scheduler, + text_encoder, + unet, + weight_dtype, + ) + + # Gather the losses across all processes for logging (if we use distributed training). + # TODO(ryand): Test that this works properly with distributed training. + avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() + train_loss += avg_loss.item() / config.gradient_accumulation_steps + + # Backpropagate. + accelerator.backward(loss) + if accelerator.sync_gradients and config.max_grad_norm is not None: + params_to_clip = lora_layers.parameters() + accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes. + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if config.save_every_n_steps is not None and (global_step + 1) % config.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + _save_checkpoint(global_step + 1, lora_layers, logger, step_checkpoint_tracker) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + if global_step >= config.max_train_steps: + break + + # Save a checkpoint every n epochs. + if config.save_every_n_epochs is not None and (epoch + 1) % config.save_every_n_epochs == 0: + if accelerator.is_main_process: + _save_checkpoint(epoch + 1, lora_layers, logger, epoch_checkpoint_tracker) + accelerator.wait_for_everyone() + # Generate validation images every n epochs. 
+ if len(config.validation_prompts) > 0 and (epoch + 1) % config.validate_every_n_epochs == 0: + if accelerator.is_main_process: + _generate_validation_images( + epoch=epoch + 1, + out_dir=out_dir, + accelerator=accelerator, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + unet=unet, + config=config, + logger=logger, + ) -def run_training(config: FinetuneLoRAConfig): - raise NotImplementedError("finetune_lora_sdxl is not implemented.") + accelerator.end_training() From f2d751fd0d29a894edba1c242e18626fad4f0074 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 9 Aug 2023 17:08:09 -0400 Subject: [PATCH 02/10] Add ImageCaptionSDXLDataset. --- .../training/shared/datasets/ARCHITECTURE.md | 16 +-- .../datasets/image_caption_dataloader.py | 6 +- ...dataset.py => image_caption_sd_dataset.py} | 8 +- .../datasets/image_caption_sdxl_dataset.py | 109 ++++++++++++++++++ ...et.py => test_image_caption_sd_dataset.py} | 12 +- .../test_image_caption_sdxl_dataset.py | 84 ++++++++++++++ 6 files changed, 215 insertions(+), 20 deletions(-) rename src/invoke_training/training/shared/datasets/{image_caption_dataset.py => image_caption_sd_dataset.py} (90%) create mode 100644 src/invoke_training/training/shared/datasets/image_caption_sdxl_dataset.py rename tests/invoke_training/training/shared/datasets/{test_image_caption_dataset.py => test_image_caption_sd_dataset.py} (74%) create mode 100644 tests/invoke_training/training/shared/datasets/test_image_caption_sdxl_dataset.py diff --git a/src/invoke_training/training/shared/datasets/ARCHITECTURE.md b/src/invoke_training/training/shared/datasets/ARCHITECTURE.md index 366a8cc4..1cfa5698 100644 --- a/src/invoke_training/training/shared/datasets/ARCHITECTURE.md +++ b/src/invoke_training/training/shared/datasets/ARCHITECTURE.md @@ -1,18 +1,20 @@ # Dataset Architecture -Dataset handling is split into 3 layers of abstraction: `BaseImageCaptionReader`s, `ImageCaptionDataset`, and `DataLoader`. Each is explained in more detail below. +Dataset handling is split into 3 layers of abstraction: Readers, Datasets, and DataLoaders. Each is explained in more detail below. -## `BaseImageCaptionReader` +## Readers -`BaseImageCaptionReader` defines an interface that can be implemented by multiple sub-classes. +`BaseImageCaptionReader` defines a reader interface that can be implemented by multiple sub-classes. `BaseImageCaptionReader` sub-classes are intended as an abstraction over different dataset formats. They are responsible for loading image-caption pairs from disk. Readers implement the [torch Dataset interface](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files), i.e. they implement `__init__`, `__len__`, and `__getitem__`. Two examples of concrete reader implementations are the `HFHubImageCaptionReader` and the `HFDirImageCaptionReader`. -## `ImageCaptionDataset` +## Datasets -An `ImageCaptionDataset` is a wrapper for a `BaseImageCaptionReader` object. The `ImageCaptionDataset` is responsible for example-level operations that are agnostic to the underlying dataset format. For example, image augmentations are handled at this layer. +Dataset classes wrap reader classes and are responsible for example-level operations that are agnostic to the underlying dataset format. -## `DataLoader` +As an example `ImageCaptionSDDataset` is a wrapper for a `BaseImageCaptionReader` object that implements image augmentations and caption tokenization. 
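
A minimal sketch of how the three layers compose (the dataset path and column names here are illustrative placeholders; the class and argument names match the code in this repository):

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

from invoke_training.training.shared.datasets.hf_dir_image_caption_reader import HFDirImageCaptionReader
from invoke_training.training.shared.datasets.image_caption_sd_dataset import ImageCaptionSDDataset

# Reader layer: loads image-caption pairs from a directory-format dataset on disk.
reader = HFDirImageCaptionReader(
    dataset_dir="/path/to/dataset",
    hf_load_dataset_kwargs=None,
    image_column="image",
    caption_column="text",
)

# Dataset layer: format-agnostic image transforms and caption tokenization.
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
dataset = ImageCaptionSDDataset(reader=reader, tokenizer=tokenizer, resolution=512)

# DataLoader layer: batch collation, shuffling, worker processes.
data_loader = DataLoader(dataset, shuffle=True, batch_size=4)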
-The `ImageCaptionDataset` is wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc. +## DataLoaders + +The dataset classes are wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc. diff --git a/src/invoke_training/training/shared/datasets/image_caption_dataloader.py b/src/invoke_training/training/shared/datasets/image_caption_dataloader.py index 990555ee..006f36b9 100644 --- a/src/invoke_training/training/shared/datasets/image_caption_dataloader.py +++ b/src/invoke_training/training/shared/datasets/image_caption_dataloader.py @@ -8,8 +8,8 @@ from invoke_training.training.shared.datasets.hf_hub_image_caption_reader import ( HFHubImageCaptionReader, ) -from invoke_training.training.shared.datasets.image_caption_dataset import ( - ImageCaptionDataset, +from invoke_training.training.shared.datasets.image_caption_sd_dataset import ( + ImageCaptionSDDataset, ) @@ -44,7 +44,7 @@ def build_image_caption_dataloader(config: DatasetConfig, tokenizer: CLIPTokeniz else: raise ValueError("One of 'dataset_name' or 'dataset_dir' must be set.") - dataset = ImageCaptionDataset( + dataset = ImageCaptionSDDataset( reader=reader, tokenizer=tokenizer, resolution=config.resolution, diff --git a/src/invoke_training/training/shared/datasets/image_caption_dataset.py b/src/invoke_training/training/shared/datasets/image_caption_sd_dataset.py similarity index 90% rename from src/invoke_training/training/shared/datasets/image_caption_dataset.py rename to src/invoke_training/training/shared/datasets/image_caption_sd_dataset.py index 2d5bc946..273dfb40 100644 --- a/src/invoke_training/training/shared/datasets/image_caption_dataset.py +++ b/src/invoke_training/training/shared/datasets/image_caption_sd_dataset.py @@ -6,9 +6,9 @@ ) -class ImageCaptionDataset: - """A image-caption dataset class that wraps a BaseImageCaptionReader and applies common image transformations and - caption tokenization. +class ImageCaptionSDDataset: + """A image-caption dataset for Stable Diffusion v1/v2 models. This class wraps a BaseImageCaptionReader and applies + common image transformations and caption tokenization. """ def __init__( @@ -19,7 +19,7 @@ def __init__( center_crop: bool = False, random_flip: bool = False, ): - """Initialize ImageCaptionDataset. + """Initialize ImageCaptionSDDataset. Args: reader (BaseImageCaptionReader): The reader to wrap. diff --git a/src/invoke_training/training/shared/datasets/image_caption_sdxl_dataset.py b/src/invoke_training/training/shared/datasets/image_caption_sdxl_dataset.py new file mode 100644 index 00000000..d649ebde --- /dev/null +++ b/src/invoke_training/training/shared/datasets/image_caption_sdxl_dataset.py @@ -0,0 +1,109 @@ +import random + +import PIL +from torchvision import transforms +from torchvision.transforms.functional import crop +from transformers import PreTrainedTokenizer + +from invoke_training.training.shared.datasets.base_image_caption_reader import ( + BaseImageCaptionReader, +) + + +class ImageCaptionSDXLDataset: + """A image-caption dataset for Stable Diffusion XL models. This class wraps a BaseImageCaptionReader and applies + common image transformations and caption tokenization. + """ + + def __init__( + self, + reader: BaseImageCaptionReader, + tokenizer_1: PreTrainedTokenizer, + tokenizer_2: PreTrainedTokenizer, + resolution: int, + center_crop: bool = False, + random_flip: bool = False, + ): + """Initialize ImageCaptionSDDataset. 
+ + Args: + reader (BaseImageCaptionReader): The reader to wrap. + tokenizer_1 (PreTrainedTokenizer): The first SDXL text tokenizer. + tokenizer_2 (PreTrainedTokenizer): The second SDXL text tokenizer. + resolution (int): The image resolution that will be produced (square images are assumed). + center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If + False, crop at a random location. + random_flip (bool, optional): Whether to apply a random horizontal flip to the images. + """ + self._reader = reader + self._tokenizer_1 = tokenizer_1 + self._tokenizer_2 = tokenizer_2 + + # Image transforms. + self._resolution = resolution + self._resize_transform = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR) + self._center_crop_enabled = center_crop + self._crop_transform = transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution) + self._random_flip_enabled = random_flip + self._flip_transform = transforms.RandomHorizontalFlip(p=1.0) + self._other_transforms = transforms.Compose( + [ + transforms.ToTensor(), + # Convert pixel values from range [0, 1.0] to range [-1.0, 1.0]. Normalize applies the following + # transform: out = (in - 0.5) / 0.5 + transforms.Normalize([0.5], [0.5]), + ] + ) + + def _tokenize_caption(self, tokenizer, caption: str): + input = tokenizer( + caption, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return input.input_ids[0, ...] + + def _preprocess_image(self, image: PIL.Image.Image): + # This SDXL image pre-processing logic is adapted from: + # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L850-L873 + original_size_hw = (image.height, image.width) + + # Resize smaller image dimension to `resolution`. + image = self._resize_transform(image) + + # Apply cropping, and record top left crop position. + if self._center_crop_enabled: + top_left_y = max(0, int(round((image.height - self._resolution) / 2.0))) + top_left_x = max(0, int(round((image.width - self._resolution) / 2.0))) + image = self._crop_transform(image) + else: + top_left_y, top_left_x, h, w = self._crop_transform.get_params(image, (self._resolution, self._resolution)) + image = crop(image, top_left_y, top_left_x, h, w) + + # Apply random flip and update top left crop position accordingly. + if self._random_flip_enabled and random.random() < 0.5: + top_left_x = image.width - top_left_x + image = self._flip_transform(image) + + crop_top_left_yx = (top_left_y, top_left_x) + + # Convert image to Tensor and normalize to range [-1.0, 1.0]. 
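        # (original_size_hw and crop_top_left_yx, recorded above, are returned alongside the normalized image
        # tensor; the SDXL trainer packs them into the "time_ids" micro-conditioning in patch 04.)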
+ image = self._other_transforms(image) + + return original_size_hw, crop_top_left_yx, image + + def __len__(self) -> int: + return len(self._reader) + + def __getitem__(self, idx: int): + example = self._reader[idx] + original_size_hw, crop_top_left_yx, image = self._preprocess_image(example["image"]) + return { + "image": image, + "original_size_hw": original_size_hw, + "crop_top_left_yx": crop_top_left_yx, + "caption_token_ids_1": self._tokenize_caption(self._tokenizer_1, example["caption"]), + "caption_token_ids_2": self._tokenize_caption(self._tokenizer_2, example["caption"]), + } diff --git a/tests/invoke_training/training/shared/datasets/test_image_caption_dataset.py b/tests/invoke_training/training/shared/datasets/test_image_caption_sd_dataset.py similarity index 74% rename from tests/invoke_training/training/shared/datasets/test_image_caption_dataset.py rename to tests/invoke_training/training/shared/datasets/test_image_caption_sd_dataset.py index 62c5f4b8..f51b459f 100644 --- a/tests/invoke_training/training/shared/datasets/test_image_caption_dataset.py +++ b/tests/invoke_training/training/shared/datasets/test_image_caption_sd_dataset.py @@ -6,24 +6,24 @@ from PIL import Image from transformers import CLIPTokenizer -from invoke_training.training.shared.datasets.image_caption_dataset import ( - ImageCaptionDataset, +from invoke_training.training.shared.datasets.image_caption_sd_dataset import ( + ImageCaptionSDDataset, ) def test_image_caption_dataset_len(): - """Test that the ImageCaptionDataset __len__() function returns the length of the underlying reader.""" + """Test that the ImageCaptionSDDataset __len__() function returns the length of the underlying reader.""" reader_mock = unittest.mock.MagicMock() reader_mock.__len__.return_value = 5 - dataset = ImageCaptionDataset(reader_mock, None, resolution=512) + dataset = ImageCaptionSDDataset(reader_mock, None, resolution=512) assert len(dataset) == 5 @pytest.mark.loads_model def test_image_caption_dataset_getitem(): - """Test that the ImageCaptionDataset __getitem__() function returns an example with the expected type and + """Test that the ImageCaptionSDDataset __getitem__() function returns an example with the expected type and dimensions. """ # Prepare mock reader. 
@@ -40,7 +40,7 @@ def test_image_caption_dataset_getitem(): revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0", ) - dataset = ImageCaptionDataset(reader_mock, tokenizer, resolution=512) + dataset = ImageCaptionSDDataset(reader_mock, tokenizer, resolution=512) example = dataset[0] diff --git a/tests/invoke_training/training/shared/datasets/test_image_caption_sdxl_dataset.py b/tests/invoke_training/training/shared/datasets/test_image_caption_sdxl_dataset.py new file mode 100644 index 00000000..de5d02ae --- /dev/null +++ b/tests/invoke_training/training/shared/datasets/test_image_caption_sdxl_dataset.py @@ -0,0 +1,84 @@ +import unittest + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import CLIPTokenizer + +from invoke_training.training.shared.datasets.image_caption_sdxl_dataset import ( + ImageCaptionSDXLDataset, +) + + +def test_image_caption_sdxl_dataset_len(): + """Test that the ImageCaptionSDXLDataset __len__() function returns the length of the underlying reader.""" + reader_mock = unittest.mock.MagicMock() + reader_mock.__len__.return_value = 5 + + dataset = ImageCaptionSDXLDataset(reader_mock, None, None, resolution=512) + + assert len(dataset) == 5 + + +@pytest.mark.loads_model +def test_image_caption_sdxl_dataset_getitem(): + """Test that the ImageCaptionSDXLDataset __getitem__() function returns a valid example.""" + # Prepare mock reader. + rgb_np = np.ones((256, 128, 3), dtype=np.uint8) + rgb_pil = Image.fromarray(rgb_np) + reader_mock = unittest.mock.MagicMock() + reader_mock.__getitem__.return_value = {"image": rgb_pil, "caption": "This is a test caption."} + + # Load tokenizers. + tokenizer_1 = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + subfolder="tokenizer", + local_files_only=True, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + subfolder="tokenizer_2", + local_files_only=True, + ) + + # We expect the shape of the 256x128 input image to be transformed as follows: + # 1. Resize to 1024x512. + # 2. Center crop at top_left_yx = (256, 0) to produce a 512x512 image. + dataset = ImageCaptionSDXLDataset( + reader_mock, tokenizer_1, tokenizer_2, center_crop=True, random_flip=False, resolution=512 + ) + + example = dataset[0] + + reader_mock.__getitem__.assert_called_with(0) + assert set(example.keys()) == { + "image", + "original_size_hw", + "crop_top_left_yx", + "caption_token_ids_1", + "caption_token_ids_2", + } + + image = example["image"] + assert isinstance(image, torch.Tensor) + assert image.shape == (3, 512, 512) + assert image.dtype == torch.float32 + + caption_token_ids = example["caption_token_ids_1"] + assert isinstance(caption_token_ids, torch.Tensor) + assert caption_token_ids.shape == (77,) + assert caption_token_ids.dtype == torch.int64 + + caption_token_ids = example["caption_token_ids_2"] + assert isinstance(caption_token_ids, torch.Tensor) + assert caption_token_ids.shape == (77,) + assert caption_token_ids.dtype == torch.int64 + + original_size_hw = example["original_size_hw"] + assert isinstance(original_size_hw, tuple) + assert original_size_hw == (256, 128) + + crop_top_left_yx = example["crop_top_left_yx"] + assert isinstance(crop_top_left_yx, tuple) + assert crop_top_left_yx == (256, 0) From 754028434d999b6d91ff19905c7e02f85d96628f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 9 Aug 2023 20:41:28 -0400 Subject: [PATCH 03/10] Add an SDXL dataloader. 
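
In outline, the new SDXL dataloader mirrors the SD v1/v2 one but takes both SDXL tokenizers and emits the extra conditioning fields that SDXL needs. A hedged usage sketch, using the same dataset and batch size as the smoke test added in this patch:

from transformers import CLIPTokenizer

from invoke_training.training.finetune_lora.finetune_lora_config import DatasetConfig
from invoke_training.training.shared.datasets.image_caption_sdxl_dataloader import (
    build_image_caption_sdxl_dataloader,
)

tokenizer_1 = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2")

config = DatasetConfig(dataset_name="lambdalabs/pokemon-blip-captions", resolution=512)
data_loader = build_image_caption_sdxl_dataloader(config, tokenizer_1, tokenizer_2, 4)

batch = next(iter(data_loader))
# batch["image"]:               float32 tensor of shape (4, 3, 512, 512), values in [-1.0, 1.0]
# batch["original_size_hw"]:    list of 4 (height, width) tuples
# batch["crop_top_left_yx"]:    list of 4 (top, left) tuples
# batch["caption_token_ids_1"]: int64 tensor of shape (4, 77)
# batch["caption_token_ids_2"]: int64 tensor of shape (4, 77)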
--- .../finetune_lora/finetune_lora_sd.py | 6 +- ...ader.py => image_caption_sd_dataloader.py} | 4 +- .../datasets/image_caption_sdxl_dataloader.py | 76 +++++++++++++++++++ ...py => test_image_caption_sd_dataloader.py} | 10 +-- .../test_image_capture_sdxl_dataloader.py | 62 +++++++++++++++ 5 files changed, 148 insertions(+), 10 deletions(-) rename src/invoke_training/training/shared/datasets/{image_caption_dataloader.py => image_caption_sd_dataloader.py} (90%) create mode 100644 src/invoke_training/training/shared/datasets/image_caption_sdxl_dataloader.py rename tests/invoke_training/training/shared/datasets/{test_image_caption_dataloader.py => test_image_caption_sd_dataloader.py} (77%) create mode 100644 tests/invoke_training/training/shared/datasets/test_image_capture_sdxl_dataloader.py diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py index 9793c744..8a596307 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py @@ -36,8 +36,8 @@ check_base_model_version, ) from invoke_training.training.shared.checkpoint_tracker import CheckpointTracker -from invoke_training.training.shared.datasets.image_caption_dataloader import ( - build_image_caption_dataloader, +from invoke_training.training.shared.datasets.image_caption_sd_dataloader import ( + build_image_caption_sd_dataloader, ) from invoke_training.training.shared.serialization import save_state_dict @@ -324,7 +324,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 optimizer = _initialize_optimizer(config, lora_layers.parameters()) - data_loader = build_image_caption_dataloader(config.dataset, tokenizer, config.train_batch_size) + data_loader = build_image_caption_sd_dataloader(config.dataset, tokenizer, config.train_batch_size) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears diff --git a/src/invoke_training/training/shared/datasets/image_caption_dataloader.py b/src/invoke_training/training/shared/datasets/image_caption_sd_dataloader.py similarity index 90% rename from src/invoke_training/training/shared/datasets/image_caption_dataloader.py rename to src/invoke_training/training/shared/datasets/image_caption_sd_dataloader.py index 006f36b9..3602483a 100644 --- a/src/invoke_training/training/shared/datasets/image_caption_dataloader.py +++ b/src/invoke_training/training/shared/datasets/image_caption_sd_dataloader.py @@ -13,8 +13,8 @@ ) -def build_image_caption_dataloader(config: DatasetConfig, tokenizer: CLIPTokenizer, batch_size: int) -> DataLoader: - """Construct a DataLoader for an image-caption dataset. +def build_image_caption_sd_dataloader(config: DatasetConfig, tokenizer: CLIPTokenizer, batch_size: int) -> DataLoader: + """Construct a DataLoader for an image-caption dataset for Stable Diffusion v1/v2.. Args: config (DatasetConfig): The dataset config. 
diff --git a/src/invoke_training/training/shared/datasets/image_caption_sdxl_dataloader.py b/src/invoke_training/training/shared/datasets/image_caption_sdxl_dataloader.py new file mode 100644 index 00000000..0702d5b1 --- /dev/null +++ b/src/invoke_training/training/shared/datasets/image_caption_sdxl_dataloader.py @@ -0,0 +1,76 @@ +import torch +from torch.utils.data import DataLoader +from transformers import PreTrainedTokenizer + +from invoke_training.training.finetune_lora.finetune_lora_config import DatasetConfig +from invoke_training.training.shared.datasets.hf_dir_image_caption_reader import ( + HFDirImageCaptionReader, +) +from invoke_training.training.shared.datasets.hf_hub_image_caption_reader import ( + HFHubImageCaptionReader, +) +from invoke_training.training.shared.datasets.image_caption_sdxl_dataset import ( + ImageCaptionSDXLDataset, +) + + +def _collate_fn(examples): + """A batch collation function for the ImageCaptionSDXLDataset.""" + return { + "image": torch.stack([example["image"] for example in examples]), + "original_size_hw": [example["original_size_hw"] for example in examples], + "crop_top_left_yx": [example["crop_top_left_yx"] for example in examples], + "caption_token_ids_1": torch.stack([example["caption_token_ids_1"] for example in examples]), + "caption_token_ids_2": torch.stack([example["caption_token_ids_2"] for example in examples]), + } + + +def build_image_caption_sdxl_dataloader( + config: DatasetConfig, tokenizer_1: PreTrainedTokenizer, tokenizer_2: PreTrainedTokenizer, batch_size: int +) -> DataLoader: + """Construct a DataLoader for an image-caption dataset for Stable Diffusion XL. + + Args: + config (DatasetConfig): The dataset config. + tokenizer (CLIPTokenizer): The tokenizer to apply to the captions. + batch_size (int): The DataLoader batch size. 
+ + Returns: + DataLoader + """ + if config.dataset_name is not None: + reader = HFHubImageCaptionReader( + dataset_name=config.dataset_name, + hf_load_dataset_kwargs={ + "name": config.dataset_config_name, + "cache_dir": config.hf_cache_dir, + }, + image_column=config.image_column, + caption_column=config.caption_column, + ) + elif config.dataset_dir is not None: + reader = HFDirImageCaptionReader( + dataset_dir=config.dataset_dir, + hf_load_dataset_kwargs=None, + image_column=config.image_column, + caption_column=config.caption_column, + ) + else: + raise ValueError("One of 'dataset_name' or 'dataset_dir' must be set.") + + dataset = ImageCaptionSDXLDataset( + reader=reader, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, + resolution=config.resolution, + center_crop=config.center_crop, + random_flip=config.random_flip, + ) + + return DataLoader( + dataset, + shuffle=True, + collate_fn=_collate_fn, + batch_size=batch_size, + num_workers=config.dataloader_num_workers, + ) diff --git a/tests/invoke_training/training/shared/datasets/test_image_caption_dataloader.py b/tests/invoke_training/training/shared/datasets/test_image_caption_sd_dataloader.py similarity index 77% rename from tests/invoke_training/training/shared/datasets/test_image_caption_dataloader.py rename to tests/invoke_training/training/shared/datasets/test_image_caption_sd_dataloader.py index 007c0728..26172f6f 100644 --- a/tests/invoke_training/training/shared/datasets/test_image_caption_dataloader.py +++ b/tests/invoke_training/training/shared/datasets/test_image_caption_sd_dataloader.py @@ -5,14 +5,14 @@ from transformers import CLIPTokenizer from invoke_training.training.finetune_lora.finetune_lora_config import DatasetConfig -from invoke_training.training.shared.datasets.image_caption_dataloader import ( - build_image_caption_dataloader, +from invoke_training.training.shared.datasets.image_caption_sd_dataloader import ( + build_image_caption_sd_dataloader, ) @pytest.mark.loads_model -def test_build_image_caption_dataloader(): - """Smoke test of build_image_caption_dataloader(...).""" +def test_build_image_caption_sd_dataloader(): + """Smoke test of build_image_caption_sd_dataloader(...).""" tokenizer = CLIPTokenizer.from_pretrained( "runwayml/stable-diffusion-v1-5", @@ -22,7 +22,7 @@ def test_build_image_caption_dataloader(): ) config = DatasetConfig(dataset_name="lambdalabs/pokemon-blip-captions", resolution=512) - data_loader = build_image_caption_dataloader(config, tokenizer, 4) + data_loader = build_image_caption_sd_dataloader(config, tokenizer, 4) # 833 is the length of the dataset determined manually here: # https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions diff --git a/tests/invoke_training/training/shared/datasets/test_image_capture_sdxl_dataloader.py b/tests/invoke_training/training/shared/datasets/test_image_capture_sdxl_dataloader.py new file mode 100644 index 00000000..3929a4c4 --- /dev/null +++ b/tests/invoke_training/training/shared/datasets/test_image_capture_sdxl_dataloader.py @@ -0,0 +1,62 @@ +import math + +import pytest +import torch +from transformers import CLIPTokenizer + +from invoke_training.training.finetune_lora.finetune_lora_config import DatasetConfig +from invoke_training.training.shared.datasets.image_caption_sdxl_dataloader import ( + build_image_caption_sdxl_dataloader, +) + + +@pytest.mark.loads_model +def test_build_image_caption_sdxl_dataloader(): + """Smoke test of build_image_caption_sdxl_dataloader(...).""" + + tokenizer_1 = CLIPTokenizer.from_pretrained( + 
"stabilityai/stable-diffusion-xl-base-1.0", + subfolder="tokenizer", + local_files_only=True, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + subfolder="tokenizer_2", + local_files_only=True, + ) + + config = DatasetConfig(dataset_name="lambdalabs/pokemon-blip-captions", resolution=512) + data_loader = build_image_caption_sdxl_dataloader(config, tokenizer_1, tokenizer_2, 4) + + # 833 is the length of the dataset determined manually here: + # https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions + assert len(data_loader) == math.ceil(833 / 4) + + example = next(iter(data_loader)) + assert set(example.keys()) == { + "image", + "original_size_hw", + "crop_top_left_yx", + "caption_token_ids_1", + "caption_token_ids_2", + } + + image = example["image"] + assert image.shape == (4, 3, 512, 512) + assert image.dtype == torch.float32 + + original_size_hw = example["original_size_hw"] + assert len(original_size_hw) == 4 + assert len(original_size_hw[0]) == 2 + + crop_top_left_yx = example["crop_top_left_yx"] + assert len(crop_top_left_yx) == 4 + assert len(crop_top_left_yx[0]) == 2 + + caption_token_ids_1 = example["caption_token_ids_1"] + assert caption_token_ids_1.shape == (4, 77) + assert caption_token_ids_1.dtype == torch.int64 + + caption_token_ids_2 = example["caption_token_ids_2"] + assert caption_token_ids_2.shape == (4, 77) + assert caption_token_ids_2.dtype == torch.int64 From fe7b6a36f00e235545d9bbd9605706596b6155fe Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 9 Aug 2023 21:38:30 -0400 Subject: [PATCH 04/10] Update finetune_lora_sdxl.py with SDXL support. --- .../lora/injection/stable_diffusion_v1.py | 2 +- .../finetune_lora/finetune_lora_config.py | 6 + .../finetune_lora/finetune_lora_sdxl.py | 232 +++++++++++++----- 3 files changed, 181 insertions(+), 59 deletions(-) diff --git a/src/invoke_training/lora/injection/stable_diffusion_v1.py b/src/invoke_training/lora/injection/stable_diffusion_v1.py index 440a870f..c700ab28 100644 --- a/src/invoke_training/lora/injection/stable_diffusion_v1.py +++ b/src/invoke_training/lora/injection/stable_diffusion_v1.py @@ -45,7 +45,7 @@ def inject_lora_into_unet_sd1( return lora_layers -def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel): +def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel, prefix: str = "lora_te"): lora_layers = inject_lora_layers( module=text_encoder, lora_map={ diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_config.py b/src/invoke_training/training/finetune_lora/finetune_lora_config.py index 217c92c2..5ab18da3 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_config.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_config.py @@ -167,3 +167,9 @@ class FinetuneLoRAConfig(BaseModel): # The training batch size. train_batch_size: int = 4 + + +class FinetuneLoRASDXLConfig(FinetuneLoRAConfig): + # The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base + # model (specified by the `model` parameter). 
+ vae_model: typing.Optional[str] = None diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index 9793c744..395da5d3 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -11,12 +11,17 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, - StableDiffusionPipeline, + StableDiffusionXLPipeline, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import ( + AutoTokenizer, + CLIPPreTrainedModel, + PretrainedConfig, + PreTrainedTokenizer, +) from invoke_training.lora.injection.stable_diffusion_v1 import ( convert_lora_state_dict_to_kohya_format_sd1, @@ -24,7 +29,7 @@ inject_lora_into_unet_sd1, ) from invoke_training.training.finetune_lora.finetune_lora_config import ( - FinetuneLoRAConfig, + FinetuneLoRASDXLConfig, ) from invoke_training.training.shared.accelerator_utils import ( get_mixed_precision_dtype, @@ -36,58 +41,122 @@ check_base_model_version, ) from invoke_training.training.shared.checkpoint_tracker import CheckpointTracker -from invoke_training.training.shared.datasets.image_caption_dataloader import ( - build_image_caption_dataloader, +from invoke_training.training.shared.datasets.image_caption_sdxl_dataloader import ( + build_image_caption_sdxl_dataloader, ) from invoke_training.training.shared.serialization import save_state_dict +def _import_model_class_for_model(pretrained_model_name_or_path: str, subfolder: str = "", revision: str = "main"): + """Lookup the model class in a diffusers model config, import the class, and return it. This function is useful when + loading models that could be one of many possible classes. + + Args: + pretrained_model_name_or_path (str): The diffusers model name/path. + subfolder (str, optional): The model subfolder. + revision (str, optional): The diffusers model revision. + + + Raises: + ValueError: If the detected model class is not recognize. + + Returns: + type: The model class. + """ + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + def _load_models( accelerator: Accelerator, - config: FinetuneLoRAConfig, -) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]: - """Load all models required for training from disk, transfer them to the - target training device and cast their weight dtypes. + config: FinetuneLoRASDXLConfig, +) -> tuple[ + PreTrainedTokenizer, + PreTrainedTokenizer, + DDPMScheduler, + CLIPPreTrainedModel, + CLIPPreTrainedModel, + AutoencoderKL, + UNet2DConditionModel, +]: + """Load all models required for training, transfer them to the target training device and cast their weight dtypes. Args: - config (FinetuneLoRAConfig): The LoRA training run config. - logger (logging.Logger): A logger. + accelerator (Accelerator): Accelerator + config (FinetuneLoRASDXLConfig): Training config. 
Returns: tuple[ - CLIPTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizer, DDPMScheduler, - CLIPTextModel, + CLIPPreTrainedModel, + CLIPPreTrainedModel, AutoencoderKL, UNet2DConditionModel, ]: A tuple of loaded models. """ weight_dtype = get_mixed_precision_dtype(accelerator) - tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(config.model, subfolder="tokenizer") + # Load tokenizers. + tokenizer_1: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + config.model, subfolder="tokenizer", use_fast=False + ) + tokenizer_2: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + config.model, subfolder="tokenizer_2", use_fast=False + ) + + # Load noise scheduler. noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(config.model, subfolder="scheduler") - text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(config.model, subfolder="text_encoder") - vae: AutoencoderKL = AutoencoderKL.from_pretrained(config.model, subfolder="vae") + + # Load text encoders. + text_encoder_cls_1 = _import_model_class_for_model(config.model, subfolder="text_encoder") + text_encoder_1 = text_encoder_cls_1.from_pretrained(config.model, subfolder="text_encoder") + text_encoder_cls_2 = _import_model_class_for_model(config.model, subfolder="text_encoder_2") + text_encoder_2 = text_encoder_cls_2.from_pretrained(config.model, subfolder="text_encoder_2") + + # Load VAE. + vae_model = config.vae_model if config.vae_model is not None else config.model + vae: AutoencoderKL = AutoencoderKL.from_pretrained(vae_model, subfolder="vae") + + # Load UNet. unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(config.model, subfolder="unet") # Disable gradient calculation for model weights to save memory. - text_encoder.requires_grad_(False) + text_encoder_1.requires_grad_(False) + text_encoder_2.requires_grad_(False) vae.requires_grad_(False) unet.requires_grad_(False) # Put models in 'eval' mode. - text_encoder.eval() + text_encoder_1.eval() + text_encoder_2.eval() vae.eval() unet.eval() - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_1.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) - return tokenizer, noise_scheduler, text_encoder, vae, unet + return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet -def _initialize_optimizer(config: FinetuneLoRAConfig, trainable_params: list) -> torch.optim.Optimizer: +# TODO(ryand): Split this function out to avoid duplication between sd and sdxl configs. +def _initialize_optimizer(config: FinetuneLoRASDXLConfig, trainable_params: list) -> torch.optim.Optimizer: """Initialize an optimizer based on the config.""" return torch.optim.AdamW( trainable_params, @@ -130,45 +199,62 @@ def _save_checkpoint( logger.info(f"Saved state to '{save_path}'.") +# encode_prompt was adapted from: +# https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L470-L496 +def _encode_prompt(text_encoders: list[CLIPPreTrainedModel], prompt_token_ids_list: list[torch.Tensor]): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + text_input_ids = prompt_token_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder. 
+ # TODO(ryand): Document this logic more clearly. + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + def _generate_validation_images( epoch: int, out_dir: str, accelerator: Accelerator, vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, + text_encoder_1: CLIPPreTrainedModel, + text_encoder_2: CLIPPreTrainedModel, + tokenizer_1: PreTrainedTokenizer, + tokenizer_2: PreTrainedTokenizer, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, - config: FinetuneLoRAConfig, + config: FinetuneLoRASDXLConfig, logger: logging.Logger, ): """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout training. - - Args: - epoch (int): Epoch number, for reporting purposes. - out_dir (str): The output directory where the validation images will be stored. - accelerator (Accelerator): Accelerator - vae (AutoencoderKL): - text_encoder (CLIPTextModel): - tokenizer (CLIPTokenizer): - noise_scheduler (DDPMScheduler): - unet (UNet2DConditionModel): - config (FinetuneLoRAConfig): Training configs. - logger (logging.Logger): Logger. """ logger.info("Generating validation images.") # Create pipeline. - pipeline = StableDiffusionPipeline( + pipeline = StableDiffusionXLPipeline( vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, + text_encoder=text_encoder_1, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer_1, + tokenizer_2=tokenizer_2, unet=unet, scheduler=noise_scheduler, safety_checker=None, feature_extractor=None, - # TODO(ryand): Add safety checker support. requires_safety_checker=False, ) pipeline = pipeline.to(accelerator.device) @@ -219,11 +305,13 @@ def _generate_validation_images( def _train_forward( - config: FinetuneLoRAConfig, + accelerator: Accelerator, + config: FinetuneLoRASDXLConfig, data_batch: dict, vae: AutoencoderKL, noise_scheduler: DDPMScheduler, - text_encoder: CLIPTextModel, + text_encoder_1: CLIPPreTrainedModel, + text_encoder_2: CLIPPreTrainedModel, unet: UNet2DConditionModel, weight_dtype: torch.dtype, ): @@ -249,12 +337,33 @@ def _train_forward( ) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep (this is the forward - # diffusion process). + # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion + # process). noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # compute_time_ids was copied from: + # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L1033-L1039 + # "time_ids" may seem like a weird naming choice. The name comes from the diffusers SDXL implementation. Presumably, + # it is a result of the fact that the original size and crop values get concatenated with the time embeddings. 
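    # Worked example, using the shapes from the dataset unit tests in this series: a 256x128 source image that
    # was center-cropped to 512x512 at offset (256, 0) produces
    # time_ids = [256, 128, 256, 0, 512, 512], i.e. (orig_h, orig_w, crop_top, crop_left, target_h, target_w).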
+ def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (config.dataset.resolution, config.dataset.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(data_batch["original_size_hw"], data_batch["crop_top_left_yx"])] + ) + unet_conditions = {"time_ids": add_time_ids} + # Get the text embedding for conditioning. - encoder_hidden_states = text_encoder(data_batch["caption_token_ids"])[0] + prompt_embeds, pooled_prompt_embeds = _encode_prompt( + text_encoders=[text_encoder_1, text_encoder_2], + prompt_token_ids_list=[data_batch["caption_token_ids_1"], data_batch["caption_token_ids_2"]], + ) + unet_conditions["text_embeds"] = pooled_prompt_embeds # Get the target for loss depending on the prediction type. if config.prediction_type is not None: @@ -268,15 +377,15 @@ def _train_forward( raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual. - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_conditions).sample return torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean") -def run_training(config: FinetuneLoRAConfig): # noqa: C901 +def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 # Give a clear error message if an unsupported base model was chosen. check_base_model_version( - {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2}, + {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL}, config.model, local_files_only=False, ) @@ -297,7 +406,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 # Log the accelerator configuration from every process to help with debugging. 
logger.info(accelerator.state, main_process_only=False) - logger.info("Starting LoRA Training.") + logger.info("Starting Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") @@ -308,13 +417,16 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 weight_dtype = get_mixed_precision_dtype(accelerator) logger.info("Loading models.") - tokenizer, noise_scheduler, text_encoder, vae, unet = _load_models(accelerator, config) + tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = _load_models( + accelerator, config + ) lora_layers = torch.nn.ModuleDict() if config.train_unet: lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks) if config.train_text_encoder: - lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder) + lora_layers["text_encoder_1"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te1") + lora_layers["text_encoder_2"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te2") if config.xformers: import xformers # noqa: F401 @@ -324,7 +436,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 optimizer = _initialize_optimizer(config, lora_layers.parameters()) - data_loader = build_image_caption_dataloader(config.dataset, tokenizer, config.train_batch_size) + data_loader = build_image_caption_sdxl_dataloader(config.dataset, tokenizer_1, tokenizer_2, config.train_batch_size) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears @@ -340,13 +452,14 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 prepared_result: tuple[ UNet2DConditionModel, - CLIPTextModel, + CLIPPreTrainedModel, + CLIPPreTrainedModel, torch.nn.ModuleDict, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, - ] = accelerator.prepare(unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler) - unet, text_encoder, lora_layers, optimizer, data_loader, lr_scheduler = prepared_result + ] = accelerator.prepare(unet, text_encoder_1, text_encoder_2, lora_layers, optimizer, data_loader, lr_scheduler) + unet, text_encoder_1, text_encoder_2, lora_layers, optimizer, data_loader, lr_scheduler = prepared_result # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation # (i.e. takes into account gradient accumulation steps). @@ -403,7 +516,8 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 data_batch, vae, noise_scheduler, - text_encoder, + text_encoder_1, + text_encoder_2, unet, weight_dtype, ) @@ -457,8 +571,10 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 out_dir=out_dir, accelerator=accelerator, vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, + text_encoder_1=text_encoder_1, + text_encoder_2=text_encoder_2, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, config=config, From 240d78bb17c89d07b216ac7574ebeab09d40130b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 9 Aug 2023 23:31:49 -0400 Subject: [PATCH 05/10] Fix bugs in initial SDXL implementation. 
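
This patch also adds an example SDXL LoRA config and points the SDXL entry-point script at the new FinetuneLoRASDXLConfig class. A hedged sketch of how the two fit together (the config path is the file added below; the YAML-to-config plumbing mirrors the script change in this patch):

import yaml

from invoke_training.training.finetune_lora.finetune_lora_config import FinetuneLoRASDXLConfig
from invoke_training.training.finetune_lora.finetune_lora_sdxl import run_training

with open("configs/finetune_lora_sdxl_pokemon_example.yaml", "r") as f:
    cfg = yaml.safe_load(f)

run_training(FinetuneLoRASDXLConfig(**cfg))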
--- .../finetune_lora_sdxl_pokemon_example.yaml | 30 +++++++++++++++++++ .../scripts/invoke_finetune_lora_sdxl.py | 4 +-- .../finetune_lora/finetune_lora_sdxl.py | 8 ++--- 3 files changed, 35 insertions(+), 7 deletions(-) create mode 100644 configs/finetune_lora_sdxl_pokemon_example.yaml diff --git a/configs/finetune_lora_sdxl_pokemon_example.yaml b/configs/finetune_lora_sdxl_pokemon_example.yaml new file mode 100644 index 00000000..bd6b8bea --- /dev/null +++ b/configs/finetune_lora_sdxl_pokemon_example.yaml @@ -0,0 +1,30 @@ +# This is a sample config for finetuning a Stable Diffusion 1.5 model with LoRA to produce a Pokemon LoRA model. + +output: + base_output_dir: output/ + +optimizer: + learning_rate: 5.0e-4 + +dataset: + dataset_name: lambdalabs/pokemon-blip-captions + +# General +model: stabilityai/stable-diffusion-xl-base-1.0 +vae_model: madebyollin/sdxl-vae-fp16-fix +seed: 1 +train_text_encoder: False +gradient_accumulation_steps: 4 +mixed_precision: fp16 +xformers: True +gradient_checkpointing: True +max_train_steps: 4000 +save_every_n_epochs: 1 +save_every_n_steps: null +max_checkpoints: 100 +validation_prompts: + - yoda + - astronaut + - yoda in a space suit +validate_every_n_epochs: 1 +train_batch_size: 1 diff --git a/src/invoke_training/scripts/invoke_finetune_lora_sdxl.py b/src/invoke_training/scripts/invoke_finetune_lora_sdxl.py index c05cff64..535bbdc0 100644 --- a/src/invoke_training/scripts/invoke_finetune_lora_sdxl.py +++ b/src/invoke_training/scripts/invoke_finetune_lora_sdxl.py @@ -4,7 +4,7 @@ import yaml from invoke_training.training.finetune_lora.finetune_lora_config import ( - FinetuneLoRAConfig, + FinetuneLoRASDXLConfig, ) from invoke_training.training.finetune_lora.finetune_lora_sdxl import run_training @@ -27,7 +27,7 @@ def main(): with open(args.cfg_file, "r") as f: cfg = yaml.safe_load(f) - train_config = FinetuneLoRAConfig(**cfg) + train_config = FinetuneLoRASDXLConfig(**cfg) run_training(train_config) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index 395da5d3..f229887a 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -130,7 +130,7 @@ def _load_models( # Load VAE. vae_model = config.vae_model if config.vae_model is not None else config.model - vae: AutoencoderKL = AutoencoderKL.from_pretrained(vae_model, subfolder="vae") + vae: AutoencoderKL = AutoencoderKL.from_pretrained(vae_model, subfolder="vae" if config.vae_model is None else None) # Load UNet. unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(config.model, subfolder="unet") @@ -253,9 +253,6 @@ def _generate_validation_images( tokenizer_2=tokenizer_2, unet=unet, scheduler=noise_scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -385,7 +382,7 @@ def compute_time_ids(original_size, crops_coords_top_left): def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 # Give a clear error message if an unsupported base model was chosen. 
check_base_model_version( - {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL}, + {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE}, config.model, local_files_only=False, ) @@ -512,6 +509,7 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 for data_batch in data_loader: with accelerator.accumulate(lora_layers): loss = _train_forward( + accelerator, config, data_batch, vae, From cda7bb7db55801cb7740c75e61da318e6433f5e4 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 10 Aug 2023 14:42:23 -0400 Subject: [PATCH 06/10] More SDXL bugfixes. --- configs/finetune_lora_sdxl_pokemon_example.yaml | 2 +- src/invoke_training/lora/injection/stable_diffusion_v1.py | 2 +- .../training/finetune_lora/finetune_lora_config.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/finetune_lora_sdxl_pokemon_example.yaml b/configs/finetune_lora_sdxl_pokemon_example.yaml index bd6b8bea..589342b4 100644 --- a/configs/finetune_lora_sdxl_pokemon_example.yaml +++ b/configs/finetune_lora_sdxl_pokemon_example.yaml @@ -1,4 +1,4 @@ -# This is a sample config for finetuning a Stable Diffusion 1.5 model with LoRA to produce a Pokemon LoRA model. +# This is a sample config for finetuning a SDXL 1.0 model with LoRA to produce a Pokemon LoRA model. output: base_output_dir: output/ diff --git a/src/invoke_training/lora/injection/stable_diffusion_v1.py b/src/invoke_training/lora/injection/stable_diffusion_v1.py index c700ab28..2676f871 100644 --- a/src/invoke_training/lora/injection/stable_diffusion_v1.py +++ b/src/invoke_training/lora/injection/stable_diffusion_v1.py @@ -54,7 +54,7 @@ def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel, prefix: str }, include_descendants_of={CLIPAttention, CLIPMLP}, exclude_descendants_of=None, - prefix="lora_te", + prefix=prefix, dtype=torch.float32, ) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_config.py b/src/invoke_training/training/finetune_lora/finetune_lora_config.py index 5ab18da3..e35ce517 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_config.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_config.py @@ -171,5 +171,6 @@ class FinetuneLoRAConfig(BaseModel): class FinetuneLoRASDXLConfig(FinetuneLoRAConfig): # The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base - # model (specified by the `model` parameter). + # model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped + # with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version. vae_model: typing.Optional[str] = None From bc0cc34c5ce2d357e68262f07248a66bdbed8d2d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 10 Aug 2023 17:57:54 -0400 Subject: [PATCH 07/10] Update stable diffusion LoRA injection code for compatibility with both SD and SDXL. 
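The _sd1 suffixes are dropped and the module is renamed to stable_diffusion.py so a single
injection/conversion code path serves both SD 1.5 and SDXL, and the smoke tests are
parametrized over both base models. Rough usage sketch of the renamed API (the model id and
the torch.save call are illustrative placeholders; the training scripts persist checkpoints
through their own save helper):

    import torch
    from diffusers import UNet2DConditionModel

    from invoke_training.lora.injection.stable_diffusion import (
        convert_lora_state_dict_to_kohya_format,
        inject_lora_into_unet,
    )

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )
    # 722 LoRA layers are expected for SDXL, 192 for SD 1.5 (per the smoke tests below).
    lora_layers = inject_lora_into_unet(unet)
    kohya_state_dict = convert_lora_state_dict_to_kohya_format(lora_layers.get_lora_state_dict())
    torch.save(kohya_state_dict, "sdxl_lora_kohya.ckpt")  # placeholder path/format
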
--- ...le_diffusion_v1.py => stable_diffusion.py} | 8 +- .../finetune_lora/finetune_lora_sd.py | 10 +- .../finetune_lora/finetune_lora_sdxl.py | 10 +- .../lora/injection/test_stable_diffusion.py | 183 ++++++++++++++++++ .../injection/test_stable_diffusion_v1.py | 145 -------------- 5 files changed, 197 insertions(+), 159 deletions(-) rename src/invoke_training/lora/injection/{stable_diffusion_v1.py => stable_diffusion.py} (93%) create mode 100644 tests/invoke_training/lora/injection/test_stable_diffusion.py delete mode 100644 tests/invoke_training/lora/injection/test_stable_diffusion_v1.py diff --git a/src/invoke_training/lora/injection/stable_diffusion_v1.py b/src/invoke_training/lora/injection/stable_diffusion.py similarity index 93% rename from src/invoke_training/lora/injection/stable_diffusion_v1.py rename to src/invoke_training/lora/injection/stable_diffusion.py index 2676f871..941de33d 100644 --- a/src/invoke_training/lora/injection/stable_diffusion_v1.py +++ b/src/invoke_training/lora/injection/stable_diffusion.py @@ -12,10 +12,10 @@ from invoke_training.lora.layers import LoRAConv2dLayer, LoRALinearLayer -def inject_lora_into_unet_sd1( +def inject_lora_into_unet( unet: UNet2DConditionModel, include_non_attention_blocks: bool = False ) -> LoRALayerCollection: - """Inject LoRA layers into a Stable Diffusion v1 UNet model. + """Inject LoRA layers into a Stable Diffusion UNet model. Args: unet (UNet2DConditionModel): The UNet model to inject LoRA layers into. @@ -61,10 +61,10 @@ def inject_lora_into_clip_text_encoder(text_encoder: CLIPTextModel, prefix: str return lora_layers -def convert_lora_state_dict_to_kohya_format_sd1( +def convert_lora_state_dict_to_kohya_format( state_dict: typing.Dict[str, torch.Tensor] ) -> typing.Dict[str, torch.Tensor]: - """Convert a Stable Diffusion v1 LoRA state_dict from internal invoke-training format to kohya_ss format. + """Convert a Stable Diffusion LoRA state_dict from internal invoke-training format to kohya_ss format. Args: state_dict (typing.Dict[str, torch.Tensor]): LoRA layer state_dict in invoke-training format. 
diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py index 8a596307..c4895a2e 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py @@ -18,10 +18,10 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -from invoke_training.lora.injection.stable_diffusion_v1 import ( - convert_lora_state_dict_to_kohya_format_sd1, +from invoke_training.lora.injection.stable_diffusion import ( + convert_lora_state_dict_to_kohya_format, inject_lora_into_clip_text_encoder, - inject_lora_into_unet_sd1, + inject_lora_into_unet, ) from invoke_training.training.finetune_lora.finetune_lora_config import ( FinetuneLoRAConfig, @@ -122,7 +122,7 @@ def _save_checkpoint( state_dict = {} for model_lora_layers in lora_layers.values(): model_state_dict = model_lora_layers.get_lora_state_dict() - model_kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(model_state_dict) + model_kohya_state_dict = convert_lora_state_dict_to_kohya_format(model_state_dict) state_dict.update(model_kohya_state_dict) save_state_dict(state_dict, save_path) @@ -312,7 +312,7 @@ def run_training(config: FinetuneLoRAConfig): # noqa: C901 lora_layers = torch.nn.ModuleDict() if config.train_unet: - lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks) + lora_layers["unet"] = inject_lora_into_unet(unet, config.train_unet_non_attention_blocks) if config.train_text_encoder: lora_layers["text_encoder"] = inject_lora_into_clip_text_encoder(text_encoder) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index f229887a..13c49b26 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -23,10 +23,10 @@ PreTrainedTokenizer, ) -from invoke_training.lora.injection.stable_diffusion_v1 import ( - convert_lora_state_dict_to_kohya_format_sd1, +from invoke_training.lora.injection.stable_diffusion import ( + convert_lora_state_dict_to_kohya_format, inject_lora_into_clip_text_encoder, - inject_lora_into_unet_sd1, + inject_lora_into_unet, ) from invoke_training.training.finetune_lora.finetune_lora_config import ( FinetuneLoRASDXLConfig, @@ -191,7 +191,7 @@ def _save_checkpoint( state_dict = {} for model_lora_layers in lora_layers.values(): model_state_dict = model_lora_layers.get_lora_state_dict() - model_kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(model_state_dict) + model_kohya_state_dict = convert_lora_state_dict_to_kohya_format(model_state_dict) state_dict.update(model_kohya_state_dict) save_state_dict(state_dict, save_path) @@ -420,7 +420,7 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 lora_layers = torch.nn.ModuleDict() if config.train_unet: - lora_layers["unet"] = inject_lora_into_unet_sd1(unet, config.train_unet_non_attention_blocks) + lora_layers["unet"] = inject_lora_into_unet(unet, config.train_unet_non_attention_blocks) if config.train_text_encoder: lora_layers["text_encoder_1"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te1") lora_layers["text_encoder_2"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te2") diff --git a/tests/invoke_training/lora/injection/test_stable_diffusion.py b/tests/invoke_training/lora/injection/test_stable_diffusion.py new 
file mode 100644 index 00000000..6d287448 --- /dev/null +++ b/tests/invoke_training/lora/injection/test_stable_diffusion.py @@ -0,0 +1,183 @@ +import pytest +import torch +from diffusers.models import UNet2DConditionModel +from transformers import CLIPTextModel + +from invoke_training.lora.injection.stable_diffusion import ( + convert_lora_state_dict_to_kohya_format, + inject_lora_into_clip_text_encoder, + inject_lora_into_unet, +) + + +@pytest.mark.loads_model +@pytest.mark.parametrize( + ["model_name", "revision", "expected_num_layers"], + [ + ("runwayml/stable-diffusion-v1-5", "c9ab35ff5f2c362e9e22fbafe278077e196057f0", 192), + ("stabilityai/stable-diffusion-xl-base-1.0", "47cd5302d866fa60cf8fb81f0e34d42e38f6100c", 722), + ], +) +def test_inject_lora_into_unet_smoke(model_name: str, revision: str, expected_num_layers: int): + """Smoke test of inject_lora_into_unet(...).""" + unet = UNet2DConditionModel.from_pretrained( + model_name, + subfolder="unet", + local_files_only=True, + revision=revision, + ) + lora_layers = inject_lora_into_unet(unet) + + # These assertions are based on a manual check of the injected layers and comparison against the behaviour of + # kohya_ss. They are included here to force another manual review after any future breaking change. + assert len(lora_layers) == expected_num_layers + for layer_name in lora_layers._names: + assert layer_name.endswith( + ("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", ".proj_in", ".proj_out") + ) + + +@pytest.mark.loads_model +@pytest.mark.parametrize( + ["model_name", "revision", "expected_num_layers"], + [ + ("runwayml/stable-diffusion-v1-5", "c9ab35ff5f2c362e9e22fbafe278077e196057f0", 278), + ("stabilityai/stable-diffusion-xl-base-1.0", "47cd5302d866fa60cf8fb81f0e34d42e38f6100c", 788), + ], +) +def test_inject_lora_into_unet_non_attention_layers_smoke(model_name: str, revision: str, expected_num_layers: int): + """Smoke test of inject_lora_into_unet(..., include_non_attention_blocks=True).""" + unet = UNet2DConditionModel.from_pretrained( + model_name, + subfolder="unet", + local_files_only=True, + revision=revision, + ) + lora_layers = inject_lora_into_unet(unet, include_non_attention_blocks=True) + + # These assertions are based on a manual check of the injected layers and comparison against the behaviour of + # kohya_ss. They are included here to force another manual review after any future breaking change. 
+ assert len(lora_layers) == expected_num_layers + for layer_name in lora_layers._names: + assert layer_name.endswith( + ( + "to_q", + "to_k", + "to_v", + "to_out.0", + "ff.net.0.proj", + "ff.net.2", + ".proj_in", + ".proj_out", + ".conv1", + ".conv2", + ".time_emb_proj", + ".conv", + ".conv_shortcut", + ) + ) + + +@pytest.mark.loads_model +@pytest.mark.parametrize( + ["model_name", "revision", "text_encoder_name", "expected_num_layers"], + [ + ("stabilityai/stable-diffusion-xl-base-1.0", "47cd5302d866fa60cf8fb81f0e34d42e38f6100c", "text_encoder", 72), + ("stabilityai/stable-diffusion-xl-base-1.0", "47cd5302d866fa60cf8fb81f0e34d42e38f6100c", "text_encoder_2", 192), + ("runwayml/stable-diffusion-v1-5", "c9ab35ff5f2c362e9e22fbafe278077e196057f0", "text_encoder", 72), + ], +) +def test_inject_lora_into_clip_text_encoder_smoke(model_name, revision, text_encoder_name, expected_num_layers): + """Smoke test of inject_lora_into_clip_text_encoder(...).""" + text_encoder = CLIPTextModel.from_pretrained( + model_name, + subfolder=text_encoder_name, + local_files_only=True, + revision=revision, + ) + + lora_layers = inject_lora_into_clip_text_encoder(text_encoder) + + # These assertions are based on a manual check of the injected layers and comparison against the behaviour of + # kohya_ss. They are included here to force another manual review after any future breaking change. + assert len(lora_layers) == expected_num_layers + for layer_name in lora_layers._names: + assert layer_name.endswith(("mlp.fc1", "mlp.fc2", "k_proj", "out_proj", "q_proj", "v_proj")) + + +@pytest.mark.loads_model +@pytest.mark.loads_model +@pytest.mark.parametrize( + ["model_name", "revision", "expected_num_layers"], + [ + ("runwayml/stable-diffusion-v1-5", "c9ab35ff5f2c362e9e22fbafe278077e196057f0", 192), + ("stabilityai/stable-diffusion-xl-base-1.0", "47cd5302d866fa60cf8fb81f0e34d42e38f6100c", 722), + ], +) +def test_convert_lora_state_dict_to_kohya_format_smoke(model_name: str, revision: str, expected_num_layers: int): + """Smoke test of convert_lora_state_dict_to_kohya_format(...) with full SD 1.5 model.""" + unet = UNet2DConditionModel.from_pretrained( + model_name, + subfolder="unet", + local_files_only=True, + revision=revision, + ) + lora_layers = inject_lora_into_unet(unet) + lora_state_dict = lora_layers.get_lora_state_dict() + kohya_state_dict = convert_lora_state_dict_to_kohya_format(lora_state_dict) + + # These assertions are based on a manual check of the injected layers and comparison against the behaviour of + # kohya_ss. They are included here to force another manual review after any future breaking change. 
+ assert len(kohya_state_dict) == expected_num_layers * 3 + for key in kohya_state_dict.keys(): + assert key.startswith("lora_unet_") + assert key.endswith((".lora_down.weight", ".lora_up.weight", ".alpha")) + + +def test_convert_lora_state_dict_to_kohya_format(): + """Basic test of convert_lora_state_dict_to_kohya_format(...).""" + down_weight = torch.Tensor(4, 2) + up_weight = torch.Tensor(2, 4) + alpha = torch.Tensor([1.0]) + in_state_dict = { + "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": down_weight, + "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._up.weight": up_weight, + "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.alpha": alpha, + } + + out_state_dict = convert_lora_state_dict_to_kohya_format(in_state_dict) + + expected_out_state_dict = { + "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight": down_weight, + "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight": up_weight, + "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha": alpha, + } + + assert out_state_dict == expected_out_state_dict + + +def test_convert_lora_state_dict_to_kohya_format_unexpected_key(): + """Test that convert_lora_state_dict_to_kohya_format(...) raises an exception if it receives an unexpected + key. + """ + in_state_dict = { + "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.unexpected": torch.Tensor(4, 2), + } + + with pytest.raises(ValueError): + _ = convert_lora_state_dict_to_kohya_format(in_state_dict) + + +def test_convert_lora_state_dict_to_kohya_format_conflicting_keys(): + """Test that convert_lora_state_dict_to_kohya_format(...) raises an exception if multiple keys map to the same + output key. + """ + # Note: There are differences in the '.' and '_' characters of these keys, but they both map to the same output + # kohya_ss keys. + in_state_dict = { + "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": torch.Tensor(4, 2), + "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1_to_q._down.weight": torch.Tensor(4, 2), + } + + with pytest.raises(RuntimeError): + _ = convert_lora_state_dict_to_kohya_format(in_state_dict) diff --git a/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py b/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py deleted file mode 100644 index 6209a40a..00000000 --- a/tests/invoke_training/lora/injection/test_stable_diffusion_v1.py +++ /dev/null @@ -1,145 +0,0 @@ -import pytest -import torch -from diffusers.models import UNet2DConditionModel -from transformers import CLIPTextModel - -from invoke_training.lora.injection.stable_diffusion_v1 import ( - convert_lora_state_dict_to_kohya_format_sd1, - inject_lora_into_clip_text_encoder, - inject_lora_into_unet_sd1, -) - - -@pytest.fixture -def unet(): - return UNet2DConditionModel.from_pretrained( - "runwayml/stable-diffusion-v1-5", - subfolder="unet", - local_files_only=True, - revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0", - ) - - -@pytest.mark.loads_model -def test_inject_lora_into_unet_sd1_smoke(unet): - """Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model.""" - lora_layers = inject_lora_into_unet_sd1(unet) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. 
- assert len(lora_layers) == 192 - for layer_name in lora_layers._names: - assert layer_name.endswith( - ("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", ".proj_in", ".proj_out") - ) - - -@pytest.mark.loads_model -def test_inject_lora_into_unet_sd1_non_attention_layers_smoke(unet): - """Smoke test of inject_lora_into_unet_sd1(..., include_non_attention_blocks=True) on full SD 1.5 model.""" - lora_layers = inject_lora_into_unet_sd1(unet, include_non_attention_blocks=True) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. - assert len(lora_layers) == 278 - for layer_name in lora_layers._names: - assert layer_name.endswith( - ( - "to_q", - "to_k", - "to_v", - "to_out.0", - "ff.net.0.proj", - "ff.net.2", - ".proj_in", - ".proj_out", - ".conv1", - ".conv2", - ".time_emb_proj", - ".conv", - ".conv_shortcut", - ) - ) - - -@pytest.mark.loads_model -def test_inject_lora_into_clip_text_encoder_smoke(): - """Smoke test of inject_lora_into_clip_text_encoder(...) on full SD 1.5 model.""" - text_encoder = CLIPTextModel.from_pretrained( - "runwayml/stable-diffusion-v1-5", - subfolder="text_encoder", - local_files_only=True, - revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0", - ) - - lora_layers = inject_lora_into_clip_text_encoder(text_encoder) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. - assert len(lora_layers) == 72 # 216 / 3 - for layer_name in lora_layers._names: - assert layer_name.endswith(("mlp.fc1", "mlp.fc2", "k_proj", "out_proj", "q_proj", "v_proj")) - - -@pytest.mark.loads_model -def test_convert_lora_state_dict_to_kohya_format_sd1_smoke(unet): - """Smoke test of convert_lora_state_dict_to_kohya_format_sd1(...) with full SD 1.5 model.""" - lora_layers = inject_lora_into_unet_sd1(unet) - lora_state_dict = lora_layers.get_lora_state_dict() - kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_state_dict) - - # These assertions are based on a manual check of the injected layers and comparison against the behaviour of - # kohya_ss. They are included here to force another manual review after any future breaking change. 
- assert len(kohya_state_dict) == 192 * 3 - for key in kohya_state_dict.keys(): - assert key.startswith("lora_unet_") - assert key.endswith((".lora_down.weight", ".lora_up.weight", ".alpha")) - - -def test_convert_lora_state_dict_to_kohya_format_sd1(): - """Basic test of convert_lora_state_dict_to_kohya_format_sd1(...).""" - down_weight = torch.Tensor(4, 2) - up_weight = torch.Tensor(2, 4) - alpha = torch.Tensor([1.0]) - in_state_dict = { - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": down_weight, - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._up.weight": up_weight, - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.alpha": alpha, - } - - out_state_dict = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict) - - expected_out_state_dict = { - "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight": down_weight, - "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight": up_weight, - "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha": alpha, - } - - assert out_state_dict == expected_out_state_dict - - -def test_convert_lora_state_dict_to_kohya_format_sd1_unexpected_key(): - """Test that convert_lora_state_dict_to_kohya_format_sd1(...) raises an exception if it receives an unexpected - key. - """ - in_state_dict = { - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.unexpected": torch.Tensor(4, 2), - } - - with pytest.raises(ValueError): - _ = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict) - - -def test_convert_lora_state_dict_to_kohya_format_sd1_conflicting_keys(): - """Test that convert_lora_state_dict_to_kohya_format_sd1(...) raises an exception if multiple keys map to the same - output key. - """ - # Note: There are differences in the '.' and '_' characters of these keys, but they both map to the same output - # kohya_ss keys. - in_state_dict = { - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": torch.Tensor(4, 2), - "lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1_to_q._down.weight": torch.Tensor(4, 2), - } - - with pytest.raises(RuntimeError): - _ = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict) From fb39d165c31749f82a1a3b97f5e2367526d0069b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 10 Aug 2023 18:01:30 -0400 Subject: [PATCH 08/10] Adhere to dataset resolution when generating validation images. 
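With this change, validation images are rendered at the configured dataset resolution rather
than each pipeline's default size, so they stay comparable to the training data. A minimal
sketch of the resulting call shape (the helper below is illustrative; the real code lives in
_generate_validation_images):

    import torch
    from diffusers import DiffusionPipeline


    def generate_validation_image(pipeline: DiffusionPipeline, prompt: str, resolution: int, seed: int):
        generator = torch.Generator(device=pipeline.device).manual_seed(seed)
        with torch.no_grad():
            return pipeline(
                prompt,
                num_inference_steps=30,
                generator=generator,
                height=resolution,  # pinned to config.dataset.resolution
                width=resolution,
            ).images[0]
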
--- src/invoke_training/training/finetune_lora/finetune_lora_sd.py | 2 ++ .../training/finetune_lora/finetune_lora_sdxl.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py index c4895a2e..6ab53be7 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sd.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sd.py @@ -189,6 +189,8 @@ def _generate_validation_images( prompt, num_inference_steps=30, generator=generator, + height=config.dataset.resolution, + width=config.dataset.resolution, ).images[0] ) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index 13c49b26..6cba8ea5 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -272,6 +272,8 @@ def _generate_validation_images( prompt, num_inference_steps=30, generator=generator, + height=config.dataset.resolution, + width=config.dataset.resolution, ).images[0] ) From 98fe4712be81de6f40725fbd2bceca6f4c5de0db Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 10 Aug 2023 20:17:47 -0400 Subject: [PATCH 09/10] Fix filename typo: capture -> caption --- ...e_sdxl_dataloader.py => test_image_caption_sdxl_dataloader.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/invoke_training/training/shared/datasets/{test_image_capture_sdxl_dataloader.py => test_image_caption_sdxl_dataloader.py} (100%) diff --git a/tests/invoke_training/training/shared/datasets/test_image_capture_sdxl_dataloader.py b/tests/invoke_training/training/shared/datasets/test_image_caption_sdxl_dataloader.py similarity index 100% rename from tests/invoke_training/training/shared/datasets/test_image_capture_sdxl_dataloader.py rename to tests/invoke_training/training/shared/datasets/test_image_caption_sdxl_dataloader.py From 859c082109e7fb029db83ae7ab462ba10b19ffaa Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 13 Aug 2023 15:42:47 -0400 Subject: [PATCH 10/10] Fix bug in SDXL LoRA training. --- .../training/finetune_lora/finetune_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py index 6cba8ea5..785f0384 100644 --- a/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py +++ b/src/invoke_training/training/finetune_lora/finetune_lora_sdxl.py @@ -425,7 +425,7 @@ def run_training(config: FinetuneLoRASDXLConfig): # noqa: C901 lora_layers["unet"] = inject_lora_into_unet(unet, config.train_unet_non_attention_blocks) if config.train_text_encoder: lora_layers["text_encoder_1"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te1") - lora_layers["text_encoder_2"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te2") + lora_layers["text_encoder_2"] = inject_lora_into_clip_text_encoder(text_encoder_2, "lora_te2") if config.xformers: import xformers # noqa: F401
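
Note on the final fix: before this change both injection calls targeted text_encoder_1, so the
second SDXL text encoder received no LoRA layers at all and the weights exported under the
lora_te2 prefix were extra layers trained against the first encoder. The corrected pairing
(excerpted from the hunk above) gives each encoder its own layers under a distinct kohya
prefix, so the exported keys split cleanly into lora_te1_* and lora_te2_* groups:

    # Each text encoder now receives its own LoRA layers and kohya prefix.
    lora_layers["text_encoder_1"] = inject_lora_into_clip_text_encoder(text_encoder_1, "lora_te1")
    lora_layers["text_encoder_2"] = inject_lora_into_clip_text_encoder(text_encoder_2, "lora_te2")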