From 35f8781ea22cccbbbb6fab92180ee2b7bca35a70 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 3 Jul 2024 11:13:16 -0400 Subject: [PATCH] Fix static type errors with SCHEDULER_NAME_VALUES. And, avoid bi-directional cross-directory imports, which contribute to circular import issues. --- invokeai/app/invocations/constants.py | 4 --- invokeai/app/invocations/denoise_latents.py | 3 +- invokeai/app/invocations/scheduler.py | 2 +- .../tiled_multi_diffusion_denoise_latents.py | 3 +- invokeai/backend/model_manager/config.py | 2 +- .../stable_diffusion/schedulers/schedulers.py | 32 ++++++++++++++++++- invokeai/invocation_api/__init__.py | 4 +-- .../schedulers/test_schedulers.py | 10 ++++++ 8 files changed, 49 insertions(+), 11 deletions(-) create mode 100644 tests/backend/stable_diffusion/schedulers/test_schedulers.py diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index e01589be812..e97275e4fd8 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -1,6 +1,5 @@ from typing import Literal -from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.util.devices import TorchDevice LATENT_SCALE_FACTOR = 8 @@ -11,9 +10,6 @@ The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. """ -SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] -"""A literal type representing the valid scheduler names.""" - IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] """A literal type for PIL image modes supported by Invoke""" diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index fd901298f77..7ccf9068939 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -17,7 +17,7 @@ from transformers import CLIPVisionModelWithProjection from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.fields import ( ConditioningField, @@ -54,6 +54,7 @@ TextConditioningRegions, ) from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.hotfixes import ControlNetModel from invokeai.backend.util.mask import to_standard_float_mask diff --git a/invokeai/app/invocations/scheduler.py b/invokeai/app/invocations/scheduler.py index 52af20378ef..a870a442ef8 100644 --- a/invokeai/app/invocations/scheduler.py +++ b/invokeai/app/invocations/scheduler.py @@ -1,5 +1,4 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output -from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( FieldDescriptions, InputField, @@ -7,6 +6,7 @@ UIType, ) from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @invocation_output("scheduler_output") diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 2566fd25514..5d408a4df7c 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -8,7 +8,7 @@ from pydantic import field_validator from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler from invokeai.app.invocations.fields import ( @@ -29,6 +29,7 @@ MultiDiffusionPipeline, MultiDiffusionRegionConditioning, ) +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.tiles.tiles import ( calc_tiles_min_overlap, ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index a8eb13d3396..dbcd2593682 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -30,10 +30,10 @@ from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict -from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.util.misc import uuid_string from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.raw_model import RawModel +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime diff --git a/invokeai/backend/stable_diffusion/schedulers/schedulers.py b/invokeai/backend/stable_diffusion/schedulers/schedulers.py index 3a55d52d4a0..7d6851e278d 100644 --- a/invokeai/backend/stable_diffusion/schedulers/schedulers.py +++ b/invokeai/backend/stable_diffusion/schedulers/schedulers.py @@ -1,3 +1,5 @@ +from typing import Any, Literal, Type + from diffusers import ( DDIMScheduler, DDPMScheduler, @@ -16,8 +18,36 @@ TCDScheduler, UniPCMultistepScheduler, ) +from diffusers.schedulers.scheduling_utils import SchedulerMixin + +SCHEDULER_NAME_VALUES = Literal[ + "ddim", + "ddpm", + "deis", + "lms", + "lms_k", + "pndm", + "heun", + "heun_k", + "euler", + "euler_k", + "euler_a", + "kdpm_2", + "kdpm_2_a", + "dpmpp_2s", + "dpmpp_2s_k", + "dpmpp_2m", + "dpmpp_2m_k", + "dpmpp_2m_sde", + "dpmpp_2m_sde_k", + "dpmpp_sde", + "dpmpp_sde_k", + "unipc", + "lcm", + "tcd", +] -SCHEDULER_MAP = { +SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = { "ddim": (DDIMScheduler, {}), "ddpm": (DDPMScheduler, {}), "deis": (DEISMultistepScheduler, {}), diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 97260c4dfe0..586f85b9c26 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -11,7 +11,6 @@ invocation, invocation_output, ) -from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -78,6 +77,7 @@ ConditioningFieldData, SDXLConditioningInfo, ) +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device from invokeai.version import __version__ @@ -163,7 +163,7 @@ "BaseModelType", "ModelType", "SubModelType", - # invokeai.app.invocations.constants + # invokeai.backend.stable_diffusion.schedulers.schedulers "SCHEDULER_NAME_VALUES", # invokeai.version "__version__", diff --git a/tests/backend/stable_diffusion/schedulers/test_schedulers.py b/tests/backend/stable_diffusion/schedulers/test_schedulers.py new file mode 100644 index 00000000000..bb49fc4f3bc --- /dev/null +++ b/tests/backend/stable_diffusion/schedulers/test_schedulers.py @@ -0,0 +1,10 @@ +from typing import get_args + +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES + + +def test_scheduler_map_has_all_keys(): + # Assert that SCHEDULER_MAP has all keys from SCHEDULER_NAME_VALUES. + # TODO(ryand): This feels like it should be a type check, but I couldn't find a clean way to do this and didn't want + # to spend more time on it. + assert set(SCHEDULER_MAP.keys()) == set(get_args(SCHEDULER_NAME_VALUES))