Skip to content

Commit

Permalink
chore: google-ai - gently handle the removal of function role (#1297)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Jan 16, 2025
1 parent fecf0e2 commit 501f31c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 25 deletions.
4 changes: 2 additions & 2 deletions integrations/google_ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "google-generativeai>=0.3.1"]
dependencies = ["haystack-ai>=2.9.0", "google-generativeai>=0.3.1"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_ai_haystack#readme"
Expand Down Expand Up @@ -56,7 +56,7 @@ cov = ["test-cov", "cov-report"]
cov-retry = ["test-cov-retry", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]
[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11"]
python = ["3.9", "3.10", "3.11"]

[tool.hatch.envs.lint]
installer = "uv"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,12 @@ def _message_to_part(self, message: ChatMessage) -> Part:
p = Part()
p.text = message.text
return p
elif message.is_from(ChatRole.FUNCTION):
elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION):
p = Part()
p.function_response.name = message.name
p.function_response.response = message.text
return p
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
elif message.is_from(ChatRole.TOOL):
p = Part()
p.function_response.name = message.tool_call_result.origin.tool_name
p.function_response.response = message.tool_call_result.result
Expand All @@ -265,13 +265,13 @@ def _message_to_content(self, message: ChatMessage) -> Content:
elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT):
part = Part()
part.text = message.text
elif message.is_from(ChatRole.FUNCTION):
elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION):
part = Part()
part.function_response.name = message.name
part.function_response.response = message.text
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
elif message.is_from(ChatRole.TOOL):
part = Part()
part.function_response.name = message.tool_call_result.origin.tool_name
part.function_response.response = message.tool_call_result.result
Expand Down
40 changes: 21 additions & 19 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,17 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
if hasattr(ChatMessage, "from_function"):
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.text, str)
# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.text, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand Down Expand Up @@ -273,16 +274,17 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.text, str)
if hasattr(ChatMessage, "from_function"):
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.text, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand Down

0 comments on commit 501f31c

Please sign in to comment.