diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 63bb003f8..82ef03126 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -147,52 +147,52 @@ def _init_validator(self) -> Validator: def _init_config(self): if os.path.exists(self.config_path): - with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id - - # if the currently selected language or embedding model are - # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model( - lm_id, raise_exc=False - ): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not self._validate_model( - em_id, raise_exc=False - ): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.embeddings_provider_id = None - - # if the currently selected language or embedding model ids are - # not associated with models, set them to `None` and log a warning. - if ( - lm_id is not None - and not get_lm_provider(lm_id, self._lm_providers)[1] - ): - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - config.model_provider_id = None - if ( - em_id is not None - and not get_em_provider(em_id, self._em_providers)[1] - ): - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - config.embeddings_provider_id = None - - # re-write to the file to validate the config and apply any - # updates to the config file immediately - self._write_config(config) - return + self._process_existing_config() + else: + self._create_default_config() + def _process_existing_config(self): + with open(self.config_path, encoding="utf-8") as f: + config = GlobalConfig(**json.loads(f.read())) + validated_config = self._validate_lm_em_id(config) + + # re-write to the file to validate the config and apply any + # updates to the config file immediately + self._write_config(validated_config) + + def _validate_lm_em_id(self, config): + lm_id = config.model_provider_id + em_id = config.embeddings_provider_id + + # if the currently selected language or embedding model are + # forbidden, set them to `None` and log a warning. + if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + ) + config.embeddings_provider_id = None + + # if the currently selected language or embedding model ids are + # not associated with models, set them to `None` and log a warning. + if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + config.embeddings_provider_id = None + + return config + + def _create_default_config(self): properties = self.validator.schema.get("properties", {}) field_list = GlobalConfig.__fields__.keys() field_dict = {