Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Google Search Grounding static tool #147

Merged
merged 13 commits into from
Nov 15, 2024
Merged
7 changes: 6 additions & 1 deletion aidial_adapter_vertexai/chat/bison/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
Expand All @@ -36,9 +37,13 @@ def send_message_async(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> BisonPrompt:
tools.not_supported()
static_tools.not_supported()
return BisonPrompt.parse(messages)

@override
Expand Down
6 changes: 5 additions & 1 deletion aidial_adapter_vertexai/chat/chat_completion_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
Expand All @@ -16,7 +17,10 @@
class ChatCompletionAdapter(ABC, Generic[P]):
@abstractmethod
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> P | UserError:
pass

Expand Down
32 changes: 26 additions & 6 deletions aidial_adapter_vertexai/chat/gemini/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_1_5 import (
Gemini_1_5_Prompt,
)
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.deployments import (
Expand Down Expand Up @@ -105,14 +106,19 @@ def __init__(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> GeminiPrompt | UserError:
match self.deployment:
case ChatCompletionDeployment.GEMINI_PRO_1:
return await Gemini_1_0_Pro_Prompt.parse(tools, messages)
return await Gemini_1_0_Pro_Prompt.parse(
tools, static_tools, messages
)
case ChatCompletionDeployment.GEMINI_PRO_VISION_1:
return await Gemini_1_0_Pro_Vision_Prompt.parse(
self.file_storage, tools, messages
self.file_storage, tools, static_tools, messages
)
case (
ChatCompletionDeployment.GEMINI_PRO_1_5_PREVIEW
Expand All @@ -122,7 +128,7 @@ async def parse_prompt(
| ChatCompletionDeployment.GEMINI_FLASH_1_5_V2
):
return await Gemini_1_5_Prompt.parse(
self.file_storage, tools, messages
self.file_storage, tools, static_tools, messages
)
case _:
assert_never(self.deployment)
Expand All @@ -136,7 +142,7 @@ def _get_model(
parameters = create_generation_config(params) if params else None

if prompt is not None:
tools = prompt.tools.to_gemini_tools()
tools = prompt.to_gemini_tools() or None
tool_config = prompt.tools.to_gemini_tool_config()
system_instruction = cast(
List[str | Part | Image] | None,
Expand Down Expand Up @@ -190,6 +196,7 @@ async def process_chunks(
yield content

await create_function_calls(candidate, consumer, tools)
await create_grounding(candidate, consumer)
await create_attachments_from_citations(candidate, consumer)
await set_finish_reason(candidate, consumer)

Expand All @@ -211,7 +218,6 @@ async def chat(
)

completion = ""

async for content in generate_with_retries(
lambda: self.process_chunks(
consumer,
Expand Down Expand Up @@ -321,6 +327,20 @@ async def create_function_calls(
)


async def create_grounding(candidate: Candidate, consumer: Consumer) -> None:
    """Surface web grounding sources of the candidate as DIAL attachments.

    Reads the candidate's grounding metadata (populated when a grounding
    tool such as Google Search was active) and adds one attachment per
    web chunk that carries a URI. Does nothing when no metadata or no
    grounding chunks are present.
    """
    metadata = candidate.grounding_metadata
    if not metadata or not metadata.grounding_chunks:
        return

    for grounding_chunk in metadata.grounding_chunks:
        web = grounding_chunk.web
        # Only web sources with a concrete URI are attachable.
        if web and web.uri:
            attachment = Attachment(url=web.uri, title=web.title)
            await consumer.add_attachment(attachment)


def to_openai_finish_reason(
finish_reason: GenFinishReason, retriable: bool
) -> FinishReason | None:
Expand Down
11 changes: 11 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/prompt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from pydantic import BaseModel, Field
from vertexai.preview.generative_models import Content, Part
from vertexai.preview.generative_models import Tool as GeminiTool

from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatablePrompt

Expand All @@ -20,6 +22,9 @@ class GeminiPrompt(BaseModel, TruncatablePrompt, ABC):
system_instruction: List[Part] | None = None
contents: List[Content]
tools: ToolsConfig = Field(default_factory=ToolsConfig.noop)
static_tools: StaticToolsConfig = Field(
default_factory=StaticToolsConfig.noop
)

class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -68,4 +73,10 @@ def select(self, indices: Set[int]) -> "GeminiPrompt":
system_instruction=system_instruction,
contents=contents,
tools=self.tools,
static_tools=self.static_tools,
)

def to_gemini_tools(self) -> List[GeminiTool]:
    """Collect Gemini tools from both the regular and the static tool configs.

    Regular (function-calling) tools come first, followed by static
    (provider built-in) tools such as search grounding.
    """
    return [
        *self.tools.to_gemini_tools(),
        *self.static_tools.to_gemini_tools(),
    ]
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
)
from aidial_adapter_vertexai.chat.gemini.processor import AttachmentProcessors
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig


class Gemini_1_0_Pro_Prompt(GeminiPrompt):
@classmethod
async def parse(
cls, tools: ToolsConfig, messages: List[Message]
cls,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Self | UserError:
if len(messages) == 0:
raise ValidationError(
Expand All @@ -34,4 +38,5 @@ async def parse(
system_instruction=conversation.system_instruction,
contents=conversation.contents,
tools=tools,
static_tools=static_tools,
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_video_processor,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.dial_api.request import get_attachments
from aidial_adapter_vertexai.dial_api.storage import FileStorage
Expand All @@ -28,9 +29,11 @@ async def parse(
cls,
file_storage: FileStorage | None,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Union[Self, UserError]:
tools.not_supported()
static_tools.not_supported()

if len(messages) == 0:
raise ValidationError(
Expand Down
3 changes: 3 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_video_processor,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.dial_api.storage import FileStorage

Expand All @@ -25,6 +26,7 @@ async def parse(
cls,
file_storage: Optional[FileStorage],
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> Self | UserError:
if len(messages) == 0:
Expand Down Expand Up @@ -55,6 +57,7 @@ async def parse(
system_instruction=conversation.system_instruction,
contents=conversation.contents,
tools=tools,
static_tools=static_tools,
)


Expand Down
8 changes: 6 additions & 2 deletions aidial_adapter_vertexai/chat/imagen/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import (
Expand Down Expand Up @@ -43,10 +44,13 @@ def __init__(

@override
async def parse_prompt(
self, tools: ToolsConfig, messages: List[Message]
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> ImagenPrompt:
tools.not_supported()

static_tools.not_supported()
if len(messages) == 0:
raise ValidationError("The list of messages must not be empty")

Expand Down
87 changes: 87 additions & 0 deletions aidial_adapter_vertexai/chat/static_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, NoReturn, Self

from aidial_sdk.chat_completion.request import (
AzureChatCompletionRequest,
StaticFunction,
StaticTool,
)
from pydantic import BaseModel
from vertexai.preview.generative_models import Tool as GeminiTool
from vertexai.preview.generative_models import grounding

from aidial_adapter_vertexai.chat.errors import ValidationError


class ToolName(str, Enum):
    """Names of the static (provider built-in) tools recognized by this adapter."""

    # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding
    GOOGLE_SEARCH = "google_search"


class StaticToolProcessor(ABC):
    """Interface for converting a DIAL static function into Gemini tools.

    Implementations return the Gemini tools for a static function they
    recognize, or None to signal that the function is not theirs to handle.
    """

    @staticmethod
    @abstractmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[GeminiTool] | None: ...


class GoogleSearchGroundingTool(StaticToolProcessor):
    """Processor for the Google Search grounding static tool."""

    @staticmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[GeminiTool] | None:
        """Return the Gemini search-retrieval tool, or None if not ours.

        Raises ValidationError when the client supplied a configuration:
        the Google Search tool takes none.
        """
        if static_function.name != ToolName.GOOGLE_SEARCH:
            return None

        if static_function.configuration:
            raise ValidationError(
                "Google search tool doesn't support configuration"
            )

        retrieval = grounding.GoogleSearchRetrieval()
        return [GeminiTool.from_google_search_retrieval(retrieval)]


def unknown_tool_name(static_function: StaticFunction) -> NoReturn:
    """Reject a static function that no registered processor recognizes."""
    message = f"Unsupported static function: {static_function.name}"
    raise ValidationError(message)


class StaticToolsConfig(BaseModel):
    """Static (provider built-in) tools extracted from a chat completion request."""

    # Static functions requested by the client, e.g. google_search.
    functions: List[StaticFunction]

    @classmethod
    def from_request(cls, request: AzureChatCompletionRequest) -> Self:
        """Collect the static functions from the request's tool list.

        Non-static tools (regular function-calling tools) are handled
        separately by ToolsConfig and are filtered out here.
        """
        if request.tools is None:
            return cls(functions=[])

        return cls(
            functions=[
                tool.static_function
                for tool in request.tools
                if isinstance(tool, StaticTool)
            ]
        )

    @classmethod
    def noop(cls) -> Self:
        """An empty configuration: no static tools requested."""
        return cls(functions=[])

    def to_gemini_tools(self) -> List[GeminiTool]:
        """Convert every requested static function into Gemini tools.

        Raises ValidationError (via unknown_tool_name) for any function
        no processor recognizes.
        """
        ret: List[GeminiTool] = []
        for tool in self.functions:
            # Check for None explicitly rather than truthiness: a processor
            # returning an empty list is a valid "handled, nothing to add"
            # answer, whereas `parse_gemini_tools(tool) or unknown_tool_name(tool)`
            # would misreport it as an unknown tool.
            tools = GoogleSearchGroundingTool.parse_gemini_tools(tool)
            if tools is None:
                unknown_tool_name(tool)
            ret.extend(tools)
        return ret

    def not_supported(self) -> None:
        """Raise if any static tool was requested — the current model can't use them."""
        if self.functions:
            raise ValidationError("Static tools aren't supported")
12 changes: 8 additions & 4 deletions aidial_adapter_vertexai/chat/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Role,
ToolChoice,
)
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest, Tool
from pydantic import BaseModel
from vertexai.preview.generative_models import (
FunctionDeclaration as GeminiFunction,
Expand Down Expand Up @@ -117,7 +117,11 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self:
tool_ids = None

elif request.tools is not None:
functions = [tool.function for tool in request.tools]
functions = [
tool.function
for tool in request.tools
if isinstance(tool, Tool)
]
function_call = ToolsConfig.tool_choice_to_function_call(
request.tool_choice
)
Expand All @@ -137,9 +141,9 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self:

return cls(functions=selected, required=required, tool_ids=tool_ids)

def to_gemini_tools(self) -> List[GeminiTool] | None:
def to_gemini_tools(self) -> List[GeminiTool]:
if not self.functions:
return None
return []
adubovik marked this conversation as resolved.
Show resolved Hide resolved

return [
GeminiTool(
Expand Down
Loading
Loading