Skip to content

Commit

Permalink
Set model type to const value in openapi schema, add model format enu…
Browse files Browse the repository at this point in the history
…ms to model schema(as they not not referenced in case of Literal definition)
  • Loading branch information
StAlKeR7779 committed Jun 20, 2023
1 parent 46dc751 commit 92c86fd
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 16 deletions.
16 changes: 16 additions & 0 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,22 @@ def custom_openapi():

invoker_schema["output"] = outputs_ref

from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__

if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue

# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)

app.openapi_schema = openapi_schema
return app.openapi_schema

Expand Down
71 changes: 55 additions & 16 deletions invokeai/backend/model_management/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
Expand Down Expand Up @@ -30,27 +33,63 @@
#},
}

def _get_all_model_configs():
configs = set()
for models in MODEL_CLASSES.values():
for _, model in models.items():
configs.update(model._get_configs().values())
configs.discard(None)
return list(configs)

MODEL_CONFIGS = _get_all_model_configs()
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()

class OpenAPIModelInfoBase(BaseModel):
name: str
base_model: BaseModelType
type: ModelType

for cfg in MODEL_CONFIGS:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
name_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), {})

#globals()[name] = value
vars()[openapi_cfg_name] = name_wrapper
OPENAPI_MODEL_CONFIGS.append(name_wrapper)
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)

for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue

api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))

#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)

def get_model_config_enums():
enums = list()

for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except:
raise Exception("format field not found")

# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]

if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)

elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))

elif field is None:
pass

else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")

return enums

0 comments on commit 92c86fd

Please sign in to comment.