From dd1b5265fdcdfb0094800d0085245ddd22df3da6 Mon Sep 17 00:00:00 2001 From: Jason Weill Date: Wed, 20 Dec 2023 17:39:15 -0800 Subject: [PATCH] Refactoring, constants for strings --- .../jupyter_ai_magics/magics.py | 78 +++++++++---------- 1 file changed, 35 insertions(+), 43 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index e06d0266e..be38044f2 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -88,6 +88,15 @@ def _repr_mimebundle_(self, include=None, exclude=None): AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"} +# Strings for listing providers and models +# Avoid composing strings, to make localization easier in the future +ENV_NOT_SET = "You have not set this environment variable, so you cannot use this provider's models." +ENV_SET = "You have set this environment variable, so you can use this provider's models." +MULTIENV_NOT_SET = "You have not set all of these environment variables, so you cannot use this provider's models." +MULTIENV_SET = "You have set all of these environment variables, so you can use this provider's models." + +ENV_REQUIRES = "Requires environment variable:" +MULTIENV_REQUIRES = "Requires environment variables:" class FormatDict(dict): """Subclass of dict to be passed to str#format(). Suppresses KeyError and @@ -170,25 +179,25 @@ def _ai_env_status_for_provider_markdown(self, provider_id): ): return na_message # No emoji - # Avoid composing strings, to make localization easier in the future - not_set_title = "You have not set this environment variable, so you cannot use this provider's models." - set_title = "You have set this environment variable, so you can use this provider's models." + not_set_title = ENV_NOT_SET + set_title = ENV_SET env_status_ok = False - try: - var_name = self.providers[provider_id].auth_strategy.name + + auth_strategy = self.providers[provider_id].auth_strategy + if auth_strategy.type == "env": + var_name = auth_strategy.name env_var_display = f"`{var_name}`" env_status_ok = var_name in os.environ - except AttributeError: # No "name" attribute - # Try multiple names - try: - var_names = self.providers[provider_id].auth_strategy.names - formatted_names = [f"`{name}`" for name in var_names] - env_var_display = ", ".join(formatted_names) - env_status_ok = all(var_name in os.environ for var_name in var_names) - not_set_title = "You have not set all of these environment variables, so you cannot use this provider's models." - set_title = "You have set all of these environment variables, so you can use this provider's models." - except AttributeError: # No "names" attribute - return na_message + elif auth_strategy.type == "multienv": + # Check multiple environment variables + var_names = self.providers[provider_id].auth_strategy.names + formatted_names = [f"`{name}`" for name in var_names] + env_var_display = ", ".join(formatted_names) + env_status_ok = all(var_name in os.environ for var_name in var_names) + not_set_title = MULTIENV_NOT_SET + set_title = MULTIENV_SET + else: # No environment variables + return na_message output = f"{env_var_display} | " if env_status_ok: @@ -199,35 +208,18 @@ def _ai_env_status_for_provider_markdown(self, provider_id): return output def _ai_env_status_for_provider_text(self, provider_id): - if ( - provider_id not in self.providers - or self.providers[provider_id].auth_strategy == None - ): - return "" # No message necessary + # only handle providers with "env" or "multienv" auth strategy + auth_strategy = getattr(self.providers[provider_id], 'auth_strategy', None) + if not auth_strategy or (auth_strategy.type != "env" and auth_strategy.type != "multienv"): + return "" - # Avoid composing strings, to make localization easier in the future - prefix = "Requires environment variable:" - var_names = [] - try: - var_names.append(self.providers[provider_id].auth_strategy.name) - except AttributeError: # No "name" attribute - # Try multiple names - try: - var_names = self.providers[provider_id].auth_strategy.names - prefix = "Requires environment variables:" - except AttributeError: # No "names" attribute - return "" - - annotated_var_names = [] - for var_name in var_names: - annotated_var_name = var_name - if var_name in os.environ: - annotated_var_name += " (set)" - else: - annotated_var_name += " (not set)" - annotated_var_names.append(annotated_var_name) + prefix = ENV_REQUIRES if auth_strategy.type == "env" else MULTIENV_REQUIRES + envvars = [auth_strategy.name] if auth_strategy.type == "env" else auth_strategy.names[:] + + for i in range(len(envvars)): + envvars[i] += " (set)" if envvars[i] in os.environ else " (not set)" - return prefix + " " + ", ".join(annotated_var_names) + "\n" + return prefix + " " + ", ".join(envvars) + "\n" # Is this a name of a Python variable that can be called as a LangChain chain? def _is_langchain_chain(self, name):