From 9daa9be09b41eb32d6976ea9637d97d308c051e9 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 12:04:35 -0700 Subject: [PATCH] fix mypy --- playground/streaming/agent/chat.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index e93b9eeda8..dd6992b51e 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -3,7 +3,7 @@ import typing from dotenv import load_dotenv -from pydantic.v1 import BaseModel +from pydantic import BaseModel from vocode.streaming.action.abstract_factory import AbstractActionFactory from vocode.streaming.action.base_action import BaseAction @@ -28,8 +28,9 @@ from vocode.streaming.agent import ChatGPTAgent from vocode.streaming.agent.base_agent import ( AgentResponse, + AgentResponseFillerAudio, AgentResponseMessage, - AgentResponseType, + AgentResponseStop, BaseAgent, TranscriptionAgentInput, ) @@ -39,7 +40,8 @@ BACKCHANNELS = ["Got it", "Sure", "Okay", "I understand"] -class ShoutActionConfig(ActionConfig, type="shout"): # type: ignore +class ShoutActionConfig(ActionConfig): + type: typing.Literal["shout"] = "shout" num_exclamation_marks: int @@ -114,16 +116,15 @@ async def receiver(): try: event = await agent_response_queue.get() response = event.payload - if response.type == AgentResponseType.FILLER_AUDIO: + if isinstance(response, AgentResponseFillerAudio): print("Would have sent filler audio") - elif response.type == AgentResponseType.STOP: + elif isinstance(response, AgentResponseStop): print("Agent returned stop") ended = True break - elif response.type == AgentResponseType.MESSAGE: - agent_response = typing.cast(AgentResponseMessage, response) + elif isinstance(response, AgentResponseMessage): - if isinstance(agent_response.message, EndOfTurn): + if isinstance(response.message, EndOfTurn): ignore_until_end_of_turn = False if random.random() < backchannel_probability: backchannel = random.choice(BACKCHANNELS) @@ -133,7 +134,7 @@ async def receiver(): conversation_id, is_backchannel=True, ) - elif isinstance(agent_response.message, BaseMessage): + elif isinstance(response.message, BaseMessage): if ignore_until_end_of_turn: continue @@ -141,12 +142,12 @@ async def receiver(): is_final: bool # TODO: consider allowing the user to interrupt the agent manually by responding fast if random.random() < interruption_probability: - stop_idx = random.randint(0, len(agent_response.message.text)) - message_sent = agent_response.message.text[:stop_idx] + stop_idx = random.randint(0, len(response.message.text)) + message_sent = response.message.text[:stop_idx] ignore_until_end_of_turn = True is_final = False else: - message_sent = agent_response.message.text + message_sent = response.message.text is_final = True agent.transcript.add_bot_message(