Commit
gemini implementation
MadcowD committed Feb 20, 2025
1 parent 6ce241d commit 3bef1e8
Showing 3 changed files with 126 additions and 143 deletions.
24 changes: 20 additions & 4 deletions examples/providers/gemini_ex.py
@@ -9,8 +9,24 @@
# custom client
client = genai.Client()

@ell.simple(model='gemini-2.0-flash', client=client, max_tokens=10)
def chat(prompt: str) -> str:
return prompt
from PIL import Image, ImageDraw

print(chat("Hello, how are you?"))
# Create a new image with white background
img = Image.new('RGB', (512, 512), 'white')

# Create a draw object
draw = ImageDraw.Draw(img)

# Draw a red dot in the middle (using a small filled circle)
center = (256, 256) # Middle of 512x512
radius = 5 # Size of the dot
draw.ellipse([center[0]-radius, center[1]-radius,
center[0]+radius, center[1]+radius],
fill='red')


@ell.simple(model='gemini-2.0-flash', client=client, max_tokens=10000)
def chat(prompt: str):
return [ell.user([prompt + " what is in this image", img])]

print(chat("Write me a really long story about"))
3 changes: 2 additions & 1 deletion src/ell/provider.py
Expand Up @@ -9,6 +9,7 @@
Dict,
FrozenSet,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -83,7 +84,7 @@ def available_api_params(self, client: Any, api_params: Optional[Dict[str, Any]]
### TRANSLATION ###############
################################
@abstractmethod
def translate_to_provider(self, ell_call: EllCallParams) -> Dict[str, Any]:
def translate_to_provider(self, ell_call: EllCallParams) -> Mapping[str, Any]:
"""Converts an ell call to provider call params!"""
return NotImplemented

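A note on the `Dict` → `Mapping` widening above: a TypedDict (like the `MessageCreateParamsStreaming` introduced in the Google provider below) is assignable to `Mapping[str, Any]` but not to `Dict[str, Any]` under strict type checkers, so the looser abstract return type lets subclasses return typed parameter structures. A minimal sketch:

from typing import Any, Dict, Mapping, TypedDict

class CallParams(TypedDict):
    model: str
    max_output_tokens: int

def make_params() -> Mapping[str, Any]:      # OK: a TypedDict satisfies Mapping
    return CallParams(model="gemini-2.0-flash", max_output_tokens=100)

# def make_params() -> Dict[str, Any]:       # rejected by mypy/pyright:
#     return CallParams(...)                 # a TypedDict is not a mutable Dict
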
242 changes: 104 additions & 138 deletions src/ell/providers/google.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, FrozenSet, Iterator, List, Literal, Optional, Tuple, Type, TypedDict, Union, cast
from ell.provider import EllCallParams, Metadata, Provider
from ell.types import Message, ContentBlock, ToolCall, ImageContent

@@ -12,153 +12,119 @@
import requests
from PIL import Image as PILImage

# Supported:
#   Streaming
#   Text in, image in, text out
# TODO (not yet supported):
#   tool use
#   function calling
#   structured output



try:

from google import genai
import google.genai.types as types

class MessageCreateParamsStreaming(TypedDict):
model: str
contents: Union[types.ContentListUnion, types.ContentListUnionDict]
config: Optional[types.GenerateContentConfigOrDict]


class GoogleProvider(Provider):
dangerous_disable_validation = False

def provider_call_function(self, client : genai.Client, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
return client.models.generate_content_stream

def translate_to_provider(self, ell_call : EllCallParams):
final_call_params = cast(MessageCreateParamsStreaming, ell_call.api_params.copy())
# XXX: Helper, but should be deprecated due to SSOT
assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))."

dirty_msgs = [
MessageParam(
role=cast(Literal["user", "assistant"], message.role),
content=[_content_block_to_anthropic_format(c) for c in message.content]) for message in ell_call.messages]
role_correct_msgs : List[MessageParam] = []
for msg in dirty_msgs:
if (not len(role_correct_msgs) or role_correct_msgs[-1]['role'] != msg['role']):
role_correct_msgs.append(msg)
else: cast(List, role_correct_msgs[-1]['content']).extend(msg['content'])

def disallowed_api_params(self) -> FrozenSet[str]:
return frozenset({"messages", "tools", "model", "stream", "stream_options", "system_instruction", "n"})

def translate_to_provider(self, ell_call : EllCallParams) -> MessageCreateParamsStreaming:
# final_call_params = cast(MessageCreateParamsStreaming, ell_call.api_params.copy())
# # XXX: Helper, but should be deprecated due to SSOT
assert not ell_call.tools, "Provider does not yet support tools"

clean_api_params = ell_call.api_params.copy()
clean_api_params.pop("stream", None)
if "max_tokens" in clean_api_params:
clean_api_params["max_output_tokens"] = clean_api_params.pop("max_tokens")

system_message = None
if role_correct_msgs and role_correct_msgs[0]["role"] == "system":
system_message = role_correct_msgs.pop(0)
msgs = [
types.Content(
role=message.role,
parts=[
_content_block_to_google_format(c)
for c in message.content
]
)
for message in ell_call.messages
]

system_instruction : Optional[types.ContentUnion] = None
system_msg = next((m for m in msgs if m.role == "system"), None)
if system_msg:
system_instruction = system_msg
msgs = [m for m in msgs if m.role != "system"]

if system_message:
final_call_params["system"] = system_message["content"][0]["text"]


final_call_params['stream'] = True
final_call_params["model"] = ell_call.model
final_call_params["messages"] = role_correct_msgs

if ell_call.tools:
final_call_params["tools"] = [
#XXX: Cleaner with LMP's as a class.
dict(
name=tool.__name__,
description=tool.__doc__,
input_schema=tool.__ell_params_model__.model_json_schema(),
)
for tool in ell_call.tools
]

# print(final_call_params)
return final_call_params
return MessageCreateParamsStreaming(
model=ell_call.model,
contents=msgs,
config=types.GenerateContentConfig(
**clean_api_params,
system_instruction=system_instruction,
response_modalities=["text"], # Text only for now
automatic_function_calling=None, # TODO: Support.
) # performs pydantic validation.
)
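
To make the translation above concrete, a small illustration (hypothetical values, not part of the commit) of how ell's OpenAI-style api params are remapped before being validated into GenerateContentConfig:

# Mirrors the clean_api_params handling in translate_to_provider above.
ell_api_params = {"max_tokens": 10000, "temperature": 0.7, "stream": True}

clean_api_params = ell_api_params.copy()
clean_api_params.pop("stream", None)    # streaming is implied by generate_content_stream
if "max_tokens" in clean_api_params:
    clean_api_params["max_output_tokens"] = clean_api_params.pop("max_tokens")

assert clean_api_params == {"temperature": 0.7, "max_output_tokens": 10000}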

def translate_from_provider(
self,
provider_response : Union[Stream[RawMessageStreamEvent], AnthropicMessage],
provider_response : Iterator[types.GenerateContentResponse],
ell_call: EllCallParams,
provider_call_params: Dict[str, Any],
origin_id: Optional[str] = None,
logger: Optional[Callable[..., None]] = None,
) -> Tuple[List[Message], Metadata]:

usage = {}
tracked_results = []
metadata = {}

# XXX: Support n > 1

if provider_call_params.get("stream", False):
content = []
current_blocks: Dict[int, Dict[str, Any]] = {}
message_metadata = {}

with cast(Stream[RawMessageStreamEvent], provider_response) as stream:
for chunk in stream:
if chunk.type == "message_start":
message_metadata = chunk.message.model_dump()
message_metadata.pop("content", None) # Remove content as we'll build it separately

elif chunk.type == "content_block_start":
block = chunk.content_block.model_dump()
current_blocks[chunk.index] = block
if block["type"] == "tool_use":
if logger: logger(f" <tool_use: {block['name']}(")
block["input"] = "" # force it to be a string, XXX: can implement partially parsed json later.
elif chunk.type == "content_block_delta":
if chunk.index in current_blocks:
block = current_blocks[chunk.index]
if (delta := chunk.delta).type == "text_delta":
block["text"] += delta.text
if logger: logger(delta.text)
if delta.type == "input_json_delta":
block["input"] += delta.partial_json
if logger: logger(delta.partial_json)

elif chunk.type == "content_block_stop":
if chunk.index in current_blocks:
block = current_blocks.pop(chunk.index)
if block["type"] == "text":
content.append(ContentBlock(text=_lstr(block["text"],origin_trace=origin_id)))
elif block["type"] == "tool_use":
try:
matching_tool = ell_call.get_tool_by_name(block["name"])
if matching_tool:
content.append(
ContentBlock(
tool_call=ToolCall(
tool=matching_tool,
tool_call_id=_lstr(
block['id'],origin_trace=origin_id
),
params=json.loads(block['input']) if block['input'] else {},
)
)
)
except json.JSONDecodeError:
if logger: logger(f" - FAILED TO PARSE JSON")
pass
if logger: logger(f")>")

elif chunk.type == "message_delta":
message_metadata.update(chunk.delta.model_dump())
if chunk.usage:
usage.update(chunk.usage.model_dump())

elif chunk.type == "message_stop":
tracked_results.append(Message(role="assistant", content=content))

# print(chunk)
metadata = message_metadata
message_metadata : Optional[types.GenerateContentResponseUsageMetadata] = None
total_text = ""
for chunk in provider_response:
message_metadata = chunk.usage_metadata if chunk.usage_metadata else message_metadata
text = chunk.text
if text:
if logger: logger(text)
total_text += text
content = [ContentBlock(text=_lstr(total_text,origin_trace=origin_id))]


# process metadata for ell
# XXX: Unify an ell metadata format for ell studio.
usage["prompt_tokens"] = usage.get("input_tokens", 0)
usage["completion_tokens"] = usage.get("output_tokens", 0)
usage["total_tokens"] = usage['prompt_tokens'] + usage['completion_tokens']
if message_metadata:
usage["prompt_tokens"] = message_metadata.prompt_token_count
usage["completion_tokens"] = message_metadata.candidates_token_count
usage["total_tokens"] = message_metadata.total_token_count

metadata["usage"] = usage
return tracked_results, metadata
metadata["usage"] = usage

return [Message(role="assistant", content=content)], metadata
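
A hedged sketch of what a caller gets back from translate_from_provider; `stream`, `ell_call`, and `call_params` are assumed from context rather than defined in the commit:

provider = GoogleProvider()
messages, metadata = provider.translate_from_provider(
    provider_response=stream,      # an Iterator[types.GenerateContentResponse]
    ell_call=ell_call,             # the EllCallParams for this invocation
    provider_call_params=call_params,
)
assistant = messages[0]            # a single Message(role="assistant", ...)
print(assistant.content[0].text)   # the accumulated streamed text
print(metadata["usage"])           # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}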

# XXX: Make a singleton.
anthropic_provider = AnthropicProvider()
register_provider(anthropic_provider, anthropic.Anthropic)
register_provider(anthropic_provider, anthropic.AnthropicBedrock)
register_provider(anthropic_provider, anthropic.AnthropicVertex)
google_provider = GoogleProvider()
register_provider(google_provider, genai.Client)

except ImportError:
pass

def serialize_image_for_anthropic(img : ImageContent):
def serialize_image_for_google(img : ImageContent):
if img.url:
# Download the image from the URL
response = requests.get(img.url)
@@ -168,35 +134,35 @@ def serialize_image_for_anthropic(img : ImageContent):
pil_image = img.image
else:
raise ValueError("Image object has neither url nor image data.")
buffer = BytesIO()
pil_image.save(buffer, format="PNG")
base64_image = base64.b64encode(buffer.getvalue()).decode()
# Convert PIL Image to bytes in memory
img_bytes_io = BytesIO()
pil_image.save(img_bytes_io, format='PNG')
img_byte_arr = img_bytes_io.getvalue()

return dict(
type="image",
source=dict(
type="base64",
media_type="image/png",
data=base64_image
inline_data=dict(
mime_type="image/png",
data=img_byte_arr
)
)
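
A usage sketch for serialize_image_for_google; the `ImageContent` constructor arguments are assumed from the attribute access above (img.url / img.image):

from PIL import Image
from ell.types import ImageContent

pil = Image.new("RGB", (8, 8), "red")
part = serialize_image_for_google(ImageContent(image=pil))   # PIL path
# part is roughly {"inline_data": {"mime_type": "image/png", "data": <PNG bytes>}}

url_part = serialize_image_for_google(
    ImageContent(url="https://example.com/dot.png")          # URL path, fetched via requests.get
)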

def _content_block_to_anthropic_format(content_block: ContentBlock):
if (image := content_block.image): return serialize_image_for_anthropic(image)
elif ((text := content_block.text) is not None): return dict(type="text", text=text)
def _content_block_to_google_format(content_block: ContentBlock):  # -> "types.PartUnion"
if (image := content_block.image): return serialize_image_for_google(image)
elif ((text := content_block.text) is not None): return dict(text=text)
elif (parsed := content_block.parsed):
return dict(type="text", text=json.dumps(parsed.model_dump(), ensure_ascii=False))
elif (tool_call := content_block.tool_call):
return dict(
type="tool_use",
id=tool_call.tool_call_id,
name=tool_call.tool.__name__,
input=tool_call.params.model_dump()
)
elif (tool_result := content_block.tool_result):
return dict(
type="tool_result",
tool_use_id=tool_result.tool_call_id,
content=[_content_block_to_anthropic_format(c) for c in tool_result.result]
)
return dict(text=json.dumps(parsed.model_dump(), ensure_ascii=False))
# elif (tool_call := content_block.tool_call):
# return dict(
# type="tool_use",
# id=tool_call.tool_call_id,
# name=tool_call.tool.__name__,
# input=tool_call.params.model_dump()
# )
# elif (tool_result := content_block.tool_result):
# return dict(
# type="tool_result",
# tool_use_id=tool_result.tool_call_id,
# content=[_content_block_to_google_format(c) for c in tool_result.result]
# )
else:
raise ValueError("Content block is not supported by the Google provider")
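
Finally, a few hedged examples of the block-to-part mapping implemented by _content_block_to_google_format; the `ContentBlock` constructor usage is assumed from ell's types:

from ell.types import ContentBlock

assert _content_block_to_google_format(ContentBlock(text="hello")) == {"text": "hello"}
# image blocks  -> {"inline_data": {"mime_type": "image/png", "data": <PNG bytes>}}
# parsed blocks -> their pydantic model JSON-serialized into a plain text part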
