Skip to content

Commit

Permalink
Model manager fixes (#3541)
Browse files Browse the repository at this point in the history
Fix lora import
Fix sd2 config - `variant` field not added
Fix list models api - `base_model` arg not provided, redundant assert
check
  • Loading branch information
blessedcoolant authored Jun 15, 2023
2 parents 5c74045 + 5f2d079 commit 4cbc802
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion invokeai/app/api/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def list_models(
),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type)
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models

Expand Down
1 change: 0 additions & 1 deletion invokeai/backend/model_management/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,6 @@ def list_models(
named 'model-name', and model_manager.config to get the full OmegaConf
object derived from models.yaml
"""
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"

models = dict()
for model_key in sorted(self.models, key=str.casefold):
Expand Down
1 change: 1 addition & 0 deletions invokeai/backend/model_management/models/lora.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
from typing import Optional, Union, Literal
from .base import (
Expand Down
2 changes: 2 additions & 0 deletions invokeai/backend/model_management/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,15 @@ class StableDiffusion2Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool

class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool

Expand Down

0 comments on commit 4cbc802

Please sign in to comment.