Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 7, 2024
1 parent 872779e commit 0c4004e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
31 changes: 22 additions & 9 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,38 @@
from __future__ import annotations
from collections.abc import Iterator
from attrs import define, field
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.utils import PromptStack, import_optional_dependency
from griptape.tokenizers import HuggingFaceTokenizer


@define
class OllamaPromptDriver(BasePromptDriver):
    """Prompt driver for models served by a local Ollama instance.

    Attributes:
        model: Model name.
        tokenizer: Tokenizer used for token counting; defaults to the
            HuggingFace tokenizer loaded for `model`.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            # `transformers` is an optional dependency, so it is imported
            # lazily, only when the default factory actually fires.
            lambda self: HuggingFaceTokenizer(
                model=self.model,
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
ollama = import_optional_dependency("ollama")

response = ollama.chat(**self._base_params(prompt_stack))

return response["message"]["content"]
return TextArtifact(value=response["message"]["content"])

def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
ollama = import_optional_dependency("ollama")
Expand All @@ -36,8 +46,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs]

return {
"message": messages,
"temperature": self.temperature,
"stop": self.tokenizer.stop_sequences,
"num_predict": self.max_tokens,
"messages": messages,
"model": self.model,
"options": {
"temperature": self.temperature,
"stop": self.tokenizer.stop_sequences,
"num_predict": self.max_tokens,
},
}
7 changes: 4 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0c4004e

Please sign in to comment.