From 0c4004e7c704d15ba235bcffe033080fde592661 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 7 Jun 2024 14:28:00 -0700 Subject: [PATCH] WIP --- .../drivers/prompt/ollama_prompt_driver.py | 31 +++++++++++++------ poetry.lock | 7 +++-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 4b2a24f9e1..3301794d48 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -1,28 +1,38 @@ from __future__ import annotations from collections.abc import Iterator -from attrs import define, field +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver from griptape.utils import PromptStack, import_optional_dependency +from griptape.tokenizers import HuggingFaceTokenizer @define class OllamaPromptDriver(BasePromptDriver): """ Attributes: - api_key: Cohere API key. - model: Cohere model name. - client: Custom `cohere.Client`. + model: Model name. """ model: str = field(kw_only=True, metadata={"serializable": True}) + tokenizer: HuggingFaceTokenizer = field( + default=Factory( + lambda self: HuggingFaceTokenizer( + model=self.model, + tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), + max_output_tokens=self.max_tokens, + ), + takes_self=True, + ), + kw_only=True, + ) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ollama = import_optional_dependency("ollama") response = ollama.chat(**self._base_params(prompt_stack)) - return response["message"]["content"] + return TextArtifact(value=response["message"]["content"]) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: ollama = import_optional_dependency("ollama") @@ -36,8 +46,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs] return { - "message": messages, - "temperature": self.temperature, - "stop": self.tokenizer.stop_sequences, - "num_predict": self.max_tokens, + "messages": messages, + "model": self.model, + "options": { + "temperature": self.temperature, + "stop": self.tokenizer.stop_sequences, + "num_predict": self.max_tokens, + }, } diff --git a/poetry.lock b/poetry.lock index 678f5ffae8..b94c05b2ab 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3239,7 +3239,7 @@ files = [ name = "ollama" version = "0.2.1" description = "The official Python client for Ollama." -optional = false +optional = true python-versions = "<4.0,>=3.8" files = [ {file = "ollama-0.2.1-py3-none-any.whl", hash = "sha256:b6e2414921c94f573a903d1069d682ba2fb2607070ea9e19ca4a7872f2a460ec"}, @@ -6070,7 +6070,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] +all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6089,6 +6089,7 @@ drivers-prompt-cohere = ["cohere"] drivers-prompt-google = ["google-generativeai"] drivers-prompt-huggingface = ["huggingface-hub", "transformers"] drivers-prompt-huggingface-pipeline = ["huggingface-hub", "torch", "transformers"] +drivers-prompt-ollama = ["ollama"] drivers-sql-postgres = ["pgvector", "psycopg2-binary"] drivers-sql-redshift = ["boto3", "sqlalchemy-redshift"] drivers-sql-snowflake = ["snowflake-sqlalchemy"] @@ -6110,4 +6111,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "04d13cfeec16de7ec4b56eea5abaa414249e165abd4c758a00fa8205440a294d" +content-hash = "6ccbbba60b534e4756d1c36e37ce1379ee2126d37b822f186b5dbb8e8f701ff3"