Skip to content

Commit

Permalink
Backport PR #1136: Add base API URL field for Ollama and OpenAI embed…
Browse files Browse the repository at this point in the history
…ding models
  • Loading branch information
srdas authored and meeseeksmachine committed Dec 9, 2024
1 parent bfe8766 commit 802f085
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain_ollama import ChatOllama, OllamaEmbeddings

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import BaseProvider, EnvAuthStrategy, TextField
from ..providers import BaseProvider, TextField


class OllamaProvider(BaseProvider, ChatOllama):
Expand All @@ -23,10 +23,14 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
id = "ollama"
name = "Ollama"
# source: https://ollama.com/library
model_id_key = "model"
models = [
"nomic-embed-text",
"mxbai-embed-large",
"all-minilm",
"snowflake-arctic-embed",
]
model_id_key = "model"
registry = True
fields = [
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
model_id_key = "model"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
registry = True
fields = [
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
]


class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings):
Expand All @@ -122,5 +128,7 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding
auth_strategy = EnvAuthStrategy(
name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
)

registry = True
fields = [
TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"),
]
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,13 @@ def _provider_params(self, key, listing, completions: bool = False):
else:
fields = config.fields.get(model_uid, {})

# exclude empty fields
# TODO: modify the config manager to never save empty fields in the
# first place.
for field_key in fields:
if isinstance(fields[field_key], str) and not len(fields[field_key]):
fields[field_key] = None

# get authn fields
_, Provider = get_em_provider(model_uid, listing)
authn_fields = {}
Expand Down
63 changes: 43 additions & 20 deletions packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
const [sendWse, setSendWse] = useState<boolean>(false);
const [fields, setFields] = useState<Record<string, any>>({});
const [embeddingModelFields, setEmbeddingModelFields] = useState<
Record<string, any>
>({});

const [isCompleterEnabled, setIsCompleterEnabled] = useState(
props.completionProvider && props.completionProvider.isEnabled()
Expand Down Expand Up @@ -191,7 +194,15 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
const currFields: Record<string, any> =
server.config.fields?.[lmGlobalId] ?? {};
setFields(currFields);
}, [server, lmProvider]);

if (!emGlobalId) {
return;
}

const initEmbeddingModelFields: Record<string, any> =
server.config.fields?.[emGlobalId] ?? {};
setEmbeddingModelFields(initEmbeddingModelFields);
}, [server, lmGlobalId, emGlobalId]);

const handleSave = async () => {
// compress fields with JSON values
Expand Down Expand Up @@ -225,6 +236,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
}),
...(clmGlobalId && {
[clmGlobalId]: fields
}),
...(emGlobalId && {
[emGlobalId]: embeddingModelFields
})
}
}),
Expand Down Expand Up @@ -379,26 +393,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
{/* Embedding model section */}
<h2 className="jp-ai-ChatSettings-header">Embedding model</h2>
{server.emProviders.providers.length > 0 ? (
<Select
value={emGlobalId}
label="Embedding model"
onChange={e => {
const emGid = e.target.value === 'null' ? null : e.target.value;
setEmGlobalId(emGid);
}}
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
>
<MenuItem value="null">None</MenuItem>
{server.emProviders.providers.map(emp =>
emp.models
.filter(em => em !== '*') // TODO: support registry providers
.map(em => (
<MenuItem value={`${emp.id}:${em}`}>
{emp.name} :: {em}
</MenuItem>
))
<Box>
<Select
value={emGlobalId}
label="Embedding model"
onChange={e => {
const emGid = e.target.value === 'null' ? null : e.target.value;
setEmGlobalId(emGid);
}}
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
>
<MenuItem value="null">None</MenuItem>
{server.emProviders.providers.map(emp =>
emp.models
.filter(em => em !== '*') // TODO: support registry providers
.map(em => (
<MenuItem value={`${emp.id}:${em}`}>
{emp.name} :: {em}
</MenuItem>
))
)}
</Select>
{emGlobalId && (
<ModelFields
fields={emProvider?.fields}
values={embeddingModelFields}
onChange={setEmbeddingModelFields}
/>
)}
</Select>
</Box>
) : (
<p>No embedding models available.</p>
)}
Expand Down

0 comments on commit 802f085

Please sign in to comment.