Skip to content

Commit

Permalink
ensure no empty field dictionaries in config
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Sep 7, 2023
1 parent 4bbb475 commit 00db056
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def _validate_config(self, config: GlobalConfig):
def _write_config(self, new_config: GlobalConfig):
"""Updates configuration and persists it to disk. This accepts a
complete `GlobalConfig` object, and should not be called publicly."""
# remove any empty field dictionaries
new_config.fields = {k: v for k, v in new_config.fields.items() if v}

self._validate_config(new_config)
with open(self.config_path, "w") as f:
json.dump(new_config.dict(), f, indent=self.indentation_depth)
Expand Down
12 changes: 11 additions & 1 deletion packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os

Expand Down Expand Up @@ -129,7 +130,7 @@ def test_indentation_depth(common_cm_kwargs, config_path):
"""Asserts that the CM indents the configuration and respects the
`indentation_depth` trait when specified."""
INDENT_DEPTH = 7
cm = ConfigManager(**common_cm_kwargs, indentation_depth=INDENT_DEPTH)
ConfigManager(**common_cm_kwargs, indentation_depth=INDENT_DEPTH)
with open(config_path) as f:
config_file = f.read()
config_lines = config_file.split("\n")
Expand Down Expand Up @@ -162,6 +163,15 @@ def test_update(cm: ConfigManager):
assert cm.em_provider_params == {**API_PARAMS, "model_id": EM_LID}


def test_update_no_empty_field_dicts(cm: ConfigManager, config_path):
LM_GID, _, _, _, _ = configure_to_cohere(cm)
cm.update_config(UpdateConfigRequest(fields={LM_GID: {}}))

with open(config_path) as f:
raw_config = json.loads(f.read())
assert raw_config["fields"] == {}


def test_update_fails_with_invalid_req():
with pytest.raises(ValidationError):
UpdateConfigRequest(send_with_shift_enter=None)
Expand Down

0 comments on commit 00db056

Please sign in to comment.