Skip to content

Commit

Permalink
Merge pull request #1304 from phidatahq/update-gemini-add_tool-fn-phi…
Browse files Browse the repository at this point in the history
…-1733

update-gemini-add_tool-fn-phi-1733
  • Loading branch information
ashpreetbedi authored Oct 22, 2024
2 parents e738b8b + 53f636d commit 43f7733
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 54 deletions.
1 change: 0 additions & 1 deletion cookbook/providers/google/agent_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
instructions=["Use tables where possible."],
markdown=True,
show_tool_calls=True,
debug_mode=True,
)

# Get the response in a variable
Expand Down
2 changes: 1 addition & 1 deletion cookbook/providers/google/basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from phi.agent import Agent, RunResponse # noqa
from phi.model.google import Gemini

agent = Agent(model=Gemini(id="gemini-1.5-flash"), markdown=True, debug_mode=True)
agent = Agent(model=Gemini(id="gemini-1.5-flash"), markdown=True)

# Get the response in a variable
# run: RunResponse = agent.run("Share a 2 sentence horror story")
Expand Down
7 changes: 5 additions & 2 deletions phi/memory/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,13 @@ def get_messages_from_last_n_chats(
return messages_from_last_n_history

def get_message_pairs(
self, user_role: str = "user", assistant_role: str = "assistant"
self, user_role: str = "user", assistant_role: Optional[List[str]] = None
) -> List[Tuple[Message, Message]]:
"""Returns a list of tuples of (user message, assistant response)."""

if assistant_role is None:
assistant_role = ["assistant", "model", "CHATBOT"]

chats_as_message_pairs: List[Tuple[Message, Message]] = []
for chat in self.chats:
if chat.response and chat.response.messages:
Expand All @@ -182,7 +185,7 @@ def get_message_pairs(

# Start from the end to look for the assistant response
for message in chat.response.messages[::-1]:
if message.role == assistant_role:
if message.role in assistant_role:
assistant_messages_from_chat = message
break

Expand Down
99 changes: 49 additions & 50 deletions phi/model/google/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,57 +189,58 @@ def add_tool(self, tool: Union["Tool", "Toolkit", Callable, dict, "Function"]) -
Args:
tool: The tool to add. Can be a Tool, Toolkit, Callable, dict, or Function.
"""
# Initialize function declarations if necessary
if self.function_declarations is None:
self.function_declarations = []

# Initialize functions if necessary
if self.functions is None:
self.functions = {}

if isinstance(tool, Toolkit):
# Add all functions from the toolkit
self.functions.update(tool.functions)
for func in tool.functions.values():
function_declaration = FunctionDeclaration(
name=func.name,
description=func.description,
parameters=self._format_functions(func.parameters),
)
self.function_declarations.append(function_declaration)
logger.debug(f"Functions from toolkit '{tool.name}' added to LLM.")

elif isinstance(tool, Function):
# Add the single Function instance
self.functions[tool.name] = tool
function_declaration = FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=self._format_functions(tool.parameters),
)
self.function_declarations.append(function_declaration)
logger.debug(f"Function '{tool.name}' added to LLM.")

elif callable(tool):
# Convert the callable to a Function instance and add it
func = Function.from_callable(tool)
self.functions[func.name] = func
function_declaration = FunctionDeclaration(
name=func.name,
description=func.description,
parameters=self._format_functions(func.parameters),
)
self.function_declarations.append(function_declaration)
logger.debug(f"Function '{func.name}' added to LLM.")

elif isinstance(tool, Tool):
logger.warning(f"Tool of type '{type(tool).__name__}' is not yet supported by Gemini.")

elif isinstance(tool, dict):
logger.warning("Tool of type 'dict' is not yet supported by Gemini.")

else:
logger.warning(f"Unsupported tool type: {type(tool).__name__}")
# If the tool is a Tool or Dict, log a warning.
if isinstance(tool, Tool) or isinstance(tool, Dict):
logger.warning("Tool of type 'Tool' or 'dict' is not yet supported by Gemini.")

# If the tool is a Callable or Toolkit, add its functions to the Model
elif callable(tool) or isinstance(tool, Toolkit) or isinstance(tool, Function):
if self.functions is None:
self.functions = {}

if isinstance(tool, Toolkit):
# For each function in the toolkit
for name, func in tool.functions.items():
# If the function does not exist in self.functions, add to self.tools
if name not in self.functions:
self.functions[name] = func
function_declaration = FunctionDeclaration(
name=func.name,
description=func.description,
parameters=self._format_functions(func.parameters),
)
self.function_declarations.append(function_declaration)
logger.debug(f"Function {name} from {tool.name} added to model.")

elif isinstance(tool, Function):
if tool.name not in self.functions:
self.functions[tool.name] = tool
function_declaration = FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=self._format_functions(tool.parameters),
)
self.function_declarations.append(function_declaration)
logger.debug(f"Function {tool.name} added to model.")

elif callable(tool):
try:
function_name = tool.__name__
if function_name not in self.functions:
func = Function.from_callable(tool)
self.functions[func.name] = func
function_declaration = FunctionDeclaration(
name=func.name,
description=func.description,
parameters=self._format_functions(func.parameters),
)
self.function_declarations.append(function_declaration)
logger.debug(f"Function '{func.name}' added to model.")
except Exception as e:
logger.warning(f"Could not add function {tool}: {e}")

def invoke(self, messages: List[Message]):
"""
Expand Down Expand Up @@ -477,8 +478,6 @@ def response(self, messages: List[Message]) -> ModelResponse:
response: GenerateContentResponse = self.invoke(messages=messages)
response_timer.stop()
logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
logger.debug(f"Gemini response type: {type(response)}")
logger.debug(f"Gemini response: {response}")

# Create assistant message
assistant_message = self._create_assistant_message(response=response, response_timer=response_timer)
Expand Down

0 comments on commit 43f7733

Please sign in to comment.