From 3bef1e8e055895df65fd66ca9ef27cd354dd8257 Mon Sep 17 00:00:00 2001 From: William Guss Date: Thu, 20 Feb 2025 13:15:17 -0800 Subject: [PATCH] gemini implementation --- examples/providers/gemini_ex.py | 24 +++- src/ell/provider.py | 3 +- src/ell/providers/google.py | 242 ++++++++++++++------------ 3 files changed, 126 insertions(+), 143 deletions(-) diff --git a/examples/providers/gemini_ex.py b/examples/providers/gemini_ex.py index 9d79fb191..fdd13871c 100644 --- a/examples/providers/gemini_ex.py +++ b/examples/providers/gemini_ex.py @@ -9,8 +9,24 @@ # custom client client = genai.Client() -@ell.simple(model='gemini-2.0-flash', client=client, max_tokens=10) -def chat(prompt: str) -> str: - return prompt +from PIL import Image, ImageDraw -print(chat("Hello, how are you?")) \ No newline at end of file +# Create a new image with white background +img = Image.new('RGB', (512, 512), 'white') + +# Create a draw object +draw = ImageDraw.Draw(img) + +# Draw a red dot in the middle (using a small filled circle) +center = (256, 256) # Middle of 512x512 +radius = 5 # Size of the dot +draw.ellipse([center[0]-radius, center[1]-radius, + center[0]+radius, center[1]+radius], + fill='red') + + +@ell.simple(model='gemini-2.0-flash', client=client, max_tokens=10000) +def chat(prompt: str): + return [ell.user([prompt + " what is in this image", img])] + +print(chat("Write me a really long story about")) \ No newline at end of file diff --git a/src/ell/provider.py b/src/ell/provider.py index f1fc34b4c..ca28eec39 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -9,6 +9,7 @@ Dict, FrozenSet, List, + Mapping, Optional, Set, Tuple, @@ -83,7 +84,7 @@ def available_api_params(self, client: Any, api_params: Optional[Dict[str, Any]] ### TRANSLATION ############### ################################ @abstractmethod - def translate_to_provider(self, ell_call: EllCallParams) -> Dict[str, Any]: + def translate_to_provider(self, ell_call: EllCallParams) -> Mapping[str, Any]: 
"""Converts an ell call to provider call params!""" return NotImplemented diff --git a/src/ell/providers/google.py b/src/ell/providers/google.py index 356ea6b7c..1fbc844d6 100644 --- a/src/ell/providers/google.py +++ b/src/ell/providers/google.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, cast +from typing import Any, Callable, Dict, FrozenSet, Iterator, List, Literal, Optional, Tuple, Type, TypedDict, Union, cast from ell.provider import EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall, ImageContent @@ -12,59 +12,78 @@ import requests from PIL import Image as PILImage +# TODO: Supported: +# Streaming +# Text in, image in, text out +# TODO: Not supported: +# tool use +# function calling +# structured output + + + try: + from google import genai + import google.genai.types as types + class MessageCreateParamsStreaming(TypedDict): + model: str + contents: Union[types.ContentListUnion, types.ContentListUnionDict] + config: Optional[types.GenerateContentConfigOrDict] + + class GoogleProvider(Provider): dangerous_disable_validation = False def provider_call_function(self, client : genai.Client, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: return client.models.generate_content_stream - - def translate_to_provider(self, ell_call : EllCallParams): - final_call_params = cast(MessageCreateParamsStreaming, ell_call.api_params.copy()) - # XXX: Helper, but should be depreicated due to ssot - assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))." 
- - dirty_msgs = [ - MessageParam( - role=cast(Literal["user", "assistant"], message.role), - content=[_content_block_to_anthropic_format(c) for c in message.content]) for message in ell_call.messages] - role_correct_msgs : List[MessageParam] = [] - for msg in dirty_msgs: - if (not len(role_correct_msgs) or role_correct_msgs[-1]['role'] != msg['role']): - role_correct_msgs.append(msg) - else: cast(List, role_correct_msgs[-1]['content']).extend(msg['content']) + + def disallowed_api_params(self) -> FrozenSet[str]: + return frozenset({"messages", "tools", "model", "stream", "stream_options", "system_instruction", "n"}) + + def translate_to_provider(self, ell_call : EllCallParams) -> MessageCreateParamsStreaming: + # final_call_params = cast(MessageCreateParamsStreaming, ell_call.api_params.copy()) + # # XXX: Helper, but should be deprecated due to ssot + assert not ell_call.tools, "Provider does not yet support tools" + + clean_api_params = ell_call.api_params.copy() + clean_api_params.pop("stream", None) + if "max_tokens" in clean_api_params: + clean_api_params["max_output_tokens"] = clean_api_params.pop("max_tokens") - system_message = None - if role_correct_msgs and role_correct_msgs[0]["role"] == "system": - system_message = role_correct_msgs.pop(0) + msgs = [ + types.Content( + role=message.role, + parts=[ + _content_block_to_google_format(c) + for c in message.content + ] + ) + for message in ell_call.messages + ] + + system_instruction : Optional[types.ContentUnion] = None + system_msg = next((m for m in msgs if m.role == "system"), None) + if system_msg: + system_instruction = system_msg + msgs = [m for m in msgs if m.role != "system"] - if system_message: - final_call_params["system"] = system_message["content"][0]["text"] - - final_call_params['stream'] = True - final_call_params["model"] = ell_call.model - final_call_params["messages"] = role_correct_msgs - - if ell_call.tools: - final_call_params["tools"] = [ - #XXX: Cleaner with LMP's as a class. 
- dict( - name=tool.__name__, - description=tool.__doc__, - input_schema=tool.__ell_params_model__.model_json_schema(), - ) - for tool in ell_call.tools - ] - - # print(final_call_params) - return final_call_params + return MessageCreateParamsStreaming( + model=ell_call.model, + contents=msgs, + config=types.GenerateContentConfig( + **clean_api_params, + system_instruction=system_instruction, + response_modalities=["text"], # Text only for now + automatic_function_calling=None, # TODO: Support. + ) # performs pydantic validation. + ) def translate_from_provider( self, - provider_response : Union[Stream[RawMessageStreamEvent], AnthropicMessage], + provider_response : Iterator[types.GenerateContentResponse], ell_call: EllCallParams, provider_call_params: Dict[str, Any], origin_id: Optional[str] = None, @@ -72,93 +91,40 @@ def translate_from_provider( ) -> Tuple[List[Message], Metadata]: usage = {} - tracked_results = [] metadata = {} #XXX: Support n > 0 - if provider_call_params.get("stream", False): - content = [] - current_blocks: Dict[int, Dict[str, Any]] = {} - message_metadata = {} - - with cast(Stream[RawMessageStreamEvent], provider_response) as stream: - for chunk in stream: - if chunk.type == "message_start": - message_metadata = chunk.message.model_dump() - message_metadata.pop("content", None) # Remove content as we'll build it separately - - elif chunk.type == "content_block_start": - block = chunk.content_block.model_dump() - current_blocks[chunk.index] = block - if block["type"] == "tool_use": - if logger: logger(f" ") - - elif chunk.type == "message_delta": - message_metadata.update(chunk.delta.model_dump()) - if chunk.usage: - usage.update(chunk.usage.model_dump()) - - elif chunk.type == "message_stop": - tracked_results.append(Message(role="assistant", content=content)) - - # print(chunk) - metadata = message_metadata + message_metadata : Optional[types.GenerateContentResponseUsageMetadata] = None + total_text = "" + for chunk in provider_response: + 
message_metadata = chunk.usage_metadata if chunk.usage_metadata else message_metadata + text = chunk.text + if text: + if logger: logger(text) + total_text += text + content = [ContentBlock(text=_lstr(total_text,origin_trace=origin_id))] + # process metadata for ell # XXX: Unify an ell metadata format for ell studio. - usage["prompt_tokens"] = usage.get("input_tokens", 0) - usage["completion_tokens"] = usage.get("output_tokens", 0) - usage["total_tokens"] = usage['prompt_tokens'] + usage['completion_tokens'] + if message_metadata: + usage["prompt_tokens"] = message_metadata.prompt_token_count + usage["completion_tokens"] = message_metadata.candidates_token_count + usage["total_tokens"] = message_metadata.total_token_count - metadata["usage"] = usage - return tracked_results, metadata + metadata["usage"] = usage + + return [Message(role="assistant", content=content)], metadata # XXX: Make a singleton. - anthropic_provider = AnthropicProvider() - register_provider(anthropic_provider, anthropic.Anthropic) - register_provider(anthropic_provider, anthropic.AnthropicBedrock) - register_provider(anthropic_provider, anthropic.AnthropicVertex) + google_provider = GoogleProvider() + register_provider(google_provider, genai.Client) except ImportError: pass -def serialize_image_for_anthropic(img : ImageContent): +def serialize_image_for_google(img : ImageContent): if img.url: # Download the image from the URL response = requests.get(img.url) @@ -168,35 +134,35 @@ def serialize_image_for_anthropic(img : ImageContent): pil_image = img.image else: raise ValueError("Image object has neither url nor image data.") - buffer = BytesIO() - pil_image.save(buffer, format="PNG") - base64_image = base64.b64encode(buffer.getvalue()).decode() + # Convert PIL Image to bytes in memory + img_bytes_io= BytesIO() + pil_image.save(img_bytes_io, format='PNG') + img_byte_arr = img_bytes_io.getvalue() + return dict( - type="image", - source=dict( - type="base64", - media_type="image/png", - 
data=base64_image + inline_data=dict( + mime_type="image/png", + data=img_byte_arr ) ) -def _content_block_to_anthropic_format(content_block: ContentBlock): - if (image := content_block.image): return serialize_image_for_anthropic(image) - elif ((text := content_block.text) is not None): return dict(type="text", text=text) +def _content_block_to_google_format(content_block: ContentBlock):# -> "types.PartUnion" + if (image := content_block.image): return serialize_image_for_google(image) + elif ((text := content_block.text) is not None): return dict(text=text) elif (parsed := content_block.parsed): - return dict(type="text", text=json.dumps(parsed.model_dump(), ensure_ascii=False)) - elif (tool_call := content_block.tool_call): - return dict( - type="tool_use", - id=tool_call.tool_call_id, - name=tool_call.tool.__name__, - input=tool_call.params.model_dump() - ) - elif (tool_result := content_block.tool_result): - return dict( - type="tool_result", - tool_use_id=tool_result.tool_call_id, - content=[_content_block_to_anthropic_format(c) for c in tool_result.result] - ) + return dict(text=json.dumps(parsed.model_dump(), ensure_ascii=False)) + # elif (tool_call := content_block.tool_call): + # return dict( + # type="tool_use", + # id=tool_call.tool_call_id, + # name=tool_call.tool.__name__, + # input=tool_call.params.model_dump() + # ) + # elif (tool_result := content_block.tool_result): + # return dict( + # type="tool_result", + # tool_use_id=tool_result.tool_call_id, + # content=[_content_block_to_google_format(c) for c in tool_result.result] + # ) else: raise ValueError("Content block is not supported by anthropic")