From 802f0850611fa94574244641b9a57a144784403e Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Mon, 9 Dec 2024 15:48:40 -0800 Subject: [PATCH] Backport PR #1136: Add base API URL field for Ollama and OpenAI embedding models --- .../partner_providers/ollama.py | 8 ++- .../partner_providers/openai.py | 10 ++- .../jupyter-ai/jupyter_ai/config_manager.py | 7 +++ .../src/components/chat-settings.tsx | 63 +++++++++++++------ 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py index 5babc5adb..bf7d8474a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py @@ -1,7 +1,7 @@ from langchain_ollama import ChatOllama, OllamaEmbeddings from ..embedding_providers import BaseEmbeddingsProvider -from ..providers import BaseProvider, EnvAuthStrategy, TextField +from ..providers import BaseProvider, TextField class OllamaProvider(BaseProvider, ChatOllama): @@ -23,10 +23,14 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings): id = "ollama" name = "Ollama" # source: https://ollama.com/library + model_id_key = "model" models = [ "nomic-embed-text", "mxbai-embed-large", "all-minilm", "snowflake-arctic-embed", ] - model_id_key = "model" + registry = True + fields = [ + TextField(key="base_url", label="Base API URL (optional)", format="text"), + ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index 7e97995d4..34ca76a8e 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -107,6 +107,12 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): model_id_key = "model" pypi_package_deps = ["langchain_openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + registry = True + fields = [ + TextField( + key="openai_api_base", label="Base API URL (optional)", format="text" + ), + ] class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings): @@ -122,5 +128,7 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding auth_strategy = EnvAuthStrategy( name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" ) - registry = True + fields = [ + TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"), + ] diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 7b309faae..4732e152c 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -462,6 +462,13 @@ def _provider_params(self, key, listing, completions: bool = False): else: fields = config.fields.get(model_uid, {}) + # exclude empty fields + # TODO: modify the config manager to never save empty fields in the + # first place. + for field_key in fields: + if isinstance(fields[field_key], str) and not len(fields[field_key]): + fields[field_key] = None + # get authn fields _, Provider = get_em_provider(model_uid, listing) authn_fields = {} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 78312edee..5922bcff1 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -91,6 +91,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [embeddingModelFields, setEmbeddingModelFields] = useState< + Record + >({}); const [isCompleterEnabled, setIsCompleterEnabled] = useState( props.completionProvider && props.completionProvider.isEnabled() @@ -191,7 +194,15 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const currFields: Record = server.config.fields?.[lmGlobalId] ?? {}; setFields(currFields); - }, [server, lmProvider]); + + if (!emGlobalId) { + return; + } + + const initEmbeddingModelFields: Record = + server.config.fields?.[emGlobalId] ?? {}; + setEmbeddingModelFields(initEmbeddingModelFields); + }, [server, lmGlobalId, emGlobalId]); const handleSave = async () => { // compress fields with JSON values @@ -225,6 +236,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { }), ...(clmGlobalId && { [clmGlobalId]: fields + }), + ...(emGlobalId && { + [emGlobalId]: embeddingModelFields }) } }), @@ -379,26 +393,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { {/* Embedding model section */}

Embedding model

{server.emProviders.providers.length > 0 ? ( - { + const emGid = e.target.value === 'null' ? null : e.target.value; + setEmGlobalId(emGid); + }} + MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }} + > + None + {server.emProviders.providers.map(emp => + emp.models + .filter(em => em !== '*') // TODO: support registry providers + .map(em => ( + + {emp.name} :: {em} + + )) + )} + + {emGlobalId && ( + )} - + ) : (

No embedding models available.

)}