Skip to content

Commit

Permalink
Fix Ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 19, 2024
1 parent 02515ff commit e6e63b9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
24 changes: 18 additions & 6 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers.base_tokenizer import BaseTokenizer
from griptape.utils import PromptStack, import_optional_dependency
from griptape.common import PromptStack, TextPromptStackContent
from griptape.utils import import_optional_dependency
from griptape.tokenizers import SimpleTokenizer
from griptape.common import (
PromptStackMessage,
BaseDeltaPromptStackContent,
DeltaPromptStackMessage,
DeltaTextPromptStackContent,
)

if TYPE_CHECKING:
from ollama import Client
Expand Down Expand Up @@ -46,24 +53,29 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage:
response = self.client.chat(**self._base_params(prompt_stack))

if isinstance(response, dict):
return TextArtifact(value=response["message"]["content"])
return PromptStackMessage(
content=[TextPromptStackContent(TextArtifact(value=response["message"]["content"]))],
role=PromptStackMessage.ASSISTANT_ROLE,
)
else:
raise Exception("invalid model response")

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]:
stream = self.client.chat(**self._base_params(prompt_stack), stream=True)

if isinstance(stream, Iterator):
for chunk in stream:
yield TextArtifact(value=chunk["message"]["content"])
yield DeltaTextPromptStackContent(chunk["message"]["content"], role=PromptStackMessage.ASSISTANT_ROLE)
else:
raise Exception("invalid model response")

def _base_params(self, prompt_stack: PromptStack) -> dict:
messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]
messages = [
{"role": message.role, "content": message.to_text_artifact().to_text()} for message in prompt_stack.messages
]

return {"messages": messages, "model": self.model, "options": self.options}
22 changes: 10 additions & 12 deletions tests/unit/drivers/prompt/test_ollama_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent
from griptape.drivers import OllamaPromptDriver
from griptape.utils import PromptStack
from griptape.common import PromptStack
import pytest


Expand All @@ -25,13 +26,11 @@ def test_init(self):
def test_try_run(self, mock_client):
# Given
prompt_stack = PromptStack()
prompt_stack.add_generic_input("generic-input")
prompt_stack.add_system_input("system-input")
prompt_stack.add_user_input("user-input")
prompt_stack.add_assistant_input("assistant-input")
prompt_stack.add_system_message("system-input")
prompt_stack.add_user_message("user-input")
prompt_stack.add_assistant_message("assistant-input")
driver = OllamaPromptDriver(model="llama")
expected_messages = [
{"role": "generic", "content": "generic-input"},
{"role": "system", "content": "system-input"},
{"role": "user", "content": "user-input"},
{"role": "assistant", "content": "assistant-input"},
Expand Down Expand Up @@ -61,12 +60,10 @@ def test_try_run_bad_response(self, mock_client):
def test_try_stream_run(self, mock_stream_client):
# Given
prompt_stack = PromptStack()
prompt_stack.add_generic_input("generic-input")
prompt_stack.add_system_input("system-input")
prompt_stack.add_user_input("user-input")
prompt_stack.add_assistant_input("assistant-input")
prompt_stack.add_system_message("system-input")
prompt_stack.add_user_message("user-input")
prompt_stack.add_assistant_message("assistant-input")
expected_messages = [
{"role": "generic", "content": "generic-input"},
{"role": "system", "content": "system-input"},
{"role": "user", "content": "user-input"},
{"role": "assistant", "content": "assistant-input"},
Expand All @@ -83,7 +80,8 @@ def test_try_stream_run(self, mock_stream_client):
options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens},
stream=True,
)
assert text_artifact.value == "model-output"
if isinstance(text_artifact, DeltaTextPromptStackContent):
assert text_artifact.text == "model-output"

def test_try_stream_bad_response(self, mock_stream_client):
# Given
Expand Down

0 comments on commit e6e63b9

Please sign in to comment.