Skip to content

Commit

Permalink
Better google function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 19, 2024
1 parent 5048418 commit 6bb026e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
2 changes: 0 additions & 2 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import json
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact
Expand Down Expand Up @@ -257,5 +256,4 @@ def __to_tools(self, tools: list[BaseTool]) -> list[dict]:
}
)

print(json.dumps(tool_definitions, indent=2))
return tool_definitions
52 changes: 41 additions & 11 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Iterator
import json
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
Expand All @@ -16,6 +17,7 @@
TextPromptStackContent,
ActionCallPromptStackContent,
ActionResultPromptStackContent,
DeltaActionCallPromptStackContent,
)
from griptape.artifacts import TextArtifact, ActionArtifact
from griptape.drivers import BasePromptDriver
Expand Down Expand Up @@ -75,19 +77,28 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess
messages, **self._base_params(prompt_stack), stream=True
)

prompt_token_count = None
for chunk in response:
print(chunk.parts)
usage_metadata = chunk.usage_metadata

yield DeltaTextPromptStackContent(chunk.text)

# TODO: Only yield the first one
yield DeltaPromptStackMessage(
role=PromptStackMessage.ASSISTANT_ROLE,
delta_usage=DeltaPromptStackMessage.DeltaUsage(
input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count
),
)
for part in chunk.parts:
yield self.__message_content_delta_to_prompt_stack_content_delta(part)

# Only want to output the prompt token count once
if prompt_token_count is None:
prompt_token_count = usage_metadata.prompt_token_count
yield DeltaPromptStackMessage(
role=PromptStackMessage.ASSISTANT_ROLE,
delta_usage=DeltaPromptStackMessage.DeltaUsage(
input_tokens=usage_metadata.prompt_token_count,
output_tokens=usage_metadata.candidates_token_count,
),
)
else:
yield DeltaPromptStackMessage(
role=PromptStackMessage.ASSISTANT_ROLE,
delta_usage=DeltaPromptStackMessage.DeltaUsage(output_tokens=usage_metadata.candidates_token_count),
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig
Expand Down Expand Up @@ -177,6 +188,23 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent
else:
raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

def __message_content_delta_to_prompt_stack_content_delta(self, content: Part) -> BaseDeltaPromptStackContent:
MessageToDict = import_optional_dependency("google.protobuf.json_format").MessageToDict
# https://stackoverflow.com/questions/64403737/attribute-error-descriptor-while-trying-to-convert-google-vision-response-to-dic
content_dict = MessageToDict(content._pb)
if "text" in content_dict:
return DeltaTextPromptStackContent(content_dict["text"])
elif "functionCall" in content_dict:
function_call = content_dict["functionCall"]

name, path = function_call["name"].split("_", 1)

return DeltaActionCallPromptStackContent(
tag=function_call["name"], name=name, path=path, delta_input=json.dumps(function_call["args"])
)
else:
raise ValueError(f"Unsupported message content type {content_dict}")

def __to_role(self, message: PromptStackMessage) -> str:
if message.is_assistant():
return "model"
Expand All @@ -189,7 +217,9 @@ def __to_tools(self, tools: list[BaseTool]) -> list[dict]:
tool_declarations = []
for tool in tools:
for activity in tool.activities():
schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema")
schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema")["properties"][
"values"
]

schema = remove_key_in_dict_recursively(schema, "additionalProperties")
tool_declaration = FunctionDeclaration(
Expand Down

0 comments on commit 6bb026e

Please sign in to comment.