Add Ollama Prompt Driver
collindutter committed Jun 18, 2024
1 parent 6ceadf6 commit 3cc531e
Showing 10 changed files with 218 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseTask.parents_outputs` to get the textual output of all parent tasks.
- `BaseTask.parents_output_text` to get a concatenated string of all parent tasks' outputs.
- `parents_output_text` to Workflow context.
- `OllamaPromptDriver` for using models with Ollama.

### Changed
- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`.
23 changes: 23 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -241,6 +241,29 @@ agent.run(
)
```

### Ollama

!!! info
This driver requires the `drivers-prompt-ollama` [extra](../index.md#extras).

The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_driver.md) connects to the [Ollama Chat Completion API](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion).

```python
from griptape.config import StructureConfig
from griptape.drivers import OllamaPromptDriver
from griptape.structures import Agent


agent = Agent(
config=StructureConfig(
prompt_driver=OllamaPromptDriver(
model="llama3",
),
),
)
agent.run("What color is the sky at different times of the day?")
```
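
The driver also supports streaming. Below is a minimal sketch of driver-level streaming that uses only the `stream` flag and `try_stream` method added in this commit; the chunk handling is illustrative and assumes a local Ollama server with the model pulled:

```python
from griptape.drivers import OllamaPromptDriver
from griptape.utils import PromptStack

driver = OllamaPromptDriver(model="llama3", stream=True)

prompt_stack = PromptStack()
prompt_stack.add_user_input("What color is the sky at different times of the day?")

# try_stream yields one TextArtifact per chunk returned by Ollama's chat endpoint.
for artifact in driver.try_stream(prompt_stack):
    print(artifact.value, end="", flush=True)
```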

### Hugging Face Hub

!!! info
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -9,6 +9,7 @@
from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
from .prompt.google_prompt_driver import GooglePromptDriver
from .prompt.dummy_prompt_driver import DummyPromptDriver
from .prompt.ollama_prompt_driver import OllamaPromptDriver

from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
@@ -109,6 +110,7 @@
"AmazonBedrockPromptDriver",
"GooglePromptDriver",
"DummyPromptDriver",
"OllamaPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
69 changes: 69 additions & 0 deletions griptape/drivers/prompt/ollama_prompt_driver.py
@@ -0,0 +1,69 @@
from __future__ import annotations
from collections.abc import Iterator
from typing import TYPE_CHECKING, Optional
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers.base_tokenizer import BaseTokenizer
from griptape.utils import PromptStack, import_optional_dependency
from griptape.tokenizers import SimpleTokenizer

if TYPE_CHECKING:
from ollama import Client


@define
class OllamaPromptDriver(BasePromptDriver):
"""
Attributes:
model: Model name.
"""

model: str = field(kw_only=True, metadata={"serializable": True})
host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True),
kw_only=True,
)
tokenizer: BaseTokenizer = field(
default=Factory(
lambda self: SimpleTokenizer(
characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens
),
takes_self=True,
),
kw_only=True,
)
options: dict = field(
default=Factory(
lambda self: {
"temperature": self.temperature,
"stop": self.tokenizer.stop_sequences,
"num_predict": self.max_tokens,
},
takes_self=True,
),
kw_only=True,
)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
response = self.client.chat(**self._base_params(prompt_stack))

if isinstance(response, dict):
return TextArtifact(value=response["message"]["content"])
else:
raise Exception("invalid model response")

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
stream = self.client.chat(**self._base_params(prompt_stack), stream=True)

if isinstance(stream, Iterator):
for chunk in stream:
yield TextArtifact(value=chunk["message"]["content"])
else:
raise Exception("invalid model response")

def _base_params(self, prompt_stack: PromptStack) -> dict:
messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]

return {"messages": messages, "model": self.model, "options": self.options}
9 changes: 5 additions & 4 deletions griptape/tokenizers/base_tokenizer.py
@@ -17,11 +17,12 @@ class BaseTokenizer(ABC):
max_output_tokens: int = field(kw_only=True, default=None)

def __attrs_post_init__(self) -> None:
if self.max_input_tokens is None:
self.max_input_tokens = self._default_max_input_tokens()
if hasattr(self, "model"):
if self.max_input_tokens is None:
self.max_input_tokens = self._default_max_input_tokens()

if self.max_output_tokens is None:
self.max_output_tokens = self._default_max_output_tokens()
if self.max_output_tokens is None:
self.max_output_tokens = self._default_max_output_tokens()

def count_input_tokens_left(self, text: str) -> int:
diff = self.max_input_tokens - self.count_tokens(text)
3 changes: 1 addition & 2 deletions griptape/tokenizers/simple_tokenizer.py
@@ -1,12 +1,11 @@
from __future__ import annotations
from attrs import define, field
from typing import Optional
from griptape.tokenizers import BaseTokenizer


@define()
class SimpleTokenizer(BaseTokenizer):
model: Optional[str] = field(default=None, kw_only=True)
model: str = field(init=False, kw_only=True)
characters_per_token: int = field(kw_only=True)

def count_tokens(self, text: str) -> int:
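
For illustration, a minimal sketch of `SimpleTokenizer` usage after this change: it no longer accepts a `model` argument and counts tokens purely from character length. The `characters_per_token` and `max_input_tokens` values mirror the Ollama driver's tokenizer factory; `max_output_tokens` is an arbitrary illustrative value:

```python
from griptape.tokenizers import SimpleTokenizer

# No model argument anymore; token counts are derived from character counts.
tokenizer = SimpleTokenizer(characters_per_token=4, max_input_tokens=2000, max_output_tokens=500)

print(tokenizer.count_tokens("What color is the sky?"))             # character-based estimate
print(tokenizer.count_input_tokens_left("What color is the sky?"))  # max_input_tokens minus the estimate
```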
19 changes: 17 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -53,6 +53,7 @@ voyageai = {version = "^0.2.1", optional = true}
elevenlabs = {version = "^1.1.2", optional = true}
torch = {version = "^2.3.0", optional = true}
pusher = {version = "^3.3.2", optional = true}
ollama = {version = "^0.2.1", optional = true}

# loaders
pandas = {version = "^1.3", optional = true}
@@ -69,6 +70,7 @@ drivers-prompt-huggingface-pipeline = ["huggingface-hub", "transformers", "torch
drivers-prompt-amazon-bedrock = ["boto3", "anthropic"]
drivers-prompt-amazon-sagemaker = ["boto3", "transformers"]
drivers-prompt-google = ["google-generativeai"]
drivers-prompt-ollama = ["ollama"]

drivers-sql-redshift = ["sqlalchemy-redshift", "boto3"]
drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connector-python"]
@@ -131,6 +133,7 @@ all = [
"elevenlabs",
"torch",
"pusher",
"ollama",

# loaders
"pandas",
96 changes: 96 additions & 0 deletions tests/unit/drivers/prompt/test_ollama_prompt_driver.py
@@ -0,0 +1,96 @@
from griptape.drivers import OllamaPromptDriver
from griptape.utils import PromptStack
import pytest


class TestOllamaPromptDriver:
@pytest.fixture
def mock_client(self, mocker):
mock_client = mocker.patch("ollama.Client")

mock_client.return_value.chat.return_value = {"message": {"content": "model-output"}}

return mock_client

@pytest.fixture
def mock_stream_client(self, mocker):
mock_stream_client = mocker.patch("ollama.Client")
mock_stream_client.return_value.chat.return_value = iter([{"message": {"content": "model-output"}}])

return mock_stream_client

def test_init(self):
assert OllamaPromptDriver(model="llama")

def test_try_run(self, mock_client):
# Given
prompt_stack = PromptStack()
prompt_stack.add_generic_input("generic-input")
prompt_stack.add_system_input("system-input")
prompt_stack.add_user_input("user-input")
prompt_stack.add_assistant_input("assistant-input")
driver = OllamaPromptDriver(model="llama")
expected_messages = [
{"role": "generic", "content": "generic-input"},
{"role": "system", "content": "system-input"},
{"role": "user", "content": "user-input"},
{"role": "assistant", "content": "assistant-input"},
]

# When
text_artifact = driver.try_run(prompt_stack)

# Then
mock_client.return_value.chat.assert_called_once_with(
messages=expected_messages,
model=driver.model,
options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens},
)
assert text_artifact.value == "model-output"

def test_try_run_bad_response(self, mock_client):
# Given
prompt_stack = PromptStack()
driver = OllamaPromptDriver(model="llama")
mock_client.return_value.chat.return_value = "bad-response"

# When/Then
with pytest.raises(Exception, match="invalid model response"):
driver.try_run(prompt_stack)

def test_try_stream_run(self, mock_stream_client):
# Given
prompt_stack = PromptStack()
prompt_stack.add_generic_input("generic-input")
prompt_stack.add_system_input("system-input")
prompt_stack.add_user_input("user-input")
prompt_stack.add_assistant_input("assistant-input")
expected_messages = [
{"role": "generic", "content": "generic-input"},
{"role": "system", "content": "system-input"},
{"role": "user", "content": "user-input"},
{"role": "assistant", "content": "assistant-input"},
]
driver = OllamaPromptDriver(model="llama", stream=True)

# When
text_artifact = next(driver.try_stream(prompt_stack))

# Then
mock_stream_client.return_value.chat.assert_called_once_with(
messages=expected_messages,
model=driver.model,
options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens},
stream=True,
)
assert text_artifact.value == "model-output"

def test_try_stream_bad_response(self, mock_stream_client):
# Given
prompt_stack = PromptStack()
driver = OllamaPromptDriver(model="llama", stream=True)
mock_stream_client.return_value.chat.return_value = "bad-response"

# When/Then
with pytest.raises(Exception, match="invalid model response"):
next(driver.try_stream(prompt_stack))
2 changes: 1 addition & 1 deletion tests/unit/tokenizers/test_simple_tokenizer.py
@@ -5,7 +5,7 @@
class TestSimpleTokenizer:
@pytest.fixture
def tokenizer(self):
return SimpleTokenizer(model="any model", max_input_tokens=1024, max_output_tokens=4096, characters_per_token=6)
return SimpleTokenizer(max_input_tokens=1024, max_output_tokens=4096, characters_per_token=6)

def test_token_count(self, tokenizer):
assert tokenizer.count_tokens("foo bar huzzah") == 3
