From a710af811768749984f103cbd6bd1fcf111dec9d Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Mon, 13 Nov 2023 12:40:22 -0800 Subject: [PATCH] If model_provider_id or embeddings_provider_id is not associated with models, set it to None (#459) * if ids are not getting associated with models, set them to None * add test_handle_bad_provider_ids * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tests * Add docstrings * adjust docstrings --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../jupyter-ai/jupyter_ai/config_manager.py | 19 +++++++++++ .../jupyter_ai/tests/test_config_manager.py | 32 ++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 5adc604cf..63bb003f8 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -169,6 +169,25 @@ def _init_config(self): ) config.embeddings_provider_id = None + # if the currently selected language or embedding model ids are + # not associated with models, set them to `None` and log a warning. + if ( + lm_id is not None + and not get_lm_provider(lm_id, self._lm_providers)[1] + ): + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + config.model_provider_id = None + if ( + em_id is not None + and not get_em_provider(em_id, self._em_providers)[1] + ): + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + config.embeddings_provider_id = None + # re-write to the file to validate the config and apply any # updates to the config file immediately self._write_config(config) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index dafba7adf..8fa59ab6a 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -1,6 +1,7 @@ import json import logging import os +from unittest.mock import mock_open, patch import pytest from jupyter_ai.config_manager import ( @@ -9,7 +10,7 @@ KeyInUseError, WriteConflictError, ) -from jupyter_ai.models import DescribeConfigResponse, UpdateConfigRequest +from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from pydantic import ValidationError @@ -83,6 +84,29 @@ def reset(config_path, schema_path): pass +@pytest.fixture +def config_with_bad_provider_ids(tmp_path): + """Fixture that creates a `config.json` with `model_provider_id` and `embeddings_provider_id` values that would not associate with models. File is created in `tmp_path` folder. Function returns path to the file.""" + config_data = { + "model_provider_id:": "foo:bar", + "embeddings_provider_id": "buzz:fizz", + "api_keys": {}, + "send_with_shift_enter": False, + "fields": {}, + } + config_path = tmp_path / "config.json" + with open(config_path, "w") as file: + json.dump(config_data, file) + return str(config_path) + + +@pytest.fixture +def cm_with_bad_provider_ids(common_cm_kwargs, config_with_bad_provider_ids): + """Config manager instance created with `config_path` set to mocked `config.json` with `model_provider_id` and `embeddings_provider_id` values that would not associate with models.""" + common_cm_kwargs["config_path"] = config_with_bad_provider_ids + return ConfigManager(**common_cm_kwargs) + + def configure_to_cohere(cm: ConfigManager): """Configures the ConfigManager to use Cohere language and embedding models with the API key set. Returns a 3-tuple of the keyword arguments used.""" @@ -290,3 +314,9 @@ def test_forbid_deleting_key_in_use(cm: ConfigManager): with pytest.raises(KeyInUseError): cm.delete_api_key("COHERE_API_KEY") + + +def test_handle_bad_provider_ids(cm_with_bad_provider_ids): + config_desc = cm_with_bad_provider_ids.get_config() + assert config_desc.model_provider_id is None + assert config_desc.embeddings_provider_id is None