Skip to content

Commit

Permalink
Fix issue with gemini functions with no parameters (#1562)
Browse files Browse the repository at this point in the history
## Description

Fixes #1558 

This ensures the Gemini model implementation is compatible with both
functions that have parameters and without parameters.

-------
Co-authored-by: Dirk Brand <[email protected]>
  • Loading branch information
dirkbrnd authored Dec 13, 2024
1 parent bd734bc commit 8f55f8b
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions phi/model/google/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
GenerateContentResponse as ResultGenerateContentResponse,
)
from google.protobuf.struct_pb2 import Struct
except ImportError:
except (ModuleNotFoundError, ImportError):
logger.error("`google-generativeai` not installed. Please install it using `pip install google-generativeai`")
raise

Expand Down Expand Up @@ -301,6 +301,7 @@ def format_functions(self, params: Dict[str, Any]) -> Dict[str, Any]:
Dict[str, Any]: The converted parameters dictionary compatible with Gemini.
"""
formatted_params = {}

for key, value in params.items():
if key == "properties" and isinstance(value, dict):
converted_properties = {}
Expand All @@ -322,8 +323,33 @@ def format_functions(self, params: Dict[str, Any]) -> Dict[str, Any]:
formatted_params[key] = converted_properties
else:
formatted_params[key] = value

return formatted_params

def _build_function_declaration(self, func: Function) -> FunctionDeclaration:
"""
Builds the function declaration for Gemini tool calling.
Args:
func: An instance of the function.
Returns:
FunctionDeclaration: The formatted function declaration.
"""
formatted_params = self.format_functions(func.parameters)
if "properties" in formatted_params and formatted_params["properties"]:
# We have parameters to add
return FunctionDeclaration(
name=func.name,
description=func.description,
parameters=formatted_params,
)
else:
return FunctionDeclaration(
name=func.name,
description=func.description,
)

def add_tool(
self,
tool: Union["Tool", "Toolkit", Callable, dict, "Function"],
Expand Down Expand Up @@ -356,11 +382,7 @@ def add_tool(
func._agent = agent
func.process_entrypoint()
self.functions[name] = func
function_declaration = FunctionDeclaration(
name=func.name,
description=func.description,
parameters=self.format_functions(func.parameters),
)
function_declaration = self._build_function_declaration(func)
self.function_declarations.append(function_declaration)
logger.debug(f"Function {name} from {tool.name} added to model.")

Expand All @@ -369,11 +391,8 @@ def add_tool(
tool._agent = agent
tool.process_entrypoint()
self.functions[tool.name] = tool
function_declaration = FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=self.format_functions(tool.parameters),
)

function_declaration = self._build_function_declaration(tool)
self.function_declarations.append(function_declaration)
logger.debug(f"Function {tool.name} added to model.")

Expand All @@ -383,11 +402,7 @@ def add_tool(
if function_name not in self.functions:
func = Function.from_callable(tool)
self.functions[func.name] = func
function_declaration = FunctionDeclaration(
name=func.name,
description=func.description,
parameters=self.format_functions(func.parameters),
)
function_declaration = self._build_function_declaration(func)
self.function_declarations.append(function_declaration)
logger.debug(f"Function '{func.name}' added to model.")
except Exception as e:
Expand Down

0 comments on commit 8f55f8b

Please sign in to comment.