Skip to content

Commit

Permalink
Migrate to ChatOllama base class in Ollama provider (#1015)
Browse files Browse the repository at this point in the history
* Added separate `ollama` provider

Created a separate file `ollama.py` as a unique provider. Refactored other code accordingly.

Also changed the `Ollama` class to `ChatOllama` so that it can support binding tools to the LLM.

Updated the imports to come from `langchain_ollama` instead of `langchain_community`

Tested on several Ollama models, both LLMs and embedding models: `mxbai-embed-large`, `nomic-embed-text`, `ima/deepseek-math`, `mathstral`, `qwen2-math`, `snowflake-arctic-embed`, `mistral`, `llama3.1`, `starcoder2:15b-instruct`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
srdas and pre-commit-ci[bot] authored Sep 26, 2024
1 parent 0884211 commit e6ec9e9
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 40 deletions.
2 changes: 0 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
BaseEmbeddingsProvider,
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OllamaEmbeddingsProvider,
QianfanEmbeddingsEndpointProvider,
)
from .exception import store_exception
Expand All @@ -21,7 +20,6 @@
BaseProvider,
GPT4AllProvider,
HfHubProvider,
OllamaProvider,
QianfanProvider,
TogetherAIProvider,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from langchain_community.embeddings import (
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OllamaEmbeddings,
QianfanEmbeddingsEndpoint,
)

Expand Down Expand Up @@ -65,19 +64,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, **model_kwargs)


class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
id = "ollama"
name = "Ollama"
# source: https://ollama.com/library
models = [
"nomic-embed-text",
"mxbai-embed-large",
"all-minilm",
"snowflake-arctic-embed",
]
model_id_key = "model"


class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings):
id = "huggingface_hub"
name = "Hugging Face Hub"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from langchain_ollama import ChatOllama, OllamaEmbeddings

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import BaseProvider, EnvAuthStrategy, TextField


class OllamaProvider(BaseProvider, ChatOllama):
id = "ollama"
name = "Ollama"
model_id_key = "model"
help = (
"See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. "
"Pass a model's name; for example, `deepseek-coder-v2`."
)
models = ["*"]
registry = True
fields = [
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]


class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
id = "ollama"
name = "Ollama"
# source: https://ollama.com/library
models = [
"nomic-embed-text",
"mxbai-embed-large",
"all-minilm",
"snowflake-arctic-embed",
]
model_id_key = "model"
23 changes: 1 addition & 22 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.llms import (
AI21,
GPT4All,
HuggingFaceEndpoint,
Ollama,
Together,
)
from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

Expand Down Expand Up @@ -707,21 +701,6 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)


class OllamaProvider(BaseProvider, Ollama):
id = "ollama"
name = "Ollama"
model_id_key = "model"
help = (
"See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. "
"Pass a model's name; for example, `deepseek-coder-v2`."
)
models = ["*"]
registry = True
fields = [
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]


class TogetherAIProvider(BaseProvider, Together):
id = "togetherai"
name = "Together AI"
Expand Down
5 changes: 3 additions & 2 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ all = [
"langchain_mistralai",
"langchain_nvidia_ai_endpoints",
"langchain_openai",
"langchain_ollama",
"pillow",
"boto3",
"qianfan",
Expand All @@ -61,7 +62,7 @@ anthropic-chat = "jupyter_ai_magics.partner_providers.anthropic:ChatAnthropicPro
cohere = "jupyter_ai_magics.partner_providers.cohere:CohereProvider"
gpt4all = "jupyter_ai_magics:GPT4AllProvider"
huggingface_hub = "jupyter_ai_magics:HfHubProvider"
ollama = "jupyter_ai_magics:OllamaProvider"
ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaProvider"
openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider"
openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider"
azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider"
Expand All @@ -83,7 +84,7 @@ cohere = "jupyter_ai_magics.partner_providers.cohere:CohereEmbeddingsProvider"
mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider"
gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
ollama = "jupyter_ai_magics:OllamaEmbeddingsProvider"
ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaEmbeddingsProvider"
openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider"
qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"

Expand Down

0 comments on commit e6ec9e9

Please sign in to comment.