Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #459 on branch 1.x (If model_provider_id or embeddings_provider_id is not associated with models, set it to None) #467

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,25 @@ def _init_config(self):
)
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if (
lm_id is not None
and not get_lm_provider(lm_id, self._lm_providers)[1]
):
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if (
em_id is not None
and not get_em_provider(em_id, self._em_providers)[1]
):
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
Expand Down
32 changes: 31 additions & 1 deletion packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
from unittest.mock import mock_open, patch

import pytest
from jupyter_ai.config_manager import (
Expand All @@ -9,7 +10,7 @@
KeyInUseError,
WriteConflictError,
)
from jupyter_ai.models import DescribeConfigResponse, UpdateConfigRequest
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from pydantic import ValidationError

Expand Down Expand Up @@ -83,6 +84,29 @@ def reset(config_path, schema_path):
pass


@pytest.fixture
def config_with_bad_provider_ids(tmp_path):
"""Fixture that creates a `config.json` with `model_provider_id` and `embeddings_provider_id` values that would not associate with models. File is created in `tmp_path` folder. Function returns path to the file."""
config_data = {
"model_provider_id:": "foo:bar",
"embeddings_provider_id": "buzz:fizz",
"api_keys": {},
"send_with_shift_enter": False,
"fields": {},
}
config_path = tmp_path / "config.json"
with open(config_path, "w") as file:
json.dump(config_data, file)
return str(config_path)


@pytest.fixture
def cm_with_bad_provider_ids(common_cm_kwargs, config_with_bad_provider_ids):
"""Config manager instance created with `config_path` set to mocked `config.json` with `model_provider_id` and `embeddings_provider_id` values that would not associate with models."""
common_cm_kwargs["config_path"] = config_with_bad_provider_ids
return ConfigManager(**common_cm_kwargs)


def configure_to_cohere(cm: ConfigManager):
"""Configures the ConfigManager to use Cohere language and embedding models
with the API key set. Returns a 3-tuple of the keyword arguments used."""
Expand Down Expand Up @@ -290,3 +314,9 @@ def test_forbid_deleting_key_in_use(cm: ConfigManager):

with pytest.raises(KeyInUseError):
cm.delete_api_key("COHERE_API_KEY")


def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
config_desc = cm_with_bad_provider_ids.get_config()
assert config_desc.model_provider_id is None
assert config_desc.embeddings_provider_id is None
Loading