
Preparation for SDXL support #12

Merged · 4 commits · Aug 14, 2023
129 changes: 28 additions & 101 deletions src/invoke_training/training/lora/lora_training.py
@@ -4,14 +4,10 @@
import os
import time

import datasets
import diffusers
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter, get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.utils import set_seed
from diffusers import (
AutoencoderKL,
DDPMScheduler,
@@ -28,90 +24,22 @@
inject_lora_into_unet_sd1,
)
from invoke_training.training.lora.lora_training_config import LoRATrainingConfig
from invoke_training.training.shared.accelerator_utils import (
get_mixed_precision_dtype,
initialize_accelerator,
initialize_logging,
)
from invoke_training.training.shared.base_model_version import (
BaseModelVersionEnum,
check_base_model_version,
)
from invoke_training.training.shared.checkpoint_tracker import CheckpointTracker
from invoke_training.training.shared.datasets.image_caption_dataloader import (
build_image_caption_dataloader,
)
from invoke_training.training.shared.serialization import save_state_dict


def _initialize_accelerator(out_dir: str, config: LoRATrainingConfig) -> Accelerator:
"""Configure Hugging Face accelerate and return an Accelerator.

Args:
out_dir (str): The output directory where results will be written.
config (LoRATrainingConfig): LoRA training configuration.

Returns:
Accelerator
"""
accelerator_project_config = ProjectConfiguration(
project_dir=out_dir,
logging_dir=os.path.join(out_dir, "logs"),
)
return Accelerator(
project_config=accelerator_project_config,
gradient_accumulation_steps=config.gradient_accumulation_steps,
mixed_precision=config.mixed_precision,
log_with=config.output.report_to,
)


def _initialize_logging(accelerator: Accelerator) -> MultiProcessAdapter:
"""Configure logging.

Returns an accelerate logger with multi-process logging support. Logging is configured to be more verbose on the
main process. Non-main processes only log at error level for Hugging Face libraries (datasets, transformers,
diffusers).

Args:
accelerator (Accelerator): The Accelerator to configure.

Returns:
MultiProcessAdapter: The configured logger with multi-process support.
"""
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
# Only log errors from non-main processes.
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

return get_logger(__name__)


def _get_weight_type(accelerator: Accelerator):
"""Extract torch.dtype from Accelerator config.

Args:
accelerator (Accelerator): The Hugging Face Accelerator.

Raises:
NotImplementedError: If the accelerator's mixed_precision configuration is not recognized.

Returns:
torch.dtype: The weight type inferred from the accelerator mixed_precision configuration.
"""
weight_dtype: torch.dtype = torch.float32
if accelerator.mixed_precision is None or accelerator.mixed_precision == "no":
weight_dtype = torch.float32
elif accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
else:
raise NotImplementedError(f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.")
return weight_dtype


def _load_models(
accelerator: Accelerator,
config: LoRATrainingConfig,
@@ -132,23 +60,13 @@ def _load_models(
UNet2DConditionModel,
]: A tuple of loaded models.
"""
weight_dtype = _get_weight_type(accelerator)
weight_dtype = get_mixed_precision_dtype(accelerator)

tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(config.model, subfolder="tokenizer", local_files_only=True)
noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
config.model, subfolder="scheduler", local_files_only=True
)
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
config.model,
subfolder="text_encoder",
local_files_only=True,
)
vae: AutoencoderKL = AutoencoderKL.from_pretrained(config.model, subfolder="vae", local_files_only=True)
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
config.model,
subfolder="unet",
local_files_only=True,
)
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(config.model, subfolder="tokenizer")
noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(config.model, subfolder="scheduler")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(config.model, subfolder="text_encoder")
vae: AutoencoderKL = AutoencoderKL.from_pretrained(config.model, subfolder="vae")
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(config.model, subfolder="unet")

# Disable gradient calculation for model weights to save memory.
text_encoder.requires_grad_(False)
@@ -354,12 +272,21 @@ def _train_forward(


def run_lora_training(config: LoRATrainingConfig): # noqa: C901
# Give a clear error message if an unsupported base model was chosen.
check_base_model_version(
{BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},
config.model,
local_files_only=False,
)

# Create a timestamped directory for all outputs.
out_dir = os.path.join(config.output.base_output_dir, f"{time.time()}")
os.makedirs(out_dir)

accelerator = _initialize_accelerator(out_dir, config)
logger = _initialize_logging(accelerator)
accelerator = initialize_accelerator(
out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.output.report_to
)
logger = initialize_logging(__name__, accelerator)

# Set the accelerate seed.
if config.seed is not None:
@@ -376,7 +303,7 @@ def run_lora_training(config: LoRATrainingConfig):  # noqa: C901
with open(os.path.join(out_dir, "config.json"), "w") as f:
json.dump(config.dict(), f, indent=2, default=str)

weight_dtype = _get_weight_type(accelerator)
weight_dtype = get_mixed_precision_dtype(accelerator)

logger.info("Loading models.")
tokenizer, noise_scheduler, text_encoder, vae, unet = _load_models(accelerator, config)
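To illustrate the new version guard added at the top of run_lora_training, here is a minimal, self-contained sketch (the model names are illustrative; the behavior follows the check_base_model_version source added below). Passing a model outside the allowed set raises a ValueError:

from invoke_training.training.shared.base_model_version import (
    BaseModelVersionEnum,
    check_base_model_version,
)

# Passes: an SD v1 model is in the allowed set.
check_base_model_version(
    {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},
    "runwayml/stable-diffusion-v1-5",
    local_files_only=False,
)

# Raises ValueError: SDXL base is not in the allowed set.
check_base_model_version(
    {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},
    "stabilityai/stable-diffusion-xl-base-1.0",
    local_files_only=False,
)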
91 changes: 91 additions & 0 deletions src/invoke_training/training/shared/accelerator_utils.py
@@ -0,0 +1,91 @@
import logging
import os

import datasets
import diffusers
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter, get_logger
from accelerate.utils import ProjectConfiguration


def initialize_accelerator(
out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str
) -> Accelerator:
"""Configure Hugging Face accelerate and return an Accelerator.

Args:
out_dir (str): The output directory where results will be written.
gradient_accumulation_steps (int): Forwarded to accelerate.Accelerator(...).
mixed_precision (str): Forwarded to accelerate.Accelerator(...).
log_with (str): Forwarded to accelerate.Accelerator(...).

Returns:
Accelerator
"""
accelerator_project_config = ProjectConfiguration(
project_dir=out_dir,
logging_dir=os.path.join(out_dir, "logs"),
)
return Accelerator(
project_config=accelerator_project_config,
gradient_accumulation_steps=gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=log_with,
)


def initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter:
"""Configure logging.

Returns an accelerate logger with multi-process logging support. Logging is configured to be more verbose on the
main process. Non-main processes only log at error level for Hugging Face libraries (datasets, transformers,
diffusers).

Args:
logger_name (str): The name of the returned logger, forwarded to accelerate.logging.get_logger(...).
accelerator (Accelerator): The Accelerator to configure.

Returns:
MultiProcessAdapter: The configured logger with multi-process support.
"""
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
# Only log errors from non-main processes.
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

return get_logger(logger_name)


def get_mixed_precision_dtype(accelerator: Accelerator):
"""Extract torch.dtype from Accelerator config.

Args:
accelerator (Accelerator): The Hugging Face Accelerator.

Raises:
NotImplementedError: If the accelerator's mixed_precision configuration is not recognized.

Returns:
torch.dtype: The weight type inferred from the accelerator mixed_precision configuration.
"""
weight_dtype: torch.dtype = torch.float32
if accelerator.mixed_precision is None or accelerator.mixed_precision == "no":
weight_dtype = torch.float32
elif accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
else:
raise NotImplementedError(f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.")
return weight_dtype
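For reference, a minimal usage sketch showing how the three helpers above compose (argument values are illustrative, not project defaults):

import torch

from invoke_training.training.shared.accelerator_utils import (
    get_mixed_precision_dtype,
    initialize_accelerator,
    initialize_logging,
)

accelerator = initialize_accelerator(
    out_dir="output/example_run",
    gradient_accumulation_steps=1,
    mixed_precision="fp16",
    log_with="tensorboard",
)
logger = initialize_logging(__name__, accelerator)

# "fp16" maps to torch.float16 per get_mixed_precision_dtype above.
weight_dtype = get_mixed_precision_dtype(accelerator)
assert weight_dtype == torch.float16
logger.info(f"Training with weight dtype: {weight_dtype}")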
75 changes: 75 additions & 0 deletions src/invoke_training/training/shared/base_model_version.py
@@ -0,0 +1,75 @@
from enum import Enum

from transformers import PretrainedConfig


class BaseModelVersionEnum(Enum):
STABLE_DIFFUSION_V1 = 1
STABLE_DIFFUSION_V2 = 2
STABLE_DIFFUSION_SDXL_BASE = 3
STABLE_DIFFUSION_SDXL_REFINER = 4


def get_base_model_version(
diffusers_model_name: str, revision: str = "main", local_files_only: bool = True
) -> BaseModelVersionEnum:
"""Returns the `BaseModelVersionEnum` of a diffusers model.

Args:
diffusers_model_name (str): The diffusers model name (on Hugging Face Hub).
revision (str, optional): The model revision (branch or commit hash). Defaults to "main".
local_files_only (bool, optional): If True, only read model files from the local cache. Defaults to True.

Raises:
Exception: If the base model version cannot be determined.

Returns:
BaseModelVersionEnum: The detected base model version.
"""
unet_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path=diffusers_model_name,
revision=revision,
subfolder="unet",
local_files_only=local_files_only,
)

# This logic was copied from
# https://github.com/invoke-ai/InvokeAI/blob/e77400ab62d24acbdf2f48a7427705e7b8b97e4a/invokeai/backend/model_management/model_probe.py#L412-L421
# This seems fragile. If you see this and know of a better way to detect the base model version, your contribution
# would be welcome.

Collaborator commented: This seems to be the standard way of doing it right now, unfortunately. To specifically point out XL vs. non-XL, the diffusers class name is pretty clear in the root config of the model. Beyond that, I'm not sure of a better way.
if unet_config.cross_attention_dim == 768:
return BaseModelVersionEnum.STABLE_DIFFUSION_V1
elif unet_config.cross_attention_dim == 1024:
return BaseModelVersionEnum.STABLE_DIFFUSION_V2
elif unet_config.cross_attention_dim == 1280:
return BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER
elif unet_config.cross_attention_dim == 2048:
return BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE
else:
raise Exception(
"Failed to determine base model version. UNet cross_attention_dim has unexpected value: "
f"'{ unet_config.cross_attention_dim}'."
)


def check_base_model_version(
allowed_versions: set[BaseModelVersionEnum],
diffusers_model_name: str,
revision: str = "main",
local_files_only: bool = True,
):
"""Helper function that checks if a diffusers model is compatible with a set of base model versions.

Args:
allowed_versions (set[BaseModelVersionEnum]): The set of allowed base model versions.
diffusers_model_name (str): The diffusers model name (on Hugging Face Hub) to check.
revision (str, optional): The model revision (branch or commit hash). Defaults to "main".
local_files_only (bool, optional): If True, only read model files from the local cache. Defaults to True.

Raises:
ValueError: If the model has an unsupported version.
"""
version = get_base_model_version(diffusers_model_name, revision, local_files_only)
if version not in allowed_versions:
raise ValueError(
f"Model '{diffusers_model_name}' (revision='{revision}') has an unsupported version: '{version.name}'. "
f"Supported versions: {[v.name for v in allowed_versions]}."
)
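As a complement to the cross_attention_dim heuristic above, here is a hedged sketch of the alternative the reviewer mentions: for telling XL from non-XL models, the pipeline class name recorded in the model's root config (model_index.json) is explicit. This is not part of the PR; it only illustrates the suggestion.

from diffusers import DiffusionPipeline

# load_config reads the model's root model_index.json (from the Hub or local cache).
config = DiffusionPipeline.load_config("stabilityai/stable-diffusion-xl-base-1.0")

# "_class_name" is e.g. "StableDiffusionXLPipeline" for SDXL base models.
is_sdxl = config["_class_name"].startswith("StableDiffusionXL")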
48 changes: 48 additions & 0 deletions tests/invoke_training/training/shared/test_base_model_version.py
@@ -0,0 +1,48 @@
import pytest
from transformers import PretrainedConfig

from invoke_training.training.shared.base_model_version import (
BaseModelVersionEnum,
check_base_model_version,
get_base_model_version,
)


@pytest.mark.loads_model
@pytest.mark.parametrize(
["diffusers_model_name", "expected_version"],
[
("runwayml/stable-diffusion-v1-5", BaseModelVersionEnum.STABLE_DIFFUSION_V1),
("stabilityai/stable-diffusion-2-1", BaseModelVersionEnum.STABLE_DIFFUSION_V2),
("stabilityai/stable-diffusion-xl-base-1.0", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE),
("stabilityai/stable-diffusion-xl-refiner-1.0", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER),
],
)
def test_get_base_model_version(diffusers_model_name: str, expected_version: BaseModelVersionEnum):
"""Test get_base_model_version(...) with one test model for each supported version."""
# Check if the diffusers_model_name model is downloaded and xfail if not.
# This check ensures that users don't have to download all of the test models just to run the test suite.
try:
_ = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path=diffusers_model_name,
subfolder="unet",
local_files_only=True,
)
except OSError:
pytest.xfail(f"'{diffusers_model_name}' is not downloaded.")

version = get_base_model_version(diffusers_model_name)
assert version == expected_version


@pytest.mark.loads_model
def test_check_base_model_version_pass():
"""Test that check_base_model_version(...) does not raise an Exception when the model is valid."""
check_base_model_version({BaseModelVersionEnum.STABLE_DIFFUSION_V1}, "runwayml/stable-diffusion-v1-5")


@pytest.mark.loads_model
def test_check_base_model_version_fail():
"""Test that check_base_model_version(...) raises a ValueError when the model is invalid."""
with pytest.raises(ValueError):
check_base_model_version({BaseModelVersionEnum.STABLE_DIFFUSION_V2}, "runwayml/stable-diffusion-v1-5")
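The loads_model marker used above lets contributors skip tests that require downloaded weights (e.g. pytest -m "not loads_model"). Below is a minimal registration sketch, assuming the marker is declared in a conftest.py; the actual project may register it in pyproject.toml or pytest.ini instead.

# conftest.py (hypothetical location for this sketch)
def pytest_configure(config):
    # Register the custom marker so pytest does not warn about it.
    config.addinivalue_line(
        "markers",
        "loads_model: marks tests that load pretrained models from the local Hugging Face cache",
    )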