From e41334141701ff24fe9c03008155cd6c9437149b Mon Sep 17 00:00:00 2001 From: david qiu Date: Thu, 5 Dec 2024 08:35:08 -0800 Subject: [PATCH] Backport PR #1137: Update completion model fields immediately on save --- .../jupyter-ai/jupyter_ai/config_manager.py | 9 ++-- .../jupyter_ai/tests/test_config_manager.py | 44 ++++++++++++++----- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 71ca3f185..7b309faae 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -442,10 +442,10 @@ def em_provider_params(self): @property def completions_lm_provider_params(self): return self._provider_params( - "completions_model_provider_id", self._lm_providers + "completions_model_provider_id", self._lm_providers, completions=True ) - def _provider_params(self, key, listing): + def _provider_params(self, key, listing, completions: bool = False): # read config config = self._read_config() @@ -457,7 +457,10 @@ def _provider_params(self, key, listing): model_id = model_uid.split(":", 1)[1] # get config fields (e.g. base API URL, etc.) - fields = config.fields.get(model_uid, {}) + if completions: + fields = config.completions_fields.get(model_uid, {}) + else: + fields = config.fields.get(model_uid, {}) # get authn fields _, Provider = get_em_provider(model_uid, listing) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 168c675f7..4a739f6e5 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -194,20 +194,35 @@ def configure_to_openai(cm: ConfigManager): return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS -def configure_with_fields(cm: ConfigManager): +def configure_with_fields(cm: ConfigManager, completions: bool = False): """ - Configures the ConfigManager with fields and API keys. + Default behavior: Configures the ConfigManager with fields and API keys. Returns the expected result of `cm.lm_provider_params`. + + If `completions` is set to `True`, this configures the ConfigManager with + completion model fields, and returns the expected result of + `cm.completions_lm_provider_params`. """ - req = UpdateConfigRequest( - model_provider_id="openai-chat:gpt-4o", - api_keys={"OPENAI_API_KEY": "foobar"}, - fields={ - "openai-chat:gpt-4o": { - "openai_api_base": "https://example.com", - } - }, - ) + if completions: + req = UpdateConfigRequest( + completions_model_provider_id="openai-chat:gpt-4o", + api_keys={"OPENAI_API_KEY": "foobar"}, + completions_fields={ + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com", + } + }, + ) + else: + req = UpdateConfigRequest( + model_provider_id="openai-chat:gpt-4o", + api_keys={"OPENAI_API_KEY": "foobar"}, + fields={ + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com", + } + }, + ) cm.update_config(req) return { "model_id": "gpt-4o", @@ -445,7 +460,7 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids): assert config_desc.embeddings_provider_id is None -def test_config_manager_returns_fields(cm): +def test_returns_chat_model_fields(cm): """ Asserts that `ConfigManager.lm_provider_params` returns model fields set by the user. @@ -454,6 +469,11 @@ def test_config_manager_returns_fields(cm): assert cm.lm_provider_params == expected_model_args +def test_returns_completion_model_fields(cm): + expected_model_args = configure_with_fields(cm, completions=True) + assert cm.completions_lm_provider_params == expected_model_args + + def test_config_manager_does_not_write_to_defaults( config_file_with_model_fields, schema_path ):