Skip to content

Commit

Permalink
Warn on invalid model configs in the DB rather than crashing.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed Jul 12, 2024
1 parent 5795617 commit 69af099
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/contributing/MODEL_MANAGER.md
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ config = get_config()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
record_store = ModelRecordServiceSQL(db, logger)
queue = DownloadQueueService()
queue.start()
Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db),
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
download_queue=download_queue_service,
events=events,
)
Expand Down
21 changes: 19 additions & 2 deletions invokeai/app/services/model_records/model_records_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@
"""

import json
import logging
import sqlite3
from math import ceil
from pathlib import Path
from typing import List, Optional, Union

import pydantic

from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException,
ModelRecordChanges,
Expand All @@ -67,7 +70,7 @@
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""

def __init__(self, db: SqliteDatabase):
def __init__(self, db: SqliteDatabase, logger: logging.Logger):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
Expand All @@ -76,6 +79,7 @@ def __init__(self, db: SqliteDatabase):
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._logger = logger

@property
def db(self) -> SqliteDatabase:
Expand Down Expand Up @@ -291,7 +295,20 @@ def search_by_attr(
tuple(bindings),
)
result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]

# Parse the model configs.
results: list[AnyModelConfig] = []
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
else:
results.append(model_config)

return results

def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
Expand Down
2 changes: 1 addition & 1 deletion tests/app/services/model_records/test_model_records_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def store(
config._root = datadir
logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)
return ModelRecordServiceSQL(db, logger)


def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
Expand Down
4 changes: 2 additions & 2 deletions tests/backend/model_manager/model_manager_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def mm2_installer(
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = TestEventService()
store = ModelRecordServiceSQL(db)
store = ModelRecordServiceSQL(db, logger)

installer = ModelInstallService(
app_config=mm2_app_config,
Expand All @@ -128,7 +128,7 @@ def mm2_installer(
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db)
store = ModelRecordServiceSQL(db, logger)
# add five simple config records to the database
config1 = VAEDiffusersConfig(
key="test_config_1",
Expand Down

0 comments on commit 69af099

Please sign in to comment.