diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index f510279f183..50d645eb572 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management.models import get_all_model_configs -MODEL_CONFIGS = Union[tuple(get_all_model_configs())] +from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS +MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel): info: DiffusersModelInfo = Field(description="The converted model info") class ModelsList(BaseModel): - models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend - #models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] + models: list[MODEL_CONFIGS] @models_router.get( @@ -72,10 +71,10 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: BaseModelType = Query( + base_model: Optional[BaseModelType] = Query( default=None, description="Base model" ), - model_type: ModelType = Query( + model_type: Optional[ModelType] = Query( default=None, description="The type of model to get" ), ) -> ModelsList: diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 22b4efec747..e14c58bab76 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -120,6 +120,22 @@ def custom_openapi(): invoker_schema["output"] = outputs_ref + from invokeai.backend.model_management.models import get_model_config_enums + for model_config_format_enum in set(get_model_config_enums()): + name = model_config_format_enum.__qualname__ + + if name in openapi_schema["components"]["schemas"]: + # print(f"Config with name {name} already defined") + continue + + # "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"} + openapi_schema["components"]["schemas"][name] = dict( + title=name, + description="An enumeration.", + type="string", + enum=list(v.value for v in model_config_format_enum), + ) + app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 9d77cadf8c5..b77aa5dafd4 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -43,115 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput): #fmt: on -class SD1ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" +class PipelineModelField(BaseModel): + """Pipeline model field""" - type: Literal["sd1_model_loader"] = "sd1_model_loader" - - model_name: str = Field(default="", description="Model to load") - # TODO: precision? - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["model", "loader"], - "type_hints": { - "model_name": "model" # TODO: rename to model_name? - } - }, - } - - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - - base_model = BaseModelType.StableDiffusion1 # TODO: - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - ): - raise Exception(f"Unkown model name: {self.model_name}!") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") - return ModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.TextEncoder, - ), - loras=[], - ), - vae=VaeField( - vae=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Vae, - ), - ) - ) +class PipelineModelLoaderInvocation(BaseInvocation): + """Loads a pipeline model, outputting its submodels.""" -# TODO: optimize(less code copy) -class SD2ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" + type: Literal["pipeline_model_loader"] = "pipeline_model_loader" - type: Literal["sd2_model_loader"] = "sd2_model_loader" - - model_name: str = Field(default="", description="Model to load") + model: PipelineModelField = Field(description="The model to load") # TODO: precision? # Schema customisation @@ -160,22 +64,24 @@ class Config(InvocationConfig): "ui": { "tags": ["model", "loader"], "type_hints": { - "model_name": "model" # TODO: rename to model_name? + "model": "model" } }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = BaseModelType.StableDiffusion2 # TODO: + base_model = self.model.base_model + model_name = self.model.model_name + model_type = ModelType.Pipeline # TODO: not found exceptions if not context.services.model_manager.model_exists( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, ): - raise Exception(f"Unkown model name: {self.model_name}!") + raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") """ if not context.services.model_manager.model_exists( @@ -210,39 +116,39 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.TextEncoder, ), loras=[], ), vae=VaeField( vae=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Vae, ), ) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index c212ff6a726..8b46b17ad04 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -5,7 +5,7 @@ import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING +from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING from dataclasses import dataclass from invokeai.backend.model_management.model_manager import ( @@ -69,19 +69,6 @@ def model_exists( ) -> bool: pass - @abstractmethod - def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: - """ - Returns the name and typeof the default model, or None - if none is defined. - """ - pass - - @abstractmethod - def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType): - """Sets the default model to the indicated name.""" - pass - @abstractmethod def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ @@ -270,17 +257,6 @@ def model_exists( model_type, ) - def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: - """ - Returns the name of the default model, or None - if none is defined. - """ - return self.mgr.default_model() - - def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType): - """Sets the default model to the indicated name.""" - self.mgr.set_default_model(model_name, base_model, model_type) - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ Given a model name returns a dict-like (OmegaConf) object describing it. @@ -297,21 +273,10 @@ def list_models( self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> dict: + ) -> list[dict]: + # ) -> dict: """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } + Return a list of models. """ return self.mgr.list_models(base_model, model_type) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index cbb319c6ea6..f9a66a87ddb 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -266,6 +266,8 @@ def __init__( for model_key, model_config in config.items(): model_name, base_model, model_type = self.parse_key(model_key) model_class = MODEL_CLASSES[base_model][model_type] + # alias for config file + model_config["model_format"] = model_config.pop("format") self.models[model_key] = model_class.create_config(**model_config) # check config version number and update on disk/RAM if necessary @@ -445,38 +447,6 @@ def get_model( _cache = self.cache, ) - def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: - """ - Returns the name of the default model, or None - if none is defined. - """ - for model_key, model_config in self.models.items(): - if model_config.default: - return self.parse_key(model_key) - - for model_key, _ in self.models.items(): - return self.parse_key(model_key) - else: - return None # TODO: or redo as (None, None, None) - - def set_default_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> None: - """ - Set the default model. The change will not take - effect until you call model_manager.commit() - """ - - model_key = self.model_key(model_name, base_model, model_type) - if model_key not in self.models: - raise Exception(f"Unknown model: {model_key}") - - for cur_model_key, config in self.models.items(): - config.default = cur_model_key == model_key - def model_info( self, model_name: str, @@ -503,9 +473,9 @@ def list_models( self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, - ) -> Dict[str, Dict[str, str]]: + ) -> list[dict]: """ - Return a dict of models, in format [base_model][model_type][model_name] + Return a list of models. Please use model_manager.models() to get all the model names, model_manager.model_info('model-name') to get the stanza for the model @@ -513,7 +483,7 @@ def list_models( object derived from models.yaml """ - models = dict() + models = [] for model_key in sorted(self.models, key=str.casefold): model_config = self.models[model_key] @@ -523,18 +493,16 @@ def list_models( if model_type is not None and cur_model_type != model_type: continue - if cur_base_model not in models: - models[cur_base_model] = dict() - if cur_model_type not in models[cur_base_model]: - models[cur_base_model][cur_model_type] = dict() - - models[cur_base_model][cur_model_type][cur_model_name] = dict( + model_dict = dict( **model_config.dict(exclude_defaults=True), + # OpenAPIModelInfoBase name=cur_model_name, base_model=cur_base_model, type=cur_model_type, ) + models.append(model_dict) + return models def print_models(self) -> None: @@ -646,7 +614,9 @@ def commit(self, conf_file: Path=None) -> None: model_class = MODEL_CLASSES[base_model][model_type] if model_class.save_to_config: # TODO: or exclude_unset better fits here? - data_to_save[model_key] = model_config.dict(exclude_defaults=True) + data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) + # alias for config file + data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format") yaml_str = OmegaConf.to_yaml(data_to_save) config_file_path = conf_file or self.config_path diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 40995498bf4..6975d45f93b 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -1,3 +1,7 @@ +import inspect +from enum import Enum +from pydantic import BaseModel +from typing import Literal, get_origin from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .vae import VaeModel @@ -29,10 +33,63 @@ #}, } -def get_all_model_configs(): - configs = set() - for models in MODEL_CLASSES.values(): - for _, model in models.items(): - configs.update(model._get_configs().values()) - configs.discard(None) - return list(configs) # TODO: set, list or tuple +MODEL_CONFIGS = list() +OPENAPI_MODEL_CONFIGS = list() + +class OpenAPIModelInfoBase(BaseModel): + name: str + base_model: BaseModelType + type: ModelType + + +for base_model, models in MODEL_CLASSES.items(): + for model_type, model_class in models.items(): + model_configs = set(model_class._get_configs().values()) + model_configs.discard(None) + MODEL_CONFIGS.extend(model_configs) + + for cfg in model_configs: + model_name, cfg_name = cfg.__qualname__.split('.')[-2:] + openapi_cfg_name = model_name + cfg_name + if openapi_cfg_name in vars(): + continue + + api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict( + __annotations__ = dict( + type=Literal[model_type.value], + ), + )) + + #globals()[openapi_cfg_name] = api_wrapper + vars()[openapi_cfg_name] = api_wrapper + OPENAPI_MODEL_CONFIGS.append(api_wrapper) + +def get_model_config_enums(): + enums = list() + + for model_config in MODEL_CONFIGS: + fields = inspect.get_annotations(model_config) + try: + field = fields["model_format"] + except: + raise Exception("format field not found") + + # model_format: None + # model_format: SomeModelFormat + # model_format: Literal[SomeModelFormat.Diffusers] + # model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint] + + if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): + enums.append(field) + + elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + enums.append(type(field.__args__[0])) + + elif field is None: + pass + + else: + raise Exception(f"Unsupported format definition in {model_configs.__qualname__}") + + return enums + diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 3bf0045918f..ef354ecc076 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -48,12 +48,10 @@ class ModelError(str, Enum): class ModelConfigBase(BaseModel): path: str # or Path - #name: str # not included as present in model key description: Optional[str] = Field(None) - format: Optional[str] = Field(None) - default: Optional[bool] = Field(False) + model_format: Optional[str] = Field(None) # do not save to config - error: Optional[ModelError] = Field(None, exclude=True) + error: Optional[ModelError] = Field(None) class Config: use_enum_values = True @@ -94,6 +92,11 @@ def __init__( def _hf_definition_to_type(self, subtypes: List[str]) -> Type: if len(subtypes) < 2: raise Exception("Invalid subfolder definition!") + if all(t is None for t in subtypes): + return None + elif any(t is None for t in subtypes): + raise Exception(f"Unsupported definition: {subtypes}") + if subtypes[0] in ["diffusers", "transformers"]: res_type = sys.modules[subtypes[0]] subtypes = subtypes[1:] @@ -122,47 +125,41 @@ def _get_configs(cls): continue fields = inspect.get_annotations(value) - if "format" not in fields: - raise Exception("Invalid config definition - format field not found") - - format_type = typing.get_origin(fields["format"]) - if format_type not in {None, Literal, Union}: - raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") - - if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__): - raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") + try: + field = fields["model_format"] + except: + raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") - if format_type == Union: - f_fields = fields["format"].__args__ - else: - f_fields = (fields["format"],) - + if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): + for model_format in field: + configs[model_format.value] = value - for field in f_fields: - if field is None: - format_name = None - else: - format_name = field.__args__[0] + elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + for model_format in field.__args__: + configs[model_format.value] = value - configs[format_name] = value # TODO: error when override(multiple)? + elif field is None: + configs[None] = value + else: + raise Exception(f"Unsupported format definition in {cls.__qualname__}") cls.__configs = configs return cls.__configs @classmethod def create_config(cls, **kwargs) -> ModelConfigBase: - if "format" not in kwargs: - raise Exception("Field 'format' not found in model config") + if "model_format" not in kwargs: + raise Exception("Field 'model_format' not found in model config") configs = cls._get_configs() - return configs[kwargs["format"]](**kwargs) + return configs[kwargs["model_format"]](**kwargs) @classmethod def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: return cls.create_config( path=path, - format=cls.detect_format(path), + model_format=cls.detect_format(path), ) @classmethod diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index d75c55010ad..9563f87afd2 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -1,5 +1,6 @@ import os import torch +from enum import Enum from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -14,12 +15,16 @@ classproperty, ) +class ControlNetModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + class ControlNetModel(ModelBase): #model_class: Type #model_size: int class Config(ModelConfigBase): - format: Union[Literal["checkpoint"], Literal["diffusers"]] + model_format: ControlNetModelFormat def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.ControlNet @@ -69,9 +74,9 @@ def save_to_config(cls) -> bool: @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return ControlNetModelFormat.Diffusers else: - return "checkpoint" + return ControlNetModelFormat.Checkpoint @classmethod def convert_if_required( @@ -81,7 +86,7 @@ def convert_if_required( config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) != "diffusers": - raise NotImlemetedError("Checkpoint controlnet models currently unsupported") + if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers: + raise NotImplementedError("Checkpoint controlnet models currently unsupported") else: return model_path diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index bcf3224ece1..59feacde065 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,5 +1,6 @@ import os import torch +from enum import Enum from typing import Optional, Union, Literal from .base import ( ModelBase, @@ -12,11 +13,15 @@ # TODO: naming from ..lora import LoRAModel as LoRAModelRaw +class LoRAModelFormat(str, Enum): + LyCORIS = "lycoris" + Diffusers = "diffusers" + class LoRAModel(ModelBase): #model_size: int class Config(ModelConfigBase): - format: Union[Literal["lycoris"], Literal["diffusers"]] + model_format: LoRAModelFormat # TODO: def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Lora @@ -52,9 +57,9 @@ def save_to_config(cls) -> bool: @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return LoRAModelFormat.Diffusers else: - return "lycoris" + return LoRAModelFormat.LyCORIS @classmethod def convert_if_required( @@ -64,7 +69,7 @@ def convert_if_required( config: ModelConfigBase, base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) == "diffusers": + if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: # TODO: add diffusers lora when it stabilizes a bit raise NotImplementedError("Diffusers lora not supported") else: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index bd519c88c80..f1693265715 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -1,5 +1,6 @@ import os import json +from enum import Enum from pydantic import Field from pathlib import Path from typing import Literal, Optional, Union @@ -19,16 +20,19 @@ from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf +class StableDiffusion1ModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" class StableDiffusion1Model(DiffusersModel): class DiffusersConfig(ModelConfigBase): - format: Literal["diffusers"] + model_format: Literal[StableDiffusion1ModelFormat.Diffusers] vae: Optional[str] = Field(None) variant: ModelVariantType class CheckpointConfig(ModelConfigBase): - format: Literal["checkpoint"] + model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) variant: ModelVariantType @@ -47,7 +51,7 @@ def __init__(self, model_path: str, base_model: BaseModelType, model_type: Model def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) ckpt_config_path = kwargs.get("config", None) - if model_format == "checkpoint": + if model_format == StableDiffusion1ModelFormat.Checkpoint: if ckpt_config_path: ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] @@ -57,7 +61,7 @@ def probe_config(cls, path: str, **kwargs): checkpoint = checkpoint.get('state_dict', checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - elif model_format == "diffusers": + elif model_format == StableDiffusion1ModelFormat.Diffusers: unet_config_path = os.path.join(path, "unet", "config.json") if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: @@ -80,7 +84,7 @@ def probe_config(cls, path: str, **kwargs): return cls.create_config( path=path, - format=model_format, + model_format=model_format, config=ckpt_config_path, variant=variant, @@ -93,9 +97,9 @@ def save_to_config(cls) -> bool: @classmethod def detect_format(cls, model_path: str): if os.path.isdir(model_path): - return "diffusers" + return StableDiffusion1ModelFormat.Diffusers else: - return "checkpoint" + return StableDiffusion1ModelFormat.Checkpoint @classmethod def convert_if_required( @@ -116,19 +120,22 @@ def convert_if_required( else: return model_path +class StableDiffusion2ModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" class StableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly class DiffusersConfig(ModelConfigBase): - format: Literal["diffusers"] + model_format: Literal[StableDiffusion2ModelFormat.Diffusers] vae: Optional[str] = Field(None) variant: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool class CheckpointConfig(ModelConfigBase): - format: Literal["checkpoint"] + model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) variant: ModelVariantType @@ -149,7 +156,7 @@ def __init__(self, model_path: str, base_model: BaseModelType, model_type: Model def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) ckpt_config_path = kwargs.get("config", None) - if model_format == "checkpoint": + if model_format == StableDiffusion2ModelFormat.Checkpoint: if ckpt_config_path: ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] @@ -159,7 +166,7 @@ def probe_config(cls, path: str, **kwargs): checkpoint = checkpoint.get('state_dict', checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - elif model_format == "diffusers": + elif model_format == StableDiffusion2ModelFormat.Diffusers: unet_config_path = os.path.join(path, "unet", "config.json") if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: @@ -191,7 +198,7 @@ def probe_config(cls, path: str, **kwargs): return cls.create_config( path=path, - format=model_format, + model_format=model_format, config=ckpt_config_path, variant=variant, @@ -206,9 +213,9 @@ def save_to_config(cls) -> bool: @classmethod def detect_format(cls, model_path: str): if os.path.isdir(model_path): - return "diffusers" + return StableDiffusion2ModelFormat.Diffusers else: - return "checkpoint" + return StableDiffusion2ModelFormat.Checkpoint @classmethod def convert_if_required( @@ -281,8 +288,8 @@ def _convert_ckpt_and_cache( prediction_type = SchedulerPredictionType.Epsilon elif version == BaseModelType.StableDiffusion2: - upcast_attention = config.upcast_attention - prediction_type = config.prediction_type + upcast_attention = model_config.upcast_attention + prediction_type = model_config.prediction_type else: raise Exception(f"Unknown model provided: {version}") diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 66847f53ebd..9a032218f03 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase): #model_size: int class Config(ModelConfigBase): - format: None + model_format: None def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.TextualInversion diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index 1edb57ccc4b..76133b074de 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,5 +1,7 @@ import os import torch +import safetensors +from enum import Enum from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -18,12 +20,16 @@ from diffusers.utils import is_safetensors_available from omegaconf import OmegaConf +class VaeModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + class VaeModel(ModelBase): #vae_class: Type #model_size: int class Config(ModelConfigBase): - format: Union[Literal["checkpoint"], Literal["diffusers"]] + model_format: VaeModelFormat def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Vae @@ -70,9 +76,9 @@ def save_to_config(cls) -> bool: @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return VaeModelFormat.Diffusers else: - return "checkpoint" + return VaeModelFormat.Checkpoint @classmethod def convert_if_required( @@ -82,7 +88,7 @@ def convert_if_required( config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) != "diffusers": + if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: return _convert_vae_ckpt_and_cache( weights_path=model_path, output_path=output_path, diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index a11d8d048cb..55fcc977459 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -24,6 +24,7 @@ import Toaster from './Toaster'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; +import { useListModelsQuery } from 'services/apiSlice'; const DEFAULT_CONFIG = {}; @@ -46,6 +47,18 @@ const App = ({ const isApplicationReady = useIsApplicationReady(); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + const { data: controlnetModels } = useListModelsQuery({ + model_type: 'controlnet', + }); + const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' }); + const { data: loraModels } = useListModelsQuery({ model_type: 'lora' }); + const { data: embeddingModels } = useListModelsQuery({ + model_type: 'embedding', + }); + const [loadingOverridden, setLoadingOverridden] = useState(false); const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index 5025ca081ad..cb18d483014 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; -import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { omit } from 'lodash-es'; @@ -18,7 +17,6 @@ const serializationDenylist: { gallery: galleryPersistDenylist, generation: generationPersistDenylist, lightbox: lightboxPersistDenylist, - models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index c6af5f36121..8f40b0bb59a 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { initialModelsState } from 'features/system/store/modelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -21,7 +20,6 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - models: initialModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 9fe554fee16..bf54e638364 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,9 +1,8 @@ -import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; import { receivedPageOfImages } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; +import { startAppListening } from '../..'; const moduleLog = log.child({ namespace: 'socketio' }); @@ -15,7 +14,7 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { models, nodes, config, images } = getState(); + const { nodes, config, images } = getState(); const { disabledTabs } = config; @@ -28,10 +27,6 @@ export const addSocketConnectedEventListener = () => { ); } - if (!models.ids.length) { - dispatch(receivedModels()); - } - if (!nodes.schema && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index a9011f93565..57a97168a38 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -5,34 +5,32 @@ import { configureStore, } from '@reduxjs/toolkit'; -import { rememberReducer, rememberEnhancer } from 'redux-remember'; import dynamicMiddlewares from 'redux-dynamic-middlewares'; +import { rememberEnhancer, rememberReducer } from 'redux-remember'; import canvasReducer from 'features/canvas/store/canvasSlice'; +import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; import imagesReducer from 'features/gallery/store/imagesSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; -import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; // import sessionReducer from 'features/system/store/sessionSlice'; -import configReducer from 'features/system/store/configSlice'; -import uiReducer from 'features/ui/store/uiSlice'; -import hotkeysReducer from 'features/ui/store/hotkeysSlice'; -import modelsReducer from 'features/system/store/modelSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; import boardsReducer from 'features/gallery/store/boardSlice'; +import configReducer from 'features/system/store/configSlice'; +import hotkeysReducer from 'features/ui/store/hotkeysSlice'; +import uiReducer from 'features/ui/store/uiSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; - +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { LOCALSTORAGE_PREFIX } from './constants'; import { api } from 'services/apiSlice'; const allReducers = { @@ -40,7 +38,6 @@ const allReducers = { gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - models: modelsReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, @@ -50,8 +47,8 @@ const allReducers = { images: imagesReducer, controlNet: controlNetReducer, boards: boardsReducer, - [api.reducerPath]: api.reducer, // session: sessionReducer, + [api.reducerPath]: api.reducer, }; const rootReducer = combineReducers(allReducers); @@ -63,7 +60,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'gallery', 'generation', 'lightbox', - // 'models', 'nodes', 'postprocessing', 'system', diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index a1ef69de01b..c274f23f260 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -1,28 +1,18 @@ -import { Select } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { SelectItem } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { ModelInputFieldTemplate, ModelInputFieldValue, } from 'features/nodes/types/types'; -import { selectModelsIds } from 'features/system/store/modelSlice'; -import { isEqual } from 'lodash-es'; -import { ChangeEvent, memo } from 'react'; -import { FieldComponentProps } from './types'; -const availableModelsSelector = createSelector( - [selectModelsIds], - (allModelNames) => { - return { allModelNames }; - // return map(modelList, (_, name) => name); - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { FieldComponentProps } from './types'; +import { forEach, isString } from 'lodash-es'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { useTranslation } from 'react-i18next'; +import { useListModelsQuery } from 'services/apiSlice'; const ModelInputFieldComponent = ( props: FieldComponentProps @@ -30,28 +20,82 @@ const ModelInputFieldComponent = ( const { nodeId, field } = props; const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], + [pipelineModels?.entities, pipelineModels?.ids, field.value] + ); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && pipelineModels?.ids.includes(field.value)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; - const { allModelNames } = useAppSelector(availableModelsSelector); + if (!isString(firstModel)) { + return; + } - const handleValueChanged = (e: ChangeEvent) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; + handleValueChanged(firstModel); + }, [field.value, handleValueChanged, pipelineModels?.ids]); return ( - + /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5425d1cfd52..341f0c467b6 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -101,21 +101,6 @@ const nodesSlice = createSlice({ builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { state.schema = action.payload; }); - - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - state.nodes.forEach((node) => { - forEach(node.data.inputs, (input) => { - if (input.type === 'image') { - if (input.value?.image_name === image_name) { - input.value.image_url = image_url; - input.value.thumbnail_url = thumbnail_url; - } - } - }); - }); - }); }, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index efaeaddff20..ccdc3e0a277 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -23,6 +23,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 785e1d2fdbf..9ffe85b3c91 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -17,6 +17,7 @@ import { INPAINT_GRAPH, INPAINT, } from './constants'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = ( // We may need to set the inpaint width and height to scale the image const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; + const model = modelIdToPipelineModelField(modelId); + const graph: NonNullableGraph = { id: INPAINT_GRAPH, nodes: { @@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = ( prompt: negativePrompt, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [RANGE_OF_SIZE]: { type: 'range_of_size', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index ca0e56e849f..920cb5bf025 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -14,6 +14,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * Builds the Canvas tab's Text to Image graph. @@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 78a6623a166..8425ac043a0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -22,6 +22,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = ( throw new Error('No initial image found in state'); } + const model = modelIdToPipelineModelField(modelId); + // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { id: IMAGE_TO_IMAGE_GRAPH, @@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index c179a89504a..973acdfb77e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,6 +1,10 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; +import { + BaseModelType, + RandomIntInvocation, + RangeOfSizeInvocation, +} from 'services/api'; import { ITERATE, LATENTS_TO_IMAGE, @@ -14,6 +18,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; type TextToImageGraphOverrides = { width: number; @@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = ( shouldRandomizeSeed, } = state.generation; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 6a700d48132..072b1a53fd3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,9 +1,10 @@ import { Graph } from 'services/api'; import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; +import { cloneDeep, omit, reduce } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; import { AnyInvocation } from 'services/events/types'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * We need to do special handling for some fields @@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'model') { + if (field.value) { + return modelIdToPipelineModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 39e0080d112..7d4469bc411 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -7,7 +7,7 @@ export const NOISE = 'noise'; export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; -export const MODEL_LOADER = 'model_loader'; +export const MODEL_LOADER = 'pipeline_model_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts new file mode 100644 index 00000000000..bbcd8d9bc6d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts @@ -0,0 +1,18 @@ +import { BaseModelType, PipelineModelField } from 'services/api'; + +/** + * Crudely converts a model id to a pipeline model field + * TODO: Make better + */ +export const modelIdToPipelineModelField = ( + modelId: string +): PipelineModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: PipelineModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx index 65da89b94d9..5092893eedb 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx @@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler'; const ParamSchedulerAndModel = () => { return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 2facb65f045..e7dcbf0d839 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,10 +1,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; +import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; import { configChanged } from 'features/system/store/configSlice'; -import { clamp, sortBy } from 'lodash-es'; +import { clamp } from 'lodash-es'; import { ImageDTO } from 'services/api'; -import { imageUrlsReceived } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { CfgScaleParam, HeightParam, @@ -17,7 +16,6 @@ import { StrengthParam, WidthParam, } from './parameterZodSchemas'; -import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; export interface GenerationState { cfgScale: CfgScaleParam; @@ -220,28 +218,12 @@ export const generationSlice = createSlice({ }, }, extraReducers: (builder) => { - builder.addCase(receivedModels.fulfilled, (state, action) => { - if (!state.model) { - const firstModel = sortBy(action.payload, 'name')[0]; - state.model = firstModel.name; - } - }); - builder.addCase(configChanged, (state, action) => { const defaultModel = action.payload.sd?.defaultModel; if (defaultModel && !state.model) { state.model = defaultModel; } }); - - // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - // const { image_name, image_url, thumbnail_url } = action.payload; - - // if (state.initialImage?.image_name === image_name) { - // state.initialImage.image_url = image_url; - // state.initialImage.thumbnail_url = thumbnail_url; - // } - // }); }, }); diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 61567d3fb8f..48eb309e7d6 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -154,3 +154,17 @@ export type StrengthParam = z.infer; */ export const isValidStrength = (val: unknown): val is StrengthParam => zStrength.safeParse(val).success; + +// /** +// * Zod schema for BaseModelType +// */ +// export const zBaseModelType = z.enum(['sd-1', 'sd-2']); +// /** +// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI. +// */ +// export type BaseModelType = z.infer; +// /** +// * Validates/type-guards a value as a base model type +// */ +// export const isValidBaseModelType = (val: unknown): val is BaseModelType => +// zBaseModelType.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index a38ab150ddf..43de14d507c 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,44 +1,59 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { isEqual } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIMantineSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { modelSelected } from 'features/parameters/store/generationSlice'; -import { selectModelsAll, selectModelsById } from '../store/modelSlice'; - -const selector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - const selectedModel = selectModelsById(state, generation.model); - - const modelData = selectModelsAll(state) - .map((m) => ({ - value: m.name, - label: m.name, - })) - .sort((a, b) => a.label.localeCompare(b.label)); - return { - selectedModel, - modelData, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); + +import { forEach, isString } from 'lodash-es'; +import { SelectItem } from '@mantine/core'; +import { RootState } from 'app/store/store'; +import { useListModelsQuery } from 'services/apiSlice'; + +export const MODEL_TYPE_MAP = { + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', +}; const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { selectedModel, modelData } = useAppSelector(selector); + + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.model + ); + + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[selectedModelId], + [pipelineModels?.entities, selectedModelId] + ); + const handleChangeModel = useCallback( (v: string | null) => { if (!v) { @@ -49,13 +64,27 @@ const ModelSelect = () => { [dispatch] ); + useEffect(() => { + if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleChangeModel(firstModel); + }, [handleChangeModel, pipelineModels?.ids, selectedModelId]); + return ( ); diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 2e0b3234c75..26c11604e10 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,6 +1,5 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { RootState } from 'app/store/store'; - import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; @@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({ export default function SettingsSchedulers() { const dispatch = useAppDispatch(); + const { t } = useTranslation(); const enabledSchedulers = useAppSelector( diff --git a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts index 193420e29cd..8ba5731a5b2 100644 --- a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts +++ b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts @@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors'; const isApplicationReadySelector = createSelector( [systemSelector, configSelector], (system, config) => { - const { wereModelsReceived, wasSchemaParsed } = system; + const { wasSchemaParsed } = system; const { disabledTabs } = config; return { disabledTabs, - wereModelsReceived, wasSchemaParsed, }; } @@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector( * Checks if the application is ready to be used, i.e. if the initial startup process is finished. */ export const useIsApplicationReady = () => { - const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( + const { disabledTabs, wasSchemaParsed } = useAppSelector( isApplicationReadySelector ); const isApplicationReady = useMemo(() => { - if (!wereModelsReceived) { - return false; - } - if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { return false; } return true; - }, [disabledTabs, wereModelsReceived, wasSchemaParsed]); + }, [disabledTabs, wasSchemaParsed]); return isApplicationReady; }; diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts deleted file mode 100644 index f857bc85bcf..00000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ /dev/null @@ -1,3 +0,0 @@ -import { RootState } from 'app/store/store'; - -export const modelSelector = (state: RootState) => state.models; diff --git a/invokeai/frontend/web/src/features/system/store/modelSlice.ts b/invokeai/frontend/web/src/features/system/store/modelSlice.ts deleted file mode 100644 index ed384258720..00000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSlice.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { createEntityAdapter } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { CkptModelInfo, DiffusersModelInfo } from 'services/api'; -import { receivedModels } from 'services/thunks/model'; - -export type Model = (CkptModelInfo | DiffusersModelInfo) & { - name: string; -}; - -export const modelsAdapter = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const initialModelsState = modelsAdapter.getInitialState(); - -export type ModelsState = typeof initialModelsState; - -export const modelsSlice = createSlice({ - name: 'models', - initialState: initialModelsState, - reducers: { - modelAdded: modelsAdapter.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(receivedModels.fulfilled, (state, action) => { - const models = action.payload; - modelsAdapter.setAll(state, models); - }); - }, -}); - -export const { - selectAll: selectModelsAll, - selectById: selectModelsById, - selectEntities: selectModelsEntities, - selectIds: selectModelsIds, - selectTotal: selectModelsTotal, -} = modelsAdapter.getSelectors((state) => state.models); - -export const { modelAdded } = modelsSlice.actions; - -export default modelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts deleted file mode 100644 index aa9fb057e1c..00000000000 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { ModelsState } from './modelSlice'; - -/** - * Models slice persist denylist - */ -export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index f86415cf379..688f69c1f71 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,20 +1,12 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import * as InvokeAI from 'app/types/invokeai'; -import { ProgressImage } from 'services/events/types'; -import { makeToast } from '../../../app/components/Toaster'; -import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; -import { receivedModels } from 'services/thunks/model'; -import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; -import { LogLevelName } from 'roarr'; import { InvokeLogLevel } from 'app/logging/useLogger'; -import { TFuncKey } from 'i18next'; -import { t } from 'i18next'; import { userInvoked } from 'app/store/actions'; -import { LANGUAGES } from '../components/LanguagePicker'; -import { imageUploaded } from 'services/thunks/image'; +import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; +import { TFuncKey, t } from 'i18next'; +import { LogLevelName } from 'roarr'; import { appSocketConnected, appSocketDisconnected, @@ -26,6 +18,11 @@ import { appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; +import { ProgressImage } from 'services/events/types'; +import { imageUploaded } from 'services/thunks/image'; +import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; +import { makeToast } from '../../../app/components/Toaster'; +import { LANGUAGES } from '../components/LanguagePicker'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -379,13 +376,6 @@ export const systemSlice = createSlice({ ); }); - /** - * Received available models from the backend - */ - builder.addCase(receivedModels.fulfilled, (state) => { - state.wereModelsReceived = true; - }); - /** * OpenAPI schema was parsed */ diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index acb7a411f79..8ce42494e53 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -25,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField'; export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; export type { ControlField } from './models/ControlField'; export type { ControlNetInvocation } from './models/ControlNetInvocation'; +export type { ControlNetModelConfig } from './models/ControlNetModelConfig'; +export type { ControlNetModelFormat } from './models/ControlNetModelFormat'; export type { ControlOutput } from './models/ControlOutput'; export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; @@ -67,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation'; export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntOutput } from './models/IntOutput'; -export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; -export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; -export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; -export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config'; export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { LatentsField } from './models/LatentsField'; @@ -87,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { LoraInfo } from './models/LoraInfo'; export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation'; export type { LoraLoaderOutput } from './models/LoraLoaderOutput'; +export type { LoRAModelConfig } from './models/LoRAModelConfig'; +export type { LoRAModelFormat } from './models/LoRAModelFormat'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskOutput } from './models/MaskOutput'; export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; @@ -109,6 +105,8 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; +export type { PipelineModelField } from './models/PipelineModelField'; +export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation'; export type { PromptCollectionOutput } from './models/PromptCollectionOutput'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; @@ -120,16 +118,23 @@ export type { ResourceOrigin } from './models/ResourceOrigin'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { SchedulerPredictionType } from './models/SchedulerPredictionType'; -export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation'; -export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation'; +export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig'; +export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig'; +export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat'; +export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig'; +export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig'; +export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat'; export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation'; export type { SubModelType } from './models/SubModelType'; export type { SubtractInvocation } from './models/SubtractInvocation'; export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; +export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig'; export type { UNetField } from './models/UNetField'; export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { VaeField } from './models/VaeField'; +export type { VaeModelConfig } from './models/VaeModelConfig'; +export type { VaeModelFormat } from './models/VaeModelFormat'; export type { VaeRepo } from './models/VaeRepo'; export type { ValidationError } from './models/ValidationError'; export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts new file mode 100644 index 00000000000..60e2958f5c3 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ControlNetModelFormat } from './ControlNetModelFormat'; +import type { ModelError } from './ModelError'; + +export type ControlNetModelConfig = { + name: string; + base_model: BaseModelType; + type: 'controlnet'; + path: string; + description?: string; + model_format: ControlNetModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts new file mode 100644 index 00000000000..500b3e8f8c4 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type ControlNetModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index e148954f164..5fba3d83115 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -49,6 +49,7 @@ import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorI import type { ParamFloatInvocation } from './ParamFloatInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation'; import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation'; +import type { PipelineModelLoaderInvocation } from './PipelineModelLoaderInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RangeInvocation } from './RangeInvocation'; @@ -56,8 +57,6 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation'; import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation'; import type { RestoreFaceInvocation } from './RestoreFaceInvocation'; import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation'; -import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation'; -import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation'; import type { ShowImageInvocation } from './ShowImageInvocation'; import type { StepParamEasingInvocation } from './StepParamEasingInvocation'; import type { SubtractInvocation } from './SubtractInvocation'; @@ -73,7 +72,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts new file mode 100644 index 00000000000..184a2661694 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { LoRAModelFormat } from './LoRAModelFormat'; +import type { ModelError } from './ModelError'; + +export type LoRAModelConfig = { + name: string; + base_model: BaseModelType; + type: 'lora'; + path: string; + description?: string; + model_format: LoRAModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts new file mode 100644 index 00000000000..829f8a7c576 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type LoRAModelFormat = 'lycoris' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/ModelsList.ts b/invokeai/frontend/web/src/services/api/models/ModelsList.ts index a2d88d19674..9186db3e295 100644 --- a/invokeai/frontend/web/src/services/api/models/ModelsList.ts +++ b/invokeai/frontend/web/src/services/api/models/ModelsList.ts @@ -2,16 +2,16 @@ /* tslint:disable */ /* eslint-disable */ -import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; -import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; -import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; -import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config'; +import type { ControlNetModelConfig } from './ControlNetModelConfig'; +import type { LoRAModelConfig } from './LoRAModelConfig'; +import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig'; +import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig'; +import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig'; +import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig'; +import type { TextualInversionModelConfig } from './TextualInversionModelConfig'; +import type { VaeModelConfig } from './VaeModelConfig'; export type ModelsList = { - models: Record>>; + models: Array<(StableDiffusion1ModelCheckpointConfig | StableDiffusion1ModelDiffusersConfig | VaeModelConfig | LoRAModelConfig | ControlNetModelConfig | TextualInversionModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion2ModelDiffusersConfig)>; }; diff --git a/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts new file mode 100644 index 00000000000..c2f1c07fbfc --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts @@ -0,0 +1,20 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; + +/** + * Pipeline model field + */ +export type PipelineModelField = { + /** + * Name of the model + */ + model_name: string; + /** + * Base model + */ + base_model: BaseModelType; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts similarity index 52% rename from invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts rename to invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts index f477c11a8dd..b8cdb27acfe 100644 --- a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts @@ -2,10 +2,12 @@ /* tslint:disable */ /* eslint-disable */ +import type { PipelineModelField } from './PipelineModelField'; + /** - * Loading submodels of selected model. + * Loads a pipeline model, outputting its submodels. */ -export type SD2ModelLoaderInvocation = { +export type PipelineModelLoaderInvocation = { /** * The id of this node. Must be unique among all nodes. */ @@ -14,10 +16,10 @@ export type SD2ModelLoaderInvocation = { * Whether or not this node is an intermediate node. */ is_intermediate?: boolean; - type?: 'sd2_model_loader'; + type?: 'pipeline_model_loader'; /** - * Model to load + * The model to load */ - model_name?: string; + model: PipelineModelField; }; diff --git a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts deleted file mode 100644 index 9a8a23077aa..00000000000 --- a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -/** - * Loading submodels of selected model. - */ -export type SD1ModelLoaderInvocation = { - /** - * The id of this node. Must be unique among all nodes. - */ - id: string; - /** - * Whether or not this node is an intermediate node. - */ - is_intermediate?: boolean; - type?: 'sd1_model_loader'; - /** - * Model to load - */ - model_name?: string; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts similarity index 60% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts index 6bdcb87dd49..be7077cc535 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts @@ -2,14 +2,17 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = { +export type StableDiffusion1ModelCheckpointConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'checkpoint'; - default?: boolean; + model_format: 'checkpoint'; error?: ModelError; vae?: string; config?: string; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts similarity index 59% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts index c88e0421783..befe0146056 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts @@ -2,14 +2,17 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = { +export type StableDiffusion1ModelDiffusersConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'diffusers'; - default?: boolean; + model_format: 'diffusers'; error?: ModelError; vae?: string; variant: ModelVariantType; diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts new file mode 100644 index 00000000000..01b50c2fc09 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type StableDiffusion1ModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts similarity index 69% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts index ec2ae4a8459..dadd7cac9bb 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts @@ -2,15 +2,18 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; import type { SchedulerPredictionType } from './SchedulerPredictionType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = { +export type StableDiffusion2ModelCheckpointConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'checkpoint'; - default?: boolean; + model_format: 'checkpoint'; error?: ModelError; vae?: string; config?: string; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts similarity index 68% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts index 67b897d9d97..1e4a34c5dce 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts @@ -2,15 +2,18 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; import type { SchedulerPredictionType } from './SchedulerPredictionType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = { +export type StableDiffusion2ModelDiffusersConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'diffusers'; - default?: boolean; + model_format: 'diffusers'; error?: ModelError; vae?: string; variant: ModelVariantType; diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts new file mode 100644 index 00000000000..7e7b8952310 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type StableDiffusion2ModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts new file mode 100644 index 00000000000..97d6aa7ffad --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts @@ -0,0 +1,17 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ModelError } from './ModelError'; + +export type TextualInversionModelConfig = { + name: string; + base_model: BaseModelType; + type: 'embedding'; + path: string; + description?: string; + model_format: null; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts new file mode 100644 index 00000000000..a73ee0aa32a --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ModelError } from './ModelError'; +import type { VaeModelFormat } from './VaeModelFormat'; + +export type VaeModelConfig = { + name: string; + base_model: BaseModelType; + type: 'vae'; + path: string; + description?: string; + model_format: VaeModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts new file mode 100644 index 00000000000..497f81d16f9 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type VaeModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts deleted file mode 100644 index f8decdb3417..00000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = { - path: string; - description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts deleted file mode 100644 index 614749a2c54..00000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__lora__LoRAModel__Config = { - path: string; - description?: string; - format: ('lycoris' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts deleted file mode 100644 index f23d5002e3a..00000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = { - path: string; - description?: string; - format: null; - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts deleted file mode 100644 index d9314a6063f..00000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__vae__VaeModel__Config = { - path: string; - description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index 2e4a83b25f7..51a36caad1d 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -51,6 +51,7 @@ import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedR import type { ParamFloatInvocation } from '../models/ParamFloatInvocation'; import type { ParamIntInvocation } from '../models/ParamIntInvocation'; import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation'; +import type { PipelineModelLoaderInvocation } from '../models/PipelineModelLoaderInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation'; @@ -58,8 +59,6 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation'; import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation'; import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation'; import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation'; -import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation'; -import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation'; import type { ShowImageInvocation } from '../models/ShowImageInvocation'; import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation'; import type { SubtractInvocation } from '../models/SubtractInvocation'; @@ -175,7 +174,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -212,7 +211,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts index 2d42931b0b6..e2d765dd90e 100644 --- a/invokeai/frontend/web/src/services/apiSlice.ts +++ b/invokeai/frontend/web/src/services/apiSlice.ts @@ -13,23 +13,68 @@ import { TagTypesFrom, TagTypesFromApi, } from '@reduxjs/toolkit/dist/query/endpointDefinitions'; +import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; +import { BaseModelType } from './api/models/BaseModelType'; +import { ModelType } from './api/models/ModelType'; +import { ModelsList } from './api/models/ModelsList'; +import { keyBy } from 'lodash-es'; type ListBoardsArg = { offset: number; limit: number }; type UpdateBoardArg = { board_id: string; changes: BoardChanges }; type AddImageToBoardArg = { board_id: string; image_name: string }; type RemoveImageFromBoardArg = { board_id: string; image_name: string }; type ListBoardImagesArg = { board_id: string; offset: number; limit: number }; +type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType }; -const tagTypes = ['Board', 'Image']; +type ModelConfig = ModelsList['models'][number]; + +const tagTypes = ['Board', 'Image', 'Model']; type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>; const LIST = 'LIST'; +const modelsAdapter = createEntityAdapter({ + selectId: (model) => getModelId(model), + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); + +const getModelId = ({ base_model, type, name }: ModelConfig) => + `${base_model}/${type}/${name}`; + export const api = createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }), reducerPath: 'api', tagTypes, endpoints: (build) => ({ + /** + * Models Queries + */ + + listModels: build.query, ListModelsArg>({ + query: (arg) => ({ url: 'models/', params: arg }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.ids.map((id) => ({ + type: 'Model' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: (response: ModelsList, meta, arg) => { + return modelsAdapter.addMany( + modelsAdapter.getInitialState(), + keyBy(response.models, getModelId) + ); + }, + }), /** * Boards Queries */ @@ -174,4 +219,5 @@ export const { useRemoveImageFromBoardMutation, useListBoardImagesQuery, useGetImageDTOQuery, + useListModelsQuery, } = api; diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts deleted file mode 100644 index 97f2bd80165..00000000000 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { log } from 'app/logging/useLogger'; -import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { Model } from 'features/system/store/modelSlice'; -import { reduce, size } from 'lodash-es'; -import { ModelsService } from 'services/api'; - -const models = log.child({ namespace: 'model' }); - -export const IMAGES_PER_PAGE = 20; - -export const receivedModels = createAppAsyncThunk( - 'models/receivedModels', - async (_) => { - const response = await ModelsService.listModels(); - - const deserializedModels = reduce( - response.models['sd-1']['pipeline'], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; - - return modelsAccumulator; - }, - {} as Record - ); - - models.info( - { response }, - `Received ${size(response.models['sd-1']['pipeline'])} models` - ); - - return deserializedModels; - } -); diff --git a/invokeai/frontend/web/src/theme/colors/greenTea.ts b/invokeai/frontend/web/src/theme/colors/greenTea.ts index ffecbf2ffa0..318aecbc615 100644 --- a/invokeai/frontend/web/src/theme/colors/greenTea.ts +++ b/invokeai/frontend/web/src/theme/colors/greenTea.ts @@ -4,8 +4,8 @@ import { generateColorPalette } from '../util/generateColorPalette'; export const greenTeaThemeColors: InvokeAIThemeColors = { base: generateColorPalette(223, 10), baseAlpha: generateColorPalette(223, 10, false, true), - accent: generateColorPalette(155, 80), - accentAlpha: generateColorPalette(155, 80, false, true), + accent: generateColorPalette(160, 60), + accentAlpha: generateColorPalette(160, 60, false, true), working: generateColorPalette(47, 68), workingAlpha: generateColorPalette(47, 68, false, true), warning: generateColorPalette(28, 75), @@ -14,5 +14,5 @@ export const greenTeaThemeColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(122, 49, false, true), error: generateColorPalette(0, 50), errorAlpha: generateColorPalette(0, 50, false, true), - gridLineColor: 'rgba(255, 255, 255, 0.2)', + gridLineColor: 'rgba(255, 255, 255, 0.15)', }; diff --git a/invokeai/frontend/web/src/theme/colors/invokeAI.ts b/invokeai/frontend/web/src/theme/colors/invokeAI.ts index c39b3bed81c..82db58bd355 100644 --- a/invokeai/frontend/web/src/theme/colors/invokeAI.ts +++ b/invokeai/frontend/web/src/theme/colors/invokeAI.ts @@ -2,8 +2,8 @@ import { InvokeAIThemeColors } from 'theme/themeTypes'; import { generateColorPalette } from 'theme/util/generateColorPalette'; export const invokeAIThemeColors: InvokeAIThemeColors = { - base: generateColorPalette(225, 15), - baseAlpha: generateColorPalette(225, 15, false, true), + base: generateColorPalette(220, 15), + baseAlpha: generateColorPalette(220, 15, false, true), accent: generateColorPalette(250, 50), accentAlpha: generateColorPalette(250, 50, false, true), working: generateColorPalette(47, 67), @@ -14,5 +14,5 @@ export const invokeAIThemeColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(113, 70, false, true), error: generateColorPalette(0, 76), errorAlpha: generateColorPalette(0, 76, false, true), - gridLineColor: 'rgba(255, 255, 255, 0.2)', + gridLineColor: 'rgba(150, 150, 180, 0.15)', }; diff --git a/invokeai/frontend/web/src/theme/colors/lightTheme.ts b/invokeai/frontend/web/src/theme/colors/lightTheme.ts index 2a7a05bbd21..2fdbd1a769a 100644 --- a/invokeai/frontend/web/src/theme/colors/lightTheme.ts +++ b/invokeai/frontend/web/src/theme/colors/lightTheme.ts @@ -14,5 +14,5 @@ export const lightThemeColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(122, 49, true, true), error: generateColorPalette(0, 50, true), errorAlpha: generateColorPalette(0, 50, true, true), - gridLineColor: 'rgba(0, 0, 0, 0.2)', + gridLineColor: 'rgba(0, 0, 0, 0.15)', }; diff --git a/invokeai/frontend/web/src/theme/colors/oceanBlue.ts b/invokeai/frontend/web/src/theme/colors/oceanBlue.ts index adfb8ab2883..952e0a50668 100644 --- a/invokeai/frontend/web/src/theme/colors/oceanBlue.ts +++ b/invokeai/frontend/web/src/theme/colors/oceanBlue.ts @@ -14,5 +14,5 @@ export const oceanBlueColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(122, 49, false, true), error: generateColorPalette(0, 100), errorAlpha: generateColorPalette(0, 100, false, true), - gridLineColor: 'rgba(136, 148, 184, 0.2)', + gridLineColor: 'rgba(136, 148, 184, 0.15)', };