Update UI To Use New Model Manager (#3548)
PR for the Model Manager UI work related to 3.0

[DONE]

- Update ModelType Config names to be specific so that the front end can
parse them correctly.
- Rebuild frontend schema to reflect these changes.
- Update Linear UI Text To Image and Image to Image to work with the new
model loader.
- Update the ModelInput component in the Node Editor to work with the new
changes.

[TODO REMEMBER]

- Add proper types for ModelLoaderType in `ModelSelect.tsx`

[TODO] 

- Everything else.
blessedcoolant committed Jun 22, 2023
2 parents 2d889e1 + 339e7ce commit 22c337b
Showing 67 changed files with 709 additions and 667 deletions.
11 changes: 5 additions & 6 deletions invokeai/app/api/routers/models.py
@@ -7,8 +7,8 @@
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]

models_router = APIRouter(prefix="/v1/models", tags=["models"])

@@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")

class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
models: list[MODEL_CONFIGS]


@models_router.get(
@@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: BaseModelType = Query(
base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
model_type: ModelType = Query(
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
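
As context for the route change above (not part of the diff): a minimal sketch of how a client might call the reworked endpoint. The `/v1/models` prefix, the optional `base_model`/`model_type` query parameters, and the flat `models` list come from the diff; the host, port, `/api` mount point, and the example query values are assumptions.

import requests  # assumed available; any HTTP client works

def fetch_models(base_model=None, model_type=None):
    """Query the reworked models endpoint and return the flat list of configs."""
    params = {}
    if base_model is not None:
        params["base_model"] = base_model   # e.g. "sd-1" or "sd-2"
    if model_type is not None:
        params["model_type"] = model_type   # e.g. the pipeline model type
    # Host/port and the "/api" mount are assumptions about a local dev server.
    resp = requests.get("http://127.0.0.1:9090/api/v1/models/", params=params)
    resp.raise_for_status()
    # ModelsList is now {"models": [<config>, ...]} rather than a nested dict.
    return resp.json()["models"]

for config in fetch_models(base_model="sd-1"):
    print(config)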
16 changes: 16 additions & 0 deletions invokeai/app/api_app.py
@@ -120,6 +120,22 @@ def custom_openapi():

invoker_schema["output"] = outputs_ref

from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__

if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue

# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)

app.openapi_schema = openapi_schema
return app.openapi_schema

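
A standalone illustration of what the custom_openapi() addition above produces for each model-config enum. BaseModelType is re-declared locally here as a stand-in; its values are taken from the example comment in the diff.

from enum import Enum

class BaseModelType(str, Enum):
    # Stand-in for the backend enum; values match the example comment in the diff.
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"

def enum_to_openapi_component(enum_cls) -> dict:
    # Mirrors the schema entry built in custom_openapi(): a string-enum component.
    return dict(
        title=enum_cls.__qualname__,
        description="An enumeration.",
        type="string",
        enum=[member.value for member in enum_cls],
    )

print(enum_to_openapi_component(BaseModelType))
# {'title': 'BaseModelType', 'description': 'An enumeration.', 'type': 'string', 'enum': ['sd-1', 'sd-2']}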
144 changes: 25 additions & 119 deletions invokeai/app/invocations/model.py
@@ -43,115 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on


class SD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
class PipelineModelField(BaseModel):
"""Pipeline model field"""

type: Literal["sd1_model_loader"] = "sd1_model_loader"

model_name: str = Field(default="", description="Model to load")
# TODO: precision?

# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}

def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

base_model = BaseModelType.StableDiffusion1 # TODO:

# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unkown model name: {self.model_name}!")

"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")


return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae,
),
)
)
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""

# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"

type: Literal["sd2_model_loader"] = "sd2_model_loader"

model_name: str = Field(default="", description="Model to load")
model: PipelineModelField = Field(description="The model to load")
# TODO: precision?

# Schema customisation
@@ -160,22 +64,24 @@ class Config(InvocationConfig):
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
"model": "model"
}
},
}

def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

base_model = BaseModelType.StableDiffusion2 # TODO:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline

# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
):
raise Exception(f"Unkown model name: {self.model_name}!")
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

"""
if not context.services.model_manager.model_exists(
@@ -210,39 +116,39 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Vae,
),
)
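
For orientation (not part of the diff): a rough sketch of the node payload the updated UI would send for the unified loader, which replaces the separate sd1_model_loader and sd2_model_loader nodes. Only the "type" literal and the shape of the "model" field come from PipelineModelLoaderInvocation and PipelineModelField above; the node id and model name are hypothetical.

pipeline_model_loader_node = {
    "id": "model_loader",                       # hypothetical node id
    "type": "pipeline_model_loader",            # Literal on PipelineModelLoaderInvocation
    "model": {
        "model_name": "stable-diffusion-v1-5",  # hypothetical installed model
        "base_model": "sd-1",                   # BaseModelType value
    },
}

When the graph runs, invoke() resolves the UNet, scheduler, CLIP, and VAE submodels from this single field, as shown in the return value above.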
43 changes: 4 additions & 39 deletions invokeai/app/services/model_manager_service.py
@@ -5,7 +5,7 @@
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass

from invokeai.backend.model_management.model_manager import (
@@ -69,19 +69,6 @@ def model_exists(
) -> bool:
pass

@abstractmethod
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name and typeof the default model, or None
if none is defined.
"""
pass

@abstractmethod
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
pass

@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
@@ -270,17 +257,6 @@ def model_exists(
model_type,
)

def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()

def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name, base_model, model_type)

def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
@@ -297,21 +273,10 @@ def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> dict:
) -> list[dict]:
# ) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)

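
Since list_models() now returns a flat list instead of the nested mapping described in the removed docstring, callers that want the old grouping have to rebuild it themselves. A minimal sketch; the "model_type" key inside each config entry is an assumption about the per-model payload.

from collections import defaultdict

def group_models_by_type(models: list[dict]) -> dict[str, list[dict]]:
    """Regroup the flat list_models() result by model type (sketch only)."""
    grouped = defaultdict(list)
    for config in models:
        # "model_type" is assumed to be present on each config entry.
        grouped[config.get("model_type", "unknown")].append(config)
    return dict(grouped)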