Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update UI To Use New Model Manager #3548

Merged
merged 29 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9838dda
chore: Update model config type names
blessedcoolant Jun 17, 2023
663f493
chore: Rebuild API
blessedcoolant Jun 17, 2023
bf0d5f4
fix: Update missing name types to new names
blessedcoolant Jun 17, 2023
01d1760
Generate config names for openapi
StAlKeR7779 Jun 17, 2023
e374211
chore: Rebuild API with new Model API names
blessedcoolant Jun 17, 2023
f8d7477
wip: Add 2.x Models to the Model List
blessedcoolant Jun 17, 2023
ef83a2f
Add name, base_mode, type fields to model info
StAlKeR7779 Jun 17, 2023
d2f3500
chore: Rebuild API - base_model and type added
blessedcoolant Jun 17, 2023
727293d
fix: 2.1 models breaking generation
blessedcoolant Jun 17, 2023
4847212
feat: Enable 2.x Model Generation in Linear UI
blessedcoolant Jun 17, 2023
604cc1a
wip: Move Model Selector to own file
blessedcoolant Jun 17, 2023
0c36162
cleanup: Updated model slice names to be more descriptive
blessedcoolant Jun 18, 2023
6bdf68d
feat: Port Schedulers to Mantine
blessedcoolant Jun 18, 2023
e48528b
revert: getModels to receivedModels
blessedcoolant Jun 18, 2023
7033071
fix: Unserialization key issue
blessedcoolant Jun 18, 2023
6256be4
fix: Remove type from Model type name
blessedcoolant Jun 18, 2023
c4c3c96
Revert "feat: Port Schedulers to Mantine"
blessedcoolant Jun 18, 2023
6c98700
fix: Adjust the Schedular select width
blessedcoolant Jun 19, 2023
d3dec59
tweal: UI colors
blessedcoolant Jun 19, 2023
aceadac
Remove default model logic
StAlKeR7779 Jun 20, 2023
e4dc9c5
Rename format to model_format(still named format when work with config)
StAlKeR7779 Jun 20, 2023
da566b5
Update model format field to use enums
StAlKeR7779 Jun 20, 2023
21245a0
Set model type to const value in openapi schema, add model format enu…
StAlKeR7779 Jun 20, 2023
b937b7d
feat(models): update model manager service & route to return list of …
psychedelicious Jun 22, 2023
42a59aa
feat(nodes): add `sd_model_loader` node
psychedelicious Jun 22, 2023
3722cdf
chore(ui): regen api client
psychedelicious Jun 22, 2023
1bc1707
tidy(nodes): rename `sd_model_loader` to `pipeline_model_loader`
psychedelicious Jun 22, 2023
2a178f5
chore(ui): regen api client
psychedelicious Jun 22, 2023
339e7ce
feat(ui): initial implementation of model loading
psychedelicious Jun 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions invokeai/app/api/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]

models_router = APIRouter(prefix="/v1/models", tags=["models"])

Expand Down Expand Up @@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")

class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
models: list[MODEL_CONFIGS]


@models_router.get(
Expand All @@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: BaseModelType = Query(
base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
model_type: ModelType = Query(
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
Expand Down
16 changes: 16 additions & 0 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,22 @@ def custom_openapi():

invoker_schema["output"] = outputs_ref

from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__

if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue

# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)

app.openapi_schema = openapi_schema
return app.openapi_schema

Expand Down
144 changes: 25 additions & 119 deletions invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,115 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on


class SD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
class PipelineModelField(BaseModel):
"""Pipeline model field"""

type: Literal["sd1_model_loader"] = "sd1_model_loader"

model_name: str = Field(default="", description="Model to load")
# TODO: precision?

# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}

def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

base_model = BaseModelType.StableDiffusion1 # TODO:

# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unkown model name: {self.model_name}!")

"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)

if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)

if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")


return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae,
),
)
)
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""

# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"

type: Literal["sd2_model_loader"] = "sd2_model_loader"

model_name: str = Field(default="", description="Model to load")
model: PipelineModelField = Field(description="The model to load")
# TODO: precision?

# Schema customisation
Expand All @@ -160,22 +64,24 @@ class Config(InvocationConfig):
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
"model": "model"
}
},
}

def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

base_model = BaseModelType.StableDiffusion2 # TODO:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline

# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
):
raise Exception(f"Unkown model name: {self.model_name}!")
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

"""
if not context.services.model_manager.model_exists(
Expand Down Expand Up @@ -210,39 +116,39 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Vae,
),
)
Expand Down
43 changes: 4 additions & 39 deletions invokeai/app/services/model_manager_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass

from invokeai.backend.model_management.model_manager import (
Expand Down Expand Up @@ -69,19 +69,6 @@ def model_exists(
) -> bool:
pass

@abstractmethod
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name and typeof the default model, or None
if none is defined.
"""
pass

@abstractmethod
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
pass

@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Expand Down Expand Up @@ -270,17 +257,6 @@ def model_exists(
model_type,
)

def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()

def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name, base_model, model_type)

def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Expand All @@ -297,21 +273,10 @@ def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> dict:
) -> list[dict]:
# ) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)

Expand Down
Loading
Loading