Skip to content

Commit

Permalink
Merge branch 'main' into lstein/remove-hardcoded-cuda-device
Browse files Browse the repository at this point in the history
  • Loading branch information
lstein authored Jul 5, 2023
2 parents fc41954 + 17c5568 commit 9f9ce08
Show file tree
Hide file tree
Showing 133 changed files with 3,966 additions and 3,991 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,6 @@ checkpoints
# If it's a Mac
.DS_Store

invokeai/frontend/web/dist/*

# Let the frontend manage its own gitignore
!invokeai/frontend/web/*

Expand Down
47 changes: 30 additions & 17 deletions invokeai/app/api/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

from typing import Literal, Optional, Union

from fastapi import Query
from fastapi import Query, Body
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]

models_router = APIRouter(prefix="/v1/models", tags=["models"])


class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
Expand Down Expand Up @@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")

class ImportModelRequest(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")

class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
Expand Down Expand Up @@ -86,7 +89,6 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw })
return models


@models_router.post(
"/",
operation_id="update_model",
Expand All @@ -109,27 +111,38 @@ async def update_model(
return model_response

@models_router.post(
"/",
"/import",
operation_id="import_model",
responses={200: {"status": "success"}},
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
model_request: ImportModelRequest
) -> None:
""" Add Model """
items_to_import = set([model_request.name])
name: str = Query(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger

installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if len(installed_models) > 0:
logger.info(f'Successfully imported {model_request.name}')
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = info,
status = "success",
)
else:
logger.error(f'Model {model_request.name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
logger.error(f'Model {name} not imported')
raise HTTPException(status_code=404, detail=f'Model {name} not found')

@models_router.delete(
"/{model_name}",
Expand Down
25 changes: 17 additions & 8 deletions invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)

from pydantic import BaseModel, Field
from pydantic import BaseConfig, BaseModel, Field

if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
Expand Down Expand Up @@ -65,8 +66,13 @@ def get_invocations(cls):
@classmethod
def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))

return dict(
map(
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
BaseInvocation.get_all_subclasses(),
)
)

@classmethod
def get_output_type(cls):
return signature(cls.invoke).return_annotation
Expand All @@ -75,11 +81,11 @@ def get_output_type(cls):
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
#fmt: off

# fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on
# fmt: on


# TODO: figure out a better way to provide these hints
Expand All @@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
"model",
"control",
"image_collection",
"vae_model",
"lora_model",
],
]
tags: List[str]
title: str


class CustomisedSchemaExtra(TypedDict):
ui: UIConfig


class InvocationConfig(BaseModel.Config):
class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
Expand Down
Loading

0 comments on commit 9f9ce08

Please sign in to comment.