Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix problem of all installed models being assigned "<NOKEY>" #5841

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def install_path(
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", self._create_key())

info: AnyModelConfig = self._probe_model(Path(model_path), config)

Expand Down Expand Up @@ -278,7 +279,7 @@ def sync_to_config(self) -> None:
self._logger.info("Model installer (re)initialized")

def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
Expand Down Expand Up @@ -531,9 +532,9 @@ def _create_key(self) -> str:
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
key = self._create_key()
if config and not config.get("key", None):
config["key"] = key
# Note that we may be passed a pre-populated AnyModelConfig object,
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", self._create_key())
info = info or ModelProbe.probe(model_path, config)

model_path = model_path.absolute()
Expand Down
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def probe(
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)

model_info = ModelConfigFactory.make_config(fields, key=fields.get("key", None))
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info

@classmethod
Expand Down
5 changes: 4 additions & 1 deletion tests/app/services/model_install/test_model_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
assert key is not None
assert key != "<NOKEY>"
assert len(key) == 32


Expand Down Expand Up @@ -58,12 +59,13 @@ def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase,
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash", "key": "xyzzy"}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash"
assert model_record.key == "xyzzy"


def test_install(
Expand Down Expand Up @@ -129,6 +131,7 @@ def test_background_install(
model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path == destination
assert model_record.key != "<NOKEY>"
assert Path(mm2_app_config.models_dir / model_record.path).exists()

# see if metadata was properly passed through
Expand Down
Loading