From 7746f03ad69950b8abf2f1b795477c064eadfc89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 5 Nov 2024 08:54:21 +0100 Subject: [PATCH] Add Gemini integration --- outlines/models/__init__.py | 1 + outlines/models/gemini.py | 121 ++++++++++++++++++++++++ outlines/types/__init__.py | 2 + pyproject.toml | 2 + tests/models/test_gemini.py | 178 ++++++++++++++++++++++++++++++++++++ 5 files changed, 304 insertions(+) create mode 100644 outlines/models/gemini.py create mode 100644 tests/models/test_gemini.py diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index af24a8169..d1b12aecc 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -9,6 +9,7 @@ from typing import Union from .exllamav2 import ExLlamaV2Model, exl2 +from .gemini import Gemini from .llamacpp import LlamaCpp, llamacpp from .mlxlm import MLXLM, mlxlm from .openai import AzureOpenAI, OpenAI diff --git a/outlines/models/gemini.py b/outlines/models/gemini.py new file mode 100644 index 000000000..7614861c2 --- /dev/null +++ b/outlines/models/gemini.py @@ -0,0 +1,121 @@ +"""Integration with Gemini's API.""" +from enum import EnumMeta +from functools import singledispatchmethod +from types import NoneType +from typing import Optional, Union + +from pydantic import BaseModel +from typing_extensions import _TypedDictMeta # type: ignore + +from outlines.prompts import Vision +from outlines.types import Choice, Json + +__all__ = ["Gemini"] + + +class GeminiBase: + """Base class for the Gemini clients. + + `GeminiBase` is responsible for preparing the arguments to Gemini's + `generate_contents` methods: the input (prompt and possibly image), as well + as the output type (only JSON). + + """ + + @singledispatchmethod + def format_input(self, model_input): + """Generate the `messages` argument to pass to the client. + + Argument + -------- + model_input + The input passed by the user. 
+
+        """
+        raise NotImplementedError(
+            f"The input type {type(model_input)} is not available with Gemini. The only available types are `str` and `Vision`."
+        )
+
+    @format_input.register(str)
+    def format_str_input(self, model_input: str):
+        """Generate the `messages` argument to pass to the client when the user
+        only passes a prompt.
+
+        """
+        return {"contents": [model_input]}
+
+    @format_input.register(Vision)
+    def format_vision_input(self, model_input: Vision):
+        """Generate the `messages` argument to pass to the client when the user
+        passes a prompt and an image.
+
+        """
+        return {"contents": [model_input.prompt, model_input.image]}
+
+    @singledispatchmethod
+    def format_output_type(self, output_type):
+        if output_type.__origin__ == list:
+            if len(output_type.__args__) == 1 and isinstance(
+                output_type.__args__[0], Json
+            ):
+                return {
+                    "response_mime_type": "application/json",
+                    "response_schema": list[
+                        output_type.__args__[0].original_definition
+                    ],
+                }
+            else:
+                raise TypeError
+        else:
+            raise NotImplementedError
+
+    @format_output_type.register(NoneType)
+    def format_none_output_type(self, output_type):
+        return {}
+
+    @format_output_type.register(Json)
+    def format_json_output_type(self, output_type):
+        if issubclass(output_type.original_definition, BaseModel):
+            return {
+                "response_mime_type": "application/json",
+                "response_schema": output_type.original_definition,
+            }
+        elif isinstance(output_type.original_definition, _TypedDictMeta):
+            return {
+                "response_mime_type": "application/json",
+                "response_schema": output_type.original_definition,
+            }
+        else:
+            raise NotImplementedError
+
+    @format_output_type.register(Choice)
+    def format_enum_output_type(self, output_type):
+        return {
+            "response_mime_type": "text/x.enum",
+            "response_schema": output_type.definition,
+        }
+
+
+class Gemini(GeminiBase):
+    def __init__(self, model_name: str, *args, **kwargs):
+        import google.generativeai as genai
+
+        self.client = genai.GenerativeModel(model_name, *args, 
**kwargs) + + def generate( + self, + model_input: Union[str, Vision], + output_type: Optional[Union[Json, EnumMeta]] = None, + **inference_kwargs, + ): + import google.generativeai as genai + + contents = self.format_input(model_input) + generation_config = genai.GenerationConfig( + **self.format_output_type(output_type) + ) + completion = self.client.generate_content( + generation_config=generation_config, **contents, **inference_kwargs + ) + + return completion.text diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index f59a05a70..ad6e50cd6 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -27,6 +27,8 @@ class Json: """ def __init__(self, definition: Union[str, dict, BaseModel]): + self.original_definition = definition + if isinstance(definition, type(BaseModel)): definition = definition.model_json_schema() if isinstance(definition, str): diff --git a/pyproject.toml b/pyproject.toml index 55c8b938a..bd0185f1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ test = [ "mlx-lm; platform_machine == 'arm64' and sys_platform == 'darwin'", "huggingface_hub", "openai>=1.0.0", + "google-generativeai", "vllm; sys_platform != 'darwin'", "transformers", "pillow", @@ -112,6 +113,7 @@ module = [ "exllamav2.*", "jinja2", "jsonschema.*", + "google.*", "mamba_ssm.*", "mlx_lm.*", "mlx.*", diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py new file mode 100644 index 000000000..2014d64e9 --- /dev/null +++ b/tests/models/test_gemini.py @@ -0,0 +1,178 @@ +import io +import json +from enum import Enum +from typing import List + +import PIL +import pytest +import requests +from pydantic import BaseModel +from typing_extensions import TypedDict + +from outlines.models.gemini import Gemini +from outlines.prompts import Vision +from outlines.types import Choice, Json + +MODEL_NAME = "gemini-1.5-flash-latest" + + +def test_gemini_wrong_init_parameters(): + with pytest.raises(TypeError, match="got an 
unexpected"): + Gemini(MODEL_NAME, foo=10) + + +def test_gemini_wrong_inference_parameters(): + with pytest.raises(TypeError, match="got an unexpected"): + model = Gemini(MODEL_NAME) + model.generate("prompt", foo=10) + + +@pytest.mark.api_call +def test_gemini_simple_call(): + model = Gemini(MODEL_NAME) + result = model.generate("Respond with one word. Not more.") + assert isinstance(result, str) + + +@pytest.mark.api_call +def test_gemini_simple_vision(): + model = Gemini(MODEL_NAME) + + url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" + r = requests.get(url, stream=True) + if r.status_code == 200: + image = PIL.Image.open(io.BytesIO(r.content)) + + result = model.generate(Vision("What does this logo represent?", image)) + assert isinstance(result, str) + + +@pytest.mark.api_call +def test_gemini_simple_pydantic(): + model = Gemini(MODEL_NAME) + + class Foo(BaseModel): + bar: int + + result = model.generate("foo?", Json(Foo)) + assert isinstance(result, str) + assert "bar" in json.loads(result) + + +@pytest.mark.xfail(reason="Vision models do not work with structured outputs.") +@pytest.mark.api_call +def test_gemini_simple_vision_pydantic(): + model = Gemini(MODEL_NAME) + + url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" + r = requests.get(url, stream=True) + if r.status_code == 200: + image = PIL.Image.open(io.BytesIO(r.content)) + + class Logo(BaseModel): + name: int + + result = model.generate(Vision("What does this logo represent?", image), Logo) + assert isinstance(result, str) + assert "name" in json.loads(result) + + +@pytest.mark.xfail(reason="Gemini seems to be unable to follow nested schemas.") +@pytest.mark.api_call +def test_gemini_nested_pydantic(): + model = Gemini(MODEL_NAME) + + class Bar(BaseModel): + fu: str + + class Foo(BaseModel): + sna: int + bar: Bar + + result = model.generate("foo?", Json(Foo)) + assert isinstance(result, str) 
+ assert "sna" in json.loads(result) + assert "bar" in json.loads(result) + assert "fu" in json.loads(result)["bar"] + + +@pytest.mark.xfail( + reason="The Gemini SDK's serialization method does not support Json Schema dictionaries." +) +@pytest.mark.api_call +def test_gemini_simple_json_schema_dict(): + model = Gemini(MODEL_NAME) + + schema = { + "properties": {"bar": {"title": "Bar", "type": "integer"}}, + "required": ["bar"], + "title": "Foo", + "type": "object", + } + result = model.generate("foo?", Json(schema)) + assert isinstance(result, str) + assert "bar" in json.loads(result) + + +@pytest.mark.xfail( + reason="The Gemini SDK's serialization method does not support Json Schema strings." +) +@pytest.mark.api_call +def test_gemini_simple_json_schema_string(): + model = Gemini(MODEL_NAME) + + schema = "{'properties': {'bar': {'title': 'Bar', 'type': 'integer'}}, 'required': ['bar'], 'title': 'Foo', 'type': 'object'}" + result = model.generate("foo?", Json(schema)) + assert isinstance(result, str) + assert "bar" in json.loads(result) + + +@pytest.mark.api_call +def test_gemini_simple_typed_dict(): + model = Gemini(MODEL_NAME) + + class Foo(TypedDict): + bar: int + + result = model.generate("foo?", Json(Foo)) + assert isinstance(result, str) + assert "bar" in json.loads(result) + + +@pytest.mark.api_call +def test_gemini_simple_enum(): + model = Gemini(MODEL_NAME) + + class Foo(Enum): + bar = "Bar" + foor = "Foo" + + result = model.generate("foo?", Choice(Foo)) + assert isinstance(result, str) + assert result == "Foo" or result == "Bar" + + +@pytest.mark.api_call +def test_gemini_simple_list_pydantic(): + model = Gemini(MODEL_NAME) + + class Foo(BaseModel): + bar: int + + result = model.generate("foo?", list[Json(Foo)]) + assert isinstance(json.loads(result), list) + assert isinstance(json.loads(result)[0], dict) + assert "bar" in json.loads(result)[0] + + +@pytest.mark.api_call +def test_gemini_simple_list_annotation_pydantic(): + model = Gemini(MODEL_NAME) + 
+ class Foo(BaseModel): + bar: int + + result = model.generate("foo?", List[Json(Foo)]) + assert isinstance(json.loads(result), list) + assert isinstance(json.loads(result)[0], dict) + assert "bar" in json.loads(result)[0]