Skip to content

Commit

Permalink
add google ai agent
Browse files Browse the repository at this point in the history
  • Loading branch information
DanteNoguez committed Aug 5, 2024
1 parent adf0d87 commit 11931a0
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 3 deletions.
118 changes: 115 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ nltk = "^3.8.1"

# LLM Providers
groq = { version = "^0.9.0", optional = true }
google-generativeai = {version = "^0.7.2", optional = true}

# Synthesizers
google-cloud-texttospeech = { version = "^2.16.3", optional = true }
Expand Down Expand Up @@ -92,6 +93,7 @@ synthesizers = [
]
transcribers = ["google-cloud-speech"]
telephony = ["twilio", "vonage"]
llms = ["groq", "google-generativeai"]
langchain = ["langchain", "langchain-community"]
langchain-extras = ["langchain-openai", "langchain-anthropic", "langchain-google-vertexai"]
all = [
Expand All @@ -107,6 +109,7 @@ all = [
"langchain-google-vertexai",
"cartesia",
"groq",
"google-generativeai",
"livekit"
]

Expand Down
115 changes: 115 additions & 0 deletions vocode/streaming/agent/google_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os
from typing import Any, AsyncGenerator, Dict

import google.generativeai as genai
from google.generativeai.generative_models import _USER_ROLE
from google.generativeai.types import content_types, generation_types
import grpc

# Initialize gRPC aio at import time to avoid this issue:
# https://github.com/google-gemini/generative-ai-python/issues/207
grpc.aio.init_grpc_aio()
import sentry_sdk
from loguru import logger

from vocode.streaming.action.abstract_factory import AbstractActionFactory
from vocode.streaming.action.default_factory import DefaultActionFactory
from vocode.streaming.agent.base_agent import GeneratedResponse, RespondAgent, StreamedResponse
from vocode.streaming.agent.streaming_utils import collate_response_async, stream_response_async
from vocode.streaming.models.agent import GoogleAIAgentConfig
from vocode.streaming.models.message import BaseMessage, LLMToken
from vocode.streaming.vector_db.factory import VectorDBFactory
from vocode.utils.sentry_utils import CustomSentrySpans, sentry_create_span


class GoogleAIAgent(RespondAgent[GoogleAIAgentConfig]):
    """Respond agent that produces replies through a Google AI (Gemini) chat session.

    A single ``genai.ChatSession`` is created in ``__init__`` and seeded with
    the configured prompt preamble as its first history entry. Every call to
    ``generate_response`` sends the new human input to that same session.
    """

    genai_chat: genai.ChatSession

    def __init__(
        self,
        agent_config: GoogleAIAgentConfig,
        action_factory: AbstractActionFactory = DefaultActionFactory(),
        vector_db_factory=VectorDBFactory(),
        **kwargs,
    ):
        """Configure the Google AI client and open the chat session.

        Args:
            agent_config: Supplies the model name, max tokens, temperature and
                prompt preamble used to build the model and the chat.
            action_factory: Forwarded to the base agent constructor.
            vector_db_factory: Accepted but not used by this agent.
            **kwargs: Forwarded to the base agent constructor.

        Raises:
            ValueError: If the ``GOOGLE_AI_API_KEY`` environment variable is
                missing or empty.
        """
        # NOTE(review): both factory defaults are evaluated once, when the
        # class body is executed, so every agent built without an explicit
        # factory shares the same default instances.
        super().__init__(
            agent_config=agent_config,
            action_factory=action_factory,
            **kwargs,
        )
        api_key = os.environ.get("GOOGLE_AI_API_KEY")
        if not api_key:
            raise ValueError("GOOGLE_AI_API_KEY must be set in environment or passed in")
        # NOTE(review): genai.configure() presumably returns None, in which
        # case this attribute carries no information -- confirm before use.
        self.genai_config = genai.configure(api_key=api_key)
        self.genai_model = genai.GenerativeModel(
            model_name=agent_config.model_name,
            generation_config=genai.GenerationConfig(
                max_output_tokens=agent_config.max_tokens,
                temperature=agent_config.temperature,
            ),
        )
        # The preamble is stored as a user-role message opening the history.
        prompt_preamble = content_types.to_content(agent_config.prompt_preamble)
        prompt_preamble.role = _USER_ROLE
        self.genai_chat = self.genai_model.start_chat(history=[prompt_preamble])

    async def _create_google_ai_stream(self, message: str):
        """Send ``message`` to the chat session and return the response object.

        NOTE(review): ``stream=True`` is not passed to ``send_message_async``,
        so the reply presumably arrives as one complete chunk rather than
        incrementally -- confirm whether that is intended.
        """
        return await self.genai_chat.send_message_async(message)

    async def google_ai_get_tokens(
        self, gen: generation_types.AsyncGenerateContentResponse
    ) -> AsyncGenerator[str, None]:
        """Yield the ``text`` of every chunk produced by iterating ``gen``."""
        async for msg in gen:
            yield msg.text

    async def generate_response(
        self,
        human_input,
        conversation_id: str,
        is_interrupt: bool = False,
        bot_was_in_medias_res: bool = False,
    ) -> AsyncGenerator[GeneratedResponse, None]:
        """Send ``human_input`` to Google AI and yield the collated reply.

        Args:
            human_input: Text forwarded to the chat session.
            conversation_id: Passed through to the response collation helper.
            is_interrupt: Not used by this implementation.
            bot_was_in_medias_res: Not used by this implementation.

        Yields:
            ``StreamedResponse`` objects when the conversation uses an
            input-streaming synthesizer, ``GeneratedResponse`` objects
            otherwise. All of them are marked interruptible.

        Raises:
            ValueError: If no transcript is attached to the agent.
            Exception: Whatever the Google AI request raises, re-raised after
                being logged.
        """
        if not self.transcript:
            raise ValueError("A transcript is not attached to the agent")
        try:
            first_sentence_total_span = sentry_create_span(
                sentry_callable=sentry_sdk.start_span, op=CustomSentrySpans.LLM_FIRST_SENTENCE_TOTAL
            )

            ttft_span = sentry_create_span(
                sentry_callable=sentry_sdk.start_span, op=CustomSentrySpans.TIME_TO_FIRST_TOKEN
            )
            stream = await self._create_google_ai_stream(human_input)
        except Exception:
            # loguru applies str.format() to the message whenever arguments
            # are supplied, so the history is passed as a positional argument
            # instead of being interpolated up front: braces in its repr then
            # cannot break formatting. logger.exception() attaches the
            # traceback of the active exception.
            logger.exception(
                "Error while hitting Google AI with history: {}",
                self.genai_chat.history,
            )
            raise

        using_input_streaming_synthesizer = (
            self.conversation_state_manager.using_input_streaming_synthesizer()
        )
        response_generator = (
            stream_response_async if using_input_streaming_synthesizer else collate_response_async
        )
        # Neither wrapper type changes during the loop, so pick them once.
        ResponseClass = StreamedResponse if using_input_streaming_synthesizer else GeneratedResponse
        MessageType = LLMToken if using_input_streaming_synthesizer else BaseMessage

        async for message in response_generator(
            conversation_id=conversation_id,
            gen=self.google_ai_get_tokens(stream),
            sentry_span=ttft_span,
        ):
            if first_sentence_total_span:
                # Finish the span once, on the first message, instead of
                # calling finish() again for every later message.
                first_sentence_total_span.finish()
                first_sentence_total_span = None

            if isinstance(message, str):
                yield ResponseClass(
                    message=MessageType(text=message),
                    is_interruptible=True,
                )
            else:
                yield ResponseClass(
                    message=message,
                    is_interruptible=True,
                )
10 changes: 10 additions & 0 deletions vocode/streaming/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
ANTHROPIC_CLAUDE_3_HAIKU_MODEL_NAME = "claude-3-haiku-20240307"
ANTHROPIC_CLAUDE_3_SONNET_MODEL_NAME = "claude-3-sonnet-20240229"
ANTHROPIC_CLAUDE_3_OPUS_MODEL_NAME = "claude-3-opus-20240229"
GOOGLE_AI_GEMINI_FLASH_MODEL_NAME = "gemini-1.5-flash"
GOOGLE_AI_GEMINI_PRO_MODEL_NAME = "gemini-1.5-pro"
GROQ_DEFAULT_MODEL_NAME = "llama3-70b-8192"
GROQ_LLAMA3_8B_MODEL_NAME = "llama3-8b-8192"
GROQ_LLAMA3_70B_MODEL_NAME = "llama3-70b-8192"
Expand All @@ -51,6 +53,7 @@ class AgentType(str, Enum):
GPT4ALL = "agent_gpt4all"
LLAMACPP = "agent_llamacpp"
GROQ = "agent_groq"
GOOGLE_AI = "agent_google_ai"
INFORMATION_RETRIEVAL = "agent_information_retrieval"
RESTFUL_USER_IMPLEMENTED = "agent_restful_user_implemented"
WEBSOCKET_USER_IMPLEMENTED = "agent_websocket_user_implemented"
Expand Down Expand Up @@ -168,6 +171,13 @@ class GroqAgentConfig(AgentConfig, type=AgentType.GROQ.value): # type: ignore
first_response_filler_message: Optional[str] = None


class GoogleAIAgentConfig(AgentConfig, type=AgentType.GOOGLE_AI.value):  # type: ignore
    """Agent configuration registered under ``AgentType.GOOGLE_AI``.

    ``prompt_preamble`` has no default and must be supplied; the remaining
    fields fall back to module-level defaults.
    """

    # Required preamble text for the agent.
    prompt_preamble: str
    # Model identifier; defaults to the Gemini Flash model name constant.
    model_name: str = GOOGLE_AI_GEMINI_FLASH_MODEL_NAME
    # Token limit; defaults to the shared LLM agent default.
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
    # Sampling temperature; defaults to the shared LLM agent default.
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE


class InformationRetrievalAgentConfig(
AgentConfig, type=AgentType.INFORMATION_RETRIEVAL.value # type: ignore
):
Expand Down

0 comments on commit 11931a0

Please sign in to comment.