diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py
index 44a2941c9b..8e8128d7aa 100644
--- a/griptape/drivers/__init__.py
+++ b/griptape/drivers/__init__.py
@@ -9,6 +9,7 @@
 from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
 from .prompt.google_prompt_driver import GooglePromptDriver
 from .prompt.dummy_prompt_driver import DummyPromptDriver
+from .prompt.ollama_prompt_driver import OllamaPromptDriver
 
 from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
 from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
@@ -109,6 +110,7 @@
     "AmazonBedrockPromptDriver",
     "GooglePromptDriver",
     "DummyPromptDriver",
+    "OllamaPromptDriver",
     "BaseConversationMemoryDriver",
     "LocalConversationMemoryDriver",
     "AmazonDynamoDbConversationMemoryDriver",
diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py
new file mode 100644
index 0000000000..4b2a24f9e1
--- /dev/null
+++ b/griptape/drivers/prompt/ollama_prompt_driver.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+from collections.abc import Iterator
+from attrs import define, field
+from griptape.artifacts import TextArtifact
+from griptape.drivers import BasePromptDriver
+from griptape.utils import PromptStack, import_optional_dependency
+
+
+@define
+class OllamaPromptDriver(BasePromptDriver):
+    """Prompt driver for a locally running Ollama server.
+
+    Attributes:
+        model: Ollama model name (e.g. ``llama3``).
+    """
+
+    model: str = field(kw_only=True, metadata={"serializable": True})
+
+    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
+        """Run a non-streaming chat completion and return the message as a TextArtifact."""
+        ollama = import_optional_dependency("ollama")
+
+        response = ollama.chat(**self._base_params(prompt_stack))
+
+        return TextArtifact(value=response["message"]["content"])
+
+    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
+        """Stream a chat completion, yielding one TextArtifact per response chunk."""
+        ollama = import_optional_dependency("ollama")
+
+        stream = ollama.chat(**self._base_params(prompt_stack), stream=True)
+
+        for chunk in stream:
+            yield TextArtifact(value=chunk["message"]["content"])
+
+    def _base_params(self, prompt_stack: PromptStack) -> dict:
+        """Build keyword arguments for ``ollama.chat``.
+
+        Note: generation parameters (temperature, stop, num_predict) must be
+        nested under ``options`` — they are not top-level ``chat()`` kwargs.
+        """
+        messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]
+
+        return {
+            "model": self.model,
+            "messages": messages,
+            "options": {
+                "temperature": self.temperature,
+                "stop": self.tokenizer.stop_sequences,
+                "num_predict": self.max_tokens,
+            },
+        }
diff --git a/poetry.lock b/poetry.lock
index ed65bc07fd..678f5ffae8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3235,6 +3235,20 @@ files = [
     {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
 ]
 
+[[package]]
+name = "ollama"
+version = "0.2.1"
+description = "The official Python client for Ollama."
+optional = false
+python-versions = "<4.0,>=3.8"
+files = [
+    {file = "ollama-0.2.1-py3-none-any.whl", hash = "sha256:b6e2414921c94f573a903d1069d682ba2fb2607070ea9e19ca4a7872f2a460ec"},
+    {file = "ollama-0.2.1.tar.gz", hash = "sha256:fa316baa9a81eac3beb4affb0a17deb3008fdd6ed05b123c26306cfbe4c349b6"},
+]
+
+[package.dependencies]
+httpx = ">=0.27.0,<0.28.0"
+
 [[package]]
 name = "openai"
 version = "1.30.1"
@@ -6096,4 +6110,4 @@ loaders-pdf = ["pypdf"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "74de0d1e5ee382332635cea14dc0d39e288ab65adfec91569da0dbb06fd316e2"
+content-hash = "04d13cfeec16de7ec4b56eea5abaa414249e165abd4c758a00fa8205440a294d"
diff --git a/pyproject.toml b/pyproject.toml
index d58432b8f9..418098522a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,7 @@ voyageai = {version = "^0.2.1", optional = true}
 elevenlabs = {version = "^1.1.2", optional = true}
 torch = {version = "^2.3.0", optional = true}
 pusher = {version = "^3.3.2", optional = true}
+ollama = {version = "^0.2.1", optional = true}
 
 # loaders
 pandas = {version = "^1.3", optional = true}
@@ -69,6 +70,7 @@ drivers-prompt-huggingface-pipeline = ["huggingface-hub", "transformers", "torch"]
 drivers-prompt-amazon-bedrock = ["boto3", "anthropic"]
 drivers-prompt-amazon-sagemaker = ["boto3", "transformers"]
 drivers-prompt-google = ["google-generativeai"]
+drivers-prompt-ollama = ["ollama"]
 
 drivers-sql-redshift = ["sqlalchemy-redshift", "boto3"]
 drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connector-python"]
@@ -131,6 +133,7 @@ all = [
     "elevenlabs",
     "torch",
     "pusher",
+    "ollama",
 
     # loaders
     "pandas",