From ab3ead239077de064fe9e92ebfc479c1473e52ae Mon Sep 17 00:00:00 2001 From: Ashpreet Bedi Date: Thu, 12 Dec 2024 19:13:10 +0000 Subject: [PATCH] Update Gemini multimodal --- cookbook/providers/google/.gitignore | 6 + .../google/image_agent_file_upload.py | 24 ++++ cookbook/providers/google/video_agent.py | 24 ++-- phi/agent/agent.py | 48 +++---- phi/knowledge/pdf.py | 1 - phi/model/base.py | 8 +- phi/model/google/gemini.py | 134 ++++++++++++------ phi/model/message.py | 12 +- 8 files changed, 166 insertions(+), 91 deletions(-) create mode 100644 cookbook/providers/google/.gitignore create mode 100644 cookbook/providers/google/image_agent_file_upload.py diff --git a/cookbook/providers/google/.gitignore b/cookbook/providers/google/.gitignore new file mode 100644 index 000000000..525cad7b9 --- /dev/null +++ b/cookbook/providers/google/.gitignore @@ -0,0 +1,6 @@ +*.jpg +*.png +*.mp3 +*.wav +*.mp4 +*.mp3 diff --git a/cookbook/providers/google/image_agent_file_upload.py b/cookbook/providers/google/image_agent_file_upload.py new file mode 100644 index 000000000..6a222f815 --- /dev/null +++ b/cookbook/providers/google/image_agent_file_upload.py @@ -0,0 +1,24 @@ +from pathlib import Path + +from phi.agent import Agent +from phi.model.google import Gemini +from phi.tools.duckduckgo import DuckDuckGo + +from google.generativeai import upload_file + +agent = Agent( + model=Gemini(id="gemini-2.0-flash-exp"), + tools=[DuckDuckGo()], + markdown=True, +) +# Please download the image using +# wget https://upload.wikimedia.org/wikipedia/commons/b/bf/Krakow_-_Kosciol_Mariacki.jpg +image_path = Path(__file__).parent.joinpath("Krakow_-_Kosciol_Mariacki.jpg") +image_file = upload_file(image_path) +print(f"Uploaded image: {image_file}") + +agent.print_response( + "Tell me about this image and give me the latest news about it.", + images=[image_file], + stream=True, +) diff --git a/cookbook/providers/google/video_agent.py b/cookbook/providers/google/video_agent.py index ae04a2858..094fffec9 100644 --- a/cookbook/providers/google/video_agent.py +++ b/cookbook/providers/google/video_agent.py @@ -1,16 +1,24 @@ +import time +from pathlib import Path + from phi.agent import Agent from phi.model.google import Gemini +from google.generativeai import upload_file, get_file agent = Agent( model=Gemini(id="gemini-2.0-flash-exp"), markdown=True, ) -# Please download "GreatRedSpot.mp4" using wget https://storage.googleapis.com/generativeai-downloads/images/GreatRedSpot.mp4 -agent.print_response( - "Tell me about this video", - videos=[ - "cookbook/providers/google/GreatRedSpot.mp4", - ], - stream=True, -) +# Please download "GreatRedSpot.mp4" using +# wget https://storage.googleapis.com/generativeai-downloads/images/GreatRedSpot.mp4 +video_path = Path(__file__).parent.joinpath("GreatRedSpot.mp4") +video_file = upload_file(video_path) +# Check whether the file is ready to be used. +while video_file.state.name == "PROCESSING": + time.sleep(2) + video_file = get_file(video_file.name) + +print(f"Uploaded video: {video_file}") + +agent.print_response("Tell me about this video", videos=[video_file], stream=True) diff --git a/phi/agent/agent.py b/phi/agent/agent.py index fcbfe0785..2b7ce303e 100644 --- a/phi/agent/agent.py +++ b/phi/agent/agent.py @@ -1128,9 +1128,9 @@ def get_user_message( self, *, message: Optional[Union[str, List]], - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, **kwargs: Any, ) -> Optional[Message]: """Return the user message for the Agent. @@ -1234,9 +1234,9 @@ def get_messages_for_run( self, *, message: Optional[Union[str, List, Dict, Message]] = None, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, **kwargs: Any, ) -> Tuple[Optional[Message], List[Message], List[Message]]: @@ -1719,9 +1719,9 @@ def _run( message: Optional[Union[str, List, Dict, Message]] = None, *, stream: bool = False, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, stream_intermediate_steps: bool = False, **kwargs: Any, @@ -1940,9 +1940,9 @@ def run( message: Optional[Union[str, List, Dict, Message]] = None, *, stream: Literal[False] = False, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, **kwargs: Any, ) -> RunResponse: ... @@ -1953,9 +1953,9 @@ def run( message: Optional[Union[str, List, Dict, Message]] = None, *, stream: Literal[True] = True, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, stream_intermediate_steps: bool = False, **kwargs: Any, @@ -1966,9 +1966,9 @@ def run( message: Optional[Union[str, List, Dict, Message]] = None, *, stream: bool = False, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, stream_intermediate_steps: bool = False, **kwargs: Any, @@ -2059,9 +2059,9 @@ async def _arun( message: Optional[Union[str, List, Dict, Message]] = None, *, stream: bool = False, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, stream_intermediate_steps: bool = False, **kwargs: Any, @@ -2277,9 +2277,9 @@ async def arun( message: Optional[Union[str, List, Dict, Message]] = None, *, stream: bool = False, - audio: Optional[Dict] = None, - images: Optional[Sequence[Union[str, Dict]]] = None, - videos: Optional[Sequence[Union[str, Dict]]] = None, + audio: Optional[Any] = None, + images: Optional[Sequence[Any]] = None, + videos: Optional[Sequence[Any]] = None, messages: Optional[Sequence[Union[Dict, Message]]] = None, stream_intermediate_steps: bool = False, **kwargs: Any, diff --git a/phi/knowledge/pdf.py b/phi/knowledge/pdf.py index a7119d90e..a0a64bcff 100644 --- a/phi/knowledge/pdf.py +++ b/phi/knowledge/pdf.py @@ -47,4 +47,3 @@ def document_lists(self) -> Iterator[List[Document]]: yield self.reader.read(url=url) else: logger.error(f"Unsupported URL: {url}") - diff --git a/phi/model/base.py b/phi/model/base.py index a98dc72af..9b8be3b15 100644 --- a/phi/model/base.py +++ b/phi/model/base.py @@ -422,7 +422,7 @@ def _process_bytes_image(self, image: bytes) -> Dict[str, Any]: image_url = f"data:image/jpeg;base64,{base64_image}" return {"type": "image_url", "image_url": {"url": image_url}} - def process_image(self, image: Union[str, Dict, bytes]) -> Optional[Dict[str, Any]]: + def process_image(self, image: Any) -> Optional[Dict[str, Any]]: """Process an image based on the format.""" if isinstance(image, dict): @@ -437,9 +437,7 @@ def process_image(self, image: Union[str, Dict, bytes]) -> Optional[Dict[str, An logger.warning(f"Unsupported image type: {type(image)}") return None - def add_images_to_message( - self, message: Message, images: Optional[Sequence[Union[str, Dict, bytes]]] = None - ) -> Message: + def add_images_to_message(self, message: Message, images: Optional[Sequence[Any]] = None) -> Message: """ Add images to a message for the model. By default, we use the OpenAI image format but other Models can override this method to use a different image format. @@ -479,7 +477,7 @@ def add_images_to_message( message.content = message_content_with_image return message - def add_audio_to_message(self, message: Message, audio: Optional[Dict] = None) -> Message: + def add_audio_to_message(self, message: Message, audio: Optional[Any] = None) -> Message: """ Add audio to a message for the model. By default, we use the OpenAI audio format but other Models can override this method to use a different audio format. diff --git a/phi/model/google/gemini.py b/phi/model/google/gemini.py index d5d0b2925..56bba8814 100644 --- a/phi/model/google/gemini.py +++ b/phi/model/google/gemini.py @@ -1,4 +1,6 @@ +import time import json +from pathlib import Path from dataclasses import dataclass, field from typing import Optional, List, Iterator, Dict, Any, Union, Callable @@ -16,6 +18,7 @@ from google.generativeai import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse from google.generativeai.types.content_types import FunctionDeclaration, Tool as GeminiTool + from google.generativeai.types import file_types from google.ai.generativelanguage_v1beta.types.generative_service import ( GenerateContentResponse as ResultGenerateContentResponse, ) @@ -139,7 +142,7 @@ def format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]: """ formatted_messages: List = [] for message in messages: - message_for_model = {} + message_for_model: Dict[str, Any] = {} # Add role to the message for the model role = "model" if message.role == "system" else "user" if message.role == "tool" else message.role @@ -147,23 +150,31 @@ def format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]: # Add content to the message for the model content = message.content + # Initialize message_parts to be used for Gemini + message_parts: List[Any] = [] if not content or message.role in ["tool", "model"]: - parts = message.parts # type: ignore + message_parts = message.parts # type: ignore else: if isinstance(content, str): - parts = [content] + message_parts = [content] elif isinstance(content, list): - parts = content + message_parts = content else: - parts = [" "] - message_for_model["parts"] = parts + message_parts = [" "] # Add images to the message for the model if message.images is not None and message.role == "user": for image in message.images: - if isinstance(image, str): - # Case 1: Image is a URL - if image.startswith("http://") or image.startswith("https://"): + # Case 1: Image is a file_types.File object (Recommended) + # Add it as a File object + if isinstance(image, file_types.File): + # Google recommends that if using a single image, place the text prompt after the image. + message_parts.insert(0, image) + # Case 2: If image is a string, it is a URL or a local path + elif isinstance(image, str) or isinstance(image, Path): + # Case 2.1: Image is a URL + # Download the image from the URL and add it as base64 encoded data + if isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")): try: import httpx import base64 @@ -173,80 +184,109 @@ def format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]: "mime_type": "image/jpeg", "data": base64.b64encode(image_content).decode("utf-8"), } - message_for_model["parts"].append(image_data) # type: ignore + message_parts.append(image_data) # type: ignore except Exception as e: logger.warning(f"Failed to download image from {image}: {e}") continue - # Case 2: Image is a path + # Case 2.2: Image is a local path + # Open the image file and add it as base64 encoded data else: try: - from os.path import exists, isfile import PIL.Image except ImportError: logger.error("`PIL.Image not installed. Please install it using 'pip install pillow'`") raise try: - if exists(image) and isfile(image): - image_data = PIL.Image.open(image) # type: ignore + image_path = image if isinstance(image, Path) else Path(image) + if image_path.exists() and image_path.is_file(): + image_data = PIL.Image.open(image_path) # type: ignore else: - logger.error(f"Image file {image} does not exist.") + logger.error(f"Image file {image_path} does not exist.") raise - message_for_model["parts"].append(image_data) # type: ignore + message_parts.append(image_data) # type: ignore except Exception as e: - logger.warning(f"Failed to load image from {image}: {e}") + logger.warning(f"Failed to load image from {image_path}: {e}") continue - + # Case 3: Image is a bytes object + # Add it as base64 encoded data elif isinstance(image, bytes): image_data = {"mime_type": "image/jpeg", "data": base64.b64encode(image).decode("utf-8")} - message_for_model["parts"].append(image_data) + message_parts.append(image_data) + else: + logger.warning(f"Unknown image type: {type(image)}") + continue if message.videos is not None and message.role == "user": try: for video in message.videos: - import time - from os.path import exists, isfile - - video_file = None - if exists(video) and isfile(video): # type: ignore - video_file = genai.upload_file(path=video) - else: - logger.error(f"Video file {video} does not exist.") - raise - - # Check whether the file is ready to be used. - while video_file.state.name == "PROCESSING": - time.sleep(10) - video_file = genai.get_file(video_file.name) + # Case 1: Video is a file_types.File object (Recommended) + # Add it as a File object + if isinstance(video, file_types.File): + # Google recommends that if using a single video, place the text prompt after the video. + message_parts.insert(0, video) + # Case 2: If video is a string, it is a local path + elif isinstance(video, str) or isinstance(video, Path): + # Upload the video file to the Gemini API + video_file = None + video_path = video if isinstance(video, Path) else Path(video) + # Check if video is already uploaded + video_file_name = video_path.name + video_file_exists = genai.get_file(video_file_name) + if video_file_exists: + video_file = video_file_exists + else: + if video_path.exists() and video_path.is_file(): + video_file = genai.upload_file(path=video_path) + else: + logger.error(f"Video file {video_path} does not exist.") + raise - if video_file.state.name == "FAILED": - raise ValueError(video_file.state.name) + # Check whether the file is ready to be used. + while video_file.state.name == "PROCESSING": + time.sleep(2) + video_file = genai.get_file(video_file.name) - message_for_model["parts"].insert(0, video_file) # type: ignore + if video_file.state.name == "FAILED": + raise ValueError(video_file.state.name) + # Google recommends that if using a single video, place the text prompt after the video. + if video_file is not None: + message_parts.insert(0, video_file) # type: ignore except Exception as e: logger.warning(f"Failed to load video from {message.videos}: {e}") continue if message.audio is not None and message.role == "user": try: - from pathlib import Path - from os.path import exists, isfile - - audio = message.audio.get("data") - if audio: - audio_file = None - if exists(audio) and isfile(audio): - audio_file = {"mime_type": "audio/mp3", "data": Path(audio).read_bytes()} + # Case 1: Audio is a file_types.File object (Recommended) + # Add it as a File object + if isinstance(message.audio, file_types.File): + # Google recommends that if using a single audio, place the text prompt after the audio. + message_parts.insert(0, message.audio) # type: ignore + # Case 2: If audio is a string, it is a local path + elif isinstance(message.audio, str) or isinstance(message.audio, Path): + audio_path = message.audio if isinstance(message.audio, Path) else Path(message.audio) + if audio_path.exists() and audio_path.is_file(): + import mimetypes + + # Get mime type from file extension + mime_type = mimetypes.guess_type(audio_path)[0] or "audio/mp3" + audio_file = {"mime_type": mime_type, "data": audio_path.read_bytes()} + message_parts.insert(0, audio_file) # type: ignore else: - logger.error(f"Audio file {audio} does not exist.") + logger.error(f"Audio file {audio_path} does not exist.") raise - message_for_model["parts"].insert(0, audio_file) # type: ignore - + # Case 3: Audio is a bytes object + # Add it as base64 encoded data + elif isinstance(message.audio, bytes): + audio_file = {"mime_type": "audio/mp3", "data": message.audio} + message_parts.insert(0, audio_file) # type: ignore except Exception as e: logger.warning(f"Failed to load video from {message.videos}: {e}") continue + message_for_model["parts"] = message_parts formatted_messages.append(message_for_model) return formatted_messages diff --git a/phi/model/message.py b/phi/model/message.py index 4c118e236..f7f1580fe 100644 --- a/phi/model/message.py +++ b/phi/model/message.py @@ -35,9 +35,9 @@ class Message(BaseModel): tool_calls: Optional[List[Dict[str, Any]]] = None # Additional modalities - images: Optional[Sequence[Union[str, Dict]]] = None - videos: Optional[Sequence[Union[str, Dict]]] = None - audio: Optional[Dict] = None + audio: Optional[Any] = None + images: Optional[Sequence[Any]] = None + videos: Optional[Sequence[Any]] = None # -*- Attributes not sent to the model # The name of the tool called @@ -109,11 +109,11 @@ def log(self, level: Optional[str] = None): if self.tool_calls: _logger(f"Tool Calls: {json.dumps(self.tool_calls, indent=2)}") if self.images: - _logger(f"Number of Images: {len(self.images)}") + _logger(f"Images added: {len(self.images)}") if self.videos: - _logger(f"Number of Videos: {len(self.videos)}") + _logger(f"Videos added: {len(self.videos)}") if self.audio: - _logger(f"Number of Audio Files: {len(self.audio)}") + _logger(f"Audio files added: {len(self.audio)}") if "id" in self.audio: _logger(f"Audio ID: {self.audio['id']}") elif "data" in self.audio: