Skip to content

Commit

Permalink
Backport PR jupyterlab#459: If model_provider_id or embeddings_provid…
Browse files Browse the repository at this point in the history
…er_id is not associated with models, set it to None (jupyterlab#467)

Co-authored-by: Andrii Ieroshenko <[email protected]>
  • Loading branch information
meeseeksmachine and andrii-i authored Nov 13, 2023
1 parent 04494ef commit 2fdb5ba
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
19 changes: 19 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,25 @@ def _init_config(self):
)
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if (
lm_id is not None
and not get_lm_provider(lm_id, self._lm_providers)[1]
):
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if (
em_id is not None
and not get_em_provider(em_id, self._em_providers)[1]
):
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
Expand Down
32 changes: 31 additions & 1 deletion packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
from unittest.mock import mock_open, patch

import pytest
from jupyter_ai.config_manager import (
Expand All @@ -9,7 +10,7 @@
KeyInUseError,
WriteConflictError,
)
from jupyter_ai.models import DescribeConfigResponse, UpdateConfigRequest
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from pydantic import ValidationError

Expand Down Expand Up @@ -83,6 +84,29 @@ def reset(config_path, schema_path):
pass


@pytest.fixture
def config_with_bad_provider_ids(tmp_path):
"""Fixture that creates a `config.json` with `model_provider_id` and `embeddings_provider_id` values that would not associate with models. File is created in `tmp_path` folder. Function returns path to the file."""
config_data = {
"model_provider_id:": "foo:bar",
"embeddings_provider_id": "buzz:fizz",
"api_keys": {},
"send_with_shift_enter": False,
"fields": {},
}
config_path = tmp_path / "config.json"
with open(config_path, "w") as file:
json.dump(config_data, file)
return str(config_path)


@pytest.fixture
def cm_with_bad_provider_ids(common_cm_kwargs, config_with_bad_provider_ids):
"""Config manager instance created with `config_path` set to mocked `config.json` with `model_provider_id` and `embeddings_provider_id` values that would not associate with models."""
common_cm_kwargs["config_path"] = config_with_bad_provider_ids
return ConfigManager(**common_cm_kwargs)


def configure_to_cohere(cm: ConfigManager):
"""Configures the ConfigManager to use Cohere language and embedding models
with the API key set. Returns a 3-tuple of the keyword arguments used."""
Expand Down Expand Up @@ -290,3 +314,9 @@ def test_forbid_deleting_key_in_use(cm: ConfigManager):

with pytest.raises(KeyInUseError):
cm.delete_api_key("COHERE_API_KEY")


def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
config_desc = cm_with_bad_provider_ids.get_config()
assert config_desc.model_provider_id is None
assert config_desc.embeddings_provider_id is None

0 comments on commit 2fdb5ba

Please sign in to comment.