diff --git a/examples/providers/gemini_ex.py b/examples/providers/gemini_ex.py index fdd13871c..0e962921a 100644 --- a/examples/providers/gemini_ex.py +++ b/examples/providers/gemini_ex.py @@ -7,7 +7,6 @@ ell.init(verbose=True) # custom client -client = genai.Client() from PIL import Image, ImageDraw @@ -25,7 +24,7 @@ fill='red') -@ell.simple(model='gemini-2.0-flash', client=client, max_tokens=10000) +@ell.simple(model='gemini-2.0-flash', max_tokens=10000) def chat(prompt: str): return [ell.user([prompt + " what is in this image", img])] diff --git a/src/ell/models/google.py b/src/ell/models/google.py index 0aac84d2d..a2a5b5e2d 100644 --- a/src/ell/models/google.py +++ b/src/ell/models/google.py @@ -21,6 +21,7 @@ """ import os +from typing import Optional from ell.configurator import config import openai @@ -29,7 +30,7 @@ logger = logging.getLogger(__name__) -def register(client: openai.Client): +def register(client: Optional[openai.Client] = None): """ Register OpenAI models with the provided client. @@ -46,7 +47,8 @@ def register(client: openai.Client): configuration with the registered models. """ standard_models = [ - 'gemini-2.0-flash-exp' + 'gemini-2.0-flash-exp', + 'gemini-2.0-flash', 'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', @@ -58,10 +60,16 @@ def register(client: openai.Client): default_client = None try: - gemini_api_key = os.environ.get("GEMINI_API_KEY") + gemini_api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") if not gemini_api_key: - raise openai.OpenAIError("GEMINI_API_KEY not found in environment variables") - default_client = openai.Client(base_url="https://generativelanguage.googleapis.com/v1beta/openai/", api_key=gemini_api_key) + raise openai.OpenAIError("Neither GEMINI_API_KEY nor GOOGLE_API_KEY found in environment variables") + try: + from google import genai + default_client = genai.Client() + except ImportError: + logger.debug(f"{colorama.Fore.YELLOW}google.genai not found - using openai proxy for google models {colorama.Style.RESET_ALL}") + default_client = openai.Client(base_url="https://generativelanguage.googleapis.com/v1beta/openai/", api_key=gemini_api_key) + except openai.OpenAIError as e: pass diff --git a/src/ell/util/_warnings.py b/src/ell/util/_warnings.py index bd933a146..16db92cbf 100644 --- a/src/ell/util/_warnings.py +++ b/src/ell/util/_warnings.py @@ -55,7 +55,7 @@ def {fn.__name__}(...): ell.simple(model, client=my_client)(...) ``` {Style.RESET_ALL}""") - elif (client_to_use := config.registry[model].default_client) is None or not client_to_use.api_key: + elif (client_to_use := config.registry[model].default_client) is None or (hasattr(client_to_use, "api_key") and not client_to_use.api_key): logger.warning(_no_api_key_warning(model, fn.__name__, client_to_use, long=False))