Skip to content

Commit

Permalink
chore: Update model config type names
Browse files Browse the repository at this point in the history
  • Loading branch information
blessedcoolant committed Jun 17, 2023
1 parent 4cbc802 commit 67d05d2
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 14 deletions.
4 changes: 2 additions & 2 deletions invokeai/backend/model_management/models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int

class Config(ModelConfigBase):
class ControlNetModelConfig(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]

def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
Expand Down Expand Up @@ -82,6 +82,6 @@ def convert_if_required(
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path
2 changes: 1 addition & 1 deletion invokeai/backend/model_management/models/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class LoRAModel(ModelBase):
#model_size: int

class Config(ModelConfigBase):
class LoraModelConfig(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]]

def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
Expand Down
18 changes: 9 additions & 9 deletions invokeai/backend/model_management/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@

class StableDiffusion1Model(DiffusersModel):

class DiffusersConfig(ModelConfigBase):
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType

class CheckpointConfig(ModelConfigBase):
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
Expand Down Expand Up @@ -107,7 +107,7 @@ def convert_if_required(
) -> str:
assert model_path == config.path

if isinstance(config, cls.CheckpointConfig):
if isinstance(config, cls.CheckpointModelConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
Expand All @@ -120,14 +120,14 @@ def convert_if_required(
class StableDiffusion2Model(DiffusersModel):

# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool

class CheckpointConfig(ModelConfigBase):
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
Expand Down Expand Up @@ -220,7 +220,7 @@ def convert_if_required(
) -> str:
assert model_path == config.path

if isinstance(config, cls.CheckpointConfig):
if isinstance(config, cls.CheckpointModelConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
Expand Down Expand Up @@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
output_path: str,
) -> str:
"""
Expand All @@ -281,8 +281,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon

elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention
prediction_type = config.prediction_type
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type

else:
raise Exception(f"Unknown model provided: {version}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class TextualInversionModel(ModelBase):
#model_size: int

class Config(ModelConfigBase):
class TextualInversionModelConfig(ModelConfigBase):
format: None

def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/model_management/models/vae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import torch
import safetensors
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
Expand All @@ -22,7 +23,7 @@ class VaeModel(ModelBase):
#vae_class: Type
#model_size: int

class Config(ModelConfigBase):
class VAEModelConfig(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]

def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
Expand Down

0 comments on commit 67d05d2

Please sign in to comment.