Skip to content

Commit

Permalink
refactor ConfigManager. _init_config
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Dec 19, 2023
1 parent 79e345a commit fc9602a
Showing 1 changed file with 45 additions and 45 deletions.
90 changes: 45 additions & 45 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,52 +147,52 @@ def _init_validator(self) -> Validator:

def _init_config(self):
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if (
lm_id is not None
and not get_lm_provider(lm_id, self._lm_providers)[1]
):
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if (
em_id is not None
and not get_em_provider(em_id, self._em_providers)[1]
):
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
return
self._process_existing_config()
else:
self._create_default_config()

def _process_existing_config(self):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
validated_config = self._validate_lm_em_id(config)

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(validated_config)

def _validate_lm_em_id(self, config):
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None

return config

def _create_default_config(self):
properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
field_dict = {
Expand Down

0 comments on commit fc9602a

Please sign in to comment.