Skip to content

Commit

Permalink
allows conversation state manager to wait on utterances sent to the c…
Browse files Browse the repository at this point in the history
…all (#308)

* allows conversation state manager to wait on utterances said to the call

* interruptibleagnetresponseworkers

* remove cast
  • Loading branch information
ajar98 authored Jul 27, 2023
1 parent af822cd commit e007d77
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 53 deletions.
2 changes: 1 addition & 1 deletion playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def sender():
None, lambda: input("Human: ")
)
agent.consume_nonblocking(
agent.interruptible_event_factory.create(
agent.interruptible_event_factory.create_interruptible_event(
TranscriptionAgentInput(
transcription=Transcription(
message=message, confidence=1.0, is_final=True
Expand Down
4 changes: 3 additions & 1 deletion playground/streaming/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ async def run_agents():
),
conversation_id=0,
)
agent.consume_nonblocking(agent.interruptible_event_factory.create(message))
agent.consume_nonblocking(
agent.interruptible_event_factory.create_interruptible_event(message)
)

while True:
try:
Expand Down
8 changes: 5 additions & 3 deletions vocode/streaming/action/base_action.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, Type, TypeVar
from typing import Any, Dict, Generic, Type, TypeVar, TYPE_CHECKING
from vocode.streaming.action.utils import exclude_keys_recursive
from vocode.streaming.models.actions import (
ActionConfig,
Expand All @@ -8,7 +8,9 @@
ParametersType,
ResponseType,
)
from vocode.streaming.utils.state_manager import ConversationStateManager

if TYPE_CHECKING:
from vocode.streaming.utils.state_manager import ConversationStateManager

ActionConfigType = TypeVar("ActionConfigType", bound=ActionConfig)

Expand All @@ -29,7 +31,7 @@ def __init__(
self.is_interruptible = is_interruptible

def attach_conversation_state_manager(
self, conversation_state_manager: ConversationStateManager
self, conversation_state_manager: "ConversationStateManager"
):
self.conversation_state_manager = conversation_state_manager

Expand Down
42 changes: 30 additions & 12 deletions vocode/streaming/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
import json
import logging
import random
from typing import AsyncGenerator, Generator, Generic, Optional, Tuple, TypeVar, Union
from typing import (
AsyncGenerator,
Generator,
Generic,
Optional,
Tuple,
TypeVar,
Union,
TYPE_CHECKING,
)
import typing
from opentelemetry import trace
from opentelemetry.trace import Span
Expand Down Expand Up @@ -33,13 +42,16 @@
from vocode.streaming.utils import remove_non_letters_digits
from vocode.streaming.utils.goodbye_model import GoodbyeModel
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import ConversationStateManager
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
InterruptibleEventFactory,
InterruptibleWorker,
)

if TYPE_CHECKING:
from vocode.streaming.utils.state_manager import ConversationStateManager

tracer = trace.get_tracer(__name__)
AGENT_TRACE_NAME = "agent"

Expand Down Expand Up @@ -128,7 +140,7 @@ def __init__(
InterruptibleEvent[AgentInput]
] = asyncio.Queue()
self.output_queue: asyncio.Queue[
InterruptibleEvent[AgentResponse]
InterruptibleAgentResponseEvent[AgentResponse]
] = asyncio.Queue()
AbstractAgent.__init__(self, agent_config=agent_config)
InterruptibleWorker.__init__(
Expand Down Expand Up @@ -167,7 +179,7 @@ def attach_conversation_state_manager(
def start(self):
super().start()
if self.agent_config.initial_message is not None:
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=self.agent_config.initial_message),
is_interruptible=False,
)
Expand All @@ -180,7 +192,9 @@ def get_input_queue(
) -> asyncio.Queue[InterruptibleEvent[AgentInput]]:
return self.input_queue

def get_output_queue(self) -> asyncio.Queue[InterruptibleEvent[AgentResponse]]:
def get_output_queue(
self,
) -> asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]]:
return self.output_queue

def create_goodbye_detection_task(self, message: str) -> asyncio.Task:
Expand Down Expand Up @@ -214,7 +228,7 @@ async def handle_generate_response(
if is_first_response:
agent_span_first.end()
is_first_response = False
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=response)),
is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
)
Expand All @@ -240,7 +254,7 @@ async def handle_respond(
response = None
return True
if response:
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=response)),
is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
)
Expand Down Expand Up @@ -288,7 +302,9 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
transcription.message
)
if self.agent_config.send_filler_audio:
self.produce_interruptible_event_nonblocking(AgentResponseFillerAudio())
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseFillerAudio()
)
self.logger.debug("Responding to transcription")
should_stop = False
if self.agent_config.generate_responses:
Expand All @@ -302,7 +318,9 @@ async def process(self, item: InterruptibleEvent[AgentInput]):

if should_stop:
self.logger.debug("Agent requested to stop")
self.produce_interruptible_event_nonblocking(AgentResponseStop())
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseStop()
)
return
if goodbye_detected_task:
try:
Expand All @@ -311,7 +329,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
)
if goodbye_detected:
self.logger.debug("Goodbye detected, ending conversation")
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseStop()
)
return
Expand Down Expand Up @@ -339,7 +357,7 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
params = json.loads(function_call.arguments)
if "user_message" in params:
user_message = params["user_message"]
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=user_message))
)
action_input: ActionInput
Expand All @@ -362,7 +380,7 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
agent_input.conversation_id,
params,
)
event = self.interruptible_event_factory.create(
event = self.interruptible_event_factory.create_interruptible_event(
action_input, is_interruptible=action.is_interruptible
)
assert self.transcript is not None
Expand Down
11 changes: 7 additions & 4 deletions vocode/streaming/agent/websocket_user_implemented_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import logging
from typing import Dict
from vocode.streaming.transcriber.base_transcriber import Transcription
from vocode.streaming.utils.worker import InterruptibleEvent
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
)
import websockets
from websockets.client import (
connect,
Expand Down Expand Up @@ -33,7 +36,7 @@

class WebSocketUserImplementedAgent(BaseAgent[WebSocketUserImplementedAgentConfig]):
input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]]
output_queue: asyncio.Queue[InterruptibleEvent[AgentResponse]]
output_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]]

def __init__(
self,
Expand Down Expand Up @@ -75,7 +78,7 @@ def _handle_incoming_socket_message(self, message: WebSocketAgentMessage) -> Non
raise Exception("Unknown Socket message type")

self.logger.info("Putting interruptible agent response event in output queue")
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
agent_response, self.get_agent_config().allow_agent_to_be_cut_off
)

Expand Down Expand Up @@ -161,5 +164,5 @@ async def receiver(ws: WebSocketClientProtocol) -> None:
await asyncio.gather(sender(ws), receiver(ws))

def terminate(self):
self.output_queue.put_nowait(AgentResponseStop())
self.produce_interruptible_agent_response_event_nonblocking(AgentResponseStop())
super().terminate()
Loading

0 comments on commit e007d77

Please sign in to comment.