Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 7, 2024
1 parent 15045bd commit 872779e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 1 deletion.
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
from .prompt.google_prompt_driver import GooglePromptDriver
from .prompt.dummy_prompt_driver import DummyPromptDriver
from .prompt.ollama_prompt_driver import OllamaPromptDriver

from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
Expand Down Expand Up @@ -109,6 +110,7 @@
"AmazonBedrockPromptDriver",
"GooglePromptDriver",
"DummyPromptDriver",
"OllamaPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
Expand Down
43 changes: 43 additions & 0 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations
from collections.abc import Iterator
from attrs import define, field
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.utils import PromptStack, import_optional_dependency


@define
class OllamaPromptDriver(BasePromptDriver):
    """Prompt driver for chat models served by an Ollama instance.

    Attributes:
        model: Ollama model name (e.g. "llama3"). Required, keyword-only.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Run a single (non-streaming) chat completion.

        Returns:
            The model's reply wrapped in a TextArtifact, matching the
            declared return type (the raw response content is a plain str).
        """
        ollama = import_optional_dependency("ollama")

        response = ollama.chat(**self._base_params(prompt_stack))

        # Wrap the raw string so the return value honors the -> TextArtifact
        # contract that BasePromptDriver callers rely on.
        return TextArtifact(value=response["message"]["content"])

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Run a streaming chat completion, yielding one TextArtifact per chunk."""
        ollama = import_optional_dependency("ollama")

        stream = ollama.chat(**self._base_params(prompt_stack), stream=True)

        for chunk in stream:
            yield TextArtifact(value=chunk["message"]["content"])

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments for ollama.chat().

        Notes:
            - ollama.chat expects the conversation under "messages" (plural)
              and requires a "model" argument.
            - Sampling parameters (temperature, stop, num_predict) must be
              nested under "options", not passed as top-level kwargs.
        """
        messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]

        return {
            "messages": messages,
            "model": self.model,
            "options": {
                "temperature": self.temperature,
                "stop": self.tokenizer.stop_sequences,
                "num_predict": self.max_tokens,
            },
        }
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ voyageai = {version = "^0.2.1", optional = true}
elevenlabs = {version = "^1.1.2", optional = true}
torch = {version = "^2.3.0", optional = true}
pusher = {version = "^3.3.2", optional = true}
ollama = {version = "^0.2.1", optional = true}

# loaders
pandas = {version = "^1.3", optional = true}
Expand All @@ -69,6 +70,7 @@ drivers-prompt-huggingface-pipeline = ["huggingface-hub", "transformers", "torch
drivers-prompt-amazon-bedrock = ["boto3", "anthropic"]
drivers-prompt-amazon-sagemaker = ["boto3", "transformers"]
drivers-prompt-google = ["google-generativeai"]
drivers-prompt-ollama = ["ollama"]

drivers-sql-redshift = ["sqlalchemy-redshift", "boto3"]
drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connector-python"]
Expand Down Expand Up @@ -131,6 +133,7 @@ all = [
"elevenlabs",
"torch",
"pusher",
"ollama",

# loaders
"pandas",
Expand Down

0 comments on commit 872779e

Please sign in to comment.