Skip to content

Commit

Permalink
Fix static type errors with SCHEDULER_NAME_VALUES. And, avoid bi-dire…
Browse files Browse the repository at this point in the history
…ctional cross-directory imports, which contribute to circular import issues.
  • Loading branch information
RyanJDick authored and hipsterusername committed Jul 5, 2024
1 parent 3a24d70 commit 35f8781
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 11 deletions.
4 changes: 0 additions & 4 deletions invokeai/app/invocations/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Literal

from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice

LATENT_SCALE_FACTOR = 8
Expand All @@ -11,9 +10,6 @@
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""

SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""

IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""

Expand Down
3 changes: 2 additions & 1 deletion invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformers import CLIPVisionModelWithProjection

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
Expand Down Expand Up @@ -54,6 +54,7 @@
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/invocations/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
OutputField,
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES


@invocation_output("scheduler_output")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import field_validator

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
Expand All @@ -29,6 +29,7 @@
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
Expand Down
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES

# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
Expand Down
32 changes: 31 additions & 1 deletion invokeai/backend/stable_diffusion/schedulers/schedulers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Literal, Type

from diffusers import (
DDIMScheduler,
DDPMScheduler,
Expand All @@ -16,8 +18,36 @@
TCDScheduler,
UniPCMultistepScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin

SCHEDULER_NAME_VALUES = Literal[
"ddim",
"ddpm",
"deis",
"lms",
"lms_k",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
"lcm",
"tcd",
]

SCHEDULER_MAP = {
SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
"ddim": (DDIMScheduler, {}),
"ddpm": (DDPMScheduler, {}),
"deis": (DEISMultistepScheduler, {}),
Expand Down
4 changes: 2 additions & 2 deletions invokeai/invocation_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
BoardField,
ColorField,
Expand Down Expand Up @@ -78,6 +77,7 @@
ConditioningFieldData,
SDXLConditioningInfo,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
from invokeai.version import __version__

Expand Down Expand Up @@ -163,7 +163,7 @@
"BaseModelType",
"ModelType",
"SubModelType",
# invokeai.app.invocations.constants
# invokeai.backend.stable_diffusion.schedulers.schedulers
"SCHEDULER_NAME_VALUES",
# invokeai.version
"__version__",
Expand Down
10 changes: 10 additions & 0 deletions tests/backend/stable_diffusion/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import get_args

from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES


def test_scheduler_map_has_all_keys():
# Assert that SCHEDULER_MAP has all keys from SCHEDULER_NAME_VALUES.
# TODO(ryand): This feels like it should be a type check, but I couldn't find a clean way to do this and didn't want
# to spend more time on it.
assert set(SCHEDULER_MAP.keys()) == set(get_args(SCHEDULER_NAME_VALUES))

0 comments on commit 35f8781

Please sign in to comment.