From ad81851ef8cf64a5ed4641d0c9e329683ceb13de Mon Sep 17 00:00:00 2001
From: "Lumberbot (aka Jack)"
<39504233+meeseeksmachine@users.noreply.github.com>
Date: Thu, 21 Dec 2023 18:06:37 +0100
Subject: [PATCH] Backport PR #531: Adds multi-environment variable
authentication, Baidu Qianfan ERNIE-bot provider (#539)
Co-authored-by: Jason Weill <93281816+JasonWeill@users.noreply.github.com>
---
.gitignore | 2 +
docs/source/users/index.md | 4 +-
.../jupyter_ai_magics/__init__.py | 2 +
.../jupyter_ai_magics/aliases.py | 3 +
.../jupyter_ai_magics/embedding_providers.py | 13 +++
.../jupyter_ai_magics/magics.py | 97 ++++++++++++-------
.../jupyter_ai_magics/providers.py | 14 ++-
packages/jupyter-ai-magics/pyproject.toml | 5 +-
.../src/components/chat-settings.tsx | 15 +++
packages/jupyter-ai/src/handler.ts | 11 ++-
10 files changed, 125 insertions(+), 41 deletions(-)
diff --git a/.gitignore b/.gitignore
index eee408498..090e2febb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,3 +131,5 @@ dev.sh
.jupyter_ystore.db
.yarn
+
+.conda/
diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 348f7acac..905bc5c03 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -122,7 +122,7 @@ Jupyter AI supports a wide range of model providers and models. To use Jupyter A
Jupyter AI supports the following model providers:
-| Provider | Provider ID | Environment variable | Python package(s) |
+| Provider | Provider ID | Environment variable(s) | Python package(s) |
|---------------------|----------------------|----------------------------|---------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
@@ -130,6 +130,7 @@ Jupyter AI supports the following model providers:
| Bedrock | `bedrock` | N/A | `boto3` |
| Bedrock (chat) | `bedrock-chat` | N/A | `boto3` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
+| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| GPT4All | `gpt4all` | N/A | `gpt4all` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
@@ -137,6 +138,7 @@ Jupyter AI supports the following model providers:
| SageMaker | `sagemaker-endpoint` | N/A | `boto3` |
The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
+If multiple variables are listed for a provider, **all** must be specified.
To use the Bedrock models, you need access to the Bedrock service. For more information, see the
[Amazon Bedrock Homepage](https://aws.amazon.com/bedrock/).
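
As a quick illustration of the note above: for the `qianfan` provider both variables must be present before any model is used. A minimal sketch, mirroring the `%env` guidance that the magics themselves print, with placeholder values:

```
%env QIANFAN_AK=your-access-key
%env QIANFAN_SK=your-secret-key
```
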
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
index cf808476a..1bfdaeb24 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -8,6 +8,7 @@
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OpenAIEmbeddingsProvider,
+ QianfanEmbeddingsEndpointProvider,
)
from .exception import store_exception
from .magics import AiMagics
@@ -27,6 +28,7 @@
GPT4AllProvider,
HfHubProvider,
OpenAIProvider,
+ QianfanProvider,
SmEndpointProvider,
)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
index ab383af32..96cac4efe 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
@@ -3,4 +3,7 @@
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
+ "ernie-bot": "qianfan:ERNIE-Bot",
+ "ernie-bot-4": "qianfan:ERNIE-Bot-4",
+ "titan": "bedrock:amazon.titan-tg1-large",
}
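
For context, an alias can be passed to the `%%ai` cell magic in place of the full `provider:model-id` string. A hypothetical notebook cell, assuming `jupyter_ai_magics` is loaded and both Qianfan credentials are set:

```
%%ai ernie-bot
Explain, in one sentence, what an environment variable is.
```
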
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
index 7dcda78b8..75c8fa0a3 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -6,6 +6,7 @@
AwsAuthStrategy,
EnvAuthStrategy,
Field,
+ MultiEnvAuthStrategy,
)
from langchain.embeddings import (
BedrockEmbeddings,
@@ -13,6 +14,7 @@
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
+ QianfanEmbeddingsEndpoint,
)
from langchain.pydantic_v1 import BaseModel, Extra
@@ -127,3 +129,14 @@ def __init__(self, **kwargs):
models = ["all-MiniLM-L6-v2-f16"]
model_id_key = "model_id"
pypi_package_deps = ["gpt4all"]
+
+
+class QianfanEmbeddingsEndpointProvider(
+ BaseEmbeddingsProvider, QianfanEmbeddingsEndpoint
+):
+ id = "qianfan"
+ name = "ERNIE-Bot"
+ models = ["ERNIE-Bot", "ERNIE-Bot-4"]
+ model_id_key = "model"
+ pypi_package_deps = ["qianfan"]
+ auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
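
A rough sketch of what this embeddings provider wraps (the LangChain `QianfanEmbeddingsEndpoint` imported above); the check and defaults are illustrative, assuming the `qianfan` package is installed and both variables are exported:

```python
import os
from langchain.embeddings import QianfanEmbeddingsEndpoint

# MultiEnvAuthStrategy semantics: every listed variable must be present.
assert all(name in os.environ for name in ("QIANFAN_AK", "QIANFAN_SK"))

embeddings = QianfanEmbeddingsEndpoint()        # credentials read from the environment
vector = embeddings.embed_query("hello world")  # returns a list of floats
print(len(vector))
```
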
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
index 1e853da67..6a4e445d0 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -11,6 +11,7 @@
from IPython import get_ipython
from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
+from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
from langchain.chains import LLMChain
from langchain.schema import HumanMessage
@@ -28,14 +29,6 @@
)
from .providers import BaseProvider
-MODEL_ID_ALIASES = {
- "gpt2": "huggingface_hub:gpt2",
- "gpt3": "openai:text-davinci-003",
- "chatgpt": "openai-chat:gpt-3.5-turbo",
- "gpt4": "openai-chat:gpt-4",
- "titan": "bedrock:amazon.titan-tg1-large",
-}
-
class TextOrMarkdown:
def __init__(self, text, markdown):
@@ -108,6 +101,18 @@ def _repr_mimebundle_(self, include=None, exclude=None):
AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}
+# Strings for listing providers and models
+# Avoid composing strings, to make localization easier in the future
+ENV_NOT_SET = "You have not set this environment variable, so you cannot use this provider's models."
+ENV_SET = (
+ "You have set this environment variable, so you can use this provider's models."
+)
+MULTIENV_NOT_SET = "You have not set all of these environment variables, so you cannot use this provider's models."
+MULTIENV_SET = "You have set all of these environment variables, so you can use this provider's models."
+
+ENV_REQUIRES = "Requires environment variable:"
+MULTIENV_REQUIRES = "Requires environment variables:"
+
class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
@@ -190,44 +195,53 @@ def _ai_env_status_for_provider_markdown(self, provider_id):
):
return na_message # No emoji
- try:
- env_var = self.providers[provider_id].auth_strategy.name
- except AttributeError: # No "name" attribute
+ not_set_title = ENV_NOT_SET
+ set_title = ENV_SET
+ env_status_ok = False
+
+ auth_strategy = self.providers[provider_id].auth_strategy
+ if auth_strategy.type == "env":
+ var_name = auth_strategy.name
+ env_var_display = f"`{var_name}`"
+ env_status_ok = var_name in os.environ
+ elif auth_strategy.type == "multienv":
+ # Check multiple environment variables
+ var_names = self.providers[provider_id].auth_strategy.names
+ formatted_names = [f"`{name}`" for name in var_names]
+ env_var_display = ", ".join(formatted_names)
+ env_status_ok = all(var_name in os.environ for var_name in var_names)
+ not_set_title = MULTIENV_NOT_SET
+ set_title = MULTIENV_SET
+ else: # No environment variables
return na_message
- output = f"`{env_var}` | "
- if os.getenv(env_var) == None:
- output += (
- '❌"
- )
+ output = f"{env_var_display} | "
+ if env_status_ok:
+ output += f'✅'
else:
- output += (
- '✅"
- )
+ output += f'❌'
return output
def _ai_env_status_for_provider_text(self, provider_id):
- if (
- provider_id not in self.providers
- or self.providers[provider_id].auth_strategy == None
+ # only handle providers with "env" or "multienv" auth strategy
+ auth_strategy = getattr(self.providers[provider_id], "auth_strategy", None)
+ if not auth_strategy or (
+ auth_strategy.type != "env" and auth_strategy.type != "multienv"
):
- return "" # No message necessary
-
- try:
- env_var = self.providers[provider_id].auth_strategy.name
- except AttributeError: # No "name" attribute
return ""
- output = f"Requires environment variable {env_var} "
- if os.getenv(env_var) != None:
- output += "(set)"
- else:
- output += "(not set)"
+ prefix = ENV_REQUIRES if auth_strategy.type == "env" else MULTIENV_REQUIRES
+ envvars = (
+ [auth_strategy.name]
+ if auth_strategy.type == "env"
+ else auth_strategy.names[:]
+ )
+
+ for i in range(len(envvars)):
+ envvars[i] += " (set)" if envvars[i] in os.environ else " (not set)"
- return output + "\n"
+ return prefix + " " + ", ".join(envvars) + "\n"
# Is this a name of a Python variable that can be called as a LangChain chain?
def _is_langchain_chain(self, name):
@@ -513,13 +527,22 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
# validate presence of authn credentials
auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy:
- # TODO: handle auth strategies besides EnvAuthStrategy
if auth_strategy.type == "env" and auth_strategy.name not in os.environ:
raise OSError(
- f"Authentication environment variable {auth_strategy.name} not provided.\n"
+ f"Authentication environment variable {auth_strategy.name} is not set.\n"
f"An authentication token is required to use models from the {Provider.name} provider.\n"
f"Please specify it via `%env {auth_strategy.name}=token`. "
) from None
+ if auth_strategy.type == "multienv":
+ # All of the listed environment variables must be set
+ missing_vars = [
+ var for var in auth_strategy.names if var not in os.environ
+ ]
+ if missing_vars:
+ raise OSError(
+ f"Authentication environment variables {missing_vars} are not set.\n"
+ f"Multiple authentication tokens are required to use models from the {Provider.name} provider.\n"
+ f"Please specify them all via `%env` commands. "
+ ) from None
# configure and instantiate provider
provider_params = {"model_id": local_model_id}
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 8d353638d..525cb931e 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -23,6 +23,7 @@
BedrockChat,
ChatAnthropic,
ChatOpenAI,
+ QianfanChatEndpoint,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms import (
@@ -34,6 +35,7 @@
HuggingFaceHub,
OpenAI,
OpenAIChat,
+ QianfanLLMEndpoint,
SagemakerEndpoint,
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
@@ -54,7 +56,7 @@ class EnvAuthStrategy(BaseModel):
class MultiEnvAuthStrategy(BaseModel):
"""Require multiple auth tokens via multiple environment variables."""
- type: Literal["file"] = "file"
+ type: Literal["multienv"] = "multienv"
names: List[str]
@@ -775,3 +777,13 @@ async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
@property
def allows_concurrency(self):
return not "anthropic" in self.model_id
+
+
+# Baidu Qianfan chat provider, temporarily living as a separate class.
+class QianfanProvider(BaseProvider, QianfanChatEndpoint):
+ id = "qianfan"
+ name = "ERNIE-Bot"
+ models = ["ERNIE-Bot", "ERNIE-Bot-4"]
+ model_id_key = "model_name"
+ pypi_package_deps = ["qianfan"]
+ auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
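
A hedged sketch of how this provider is expected to be constructed: `BaseProvider` maps `model_id` onto the declared `model_id_key` (`model_name`), and `QianfanChatEndpoint` reads `QIANFAN_AK`/`QIANFAN_SK` from the environment. The call below assumes valid credentials and the `qianfan` package:

```python
from jupyter_ai_magics import QianfanProvider

# model_id is forwarded as model_name="ERNIE-Bot" to QianfanChatEndpoint.
provider = QianfanProvider(model_id="ERNIE-Bot")

# The provider is a LangChain chat model, so the usual interfaces apply.
print(provider.predict("Say hello in one short sentence."))
```
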
diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml
index f280f1619..05e11a10d 100644
--- a/packages/jupyter-ai-magics/pyproject.toml
+++ b/packages/jupyter-ai-magics/pyproject.toml
@@ -50,7 +50,8 @@ all = [
"ipywidgets",
"pillow",
"openai",
- "boto3"
+ "boto3",
+ "qianfan"
]
[project.entry-points."jupyter_ai.model_providers"]
@@ -67,6 +68,7 @@ sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider"
+qianfan = "jupyter_ai_magics:QianfanProvider"
[project.entry-points."jupyter_ai.embeddings_model_providers"]
bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
@@ -74,6 +76,7 @@ cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"
+qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"
[tool.hatch.version]
source = "nodejs"
diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx
index 583cfa9f0..9f048d827 100644
--- a/packages/jupyter-ai/src/components/chat-settings.tsx
+++ b/packages/jupyter-ai/src/components/chat-settings.tsx
@@ -102,12 +102,27 @@ export function ChatSettings(): JSX.Element {
) {
newApiKeys[lmAuth.name] = '';
}
+ if (lmAuth?.type === 'multienv') {
+ lmAuth.names.forEach(apiKey => {
+ if (!server.config.api_keys.includes(apiKey)) {
+ newApiKeys[apiKey] = '';
+ }
+ });
+ }
+
if (
emAuth?.type === 'env' &&
!server.config.api_keys.includes(emAuth.name)
) {
newApiKeys[emAuth.name] = '';
}
+ if (emAuth?.type === 'multienv') {
+ emAuth.names.forEach(apiKey => {
+ if (!server.config.api_keys.includes(apiKey)) {
+ newApiKeys[apiKey] = '';
+ }
+ });
+ }
setApiKeys(newApiKeys);
}, [lmProvider, emProvider, server]);
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index ec91c1f3f..809213b1f 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -135,7 +135,16 @@ export namespace AiService {
type: 'aws';
};
- export type AuthStrategy = EnvAuthStrategy | AwsAuthStrategy | null;
+ export type MultiEnvAuthStrategy = {
+ type: 'multienv';
+ names: string[];
+ };
+
+ export type AuthStrategy =
+ | AwsAuthStrategy
+ | EnvAuthStrategy
+ | MultiEnvAuthStrategy
+ | null;
export type TextField = {
type: 'text';