From 821897ad7d25b60bffe8cd91760b397fcd46e1ff Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 16 Jan 2025 10:52:41 +0100 Subject: [PATCH] vertex: handle function role removal --- .../components/generators/google_vertex/chat/gemini.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 845e24f5f..516116321 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -215,7 +215,7 @@ def _message_to_part(self, message: ChatMessage) -> Part: return p elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): return Part.from_text(message.text) - elif message.is_from(ChatRole.FUNCTION): + elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): return Part.from_function_response(name=message.name, response=message.text) elif message.is_from(ChatRole.USER): return self._convert_part(message.text) @@ -227,14 +227,15 @@ def _message_to_content(self, message: ChatMessage) -> Content: part.function_call.args[k] = v elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): part = Part.from_text(message.text) - elif message.is_from(ChatRole.FUNCTION): + elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): part = Part.from_function_response(name=message.name, response=message.text) elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" + + role = "model" if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM) else "user" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage])