From 11931a0a2bac667e7fb1377c3ded7adb7641c35b Mon Sep 17 00:00:00 2001 From: Dante Noguez Date: Mon, 5 Aug 2024 17:47:38 -0600 Subject: [PATCH 1/2] add google ai agent --- poetry.lock | 118 ++++++++++++++++++++++++- pyproject.toml | 3 + vocode/streaming/agent/google_agent.py | 115 ++++++++++++++++++++++++ vocode/streaming/models/agent.py | 10 +++ 4 files changed, 243 insertions(+), 3 deletions(-) create mode 100644 vocode/streaming/agent/google_agent.py diff --git a/poetry.lock b/poetry.lock index e6030148a..dfb8ea68e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -979,6 +979,23 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.6" +description = "Google Ai Generativelanguage API client library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google-ai-generativelanguage-0.6.6.tar.gz", hash = "sha256:1739f035caeeeca5c28f887405eec8690f3372daf79fecf26454a97a4f1733a8"}, + {file = "google_ai_generativelanguage-0.6.6-py3-none-any.whl", hash = "sha256:59297737931f073d55ce1268dcc6d95111ee62850349d2b6cde942b16a4fca5c"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || 
>4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" + [[package]] name = "google-api-core" version = "2.19.1" @@ -1010,6 +1027,24 @@ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +[[package]] +name = "google-api-python-client" +version = "2.139.0" +description = "Google API Client Library for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_api_python_client-2.139.0-py2.py3-none-any.whl", hash = "sha256:1850a92505d91a82e2ca1635ab2b8dff179f4b67082c2651e1db332e8039840c"}, + {file = "google_api_python_client-2.139.0.tar.gz", hash = "sha256:ed4bc3abe2c060a87412465b4e8254620bbbc548eefc5388e2c5ff912d36a68b"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + [[package]] name = "google-auth" version = "2.32.0" @@ -1033,6 +1068,21 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = true +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + [[package]] name = "google-cloud-aiplatform" version = "1.59.0" @@ -1284,6 +1334,29 @@ files = [ [package.extras] testing = ["pytest"] +[[package]] +name 
= "google-generativeai" +version = "0.7.2" +description = "Google Generative AI High level API client library and tools." +optional = true +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.7.2-py3-none-any.whl", hash = "sha256:3117d1ebc92ee77710d4bc25ab4763492fddce9b6332eb25d124cf5d8b78b339"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.6.6" +google-api-core = "*" +google-api-python-client = "*" +google-auth = ">=2.15.0" +protobuf = "*" +pydantic = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + [[package]] name = "google-resumable-media" version = "2.7.1" @@ -1532,6 +1605,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<0.26.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httptools" version = "0.6.1" @@ -2988,6 +3075,20 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = true +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = 
"sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pytest" version = "8.2.2" @@ -3177,7 +3278,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4131,6 +4231,17 @@ files = [ mypy-extensions = ">=0.3.0" typing-extensions = ">=3.7.4" +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = true +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.2.2" @@ -4646,9 +4757,10 @@ doc = ["furo", "jaraco.packaging (>=9.3)", 
"jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["cartesia", "google-cloud-speech", "google-cloud-texttospeech", "groq", "langchain", "langchain-anthropic", "langchain-community", "langchain-google-vertexai", "langchain-openai", "livekit", "pvkoala", "twilio", "vonage"] +all = ["cartesia", "google-cloud-speech", "google-cloud-texttospeech", "google-generativeai", "groq", "langchain", "langchain-anthropic", "langchain-community", "langchain-google-vertexai", "langchain-openai", "livekit", "pvkoala", "twilio", "vonage"] langchain = ["langchain", "langchain-community"] langchain-extras = ["langchain-anthropic", "langchain-google-vertexai", "langchain-openai"] +llms = ["google-generativeai", "groq"] synthesizers = ["cartesia", "google-cloud-texttospeech", "pvkoala"] telephony = ["twilio", "vonage"] transcribers = ["google-cloud-speech"] @@ -4656,4 +4768,4 @@ transcribers = ["google-cloud-speech"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "645eeaebdbc1191899f3c62d06923897d88986d7f3f0010724f974bee266a9f8" +content-hash = "81f8bafde8cb0c5f908348dbc0b44ecddfaf955bc61b221e3e3680f270e47a84" diff --git a/pyproject.toml b/pyproject.toml index 0a5184e45..f52ec7015 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ nltk = "^3.8.1" # LLM Providers groq = { version = "^0.9.0", optional = true } +google-generativeai = {version = "^0.7.2", optional = true} # Synthesizers google-cloud-texttospeech = { version = "^2.16.3", optional = true } @@ -92,6 +93,7 @@ synthesizers = [ ] transcribers = ["google-cloud-speech"] telephony = ["twilio", "vonage"] +llms = ["groq", "google-generativeai"] langchain = ["langchain", "langchain-community"] langchain-extras = 
class GoogleAIAgent(RespondAgent[GoogleAIAgentConfig]):
    """Respond agent that answers through a Google AI (``google.generativeai``) chat session.

    A single ``genai.ChatSession`` is created at construction time and reused for
    every turn, so the conversation history lives inside ``self.genai_chat``.
    """

    genai_chat: genai.ChatSession

    def __init__(
        self,
        agent_config: GoogleAIAgentConfig,
        action_factory: AbstractActionFactory = DefaultActionFactory(),
        vector_db_factory=VectorDBFactory(),
        **kwargs,
    ):
        """Configure the Google AI client and open the chat session.

        Args:
            agent_config: Supplies ``model_name``, ``max_tokens``, ``temperature``
                and ``prompt_preamble`` for the session.
            action_factory: Forwarded unchanged to ``RespondAgent.__init__``.
            vector_db_factory: Accepted for signature compatibility; it is not
                used anywhere in this class.
            **kwargs: Forwarded unchanged to ``RespondAgent.__init__``.

        Raises:
            ValueError: If the ``GOOGLE_AI_API_KEY`` environment variable is
                unset or empty.
        """
        super().__init__(
            agent_config=agent_config,
            action_factory=action_factory,
            **kwargs,
        )
        # Read the key once; it can only come from the environment, so the
        # error message must not suggest another way of supplying it.
        api_key = os.environ.get("GOOGLE_AI_API_KEY")
        if not api_key:
            raise ValueError("GOOGLE_AI_API_KEY must be set in environment")
        # NOTE(review): genai.configure() presumably returns None and only sets
        # library-wide client options -- confirm before relying on this attribute.
        self.genai_config = genai.configure(api_key=api_key)
        self.genai_model = genai.GenerativeModel(
            model_name=agent_config.model_name,
            generation_config=genai.GenerationConfig(
                max_output_tokens=agent_config.max_tokens,
                temperature=agent_config.temperature,
            ),
        )
        # The preamble is seeded as the first history entry, tagged with the user role.
        prompt_preamble = content_types.to_content(agent_config.prompt_preamble)
        prompt_preamble.role = _USER_ROLE
        self.genai_chat = self.genai_model.start_chat(history=[prompt_preamble])

    async def _create_google_ai_stream(self, message: str):
        """Send ``message`` on the chat session and return the awaited response object."""
        # NOTE(review): no ``stream=True`` is passed, so presumably the whole reply
        # is generated before iteration starts -- confirm whether incremental
        # streaming was intended (and how interruption would then affect history).
        return await self.genai_chat.send_message_async(message)

    async def google_ai_get_tokens(
        self, gen: generation_types.AsyncGenerateContentResponse
    ) -> AsyncGenerator[str, None]:
        """Yield the ``text`` of every chunk produced by iterating ``gen``."""
        async for msg in gen:
            yield msg.text

    async def generate_response(
        self,
        human_input,
        conversation_id: str,
        is_interrupt: bool = False,
        bot_was_in_medias_res: bool = False,
    ) -> AsyncGenerator[GeneratedResponse, None]:
        """Generate the agent's reply to ``human_input`` as a stream of responses.

        Args:
            human_input: The user's utterance, sent as-is to the chat session.
            conversation_id: Forwarded to the response collation/streaming helper.
            is_interrupt: Not used by this implementation.
            bot_was_in_medias_res: Not used by this implementation.

        Yields:
            ``StreamedResponse`` objects when the conversation uses an input
            streaming synthesizer, otherwise ``GeneratedResponse`` objects. Every
            yielded response is marked interruptible.

        Raises:
            ValueError: If no transcript is attached to the agent.
            Exception: Any error raised while contacting Google AI is logged and
                re-raised unchanged.
        """
        if not self.transcript:
            raise ValueError("A transcript is not attached to the agent")

        first_sentence_total_span = sentry_create_span(
            sentry_callable=sentry_sdk.start_span, op=CustomSentrySpans.LLM_FIRST_SENTENCE_TOTAL
        )
        ttft_span = sentry_create_span(
            sentry_callable=sentry_sdk.start_span, op=CustomSentrySpans.TIME_TO_FIRST_TOKEN
        )

        try:
            stream = await self._create_google_ai_stream(human_input)
        except Exception:
            # loguru runs str.format on the message whenever extra arguments are
            # given, so the history is passed as a positional argument instead of
            # being pre-interpolated (its repr may contain braces). logger.exception
            # also records the active traceback, which ``exc_info=True`` does not
            # do with loguru.
            logger.exception(
                "Error while hitting Google AI with history: {}", self.genai_chat.history
            )
            raise

        # These choices depend only on the synthesizer type, so make them once
        # rather than on every message.
        using_input_streaming_synthesizer = (
            self.conversation_state_manager.using_input_streaming_synthesizer()
        )
        if using_input_streaming_synthesizer:
            response_generator = stream_response_async
            ResponseClass = StreamedResponse
            MessageType = LLMToken
        else:
            response_generator = collate_response_async
            ResponseClass = GeneratedResponse
            MessageType = BaseMessage

        async for message in response_generator(
            conversation_id=conversation_id,
            gen=self.google_ai_get_tokens(stream),
            sentry_span=ttft_span,
        ):
            # Close the "first sentence" span exactly once, on the first message.
            if first_sentence_total_span:
                first_sentence_total_span.finish()
                first_sentence_total_span = None

            # Plain strings are wrapped in the message type that matches the
            # response class; anything else is passed through untouched.
            if isinstance(message, str):
                message = MessageType(text=message)
            yield ResponseClass(
                message=message,
                is_interruptible=True,
            )
2024 17:54:25 -0600 Subject: [PATCH 2/2] fix isort --- vocode/streaming/agent/google_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vocode/streaming/agent/google_agent.py b/vocode/streaming/agent/google_agent.py index 789f0052a..bd8e819ad 100644 --- a/vocode/streaming/agent/google_agent.py +++ b/vocode/streaming/agent/google_agent.py @@ -2,9 +2,9 @@ from typing import Any, AsyncGenerator, Dict import google.generativeai as genai +import grpc from google.generativeai.generative_models import _USER_ROLE from google.generativeai.types import content_types, generation_types -import grpc grpc.aio.init_grpc_aio() # we initialize gRPC aio to avoid this issue: https://github.com/google-gemini/generative-ai-python/issues/207 import sentry_sdk