-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Preparation for SDXL support #12
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
edf6ec6
Add get_base_model_version(...) helper function.
RyanJDick 13fd1d9
Add check_base_model_version(...) helper function.
RyanJDick a06b8f8
Allow model download in training script, and add base model version c…
RyanJDick 74f3f18
Move accelerator utils to shared/ directory in preparation for SDXL s…
RyanJDick File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import logging | ||
import os | ||
|
||
import datasets | ||
import diffusers | ||
import torch | ||
import transformers | ||
from accelerate import Accelerator | ||
from accelerate.logging import MultiProcessAdapter, get_logger | ||
from accelerate.utils import ProjectConfiguration | ||
|
||
|
||
def initialize_accelerator(
    out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str
) -> Accelerator:
    """Configure Hugging Face accelerate and return an Accelerator.

    Args:
        out_dir (str): The output directory where results will be written.
        gradient_accumulation_steps (int): Forwarded to accelerate.Accelerator(...).
        mixed_precision (str): Forwarded to accelerate.Accelerator(...).
        log_with (str): Forwarded to accelerate.Accelerator(...).

    Returns:
        Accelerator: The configured Accelerator.
    """
    # Root all Accelerator artifacts (checkpoints, tracker logs) under out_dir.
    accelerator_project_config = ProjectConfiguration(
        project_dir=out_dir,
        logging_dir=os.path.join(out_dir, "logs"),
    )
    return Accelerator(
        project_config=accelerator_project_config,
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=log_with,
    )
|
||
|
||
def initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter:
    """Configure logging.

    Returns an accelerate logger with multi-process logging support. Logging is configured to be more verbose on the
    main process. Non-main processes only log at error level for Hugging Face libraries (datasets, transformers,
    diffusers).

    Args:
        logger_name (str): The name of the logger to return (typically the calling module's `__name__`).
        accelerator (Accelerator): The Accelerator whose process rank determines the logging verbosity.

    Returns:
        MultiProcessAdapter: A logger that handles multi-process logging for `logger_name`.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        # Only log errors from non-main processes.
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    return get_logger(logger_name)
|
||
|
||
def get_mixed_precision_dtype(accelerator: "Accelerator") -> torch.dtype:
    """Extract torch.dtype from Accelerator config.

    Args:
        accelerator (Accelerator): The Hugging Face Accelerator.

    Raises:
        NotImplementedError: If the accelerator's mixed_precision configuration is not recognized.

    Returns:
        torch.dtype: The weight type inferred from the accelerator mixed_precision configuration.
    """
    # Both None and "no" mean that mixed precision is disabled, i.e. full float32 precision.
    dtype_by_mode = {
        None: torch.float32,
        "no": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    try:
        return dtype_by_mode[accelerator.mixed_precision]
    except KeyError:
        raise NotImplementedError(
            f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported."
        ) from None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from enum import Enum | ||
|
||
from transformers import PretrainedConfig | ||
|
||
|
||
class BaseModelVersionEnum(Enum):
    """The Stable Diffusion base model versions that can be detected.

    Used by get_base_model_version(...) / check_base_model_version(...) to identify which base model family a
    diffusers model belongs to.
    """

    STABLE_DIFFUSION_V1 = 1
    STABLE_DIFFUSION_V2 = 2
    STABLE_DIFFUSION_SDXL_BASE = 3
    STABLE_DIFFUSION_SDXL_REFINER = 4
|
||
|
||
def get_base_model_version(
    diffusers_model_name: str, revision: str = "main", local_files_only: bool = True
) -> BaseModelVersionEnum:
    """Returns the `BaseModelVersionEnum` of a diffusers model.

    Args:
        diffusers_model_name (str): The diffusers model name (on Hugging Face Hub).
        revision (str, optional): The model revision (branch or commit hash). Defaults to "main".
        local_files_only (bool, optional): If True, only look in the local model cache and do not attempt to download
            the model config. Defaults to True.

    Raises:
        ValueError: If the base model version can not be determined. (A subclass of Exception, so existing callers
            that catch Exception are unaffected.)

    Returns:
        BaseModelVersionEnum: The detected base model version.
    """
    unet_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path=diffusers_model_name,
        revision=revision,
        subfolder="unet",
        local_files_only=local_files_only,
    )

    # This logic was copied from
    # https://github.com/invoke-ai/InvokeAI/blob/e77400ab62d24acbdf2f48a7427705e7b8b97e4a/invokeai/backend/model_management/model_probe.py#L412-L421
    # This seems fragile. If you see this and know of a better way to detect the base model version, your contribution
    # would be welcome.
    version_by_cross_attention_dim = {
        768: BaseModelVersionEnum.STABLE_DIFFUSION_V1,
        1024: BaseModelVersionEnum.STABLE_DIFFUSION_V2,
        1280: BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER,
        2048: BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE,
    }
    try:
        return version_by_cross_attention_dim[unet_config.cross_attention_dim]
    except KeyError:
        raise ValueError(
            "Failed to determine base model version. UNet cross_attention_dim has unexpected value: "
            f"'{unet_config.cross_attention_dim}'."
        ) from None
|
||
|
||
def check_base_model_version(
    allowed_versions: set[BaseModelVersionEnum],
    diffusers_model_name: str,
    revision: str = "main",
    local_files_only: bool = True,
):
    """Helper function that checks if a diffusers model is compatible with a set of base model versions.

    Args:
        allowed_versions (set[BaseModelVersionEnum]): The set of allowed base model versions.
        diffusers_model_name (str): The diffusers model name (on Hugging Face Hub) to check.
        revision (str, optional): The model revision (branch or commit hash). Defaults to "main".
        local_files_only (bool, optional): Forwarded to get_base_model_version(...). Defaults to True.

    Raises:
        ValueError: If the model has an unsupported version.
    """
    version = get_base_model_version(diffusers_model_name, revision, local_files_only)
    if version not in allowed_versions:
        # Sort the allowed version names so the error message is deterministic (set iteration order is not).
        supported_version_names = sorted(v.name for v in allowed_versions)
        raise ValueError(
            f"Model '{diffusers_model_name}' (revision='{revision}') has an unsupported version: '{version.name}'. "
            f"Supported versions: {supported_version_names}."
        )
48 changes: 48 additions & 0 deletions
48
tests/invoke_training/training/shared/test_base_model_version.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import pytest | ||
from transformers import PretrainedConfig | ||
|
||
from invoke_training.training.shared.base_model_version import ( | ||
BaseModelVersionEnum, | ||
check_base_model_version, | ||
get_base_model_version, | ||
) | ||
|
||
|
||
@pytest.mark.loads_model
@pytest.mark.parametrize(
    ["diffusers_model_name", "expected_version"],
    [
        ("runwayml/stable-diffusion-v1-5", BaseModelVersionEnum.STABLE_DIFFUSION_V1),
        ("stabilityai/stable-diffusion-2-1", BaseModelVersionEnum.STABLE_DIFFUSION_V2),
        ("stabilityai/stable-diffusion-xl-base-1.0", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE),
        ("stabilityai/stable-diffusion-xl-refiner-1.0", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER),
    ],
)
def test_get_base_model_version(diffusers_model_name: str, expected_version: BaseModelVersionEnum):
    """Verify get_base_model_version(...) against one known model per supported version."""
    # xfail rather than fail when the model is absent from the local Hugging Face cache, so that running the test
    # suite does not require downloading every test model.
    try:
        PretrainedConfig.from_pretrained(
            pretrained_model_name_or_path=diffusers_model_name,
            subfolder="unet",
            local_files_only=True,
        )
    except OSError:
        pytest.xfail(f"'{diffusers_model_name}' is not downloaded.")

    assert get_base_model_version(diffusers_model_name) == expected_version
|
||
|
||
@pytest.mark.loads_model
def test_check_base_model_version_pass():
    """check_base_model_version(...) should return without raising when the model's version is allowed."""
    check_base_model_version(
        allowed_versions={BaseModelVersionEnum.STABLE_DIFFUSION_V1},
        diffusers_model_name="runwayml/stable-diffusion-v1-5",
    )
|
||
|
||
@pytest.mark.loads_model
def test_check_base_model_version_fail():
    """check_base_model_version(...) should raise ValueError when the model's version is not allowed."""
    with pytest.raises(ValueError):
        check_base_model_version(
            allowed_versions={BaseModelVersionEnum.STABLE_DIFFUSION_V2},
            diffusers_model_name="runwayml/stable-diffusion-v1-5",
        )
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be the standard way of doing it right now unfortunately. To specifically point out xl vs not xl, the diffusers class name is pretty clear in the root config of the model. Beyond that, I'm not sure of a better way.