diff --git a/.vscode/settings.json b/.vscode/settings.json index 5353050c8..03e207285 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,5 +12,6 @@ "rewrap.wrappingColumn": 100, "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, - "isort.check": true + "isort.check": true, + "python.testing.pytestArgs": ["tests"] } diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index 659443c9a..e93b9eeda 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -21,12 +21,13 @@ from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.transcript import Transcript from vocode.streaming.utils.state_manager import AbstractConversationStateManager -from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent +from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent, QueueConsumer load_dotenv() from vocode.streaming.agent import ChatGPTAgent from vocode.streaming.agent.base_agent import ( + AgentResponse, AgentResponseMessage, AgentResponseType, BaseAgent, @@ -96,6 +97,11 @@ async def run_agent( ): ended = False conversation_id = create_conversation_id() + agent_response_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] = ( + asyncio.Queue() + ) + agent_consumer = QueueConsumer(input_queue=agent_response_queue) + agent.agent_responses_consumer = agent_consumer async def receiver(): nonlocal ended @@ -106,7 +112,7 @@ async def receiver(): while not ended: try: - event = await agent.get_output_queue().get() + event = await agent_response_queue.get() response = event.payload if response.type == AgentResponseType.FILLER_AUDIO: print("Would have sent filler audio") @@ -152,6 +158,13 @@ async def receiver(): break async def sender(): + if agent.agent_config.initial_message is not None: + agent.agent_responses_consumer.consume_nonblocking( + InterruptibleAgentResponseEvent( + payload=AgentResponseMessage(message=agent.agent_config.initial_message), + agent_response_tracker=asyncio.Event(), + ) + ) while not ended: try: message = await asyncio.get_event_loop().run_in_executor( @@ -175,10 +188,10 @@ async def sender(): actions_worker = None if isinstance(agent, ChatGPTAgent): actions_worker = ActionsWorker( - input_queue=agent.actions_queue, - output_queue=agent.get_input_queue(), action_factory=agent.action_factory, ) + actions_worker.consumer = agent + agent.actions_consumer = actions_worker actions_worker.attach_conversation_state_manager(agent.conversation_state_manager) actions_worker.start() @@ -215,13 +228,6 @@ async def agent_main(): ) agent.attach_conversation_state_manager(DummyConversationManager()) agent.attach_transcript(transcript) - if agent.agent_config.initial_message is not None: - agent.output_queue.put_nowait( - InterruptibleAgentResponseEvent( - payload=AgentResponseMessage(message=agent.agent_config.initial_message), - agent_response_tracker=asyncio.Event(), - ) - ) agent.start() try: diff --git a/playground/streaming/transcriber/transcribe.py b/playground/streaming/transcriber/transcribe.py index 111c27246..d7a097df3 100644 --- a/playground/streaming/transcriber/transcribe.py +++ b/playground/streaming/transcriber/transcribe.py @@ -1,3 +1,4 @@ +from vocode.streaming.input_device.file_input_device import FileInputDevice from vocode.streaming.input_device.microphone_input import MicrophoneInput from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, Transcription from vocode.streaming.transcriber.base_transcriber import BaseTranscriber @@ -5,6 +6,15 @@ DeepgramEndpointingConfig, DeepgramTranscriber, ) +from vocode.streaming.utils.worker import AsyncWorker + + +class TranscriptionPrinter(AsyncWorker[Transcription]): + async def _run_loop(self): + while True: + transcription: Transcription = await self._input_queue.get() + print(transcription) + if __name__ == "__main__": import asyncio @@ -13,25 +23,23 @@ load_dotenv() - async def print_output(transcriber: BaseTranscriber): - while True: - transcription: Transcription = await transcriber.output_queue.get() - print(transcription) - async def listen(): - microphone_input = MicrophoneInput.from_default_device() + input_device = MicrophoneInput.from_default_device() + # input_device = FileInputDevice(file_path="spacewalk.wav") # replace with the transcriber you want to test transcriber = DeepgramTranscriber( DeepgramTranscriberConfig.from_input_device( - microphone_input, endpointing_config=DeepgramEndpointingConfig() + input_device, endpointing_config=DeepgramEndpointingConfig() ) ) transcriber.start() - asyncio.create_task(print_output(transcriber)) + transcription_printer = TranscriptionPrinter() + transcriber.consumer = transcription_printer + transcription_printer.start() print("Start speaking...press Ctrl+C to end. ") while True: - chunk = await microphone_input.get_audio() + chunk = await input_device.get_audio() transcriber.send_audio(chunk) asyncio.run(listen()) diff --git a/tests/fakedata/conversation.py b/tests/fakedata/conversation.py index 69e407d33..890dcb44e 100644 --- a/tests/fakedata/conversation.py +++ b/tests/fakedata/conversation.py @@ -1,5 +1,4 @@ import asyncio -import time from typing import Optional from pytest_mock import MockerFixture @@ -52,6 +51,7 @@ def __init__( self.wait_for_interrupt = wait_for_interrupt self.chunks_before_interrupt = chunks_before_interrupt self.interrupt_event = asyncio.Event() + self.dummy_playback_queue = asyncio.Queue() async def process(self, item): self.interruptible_event = item @@ -61,6 +61,7 @@ async def process(self, item): audio_chunk.on_interrupt() audio_chunk.state = ChunkState.INTERRUPTED else: + self.dummy_playback_queue.put_nowait(audio_chunk) audio_chunk.on_play() audio_chunk.state = ChunkState.PLAYED self.interruptible_event.is_interruptible = False @@ -69,7 +70,7 @@ async def _run_loop(self): chunk_counter = 0 while True: try: - item = await self.input_queue.get() + item = await self._input_queue.get() except asyncio.CancelledError: return if self.wait_for_interrupt and chunk_counter == self.chunks_before_interrupt: @@ -80,7 +81,7 @@ async def _run_loop(self): def flush(self): while True: try: - item = self.input_queue.get_nowait() + item = self._input_queue.get_nowait() except asyncio.QueueEmpty: break self.process(item) diff --git a/tests/fixtures/synthesizer.py b/tests/fixtures/synthesizer.py new file mode 100644 index 000000000..6181db2a2 --- /dev/null +++ b/tests/fixtures/synthesizer.py @@ -0,0 +1,50 @@ +import wave +from io import BytesIO + +from vocode.streaming.models.message import BaseMessage +from vocode.streaming.models.synthesizer import SynthesizerConfig +from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer, SynthesisResult + + +def create_fake_audio(message: str, synthesizer_config: SynthesizerConfig): + file = BytesIO() + with wave.open(file, "wb") as wave_file: + wave_file.setnchannels(1) + wave_file.setsampwidth(2) + wave_file.setframerate(synthesizer_config.sampling_rate) + wave_file.writeframes(message.encode()) + file.seek(0) + return file + + +class TestSynthesizerConfig(SynthesizerConfig, type="synthesizer_test"): + __test__ = False + + +class TestSynthesizer(BaseSynthesizer[TestSynthesizerConfig]): + """Accepts text and creates a SynthesisResult containing audio data which is the same as the text as bytes.""" + + __test__ = False + + def __init__(self, synthesizer_config: SynthesizerConfig): + super().__init__(synthesizer_config) + + async def create_speech_uncached( + self, + message: BaseMessage, + chunk_size: int, + is_first_text_chunk: bool = False, + is_sole_text_chunk: bool = False, + ) -> SynthesisResult: + return self.create_synthesis_result_from_wav( + synthesizer_config=self.synthesizer_config, + message=message, + chunk_size=chunk_size, + file=create_fake_audio( + message=message.text, synthesizer_config=self.synthesizer_config + ), + ) + + @classmethod + def get_voice_identifier(cls, synthesizer_config: TestSynthesizerConfig) -> str: + return "test_voice" diff --git a/tests/fixtures/transcriber.py b/tests/fixtures/transcriber.py new file mode 100644 index 000000000..4c1b2e9f4 --- /dev/null +++ b/tests/fixtures/transcriber.py @@ -0,0 +1,24 @@ +import asyncio + +from vocode.streaming.models.transcriber import TranscriberConfig +from vocode.streaming.transcriber.base_transcriber import BaseAsyncTranscriber, Transcription + + +class TestTranscriberConfig(TranscriberConfig, type="transcriber_test"): + __test__ = False + + +class TestAsyncTranscriber(BaseAsyncTranscriber[TestTranscriberConfig]): + """Accepts fake audio chunks and sends out transcriptions which are the same as the audio chunks.""" + + __test__ = False + + async def _run_loop(self): + while True: + try: + audio_chunk = await self._input_queue.get() + self.produce_nonblocking( + Transcription(message=audio_chunk.decode("utf-8"), confidence=1, is_final=True) + ) + except asyncio.CancelledError: + return diff --git a/tests/streaming/agent/test_base_agent.py b/tests/streaming/agent/test_base_agent.py index 972040579..26519ad14 100644 --- a/tests/streaming/agent/test_base_agent.py +++ b/tests/streaming/agent/test_base_agent.py @@ -19,7 +19,11 @@ from vocode.streaming.models.transcriber import Transcription from vocode.streaming.models.transcript import Transcript from vocode.streaming.utils.state_manager import ConversationStateManager -from vocode.streaming.utils.worker import InterruptibleEvent +from vocode.streaming.utils.worker import ( + InterruptibleAgentResponseEvent, + InterruptibleEvent, + QueueConsumer, +) @pytest.fixture(autouse=True) @@ -51,11 +55,16 @@ def _create_agent( return agent -async def _consume_until_end_of_turn(agent: BaseAgent, timeout: float = 0.1) -> List[AgentResponse]: +async def _consume_until_end_of_turn( + agent_consumer: QueueConsumer[InterruptibleAgentResponseEvent[AgentResponse]], + timeout: float = 0.1, +) -> List[AgentResponse]: agent_responses = [] try: while True: - agent_response = await asyncio.wait_for(agent.output_queue.get(), timeout=timeout) + agent_response = await asyncio.wait_for( + agent_consumer.input_queue.get(), timeout=timeout + ) agent_responses.append(agent_response.payload) if isinstance(agent_response.payload, AgentResponseMessage) and isinstance( agent_response.payload.message, EndOfTurn @@ -127,37 +136,10 @@ async def test_generate_responses(mocker: MockerFixture): agent, Transcription(message="Hello?", confidence=1.0, is_final=True), ) + agent_consumer = QueueConsumer() + agent.agent_responses_consumer = agent_consumer agent.start() - agent_responses = await _consume_until_end_of_turn(agent) - agent.terminate() - - messages = [response.message for response in agent_responses] - - assert messages == [BaseMessage(text="Hi, how are you doing today?"), EndOfTurn()] - - -@pytest.mark.asyncio -async def test_generate_response(mocker: MockerFixture): - agent_config = ChatGPTAgentConfig( - prompt_preamble="Have a pleasant conversation about life", - generate_responses=True, - ) - agent = _create_agent(mocker, agent_config) - _mock_generate_response( - mocker, - agent, - [ - GeneratedResponse( - message=BaseMessage(text="Hi, how are you doing today?"), is_interruptible=True - ) - ], - ) - _send_transcription( - agent, - Transcription(message="Hello?", confidence=1.0, is_final=True), - ) - agent.start() - agent_responses = await _consume_until_end_of_turn(agent) + agent_responses = await _consume_until_end_of_turn(agent_consumer) agent.terminate() messages = [response.message for response in agent_responses] diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py index 1d4ef026d..2266a7e47 100644 --- a/tests/streaming/test_streaming_conversation.py +++ b/tests/streaming/test_streaming_conversation.py @@ -9,16 +9,23 @@ from tests.fakedata.conversation import ( DEFAULT_CHAT_GPT_AGENT_CONFIG, + DummyOutputDevice, create_fake_agent, create_fake_streaming_conversation, ) +from tests.fixtures.synthesizer import TestSynthesizer, TestSynthesizerConfig +from tests.fixtures.transcriber import TestAsyncTranscriber, TestTranscriberConfig +from vocode.streaming.agent.echo_agent import EchoAgent from vocode.streaming.models.actions import ActionInput -from vocode.streaming.models.agent import InterruptSensitivity +from vocode.streaming.models.agent import EchoAgentConfig, InterruptSensitivity +from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.events import Sender +from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.transcriber import Transcription from vocode.streaming.models.transcript import ActionStart, Message, Transcript +from vocode.streaming.streaming_conversation import StreamingConversation from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult -from vocode.streaming.utils.worker import AsyncWorker +from vocode.streaming.utils.worker import QueueConsumer class ShouldIgnoreUtteranceTestCase(BaseModel): @@ -27,9 +34,9 @@ class ShouldIgnoreUtteranceTestCase(BaseModel): expected: bool -async def _consume_worker_output(worker: AsyncWorker, timeout: float = 0.1): +async def _get_from_consumer_queue_if_exists(queue_consumer: QueueConsumer, timeout: float = 0.1): try: - return await asyncio.wait_for(worker.output_queue.get(), timeout=timeout) + return await asyncio.wait_for(queue_consumer.input_queue.get(), timeout=timeout) except asyncio.TimeoutError: return None @@ -174,8 +181,6 @@ def test_should_ignore_utterance( conversation = mocker.MagicMock() transcriptions_worker = StreamingConversation.TranscriptionsWorker( - input_queue=mocker.MagicMock(), - output_queue=mocker.MagicMock(), conversation=conversation, interruptible_event_factory=mocker.MagicMock(), ) @@ -229,9 +234,8 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message( mocker, ) - streaming_conversation.transcriptions_worker.input_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.output_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.start() + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer streaming_conversation.transcriptions_worker.consume_nonblocking( Transcription( message="sup", @@ -253,7 +257,8 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message( is_final=True, ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + streaming_conversation.transcriptions_worker.start() + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert not streaming_conversation.broadcast_interrupt.called streaming_conversation.transcript.add_bot_message( @@ -269,8 +274,8 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message( ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert transcription_agent_input.payload.transcription.message == "hi, who is this?" assert streaming_conversation.broadcast_interrupt.called @@ -293,9 +298,6 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( mocker, ) - streaming_conversation.transcriptions_worker.input_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.output_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.start() streaming_conversation.initial_message_tracker.set() streaming_conversation.transcript.add_bot_message( text="Hi, I was wondering", @@ -303,6 +305,10 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( conversation_id="test", ) + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + streaming_conversation.transcriptions_worker.start() + streaming_conversation.transcriptions_worker.consume_nonblocking( Transcription( message="i'm listening", @@ -310,7 +316,8 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( is_final=False, ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert not streaming_conversation.broadcast_interrupt.called # ignored for length of response streaming_conversation.transcript.event_logs[-1].text = ( @@ -325,7 +332,7 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert not streaming_conversation.broadcast_interrupt.called # ignored for length of response streaming_conversation.transcriptions_worker.consume_nonblocking( @@ -336,8 +343,8 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert ( transcription_agent_input.payload.transcription.message == "I have not yet gotten a chance." @@ -359,9 +366,6 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( mocker, ) - streaming_conversation.transcriptions_worker.input_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.output_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.start() streaming_conversation.initial_message_tracker.set() streaming_conversation.transcript.add_bot_message( text="Hi, I was wondering", @@ -369,6 +373,9 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( conversation_id="test", ) + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + streaming_conversation.transcriptions_worker.start() streaming_conversation.transcriptions_worker.consume_nonblocking( Transcription( message="Sorry, could you stop", @@ -377,7 +384,7 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.consume_nonblocking( @@ -388,8 +395,8 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert ( transcription_agent_input.payload.transcription.message @@ -409,11 +416,11 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun mocker, ) - streaming_conversation.transcriptions_worker.input_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.output_queue = asyncio.Queue() - streaming_conversation.transcriptions_worker.start() streaming_conversation.initial_message_tracker.set() + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + streaming_conversation.transcriptions_worker.start() streaming_conversation.transcriptions_worker.consume_nonblocking( Transcription( message="Sorry,", @@ -421,7 +428,8 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun is_final=False, ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.consume_nonblocking( @@ -431,8 +439,8 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun is_final=True, ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert transcription_agent_input.payload.transcription.message == "Sorry, what?" assert streaming_conversation.broadcast_interrupt.called @@ -449,7 +457,7 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.terminate() @@ -574,3 +582,31 @@ async def chunk_generator(): assert cut_off assert transcript_message.text != "Hi there" assert not transcript_message.is_final + + +@pytest.mark.asyncio +async def test_streaming_conversation_pipeline( + mocker: MockerFixture, +): + output_device = DummyOutputDevice(sampling_rate=48000, audio_encoding=AudioEncoding.LINEAR16) + streaming_conversation = StreamingConversation( + output_device=output_device, + transcriber=TestAsyncTranscriber( + TestTranscriberConfig( + sampling_rate=48000, + audio_encoding=AudioEncoding.LINEAR16, + chunk_size=480, + ) + ), + agent=EchoAgent( + EchoAgentConfig(initial_message=BaseMessage(text="Hi there")), + ), + synthesizer=TestSynthesizer(TestSynthesizerConfig.from_output_device(output_device)), + ) + await streaming_conversation.start() + await streaming_conversation.initial_message_tracker.wait() + streaming_conversation.receive_audio(b"test") + initial_message_audio_chunk = await output_device.dummy_playback_queue.get() + assert initial_message_audio_chunk.data == b"Hi there" + first_response_audio_chunk = await output_device.dummy_playback_queue.get() + assert first_response_audio_chunk.data == b"test" diff --git a/vocode/streaming/action/worker.py b/vocode/streaming/action/worker.py index 26c51b240..b256489cb 100644 --- a/vocode/streaming/action/worker.py +++ b/vocode/streaming/action/worker.py @@ -12,6 +12,7 @@ ) from vocode.streaming.utils.state_manager import AbstractConversationStateManager from vocode.streaming.utils.worker import ( + AbstractWorker, InterruptibleEvent, InterruptibleEventFactory, InterruptibleWorker, @@ -19,16 +20,14 @@ class ActionsWorker(InterruptibleWorker): + consumer: AbstractWorker[InterruptibleEvent[ActionResultAgentInput]] + def __init__( self, action_factory: AbstractActionFactory, - input_queue: asyncio.Queue[InterruptibleEvent[ActionInput]], - output_queue: asyncio.Queue[InterruptibleEvent[AgentInput]], interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), ): super().__init__( - input_queue=input_queue, - output_queue=output_queue, interruptible_event_factory=interruptible_event_factory, ) self.action_factory = action_factory @@ -43,22 +42,24 @@ async def process(self, item: InterruptibleEvent[ActionInput]): action = self.action_factory.create_action(action_input.action_config) action.attach_conversation_state_manager(self.conversation_state_manager) action_output = await action.run(action_input) - self.produce_interruptible_event_nonblocking( - ActionResultAgentInput( - conversation_id=action_input.conversation_id, - action_input=action_input, - action_output=action_output, - vonage_uuid=( - action_input.vonage_uuid - if isinstance(action_input, VonagePhoneConversationActionInput) - else None - ), - twilio_sid=( - action_input.twilio_sid - if isinstance(action_input, TwilioPhoneConversationActionInput) - else None + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_event( + ActionResultAgentInput( + conversation_id=action_input.conversation_id, + action_input=action_input, + action_output=action_output, + vonage_uuid=( + action_input.vonage_uuid + if isinstance(action_input, VonagePhoneConversationActionInput) + else None + ), + twilio_sid=( + action_input.twilio_sid + if isinstance(action_input, TwilioPhoneConversationActionInput) + else None + ), + is_quiet=action.quiet, ), - is_quiet=action.quiet, - ), - is_interruptible=False, + is_interruptible=False, + ) ) diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index f6a56a7e7..7dd5a3fe7 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -38,6 +38,7 @@ from vocode.streaming.utils import unrepeating_randomizer from vocode.streaming.utils.speed_manager import SpeedManager from vocode.streaming.utils.worker import ( + AbstractWorker, InterruptibleAgentResponseEvent, InterruptibleEvent, InterruptibleEventFactory, @@ -154,6 +155,9 @@ def get_cut_off_response(self) -> str: class BaseAgent(AbstractAgent[AgentConfigType], InterruptibleWorker): + agent_responses_consumer: AbstractWorker[InterruptibleAgentResponseEvent[AgentResponse]] + actions_consumer: Optional[AbstractWorker[InterruptibleEvent[ActionInput]]] + def __init__( self, agent_config: AgentConfigType, @@ -161,18 +165,12 @@ def __init__( interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), ): self.input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]] = asyncio.Queue() - self.output_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] = ( - asyncio.Queue() - ) AbstractAgent.__init__(self, agent_config=agent_config) InterruptibleWorker.__init__( self, - input_queue=self.input_queue, - output_queue=self.output_queue, interruptible_event_factory=interruptible_event_factory, ) self.action_factory = action_factory - self.actions_queue: asyncio.Queue[InterruptibleEvent[ActionInput]] = asyncio.Queue() self.transcript: Optional[Transcript] = None self.functions = self.get_functions() if self.agent_config.actions else None @@ -211,11 +209,6 @@ def get_input_queue( ) -> asyncio.Queue[InterruptibleEvent[AgentInput]]: return self.input_queue - def get_output_queue( - self, - ) -> asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]]: - return self.output_queue - def is_first_response(self): assert self.transcript is not None @@ -299,14 +292,16 @@ async def handle_generate_response( continue agent_response_tracker = agent_input.agent_response_tracker or asyncio.Event() - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=generated_response.message, - is_first=is_first_response_of_turn, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=generated_response.message, + is_first=is_first_response_of_turn, + ), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off + and generated_response.is_interruptible, + agent_response_tracker=agent_response_tracker, ), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off - and generated_response.is_interruptible, - agent_response_tracker=agent_response_tracker, ) if isinstance(generated_response.message, BaseMessage): responses_buffer = f"{responses_buffer} {generated_response.message.text}" @@ -330,14 +325,15 @@ async def handle_generate_response( end_of_turn_agent_response_tracker = ( agent_input.agent_response_tracker or asyncio.Event() ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=EndOfTurn(), - is_first=is_first_response_of_turn, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=EndOfTurn(), + is_first=is_first_response_of_turn, + ), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + agent_response_tracker=end_of_turn_agent_response_tracker, ), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off - and generated_response.is_interruptible, - agent_response_tracker=end_of_turn_agent_response_tracker, ) phrase_trigger_match = ( @@ -374,13 +370,17 @@ async def handle_respond(self, transcription: Transcription, conversation_id: st response = None return True if response: - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=BaseMessage(text=response)), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=BaseMessage(text=response)), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + ) ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=EndOfTurn()), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=EndOfTurn()), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + ) ) return should_stop else: @@ -408,15 +408,19 @@ async def process(self, item: InterruptibleEvent[AgentInput]): logger.debug("Action is quiet, skipping response generation") return if agent_input.action_output.canned_response is not None: - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=agent_input.action_output.canned_response, - is_sole_text_chunk=True, - ), - is_interruptible=True, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=agent_input.action_output.canned_response, + is_sole_text_chunk=True, + ), + is_interruptible=True, + ) ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=EndOfTurn()), + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=EndOfTurn()), + ) ) return transcription = Transcription( @@ -432,8 +436,10 @@ async def process(self, item: InterruptibleEvent[AgentInput]): return if self.agent_config.send_filler_audio: - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseFillerAudio(), + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseFillerAudio(), + ) ) logger.debug("Responding to transcription") @@ -451,7 +457,11 @@ async def process(self, item: InterruptibleEvent[AgentInput]): if should_stop: logger.debug("Agent requested to stop") - self.produce_interruptible_agent_response_event_nonblocking(AgentResponseStop()) + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseStop(), + ) + ) return except asyncio.CancelledError: pass @@ -478,16 +488,20 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp if "user_message" in params: user_message = params["user_message"] user_message_tracker = asyncio.Event() - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=BaseMessage(text=user_message), - is_sole_text_chunk=True, - ), - is_interruptible=action.is_interruptible, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=BaseMessage(text=user_message), + is_sole_text_chunk=True, + ), + is_interruptible=action.is_interruptible, + ) ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=EndOfTurn()), - agent_response_tracker=user_message_tracker, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=EndOfTurn()), + agent_response_tracker=user_message_tracker, + ) ) action_input = self.create_action_input(action, agent_input, params, user_message_tracker) self.enqueue_action_input(action, action_input, agent_input.conversation_id) @@ -534,6 +548,9 @@ def enqueue_action_input( action_input: ActionInput, conversation_id: str, ): + if self.actions_consumer is None: + logger.warning("No actions consumer attached, skipping action") + return event = self.interruptible_event_factory.create_interruptible_event( action_input, is_interruptible=action.is_interruptible, @@ -543,7 +560,7 @@ def enqueue_action_input( action_input=action_input, conversation_id=conversation_id, ) - self.actions_queue.put_nowait(event) + self.actions_consumer.consume_nonblocking(event) async def respond( self, diff --git a/vocode/streaming/agent/websocket_user_implemented_agent.py b/vocode/streaming/agent/websocket_user_implemented_agent.py index 1a09e199d..f8233588f 100644 --- a/vocode/streaming/agent/websocket_user_implemented_agent.py +++ b/vocode/streaming/agent/websocket_user_implemented_agent.py @@ -27,7 +27,6 @@ class WebSocketUserImplementedAgent(BaseAgent[WebSocketUserImplementedAgentConfig]): input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]] - output_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] def __init__( self, @@ -63,8 +62,11 @@ def _handle_incoming_socket_message(self, message: WebSocketAgentMessage) -> Non raise Exception("Unknown Socket message type") logger.info("Putting interruptible agent response event in output queue") - self.produce_interruptible_agent_response_event_nonblocking( - agent_response, self.get_agent_config().allow_agent_to_be_cut_off + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + agent_response, + is_interruptible=self.get_agent_config().allow_agent_to_be_cut_off, + ) ) async def _process(self) -> None: @@ -137,5 +139,9 @@ async def receiver(ws: WebSocketClientProtocol) -> None: await asyncio.gather(sender(ws), receiver(ws)) def terminate(self): - self.produce_interruptible_agent_response_event_nonblocking(AgentResponseStop()) + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseStop() + ) + ) super().terminate() diff --git a/vocode/streaming/models/synthesizer.py b/vocode/streaming/models/synthesizer.py index cc7969eae..ce8af4bab 100644 --- a/vocode/streaming/models/synthesizer.py +++ b/vocode/streaming/models/synthesizer.py @@ -220,7 +220,7 @@ class BarkSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.BARK.value): DEFAULT_POLLY_LANGUAGE_CODE = "en-US" DEFAULT_POLLY_VOICE_ID = "Matthew" -DEFAULT_POLLY_SAMPLING_RATE = SamplingRate.RATE_16000 +DEFAULT_POLLY_SAMPLING_RATE = SamplingRate.RATE_16000.value class PollySynthesizerConfig(SynthesizerConfig, type=SynthesizerType.POLLY.value): # type: ignore diff --git a/vocode/streaming/output_device/abstract_output_device.py b/vocode/streaming/output_device/abstract_output_device.py index 985077f63..2137d1321 100644 --- a/vocode/streaming/output_device/abstract_output_device.py +++ b/vocode/streaming/output_device/abstract_output_device.py @@ -16,7 +16,7 @@ class AbstractOutputDevice(AsyncWorker[InterruptibleEvent[AudioChunk]]): """ def __init__(self, sampling_rate: int, audio_encoding: AudioEncoding): - super().__init__(input_queue=asyncio.Queue()) + super().__init__() self.sampling_rate = sampling_rate self.audio_encoding = audio_encoding diff --git a/vocode/streaming/output_device/blocking_speaker_output.py b/vocode/streaming/output_device/blocking_speaker_output.py index d273cb9e0..4679fb0a7 100644 --- a/vocode/streaming/output_device/blocking_speaker_output.py +++ b/vocode/streaming/output_device/blocking_speaker_output.py @@ -19,8 +19,7 @@ class _PlaybackWorker(ThreadAsyncWorker[bytes]): def __init__(self, *, device_info: dict, sampling_rate: int): self.sampling_rate = sampling_rate self.device_info = device_info - self.input_queue: asyncio.Queue[bytes] = asyncio.Queue() - super().__init__(self.input_queue) + super().__init__() self.stream = sd.OutputStream( channels=1, samplerate=self.sampling_rate, @@ -28,7 +27,7 @@ def __init__(self, *, device_info: dict, sampling_rate: int): device=int(self.device_info["index"]), ) self._ended = False - self.input_queue.put_nowait(self.sampling_rate * b"\x00") + self.consume_nonblocking(self.sampling_rate * b"\x00") self.stream.start() def _run_loop(self): diff --git a/vocode/streaming/output_device/livekit_output_device.py b/vocode/streaming/output_device/livekit_output_device.py index 5a2e1a040..bfb5dd5d5 100644 --- a/vocode/streaming/output_device/livekit_output_device.py +++ b/vocode/streaming/output_device/livekit_output_device.py @@ -25,7 +25,7 @@ def __init__( async def _run_loop(self): while True: try: - item = await self.input_queue.get() + item = await self._input_queue.get() except asyncio.CancelledError: return diff --git a/vocode/streaming/output_device/rate_limit_interruptions_output_device.py b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py index 98493b9d8..b621b213f 100644 --- a/vocode/streaming/output_device/rate_limit_interruptions_output_device.py +++ b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py @@ -27,7 +27,7 @@ async def _run_loop(self): while True: start_time = time.time() try: - item = await self.input_queue.get() + item = await self._input_queue.get() except asyncio.CancelledError: return diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 52e862ca6..f6fc02500 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -66,11 +66,13 @@ from vocode.streaming.utils.speed_manager import SpeedManager from vocode.streaming.utils.state_manager import ConversationStateManager from vocode.streaming.utils.worker import ( + AbstractWorker, AsyncQueueWorker, InterruptibleAgentResponseEvent, InterruptibleAgentResponseWorker, InterruptibleEvent, InterruptibleEventFactory, + InterruptibleWorker, ) from vocode.utils.sentry_utils import ( CustomSentrySpans, @@ -143,16 +145,14 @@ class TranscriptionsWorker(AsyncQueueWorker[Transcription]): """Processes all transcriptions: sends an interrupt if needed and sends final transcriptions to the output queue""" + consumer: AbstractWorker[InterruptibleEvent[Transcription]] + def __init__( self, - input_queue: asyncio.Queue[Transcription], - output_queue: asyncio.Queue[InterruptibleEvent[AgentInput]], conversation: "StreamingConversation", interruptible_event_factory: InterruptibleEventFactory, ): - super().__init__(input_queue, output_queue) - self.input_queue = input_queue - self.output_queue = output_queue + super().__init__() self.conversation = conversation self.interruptible_event_factory = interruptible_event_factory self.in_interrupt_endpointing_config = False @@ -313,9 +313,9 @@ async def process(self, transcription: Transcription): agent_response_tracker=agent_response_tracker, ), ) - self.output_queue.put_nowait(event) + self.consumer.consume_nonblocking(event) - class FillerAudioWorker(InterruptibleAgentResponseWorker): + class FillerAudioWorker(InterruptibleWorker[InterruptibleAgentResponseEvent[FillerAudio]]): """ - Waits for a configured number of seconds and then sends filler audio to the output - Exposes wait_for_filler_audio_to_finish() which the AgentResponsesWorker waits on before @@ -324,11 +324,9 @@ class FillerAudioWorker(InterruptibleAgentResponseWorker): def __init__( self, - input_queue: asyncio.Queue[InterruptibleAgentResponseEvent[FillerAudio]], conversation: "StreamingConversation", ): - super().__init__(input_queue=input_queue) - self.input_queue = input_queue + super().__init__() self.conversation = conversation self.current_filler_seconds_per_chunk: Optional[int] = None self.filler_audio_started_event: Optional[threading.Event] = None @@ -369,26 +367,21 @@ async def process(self, item: InterruptibleAgentResponseEvent[FillerAudio]): except asyncio.CancelledError: pass - class AgentResponsesWorker(InterruptibleAgentResponseWorker): + class AgentResponsesWorker(InterruptibleWorker[InterruptibleAgentResponseEvent[AgentResponse]]): """Runs Synthesizer.create_speech and sends the SynthesisResult to the output queue""" + consumer: AbstractWorker[ + InterruptibleAgentResponseEvent[ + Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] + ] + ] + def __init__( self, - input_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]], - output_queue: asyncio.Queue[ - InterruptibleAgentResponseEvent[ - Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] - ] - ], conversation: "StreamingConversation", interruptible_event_factory: InterruptibleEventFactory, ): - super().__init__( - input_queue=input_queue, - output_queue=output_queue, - ) - self.input_queue = input_queue - self.output_queue = output_queue + super().__init__() self.conversation = conversation self.interruptible_event_factory = interruptible_event_factory self.chunk_size = self.conversation._get_synthesizer_chunk_size() @@ -437,10 +430,12 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): logger.debug("Sending end of turn") if isinstance(self.conversation.synthesizer, InputStreamingSynthesizer): await self.conversation.synthesizer.handle_end_of_turn() - self.produce_interruptible_agent_response_event_nonblocking( - (agent_response_message.message, None), - is_interruptible=item.is_interruptible, - agent_response_tracker=item.agent_response_tracker, + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + (agent_response_message.message, None), + is_interruptible=item.is_interruptible, + agent_response_tracker=item.agent_response_tracker, + ), ) self.is_first_text_chunk = True return @@ -507,10 +502,12 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): if not synthesis_result.cached and synthesis_span: synthesis_result.synthesis_total_span = synthesis_span synthesis_result.ttft_span = ttft_span - self.produce_interruptible_agent_response_event_nonblocking( - (agent_response_message.message, synthesis_result), - is_interruptible=item.is_interruptible, - agent_response_tracker=item.agent_response_tracker, + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + (agent_response_message.message, synthesis_result), + is_interruptible=item.is_interruptible, + agent_response_tracker=item.agent_response_tracker, + ), ) self.last_agent_response_tracker = item.agent_response_tracker if not isinstance(agent_response_message.message, SilenceMessage): @@ -518,20 +515,20 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): except asyncio.CancelledError: pass - class SynthesisResultsWorker(InterruptibleAgentResponseWorker): + class SynthesisResultsWorker( + InterruptibleWorker[ + InterruptibleAgentResponseEvent[ + Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] + ] + ] + ): """Plays SynthesisResults from the output queue on the output device""" def __init__( self, - input_queue: asyncio.Queue[ - InterruptibleAgentResponseEvent[ - Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] - ] - ], conversation: "StreamingConversation", ): - super().__init__(input_queue=input_queue) - self.input_queue = input_queue + super().__init__() self.conversation = conversation self.last_transcript_message: Optional[Message] = None @@ -604,49 +601,52 @@ def __init__( self.interruptible_events: queue.Queue[InterruptibleEvent] = queue.Queue() self.interruptible_event_factory = self.QueueingInterruptibleEventFactory(conversation=self) - self.agent.set_interruptible_event_factory(self.interruptible_event_factory) self.synthesis_results_queue: asyncio.Queue[ InterruptibleAgentResponseEvent[ Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] ] ] = asyncio.Queue() - self.filler_audio_queue: asyncio.Queue[InterruptibleAgentResponseEvent[FillerAudio]] = ( - asyncio.Queue() - ) self.state_manager = self.create_state_manager() + + # Transcriptions Worker self.transcriptions_worker = self.TranscriptionsWorker( - input_queue=self.transcriber.output_queue, - output_queue=self.agent.get_input_queue(), conversation=self, interruptible_event_factory=self.interruptible_event_factory, ) + self.transcriber.consumer = self.transcriptions_worker + + # Agent + self.transcriptions_worker.consumer = self.agent + self.agent.set_interruptible_event_factory(self.interruptible_event_factory) self.agent.attach_conversation_state_manager(self.state_manager) + + # Agent Responses Worker self.agent_responses_worker = self.AgentResponsesWorker( - input_queue=self.agent.get_output_queue(), - output_queue=self.synthesis_results_queue, conversation=self, interruptible_event_factory=self.interruptible_event_factory, ) + self.agent.agent_responses_consumer = self.agent_responses_worker + + # Actions Worker self.actions_worker = None if self.agent.get_agent_config().actions: self.actions_worker = ActionsWorker( - input_queue=self.agent.actions_queue, - output_queue=self.agent.get_input_queue(), - interruptible_event_factory=self.interruptible_event_factory, action_factory=self.agent.action_factory, + interruptible_event_factory=self.interruptible_event_factory, ) self.actions_worker.attach_conversation_state_manager(self.state_manager) - self.synthesis_results_worker = self.SynthesisResultsWorker( - input_queue=self.synthesis_results_queue, - conversation=self, - ) + self.actions_worker.consumer = self.agent + self.agent.actions_consumer = self.actions_worker + + # Synthesis Results Worker + self.synthesis_results_worker = self.SynthesisResultsWorker(conversation=self) + self.agent_responses_worker.consumer = self.synthesis_results_worker + + # Filler Audio Worker self.filler_audio_worker = None self.filler_audio_config: Optional[FillerAudioConfig] = None if self.agent.get_agent_config().send_filler_audio: - self.filler_audio_worker = self.FillerAudioWorker( - input_queue=self.filler_audio_queue, - conversation=self, - ) + self.filler_audio_worker = self.FillerAudioWorker(conversation=self) self.speed_coefficient = speed_coefficient self.speed_manager = SpeedManager( diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index bdbd6d7d5..8ccd6aad2 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -22,6 +22,7 @@ from vocode.streaming.utils import convert_wav, get_chunk_size_per_second from vocode.streaming.utils.async_requester import AsyncRequestor from vocode.streaming.utils.create_task import asyncio_create_task +from vocode.streaming.utils.worker import QueueConsumer FILLER_PHRASES = [ BaseMessage(text="Um..."), @@ -410,14 +411,12 @@ async def experimental_mp3_streaming_output_generator( response: aiohttp.ClientResponse, chunk_size: int, ) -> AsyncGenerator[SynthesisResult.ChunkResult, None]: - miniaudio_worker_input_queue: asyncio.Queue[Union[bytes, None]] = asyncio.Queue() - miniaudio_worker_output_queue: asyncio.Queue[Tuple[bytes, bool]] = asyncio.Queue() + miniaudio_worker_consumer: QueueConsumer = QueueConsumer() miniaudio_worker = MiniaudioWorker( self.synthesizer_config, chunk_size, - miniaudio_worker_input_queue, - miniaudio_worker_output_queue, ) + miniaudio_worker.consumer = miniaudio_worker_consumer miniaudio_worker.start() stream_reader = response.content @@ -428,12 +427,12 @@ async def send_chunks(): miniaudio_worker.consume_nonblocking(None) # sentinel try: - asyncio_create_task(send_chunks(), reraise_cancelled=True) + asyncio_create_task(send_chunks()) # Await the output queue of the MiniaudioWorker and yield the wav chunks in another loop while True: # Get the wav chunk and the flag from the output queue of the MiniaudioWorker - wav_chunk, is_last = await miniaudio_worker.output_queue.get() + wav_chunk, is_last = await miniaudio_worker_consumer.input_queue.get() if self.synthesizer_config.should_encode_as_wav: wav_chunk = encode_as_wav(wav_chunk, self.synthesizer_config) diff --git a/vocode/streaming/synthesizer/miniaudio_worker.py b/vocode/streaming/synthesizer/miniaudio_worker.py index 92d33adc3..fcba60460 100644 --- a/vocode/streaming/synthesizer/miniaudio_worker.py +++ b/vocode/streaming/synthesizer/miniaudio_worker.py @@ -10,23 +10,39 @@ from vocode.streaming.models.synthesizer import SynthesizerConfig from vocode.streaming.utils import convert_wav from vocode.streaming.utils.mp3_helper import decode_mp3 -from vocode.streaming.utils.worker import ThreadAsyncWorker +from vocode.streaming.utils.worker import AbstractWorker, ThreadAsyncWorker class MiniaudioWorker(ThreadAsyncWorker[Union[bytes, None]]): + consumer: AbstractWorker[Tuple[bytes, bool]] + def __init__( self, synthesizer_config: SynthesizerConfig, chunk_size: int, - input_queue: asyncio.Queue[Union[bytes, None]], - output_queue: asyncio.Queue[Tuple[bytes, bool]], ) -> None: - super().__init__(input_queue, output_queue) - self.output_queue = output_queue # for typing + super().__init__() self.synthesizer_config = synthesizer_config self.chunk_size = chunk_size self._ended = False + async def run_thread_forwarding(self): + try: + await asyncio.gather( + self._forward_to_thread(), + self._forward_from_thread(), + ) + except asyncio.CancelledError: + return + + async def _forward_from_thread(self): + while True: + try: + chunk, done = await self.output_janus_queue.async_q.get() + self.consumer.consume_nonblocking((chunk, done)) + except asyncio.CancelledError: + break + def _run_loop(self): # tracks the mp3 so far current_mp3_buffer = bytearray() diff --git a/vocode/streaming/telephony/constants.py b/vocode/streaming/telephony/constants.py index 889971bbf..34c1766fc 100644 --- a/vocode/streaming/telephony/constants.py +++ b/vocode/streaming/telephony/constants.py @@ -1,12 +1,12 @@ from vocode.streaming.models.audio import AudioEncoding, SamplingRate # TODO(EPD-186): namespace as Twilio -DEFAULT_SAMPLING_RATE = SamplingRate.RATE_8000 +DEFAULT_SAMPLING_RATE: int = SamplingRate.RATE_8000.value DEFAULT_AUDIO_ENCODING = AudioEncoding.MULAW DEFAULT_CHUNK_SIZE = 20 * 160 MULAW_SILENCE_BYTE = b"\xff" -VONAGE_SAMPLING_RATE = SamplingRate.RATE_16000 +VONAGE_SAMPLING_RATE: int = SamplingRate.RATE_16000.value VONAGE_AUDIO_ENCODING = AudioEncoding.LINEAR16 VONAGE_CHUNK_SIZE = 640 # 20ms at 16kHz with 16bit samples VONAGE_CONTENT_TYPE = "audio/l16;rate=16000" diff --git a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py index 1e3cf4c0f..cf876a180 100644 --- a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py @@ -66,12 +66,6 @@ def __init__( events_manager=events_manager, speed_coefficient=speed_coefficient, ) - self.transcriptions_worker = self.TranscriptionsWorker( - input_queue=self.transcriber.output_queue, - output_queue=self.agent.get_input_queue(), - conversation=self, - interruptible_event_factory=self.interruptible_event_factory, - ) self.config_manager = config_manager def attach_ws(self, ws: WebSocket): diff --git a/vocode/streaming/transcriber/assembly_ai_transcriber.py b/vocode/streaming/transcriber/assembly_ai_transcriber.py index 47fa12c36..331810e73 100644 --- a/vocode/streaming/transcriber/assembly_ai_transcriber.py +++ b/vocode/streaming/transcriber/assembly_ai_transcriber.py @@ -75,7 +75,7 @@ def send_audio(self, chunk): if ( len(self.buffer) / (2 * self.transcriber_config.sampling_rate) ) >= self.transcriber_config.buffer_size_seconds: - self.input_queue.put_nowait(self.buffer) + self.consume_nonblocking(self.buffer) self.buffer = bytearray() def terminate(self): @@ -106,7 +106,7 @@ async def process(self): async def sender(ws): # sends audio to websocket while not self._ended: try: - data = await asyncio.wait_for(self.input_queue.get(), 5) + data = await asyncio.wait_for(self._input_queue.get(), 5) except asyncio.exceptions.TimeoutError: break num_channels = 1 @@ -133,7 +133,7 @@ async def receiver(ws): is_final = "message_type" in data and data["message_type"] == "FinalTranscript" if "text" in data and data["text"]: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=data["text"], confidence=data["confidence"], diff --git a/vocode/streaming/transcriber/azure_transcriber.py b/vocode/streaming/transcriber/azure_transcriber.py index c272fdf89..2cc666340 100644 --- a/vocode/streaming/transcriber/azure_transcriber.py +++ b/vocode/streaming/transcriber/azure_transcriber.py @@ -79,12 +79,12 @@ def recognized_sentence_final(self, evt): op=CustomSentrySpans.LATENCY_OF_CONVERSATION, start_timestamp=datetime.now(tz=timezone.utc), ) - self.output_janus_queue.sync_q.put_nowait( + self.produce_nonblocking( Transcription(message=evt.result.text, confidence=1.0, is_final=True) ) def recognized_sentence_stream(self, evt): - self.output_janus_queue.sync_q.put_nowait( + self.produce_nonblocking( Transcription(message=evt.result.text, confidence=1.0, is_final=False) ) diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index 3e89d23d1..beea852b1 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -8,18 +8,19 @@ from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.transcriber import TranscriberConfig, Transcription from vocode.streaming.utils.speed_manager import SpeedManager -from vocode.streaming.utils.worker import AsyncWorker, ThreadAsyncWorker +from vocode.streaming.utils.worker import AbstractWorker, AsyncWorker, ThreadAsyncWorker TranscriberConfigType = TypeVar("TranscriberConfigType", bound=TranscriberConfig) -class AbstractTranscriber(Generic[TranscriberConfigType], ABC): +class AbstractTranscriber(Generic[TranscriberConfigType], AbstractWorker[bytes]): + consumer: AbstractWorker[Transcription] + def __init__(self, transcriber_config: TranscriberConfigType): + AbstractWorker.__init__(self) self.transcriber_config = transcriber_config self.is_muted = False self.speed_manager: Optional[SpeedManager] = None - self.input_queue: asyncio.Queue[bytes] = asyncio.Queue() - self.output_queue: asyncio.Queue[Transcription] = asyncio.Queue() def attach_speed_manager(self, speed_manager: SpeedManager): self.speed_manager = speed_manager @@ -47,12 +48,15 @@ def create_silent_chunk(self, chunk_size, sample_width=2): async def _run_loop(self): pass - def send_audio(self, chunk): + def send_audio(self, chunk: bytes): if not self.is_muted: self.consume_nonblocking(chunk) else: self.consume_nonblocking(self.create_silent_chunk(len(chunk))) + def produce_nonblocking(self, item: Transcription): + self.consumer.consume_nonblocking(item) + @abstractmethod def terminate(self): pass @@ -61,7 +65,7 @@ def terminate(self): class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) - AsyncWorker.__init__(self, self.input_queue, self.output_queue) + AsyncWorker.__init__(self) def terminate(self): AsyncWorker.terminate(self) @@ -72,11 +76,31 @@ class BaseThreadAsyncTranscriber( # type: ignore ): def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) - ThreadAsyncWorker.__init__(self, self.input_queue, self.output_queue) + ThreadAsyncWorker.__init__(self) def _run_loop(self): raise NotImplementedError + async def run_thread_forwarding(self): + try: + await asyncio.gather( + self._forward_to_thread(), + self._forward_from_thread(), + ) + except asyncio.CancelledError: + return + + async def _forward_from_thread(self): + while True: + try: + transcription = await self.output_janus_queue.async_q.get() + self.consumer.consume_nonblocking(transcription) + except asyncio.CancelledError: + break + + def produce_nonblocking(self, item: Transcription): + self.output_janus_queue.sync_q.put_nowait(item) + def terminate(self): ThreadAsyncWorker.terminate(self) diff --git a/vocode/streaming/transcriber/deepgram_transcriber.py b/vocode/streaming/transcriber/deepgram_transcriber.py index eef281d58..4835d9646 100644 --- a/vocode/streaming/transcriber/deepgram_transcriber.py +++ b/vocode/streaming/transcriber/deepgram_transcriber.py @@ -190,7 +190,7 @@ def terminate(self): }, ) terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8") - self.input_queue.put_nowait(terminate_msg) + self.consume_nonblocking(terminate_msg) # todo (dow-107): typing self._ended = True super().terminate() @@ -404,7 +404,7 @@ async def sender( while not self._ended: try: - data = await asyncio.wait_for(self.input_queue.get(), 5) + data = await asyncio.wait_for(self._input_queue.get(), 5) except asyncio.exceptions.TimeoutError: break @@ -485,7 +485,7 @@ async def receiver(ws: WebSocketClientProtocol): is_final_ts=is_final_ts, output_ts=output_ts, ) - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=buffer, confidence=buffer_avg_confidence, @@ -513,7 +513,7 @@ async def receiver(ws: WebSocketClientProtocol): else: interim_message = buffer - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=interim_message, confidence=deepgram_response.top_choice.confidence, diff --git a/vocode/streaming/transcriber/gladia_transcriber.py b/vocode/streaming/transcriber/gladia_transcriber.py index 54366035a..1cd8f32b0 100644 --- a/vocode/streaming/transcriber/gladia_transcriber.py +++ b/vocode/streaming/transcriber/gladia_transcriber.py @@ -53,7 +53,7 @@ def send_audio(self, chunk): if ( len(self.buffer) / (2 * self.transcriber_config.sampling_rate) ) >= self.transcriber_config.buffer_size_seconds: - self.input_queue.put_nowait(self.buffer) + self.consume_nonblocking(self.buffer) self.buffer = bytearray() def terminate(self): @@ -75,7 +75,7 @@ async def process(self): async def sender(ws): while not self._ended: try: - data = await asyncio.wait_for(self.input_queue.get(), 5) + data = await asyncio.wait_for(self._input_queue.get(), 5) except asyncio.exceptions.TimeoutError: break @@ -104,7 +104,7 @@ async def receiver(ws): is_final = data["type"] == "final" if "transcription" in data and data["transcription"]: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=data["transcription"], confidence=data["confidence"], diff --git a/vocode/streaming/transcriber/google_transcriber.py b/vocode/streaming/transcriber/google_transcriber.py index 5f3da9a68..97c2c34b5 100644 --- a/vocode/streaming/transcriber/google_transcriber.py +++ b/vocode/streaming/transcriber/google_transcriber.py @@ -80,7 +80,7 @@ def _on_response(self, response): message = top_choice.transcript confidence = top_choice.confidence - self.output_janus_queue.sync_q.put_nowait( + self.produce_nonblocking( Transcription(message=message, confidence=confidence, is_final=result.is_final) ) diff --git a/vocode/streaming/transcriber/rev_ai_transcriber.py b/vocode/streaming/transcriber/rev_ai_transcriber.py index 856768695..0684f3dce 100644 --- a/vocode/streaming/transcriber/rev_ai_transcriber.py +++ b/vocode/streaming/transcriber/rev_ai_transcriber.py @@ -74,7 +74,7 @@ async def process(self): async def sender(ws: WebSocketClientProtocol): while not self.closed: try: - data = await asyncio.wait_for(self.input_queue.get(), 5) + data = await asyncio.wait_for(self._input_queue.get(), 5) except asyncio.exceptions.TimeoutError: break await ws.send(data) @@ -118,12 +118,12 @@ async def receiver(ws: WebSocketClientProtocol): confidence = 1.0 if is_done: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription(message=buffer, confidence=confidence, is_final=True) ) buffer = "" else: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=buffer, confidence=confidence, @@ -137,5 +137,5 @@ async def receiver(ws: WebSocketClientProtocol): def terminate(self): terminate_msg = json.dumps({"type": "CloseStream"}) - self.input_queue.put_nowait(terminate_msg) + self.consume_nonblocking(terminate_msg) self.closed = True diff --git a/vocode/streaming/transcriber/whisper_cpp_transcriber.py b/vocode/streaming/transcriber/whisper_cpp_transcriber.py index 6f1967de5..c12c6333c 100644 --- a/vocode/streaming/transcriber/whisper_cpp_transcriber.py +++ b/vocode/streaming/transcriber/whisper_cpp_transcriber.py @@ -72,7 +72,7 @@ def _run_loop(self): message_buffer += message is_final = any(message_buffer.endswith(ending) for ending in SENTENCE_ENDINGS) in_memory_wav, audio_buffer = self.create_new_buffer() - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription(message=message_buffer, confidence=confidence, is_final=is_final) ) if is_final: diff --git a/vocode/streaming/utils/create_task.py b/vocode/streaming/utils/create_task.py index 38778d0bf..0680709c5 100644 --- a/vocode/streaming/utils/create_task.py +++ b/vocode/streaming/utils/create_task.py @@ -8,7 +8,6 @@ def asyncio_create_task( *args, - reraise_cancelled: bool = False, **kwargs, ) -> asyncio.Task: task = asyncio.create_task(*args, **kwargs) diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index 16c52a655..d321ea193 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -2,7 +2,8 @@ import asyncio import threading -from typing import Any, Generic, Optional, TypeVar +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Optional, TypeVar import janus from loguru import logger @@ -12,15 +13,44 @@ WorkerInputType = TypeVar("WorkerInputType") -class AsyncWorker(Generic[WorkerInputType]): +class AbstractWorker(Generic[WorkerInputType], ABC): + """ + A generic processor - knows only how to consume typed items. + In order for a worker to process items, clients must invoke start() and tear down with terminate() + """ + + @abstractmethod + def start(self): + raise NotImplementedError + + @abstractmethod + def consume_nonblocking(self, item: WorkerInputType): + raise NotImplementedError + + def terminate(self): + pass + + +class QueueConsumer(AbstractWorker[WorkerInputType]): + def __init__( + self, + input_queue: Optional[asyncio.Queue[WorkerInputType]] = None, + ) -> None: + self.input_queue: asyncio.Queue[WorkerInputType] = input_queue or asyncio.Queue() + + def consume_nonblocking(self, item: WorkerInputType): + self.input_queue.put_nowait(item) + + def start(self): + pass + + +class AsyncWorker(AbstractWorker[WorkerInputType]): def __init__( self, - input_queue: asyncio.Queue[WorkerInputType], - output_queue: asyncio.Queue = asyncio.Queue(), ) -> None: self.worker_task: Optional[asyncio.Task] = None - self.input_queue = input_queue - self.output_queue = output_queue + self._input_queue: asyncio.Queue[WorkerInputType] = asyncio.Queue() def start(self) -> asyncio.Task: self.worker_task = asyncio_create_task( @@ -31,10 +61,7 @@ def start(self) -> asyncio.Task: return self.worker_task def consume_nonblocking(self, item: WorkerInputType): - self.input_queue.put_nowait(item) - - def produce_nonblocking(self, item): - self.output_queue.put_nowait(item) + self._input_queue.put_nowait(item) async def _run_loop(self): raise NotImplementedError @@ -49,10 +76,8 @@ def terminate(self): class ThreadAsyncWorker(AsyncWorker[WorkerInputType]): def __init__( self, - input_queue: asyncio.Queue[WorkerInputType], - output_queue: asyncio.Queue = asyncio.Queue(), ) -> None: - super().__init__(input_queue, output_queue) + super().__init__() self.worker_thread: Optional[threading.Thread] = None self.input_janus_queue: janus.Queue[WorkerInputType] = janus.Queue() self.output_janus_queue: janus.Queue = janus.Queue() @@ -69,35 +94,24 @@ def start(self) -> asyncio.Task: async def run_thread_forwarding(self): try: - await asyncio.gather( - self._forward_to_thread(), - self._forward_from_thead(), - ) + await self._forward_to_thread() except asyncio.CancelledError: return async def _forward_to_thread(self): while True: - item = await self.input_queue.get() + item = await self._input_queue.get() self.input_janus_queue.async_q.put_nowait(item) - async def _forward_from_thead(self): - while True: - item = await self.output_janus_queue.async_q.get() - self.output_queue.put_nowait(item) - def _run_loop(self): raise NotImplementedError - def terminate(self): - return super().terminate() - class AsyncQueueWorker(AsyncWorker[WorkerInputType]): async def _run_loop(self): while True: try: - item = await self.input_queue.get() + item = await self._input_queue.get() await self.process(item) except asyncio.CancelledError: return @@ -180,44 +194,20 @@ def create_interruptible_agent_response_event( class InterruptibleWorker(AsyncWorker[InterruptibleEventType]): def __init__( self, - input_queue: asyncio.Queue[InterruptibleEventType], - output_queue: asyncio.Queue = asyncio.Queue(), interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), max_concurrency=2, ) -> None: - super().__init__(input_queue, output_queue) - self.input_queue = input_queue + super().__init__() self.max_concurrency = max_concurrency self.interruptible_event_factory = interruptible_event_factory self.current_task = None self.interruptible_event = None - def produce_interruptible_event_nonblocking(self, item: Any, is_interruptible: bool = True): - interruptible_event = self.interruptible_event_factory.create_interruptible_event( - item, is_interruptible=is_interruptible - ) - return super().produce_nonblocking(interruptible_event) - - def produce_interruptible_agent_response_event_nonblocking( - self, - item: Any, - is_interruptible: bool = True, - agent_response_tracker: Optional[asyncio.Event] = None, - ): - interruptible_utterance_event = ( - self.interruptible_event_factory.create_interruptible_agent_response_event( - item, - is_interruptible=is_interruptible, - agent_response_tracker=agent_response_tracker or asyncio.Event(), - ) - ) - return super().produce_nonblocking(interruptible_utterance_event) - async def _run_loop(self): # TODO Implement concurrency with max_nb_of_thread while True: try: - item = await self.input_queue.get() + item = await self._input_queue.get() except asyncio.CancelledError: return @@ -226,7 +216,6 @@ async def _run_loop(self): self.interruptible_event = item self.current_task = asyncio_create_task( self.process(item), - reraise_cancelled=True, ) try: