Partial google support
collindutter committed Jun 18, 2024
Parent: 5c588db · Commit: fc53df6
Showing 5 changed files with 122 additions and 38 deletions.
128 changes: 101 additions & 27 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -2,10 +2,10 @@

from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Optional
import json

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import (
BaseDeltaPromptStackContent,
BasePromptStackContent,
@@ -15,14 +15,20 @@
PromptStack,
PromptStackMessage,
TextPromptStackContent,
ActionCallPromptStackContent,
ActionResultPromptStackContent,
)
from griptape.artifacts import TextArtifact, ActionCallArtifact
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, GoogleTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils import import_optional_dependency, remove_key_in_dict_recursively

if TYPE_CHECKING:
from google.generativeai import GenerativeModel
from google.generativeai.types import ContentDict, GenerateContentResponse
from google.generativeai.protos import Part
from griptape.tools import BaseTool
from schema import Schema


@define
@@ -45,46 +45,29 @@ class GooglePromptDriver(BasePromptDriver):
)
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})

def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

messages = self._prompt_stack_to_messages(prompt_stack)
response: GenerateContentResponse = self.model_client.generate_content(
messages,
generation_config=GenerationConfig(
stop_sequences=self.tokenizer.stop_sequences,
max_output_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
),
messages, **self._base_params(prompt_stack)
)

usage_metadata = response.usage_metadata

return PromptStackMessage(
content=[TextPromptStackContent(TextArtifact(response.text))],
content=[self.__message_content_to_prompt_stack_content(part) for part in response.parts],
role=PromptStackMessage.ASSISTANT_ROLE,
usage=PromptStackMessage.Usage(
input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count
),
)

def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

messages = self._prompt_stack_to_messages(prompt_stack)
response: Iterator[GenerateContentResponse] = self.model_client.generate_content(
messages,
stream=True,
generation_config=GenerationConfig(
stop_sequences=self.tokenizer.stop_sequences,
max_output_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
),
messages, **self._base_params(prompt_stack), stream=True
)

for chunk in response:
@@ -100,6 +89,20 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]:
),
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

return {
"generation_config": GenerationConfig(
stop_sequences=self.tokenizer.stop_sequences,
max_output_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
),
**self._prompt_stack_to_tools(prompt_stack),
}

def _default_model_client(self) -> GenerativeModel:
genai = import_optional_dependency("google.generativeai")
genai.configure(api_key=self.api_key)
@@ -120,21 +123,92 @@ def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:

return inputs

def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> ContentDict | str:
def _prompt_stack_to_tools(self, prompt_stack: PromptStack) -> dict:
return (
{
"tools": self.__to_tools(prompt_stack.actions),
"tool_config": {"function_calling_config": {"mode": self.tool_choice}},
}
if prompt_stack.actions and self.use_native_tools
else {}
)

def __message_content_to_prompt_stack_content(self, content: Part) -> BasePromptStackContent:
MessageToDict = import_optional_dependency("google.protobuf.json_format").MessageToDict
# https://stackoverflow.com/questions/64403737/attribute-error-descriptor-while-trying-to-convert-google-vision-response-to-dic
content_dict = MessageToDict(content._pb)

if "text" in content_dict:
return TextPromptStackContent(TextArtifact(content_dict["text"]))
elif "functionCall" in content_dict:
function_call = content_dict["functionCall"]

name, path = function_call["name"].split("_", 1)

return ActionCallPromptStackContent(
artifact=ActionCallArtifact(
value=ActionCallArtifact.ActionCall(
tag=function_call["name"], name=name, path=path, input=json.dumps(function_call["args"])
)
)
)
else:
raise ValueError(f"Unsupported message content type {content_dict}")

def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> ContentDict | Part | str:
ContentDict = import_optional_dependency("google.generativeai.types").ContentDict
Part = import_optional_dependency("google.generativeai.protos").Part
FunctionCall = import_optional_dependency("google.generativeai.protos").FunctionCall
FunctionResponse = import_optional_dependency("google.generativeai.protos").FunctionResponse

if isinstance(content, TextPromptStackContent):
return content.artifact.to_text()
elif isinstance(content, ImagePromptStackContent):
return ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
elif isinstance(content, ActionCallPromptStackContent):
action = content.artifact.value

return Part(function_call=FunctionCall(name=action.tag, args=json.loads(action.input)))
elif isinstance(content, ActionResultPromptStackContent):
artifact = content.artifact

return Part(
function_response=FunctionResponse(
name=f"{content.action_name}_{content.action_path}", response=artifact.to_dict()
)
)

else:
raise ValueError(f"Unsupported content type: {type(content)}")
raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

def __to_role(self, message: PromptStackMessage) -> str:
if message.is_assistant():
return "model"
else:
return "user"

def __to_content(self, message: PromptStackMessage) -> list[ContentDict | str]:
def __to_tools(self, tools: list[BaseTool]) -> list[dict]:
FunctionDeclaration = import_optional_dependency("google.generativeai.types").FunctionDeclaration

tool_declarations = []
for tool in tools:
for activity in tool.activities():
schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema")

schema = remove_key_in_dict_recursively(schema, "additionalProperties")
tool_declaration = FunctionDeclaration(
name=f"{tool.name}_{tool.activity_name(activity)}",
description=tool.activity_description(activity),
parameters={
"type": schema["type"],
"properties": schema["properties"],
"required": schema.get("required", []),
},
)

tool_declarations.append(tool_declaration)

return tool_declarations

def __to_content(self, message: PromptStackMessage) -> list[ContentDict | str | Part]:
return [self.__prompt_stack_content_message_content(content) for content in message.content]
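
Taken together, the tool-calling flow added here works as follows: with use_native_tools=True, each tool activity is declared to Gemini as a function named {tool.name}_{activity_name}, and that name is split on its first underscore when the model's functionCall comes back. A minimal configuration sketch (only the driver fields come from this commit; the model name is an illustrative assumption):

    from griptape.drivers import GooglePromptDriver

    driver = GooglePromptDriver(
        model="gemini-1.5-pro",  # assumed model name, not part of this diff
        api_key="<GOOGLE_API_KEY>",
        use_native_tools=True,   # new: advertise tool activities as function declarations
        tool_choice="auto",      # new: forwarded as function_calling_config.mode
    )

A tool named Calculator with a calculate activity is therefore declared as Calculator_calculate, matching the round trip sketched in __message_content_to_prompt_stack_content above.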
3 changes: 2 additions & 1 deletion griptape/utils/__init__.py
@@ -7,7 +7,7 @@
from .chat import Chat
from .futures import execute_futures_dict
from .token_counter import TokenCounter
from .dict_utils import remove_null_values_in_dict_recursively, dict_merge
from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively
from .file_utils import load_file, load_files
from .hash import str_to_hash
from .import_utils import import_optional_dependency
@@ -37,6 +37,7 @@ def minify_json(value: str) -> str:
"TokenCounter",
"remove_null_values_in_dict_recursively",
"dict_merge",
"remove_key_in_dict_recursively",
"Stream",
"load_artifact_from_memory",
"deprecation_warn",
7 changes: 7 additions & 0 deletions griptape/utils/dict_utils.py
@@ -5,6 +5,13 @@ def remove_null_values_in_dict_recursively(d: dict) -> dict:
return d


def remove_key_in_dict_recursively(d: dict, key: str) -> dict:
if isinstance(d, dict):
return {k: remove_key_in_dict_recursively(v, key) for k, v in d.items() if k != key}
else:
return d


def dict_merge(dct: dict, merge_dct: dict, add_keys: bool = True) -> dict:
"""Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested
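
The driver uses this helper to strip additionalProperties from activity schemas before building function declarations, presumably because Gemini's schema format rejects that key. A quick illustrative example (the schema itself is made up):

    from griptape.utils import remove_key_in_dict_recursively

    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"values": {"type": "object", "additionalProperties": False}},
    }

    remove_key_in_dict_recursively(schema, "additionalProperties")
    # -> {'type': 'object', 'properties': {'values': {'type': 'object'}}}

Note that the helper only recurses into dicts; a matching key nested inside a list is left untouched.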
16 changes: 8 additions & 8 deletions poetry.lock

Some generated files are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
@@ -44,7 +44,8 @@ redis = { version = "^4.6.0", optional = true }
opensearch-py = { version = "^2.3.1", optional = true }
pgvector = { version = "^0.2.3", optional = true }
psycopg2-binary = { version = "^2.9.9", optional = true }
google-generativeai = { version = "^0.6.0", optional = true }
google-generativeai = { version = "^0.7.0", optional = true }
protobuf = { version = "4.25.3", optional = true }
trafilatura = {version = "^1.6", optional = true}
playwright = {version = "^1.42", optional = true}
beautifulsoup4 = {version = "^4.12.3", optional = true}
@@ -69,7 +70,7 @@ drivers-prompt-huggingface = ["huggingface-hub", "transformers"]
drivers-prompt-huggingface-pipeline = ["huggingface-hub", "transformers", "torch"]
drivers-prompt-amazon-bedrock = ["boto3", "anthropic"]
drivers-prompt-amazon-sagemaker = ["boto3", "transformers"]
drivers-prompt-google = ["google-generativeai"]
drivers-prompt-google = ["google-generativeai", "protobuf"]
drivers-prompt-ollama = ["ollama"]

drivers-sql-redshift = ["sqlalchemy-redshift", "boto3"]
@@ -125,6 +126,7 @@ all = [
"pgvector",
"psycopg2-binary",
"google-generativeai",
"protobuf",
"trafilatura",
"playwright",
"beautifulsoup4",
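
Because google-generativeai and the new protobuf pin are both optional, the driver resolves them lazily via import_optional_dependency, which fails with a helpful install message when the extra is missing. A quick sanity check, mirroring the imports the driver itself performs (sketch, not from this commit):

    from griptape.utils import import_optional_dependency

    genai = import_optional_dependency("google.generativeai")  # needs the drivers-prompt-google extra
    json_format = import_optional_dependency("google.protobuf.json_format")  # provided by the protobuf pin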
