feat: Add Google Search Grounding static tool #147

Merged · 13 commits · Nov 15, 2024
7 changes: 6 additions & 1 deletion aidial_adapter_vertexai/chat/bison/base.py
@@ -14,6 +14,7 @@
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
@@ -36,9 +37,13 @@ def send_message_async(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> BisonPrompt:
tools.not_supported()
static_tools.not_supported()
return BisonPrompt.parse(messages)

@override
6 changes: 5 additions & 1 deletion aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -5,6 +5,7 @@

from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
@@ -16,7 +17,10 @@
class ChatCompletionAdapter(ABC, Generic[P]):
@abstractmethod
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> P | UserError:
pass

42 changes: 36 additions & 6 deletions aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -41,6 +41,7 @@
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_1_5 import (
Gemini_1_5_Prompt,
)
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.deployments import (
@@ -105,14 +106,19 @@ def __init__(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> GeminiPrompt | UserError:
match self.deployment:
case ChatCompletionDeployment.GEMINI_PRO_1:
return await Gemini_1_0_Pro_Prompt.parse(tools, messages)
return await Gemini_1_0_Pro_Prompt.parse(
tools, static_tools, messages
)
case ChatCompletionDeployment.GEMINI_PRO_VISION_1:
return await Gemini_1_0_Pro_Vision_Prompt.parse(
self.file_storage, tools, messages
self.file_storage, tools, static_tools, messages
)
case (
ChatCompletionDeployment.GEMINI_PRO_1_5_PREVIEW
@@ -122,7 +128,7 @@
| ChatCompletionDeployment.GEMINI_FLASH_1_5_V2
):
return await Gemini_1_5_Prompt.parse(
self.file_storage, tools, messages
self.file_storage, tools, static_tools, messages
)
case _:
assert_never(self.deployment)
@@ -136,7 +142,10 @@ def _get_model(
parameters = create_generation_config(params) if params else None

if prompt is not None:
tools = prompt.tools.to_gemini_tools()
tools = (
prompt.tools.to_gemini_tools()
+ prompt.static_tools.to_gemini_tools()
)
tool_config = prompt.tools.to_gemini_tool_config()
system_instruction = cast(
List[str | Part | Image] | None,
@@ -190,6 +199,7 @@ async def process_chunks(
yield content

await create_function_calls(candidate, consumer, tools)
await create_grounding(candidate, consumer)
await create_attachments_from_citations(candidate, consumer)
await set_finish_reason(candidate, consumer)

@@ -211,7 +221,6 @@ async def chat(
)

completion = ""

async for content in generate_with_retries(
lambda: self.process_chunks(
consumer,
@@ -321,6 +330,27 @@ async def create_function_calls(
)


async def create_grounding(candidate: Candidate, consumer: Consumer) -> None:
# TODO: candidate.grounding_metadata.grounding_supports[i].segment
# actually points to the piece of the generated text that it grounds, e.g.:
# segment {
#   end_index: 61
#   text: "Carlos Alcaraz won the men's singles title at Wimbledon 2024."
# }
# We need to figure out how to use such references in the future.
if (
not candidate.grounding_metadata
or not candidate.grounding_metadata.grounding_chunks
):
return

for chunk in candidate.grounding_metadata.grounding_chunks:
if chunk.web and chunk.web.uri:
await consumer.add_attachment(
Attachment(url=chunk.web.uri, title=chunk.web.title)
)
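
For illustration, a minimal sketch of the traversal above against a stand-in payload. The `SimpleNamespace` objects are hypothetical; only the attribute paths `grounding_chunks[i].web.uri` / `.title` come from this function:

```python
from types import SimpleNamespace

# Hypothetical grounded candidate, shaped like the fields read above
# (not a real vertexai SDK object).
candidate = SimpleNamespace(
    grounding_metadata=SimpleNamespace(
        grounding_chunks=[
            SimpleNamespace(
                web=SimpleNamespace(
                    uri="https://en.wikipedia.org/wiki/2024_Wimbledon_Championships",
                    title="2024 Wimbledon Championships",
                )
            )
        ]
    )
)

# create_grounding would emit one DIAL attachment per web chunk:
for chunk in candidate.grounding_metadata.grounding_chunks:
    if chunk.web and chunk.web.uri:
        print(chunk.web.uri, chunk.web.title)
```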


def to_openai_finish_reason(
finish_reason: GenFinishReason, retriable: bool
) -> FinishReason | None:
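For context, a standalone sketch of the tool wiring that `_get_model` ends up doing once static tools are merged in. It uses the public `vertexai.preview.generative_models` API; the model name and prompt are placeholders:

```python
from vertexai.preview.generative_models import (
    GenerativeModel,
    Tool,
    grounding,
)

# Regular function tools (empty here) concatenated with the grounding tool
# that StaticToolsConfig produces for "google_search".
tools = [] + [
    Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval())
]

model = GenerativeModel("gemini-1.5-pro-002", tools=tools)  # placeholder model
response = model.generate_content(
    "Who won the men's singles title at Wimbledon 2024?"
)
print(response.candidates[0].grounding_metadata)
```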
11 changes: 11 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/prompt/base.py
@@ -3,7 +3,9 @@

from pydantic import BaseModel, Field
from vertexai.preview.generative_models import Content, Part
from vertexai.preview.generative_models import Tool as GeminiTool

from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatablePrompt

@@ -20,6 +22,9 @@ class GeminiPrompt(BaseModel, TruncatablePrompt, ABC):
system_instruction: List[Part] | None = None
contents: List[Content]
tools: ToolsConfig = Field(default_factory=ToolsConfig.noop)
static_tools: StaticToolsConfig = Field(
default_factory=StaticToolsConfig.noop
)

class Config:
arbitrary_types_allowed = True
@@ -68,4 +73,10 @@ def select(self, indices: Set[int]) -> "GeminiPrompt":
system_instruction=system_instruction,
contents=contents,
tools=self.tools,
static_tools=self.static_tools,
)

def to_gemini_tools(self) -> List[GeminiTool]:
regular_tools = self.tools.to_gemini_tools()
static_tools = self.static_tools.to_gemini_tools()
return regular_tools + static_tools
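
A small consequence worth noting: both configs default to `noop`, and `ToolsConfig.to_gemini_tools` now returns `[]` instead of `None` (see `tools.py` below), so the concatenation in this helper is always well-typed. A sketch:

```python
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig

# With both configs empty, the combined tool list is simply empty.
assert ToolsConfig.noop().to_gemini_tools() == []
assert StaticToolsConfig.noop().to_gemini_tools() == []
```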
aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_0_pro.py
@@ -8,13 +8,17 @@
)
from aidial_adapter_vertexai.chat.gemini.processor import AttachmentProcessors
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig


class Gemini_1_0_Pro_Prompt(GeminiPrompt):
@classmethod
async def parse(
cls, tools: ToolsConfig, messages: List[Message]
cls,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Self | UserError:
if len(messages) == 0:
raise ValidationError(
@@ -34,4 +38,5 @@ async def parse(
system_instruction=conversation.system_instruction,
contents=conversation.contents,
tools=tools,
static_tools=static_tools,
)
aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_0_pro_vision.py
@@ -17,6 +17,7 @@
get_video_processor,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.dial_api.request import get_attachments
from aidial_adapter_vertexai.dial_api.storage import FileStorage
@@ -28,9 +29,11 @@ async def parse(
cls,
file_storage: FileStorage | None,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Union[Self, UserError]:
tools.not_supported()
static_tools.not_supported()

if len(messages) == 0:
raise ValidationError(
3 changes: 3 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_5.py
@@ -15,6 +15,7 @@
get_video_processor,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.dial_api.storage import FileStorage

@@ -25,6 +26,7 @@ async def parse(
cls,
file_storage: Optional[FileStorage],
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Self | UserError:
if len(messages) == 0:
@@ -55,6 +57,7 @@
system_instruction=conversation.system_instruction,
contents=conversation.contents,
tools=tools,
static_tools=static_tools,
)


8 changes: 6 additions & 2 deletions aidial_adapter_vertexai/chat/imagen/adapter.py
@@ -14,6 +14,7 @@
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import (
@@ -43,10 +44,13 @@ def __init__(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> ImagenPrompt:
tools.not_supported()

static_tools.not_supported()
if len(messages) == 0:
raise ValidationError("The list of messages must not be empty")

82 changes: 82 additions & 0 deletions aidial_adapter_vertexai/chat/static_tools.py
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Self

from aidial_sdk.chat_completion.request import (
AzureChatCompletionRequest,
StaticTool,
)
from pydantic import BaseModel
from vertexai.preview.generative_models import Tool as GeminiTool
from vertexai.preview.generative_models import grounding

from aidial_adapter_vertexai.chat.errors import ValidationError


class ToolName(str, Enum):
"""
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding
"""

GOOGLE_SEARCH = "google_search"


class StaticToolProcessor(ABC):
@abstractmethod
def validate_config(self, config: dict | None) -> None: ...

@abstractmethod
def to_gemini_tools(self, tool: StaticTool) -> List[GeminiTool]: ...

def process(self, tool: StaticTool) -> List[GeminiTool]:
self.validate_config(tool.static_function.configuration)
return self.to_gemini_tools(tool)


class GoogleSearchGroundingTool(StaticToolProcessor):
def validate_config(self, config: dict | None) -> None:
if config:
raise ValidationError(
"Google search tool doesn't support configuration"
)

def to_gemini_tools(self, tool: StaticTool) -> List[GeminiTool]:
return [
GeminiTool.from_google_search_retrieval(
grounding.GoogleSearchRetrieval()
)
]


class StaticToolsConfig(BaseModel):
tools: List[StaticTool]

@classmethod
def from_request(cls, request: AzureChatCompletionRequest) -> Self:
if request.tools is None:
return cls(tools=[])

return cls(
tools=[
tool for tool in request.tools if isinstance(tool, StaticTool)
]
)

@classmethod
def noop(cls) -> Self:
return cls(tools=[])

def to_gemini_tools(self) -> List[GeminiTool]:
gemini_tools = []
for tool in self.tools:
if tool.static_function.name == ToolName.GOOGLE_SEARCH.value:
gemini_tools.extend(GoogleSearchGroundingTool().process(tool))
else:
raise ValidationError(
f"Unsupported static tool: {tool.static_function.name}"
)
return gemini_tools

def not_supported(self) -> None:
if len(self.tools) > 0:
raise ValidationError("Static tools aren't supported")
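
To illustrate the request side, a hedged sketch of a payload that would activate the tool. The exact wire format of `StaticTool` is defined in `aidial_sdk`; the `static_function` shape below is inferred from the attribute access in this file and is not a verified schema:

```python
# Hypothetical DIAL chat completion request body (wire format assumed).
request_body = {
    "messages": [{"role": "user", "content": "Who won Wimbledon 2024?"}],
    "tools": [
        {
            "type": "static_function",  # assumed discriminator
            "static_function": {"name": "google_search", "configuration": None},
        }
    ],
}
```

`StaticToolsConfig.from_request` picks this tool out of `request.tools`, and `to_gemini_tools` turns it into a single `GeminiTool` wrapping `grounding.GoogleSearchRetrieval()`.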
12 changes: 8 additions & 4 deletions aidial_adapter_vertexai/chat/tools.py
@@ -7,7 +7,7 @@
Role,
ToolChoice,
)
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest, Tool
from pydantic import BaseModel
from vertexai.preview.generative_models import (
FunctionDeclaration as GeminiFunction,
@@ -117,7 +117,11 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self:
tool_ids = None

elif request.tools is not None:
functions = [tool.function for tool in request.tools]
functions = [
tool.function
for tool in request.tools
if isinstance(tool, Tool)
]
function_call = ToolsConfig.tool_choice_to_function_call(
request.tool_choice
)
@@ -137,9 +141,9 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self:

return cls(functions=selected, required=required, tool_ids=tool_ids)

def to_gemini_tools(self) -> List[GeminiTool] | None:
def to_gemini_tools(self) -> List[GeminiTool]:
if not self.functions:
return None
return []

return [
GeminiTool(