From 42e69bf1ecda14343c9e873286a87fc17b3b6d6a Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 14:53:38 -0700 Subject: [PATCH] [DOW-119] creates AudioPipeline abstraction (#625) * make terminate async * creates audio pipeline abstraction * fix streaming conversation api * make terminate() invocations in tests async * removes the vector_db.tear_down() call in streaming conversation --- playground/streaming/agent/chat.py | 4 +-- tests/streaming/action/test_dtmf.py | 2 +- tests/streaming/agent/test_base_agent.py | 2 +- ..._rate_limit_interruptions_output_device.py | 2 +- .../test_twilio_output_device.py | 2 +- .../streaming/test_streaming_conversation.py | 10 +++--- vocode/streaming/agent/chat_gpt_agent.py | 5 +++ .../agent/websocket_user_implemented_agent.py | 4 +-- .../output_device/blocking_speaker_output.py | 10 +++--- .../output_device/file_output_device.py | 4 +-- .../streaming/output_device/speaker_output.py | 2 +- vocode/streaming/streaming_conversation.py | 35 ++++++++----------- .../streaming/synthesizer/base_synthesizer.py | 2 +- .../streaming/synthesizer/miniaudio_worker.py | 4 +-- .../transcriber/assembly_ai_transcriber.py | 4 +-- .../transcriber/azure_transcriber.py | 4 +-- .../streaming/transcriber/base_transcriber.py | 12 +++---- .../transcriber/deepgram_transcriber.py | 4 +-- .../transcriber/gladia_transcriber.py | 4 +-- .../transcriber/google_transcriber.py | 4 +-- .../transcriber/rev_ai_transcriber.py | 3 +- .../transcriber/whisper_cpp_transcriber.py | 2 +- vocode/streaming/utils/audio_pipeline.py | 21 +++++++++++ vocode/streaming/utils/worker.py | 4 +-- 24 files changed, 84 insertions(+), 66 deletions(-) create mode 100644 vocode/streaming/utils/audio_pipeline.py diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index e93b9eeda8..dfcede2dec 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -197,7 +197,7 @@ async def sender(): await asyncio.gather(receiver(), sender()) if actions_worker is not None: - actions_worker.terminate() + await actions_worker.terminate() async def agent_main(): @@ -233,7 +233,7 @@ async def agent_main(): try: await run_agent(agent, interruption_probability=0, backchannel_probability=0) except KeyboardInterrupt: - agent.terminate() + await agent.terminate() if __name__ == "__main__": diff --git a/tests/streaming/action/test_dtmf.py b/tests/streaming/action/test_dtmf.py index 026779c0a3..e65f46cf5e 100644 --- a/tests/streaming/action/test_dtmf.py +++ b/tests/streaming/action/test_dtmf.py @@ -118,7 +118,7 @@ async def test_twilio_dtmf_press_digits( assert False, "Timed out waiting for DTMF tones to be sent" assert action_output.response.success - mock_twilio_output_device.terminate() + await mock_twilio_output_device.terminate() for digit, call in zip(digits, mock_twilio_output_device.ws.send_text.call_args_list): expected_dtmf = DTMFToneGenerator().generate( diff --git a/tests/streaming/agent/test_base_agent.py b/tests/streaming/agent/test_base_agent.py index 26519ad141..e3fe9e5f90 100644 --- a/tests/streaming/agent/test_base_agent.py +++ b/tests/streaming/agent/test_base_agent.py @@ -140,7 +140,7 @@ async def test_generate_responses(mocker: MockerFixture): agent.agent_responses_consumer = agent_consumer agent.start() agent_responses = await _consume_until_end_of_turn(agent_consumer) - agent.terminate() + await agent.terminate() messages = [response.message for response in agent_responses] diff --git a/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py b/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py index 44c6679ab1..66891a420d 100644 --- a/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py +++ b/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py @@ -65,4 +65,4 @@ def uninterruptible_on_play(): await uninterruptible_played_event.wait() assert uninterruptible_audio_chunk.state == ChunkState.PLAYED - output_device.terminate() + await output_device.terminate() diff --git a/tests/streaming/output_device/test_twilio_output_device.py b/tests/streaming/output_device/test_twilio_output_device.py index e8cd69f3d5..c2c26c8d87 100644 --- a/tests/streaming/output_device/test_twilio_output_device.py +++ b/tests/streaming/output_device/test_twilio_output_device.py @@ -57,7 +57,7 @@ def on_play(): assert mark_message["streamSid"] == twilio_output_device.stream_sid assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id) - twilio_output_device.terminate() + await twilio_output_device.terminate() @pytest.mark.asyncio diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py index 2266a7e47d..89149994ab 100644 --- a/tests/streaming/test_streaming_conversation.py +++ b/tests/streaming/test_streaming_conversation.py @@ -287,7 +287,7 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message( backchannel.sender == Sender.HUMAN and backchannel.is_backchannel for backchannel in human_backchannels ) - streaming_conversation.transcriptions_worker.terminate() + await streaming_conversation.transcriptions_worker.terminate() @pytest.mark.asyncio @@ -355,7 +355,7 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( "I'm listening.", ] assert streaming_conversation.transcript.event_logs[-1].is_backchannel - streaming_conversation.transcriptions_worker.terminate() + await streaming_conversation.transcriptions_worker.terminate() @pytest.mark.asyncio @@ -405,7 +405,7 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( assert streaming_conversation.transcript.event_logs[-1].sender == Sender.BOT assert streaming_conversation.transcript.event_logs[-1].text == "Hi, I was wondering" - streaming_conversation.transcriptions_worker.terminate() + await streaming_conversation.transcriptions_worker.terminate() @pytest.mark.asyncio @@ -460,7 +460,7 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called - streaming_conversation.transcriptions_worker.terminate() + await streaming_conversation.transcriptions_worker.terminate() def _create_dummy_synthesis_result( @@ -576,7 +576,7 @@ async def chunk_generator(): stop_event.set() streaming_conversation.output_device.interrupt_event.set() message_sent, cut_off = await send_speech_to_output_task - streaming_conversation.output_device.terminate() + await streaming_conversation.output_device.terminate() assert message_sent != "Hi there" assert cut_off diff --git a/vocode/streaming/agent/chat_gpt_agent.py b/vocode/streaming/agent/chat_gpt_agent.py index 4d54eb54a2..0ca8c5c1b9 100644 --- a/vocode/streaming/agent/chat_gpt_agent.py +++ b/vocode/streaming/agent/chat_gpt_agent.py @@ -290,3 +290,8 @@ async def generate_response( message=message, is_interruptible=True, ) + + async def terminate(self): + if hasattr(self, "vector_db") and self.vector_db is not None: + await self.vector_db.tear_down() + return await super().terminate() diff --git a/vocode/streaming/agent/websocket_user_implemented_agent.py b/vocode/streaming/agent/websocket_user_implemented_agent.py index f8233588fb..0faa522544 100644 --- a/vocode/streaming/agent/websocket_user_implemented_agent.py +++ b/vocode/streaming/agent/websocket_user_implemented_agent.py @@ -138,10 +138,10 @@ async def receiver(ws: WebSocketClientProtocol) -> None: await asyncio.gather(sender(ws), receiver(ws)) - def terminate(self): + async def terminate(self): self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( AgentResponseStop() ) ) - super().terminate() + await super().terminate() diff --git a/vocode/streaming/output_device/blocking_speaker_output.py b/vocode/streaming/output_device/blocking_speaker_output.py index 4679fb0a72..c872b42801 100644 --- a/vocode/streaming/output_device/blocking_speaker_output.py +++ b/vocode/streaming/output_device/blocking_speaker_output.py @@ -38,9 +38,9 @@ def _run_loop(self): except queue.Empty: continue - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() self.stream.close() @@ -65,9 +65,9 @@ def start(self) -> asyncio.Task: self.playback_worker.start() return super().start() - def terminate(self): - self.playback_worker.terminate() - super().terminate() + async def terminate(self): + await self.playback_worker.terminate() + await super().terminate() @classmethod def from_default_device( diff --git a/vocode/streaming/output_device/file_output_device.py b/vocode/streaming/output_device/file_output_device.py index 607fd5b3e1..d2df2aa621 100644 --- a/vocode/streaming/output_device/file_output_device.py +++ b/vocode/streaming/output_device/file_output_device.py @@ -30,6 +30,6 @@ def __init__( async def play(self, chunk: bytes): await asyncio.to_thread(lambda: self.wav.writeframes(chunk)) - def terminate(self): + async def terminate(self): self.wav.close() - super().terminate() + await super().terminate() diff --git a/vocode/streaming/output_device/speaker_output.py b/vocode/streaming/output_device/speaker_output.py index c416c814ef..17acc9e266 100644 --- a/vocode/streaming/output_device/speaker_output.py +++ b/vocode/streaming/output_device/speaker_output.py @@ -53,7 +53,7 @@ def consume_nonblocking(self, chunk): block[:size] = chunk_arr[i : i + size] self.queue.put_nowait(block) - def terminate(self): + async def terminate(self): self.stream.close() @classmethod diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index f6fc02500c..bb13270806 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -24,6 +24,7 @@ from loguru import logger from sentry_sdk.tracing import Span +from vocode import conversation_id as ctx_conversation_id from vocode.streaming.action.worker import ActionsWorker from vocode.streaming.agent.base_agent import ( AgentInput, @@ -61,6 +62,7 @@ enumerate_async_iter, get_chunk_size_per_second, ) +from vocode.streaming.utils.audio_pipeline import AudioPipeline, OutputDeviceType from vocode.streaming.utils.create_task import asyncio_create_task from vocode.streaming.utils.events_manager import EventsManager from vocode.streaming.utils.speed_manager import SpeedManager @@ -107,10 +109,7 @@ LOW_INTERRUPT_SENSITIVITY_BACKCHANNEL_UTTERANCE_LENGTH_THRESHOLD = 3 -OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice) - - -class StreamingConversation(Generic[OutputDeviceType]): +class StreamingConversation(AudioPipeline[OutputDeviceType]): class QueueingInterruptibleEventFactory(InterruptibleEventFactory): def __init__(self, conversation: "StreamingConversation"): self.conversation = conversation @@ -593,6 +592,8 @@ def __init__( events_manager: Optional[EventsManager] = None, ): self.id = conversation_id or create_conversation_id() + ctx_conversation_id.set(self.id) + self.output_device = output_device self.transcriber = transcriber self.agent = agent @@ -800,8 +801,8 @@ def receive_message(self, message: str): ) self.transcriptions_worker.consume_nonblocking(transcription) - def receive_audio(self, chunk: bytes): - self.transcriber.send_audio(chunk) + def consume_nonblocking(self, item: bytes): + self.transcriber.send_audio(item) def warmup_synthesizer(self): self.synthesizer.ready_synthesizer(self._get_synthesizer_chunk_size()) @@ -1010,29 +1011,23 @@ async def terminate(self): logger.debug("Tearing down synthesizer") await self.synthesizer.tear_down() logger.debug("Terminating agent") - if isinstance(self.agent, ChatGPTAgent) and self.agent.agent_config.vector_db_config: - # Shutting down the vector db should be done in the agent's terminate method, - # but it is done here because `vector_db.tear_down()` is async and - # `agent.terminate()` is not async. - logger.debug("Terminating vector db") - await self.agent.vector_db.tear_down() - self.agent.terminate() + await self.agent.terminate() logger.debug("Terminating output device") - self.output_device.terminate() + await self.output_device.terminate() logger.debug("Terminating speech transcriber") - self.transcriber.terminate() + await self.transcriber.terminate() logger.debug("Terminating transcriptions worker") - self.transcriptions_worker.terminate() + await self.transcriptions_worker.terminate() logger.debug("Terminating final transcriptions worker") - self.agent_responses_worker.terminate() + await self.agent_responses_worker.terminate() logger.debug("Terminating synthesis results worker") - self.synthesis_results_worker.terminate() + await self.synthesis_results_worker.terminate() if self.filler_audio_worker is not None: logger.debug("Terminating filler audio worker") - self.filler_audio_worker.terminate() + await self.filler_audio_worker.terminate() if self.actions_worker is not None: logger.debug("Terminating actions worker") - self.actions_worker.terminate() + await self.actions_worker.terminate() logger.debug("Successfully terminated") def is_active(self): diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index 8ccd6aad25..f6e70aca0b 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -446,7 +446,7 @@ async def send_chunks(): except asyncio.CancelledError: pass finally: - miniaudio_worker.terminate() + await miniaudio_worker.terminate() def _resample_chunk( self, diff --git a/vocode/streaming/synthesizer/miniaudio_worker.py b/vocode/streaming/synthesizer/miniaudio_worker.py index fcba60460b..9cb67e4d74 100644 --- a/vocode/streaming/synthesizer/miniaudio_worker.py +++ b/vocode/streaming/synthesizer/miniaudio_worker.py @@ -96,6 +96,6 @@ def _run_loop(self): current_wav_output_buffer = current_wav_output_buffer[output_buffer_idx:] current_wav_buffer.extend(new_bytes) - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() diff --git a/vocode/streaming/transcriber/assembly_ai_transcriber.py b/vocode/streaming/transcriber/assembly_ai_transcriber.py index 331810e735..00953bd44d 100644 --- a/vocode/streaming/transcriber/assembly_ai_transcriber.py +++ b/vocode/streaming/transcriber/assembly_ai_transcriber.py @@ -78,9 +78,9 @@ def send_audio(self, chunk): self.consume_nonblocking(self.buffer) self.buffer = bytearray() - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() def get_assembly_ai_url(self): url_params = {"sample_rate": self.transcriber_config.sampling_rate} diff --git a/vocode/streaming/transcriber/azure_transcriber.py b/vocode/streaming/transcriber/azure_transcriber.py index 2cc666340f..00a7a7a7f9 100644 --- a/vocode/streaming/transcriber/azure_transcriber.py +++ b/vocode/streaming/transcriber/azure_transcriber.py @@ -141,7 +141,7 @@ def generator(self): yield b"".join(data) - def terminate(self): + async def terminate(self): self._ended = True self.speech.stop_continuous_recognition_async() - super().terminate() + await super().terminate() diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index beea852b15..edd1009d48 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -57,18 +57,14 @@ def send_audio(self, chunk: bytes): def produce_nonblocking(self, item: Transcription): self.consumer.consume_nonblocking(item) - @abstractmethod - def terminate(self): - pass - class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) AsyncWorker.__init__(self) - def terminate(self): - AsyncWorker.terminate(self) + async def terminate(self): + await AsyncWorker.terminate(self) class BaseThreadAsyncTranscriber( # type: ignore @@ -101,8 +97,8 @@ async def _forward_from_thread(self): def produce_nonblocking(self, item: Transcription): self.output_janus_queue.sync_q.put_nowait(item) - def terminate(self): - ThreadAsyncWorker.terminate(self) + async def terminate(self): + await ThreadAsyncWorker.terminate(self) BaseTranscriber = Union[ diff --git a/vocode/streaming/transcriber/deepgram_transcriber.py b/vocode/streaming/transcriber/deepgram_transcriber.py index 4835d9646e..7f64f4a976 100644 --- a/vocode/streaming/transcriber/deepgram_transcriber.py +++ b/vocode/streaming/transcriber/deepgram_transcriber.py @@ -177,7 +177,7 @@ async def _run_loop(self): logger.error("Deepgram connection died, not restarting") - def terminate(self): + async def terminate(self): self._track_latency_of_transcription_start() # Put this in logs until we sentry metrics show up # properly on dashboard @@ -192,7 +192,7 @@ def terminate(self): terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8") self.consume_nonblocking(terminate_msg) # todo (dow-107): typing self._ended = True - super().terminate() + await super().terminate() def get_input_sample_width(self): encoding = self.transcriber_config.audio_encoding diff --git a/vocode/streaming/transcriber/gladia_transcriber.py b/vocode/streaming/transcriber/gladia_transcriber.py index 1cd8f32b0a..c01b82fc56 100644 --- a/vocode/streaming/transcriber/gladia_transcriber.py +++ b/vocode/streaming/transcriber/gladia_transcriber.py @@ -56,9 +56,9 @@ def send_audio(self, chunk): self.consume_nonblocking(self.buffer) self.buffer = bytearray() - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() async def process(self): async with websockets.connect(GLADIA_URL) as ws: diff --git a/vocode/streaming/transcriber/google_transcriber.py b/vocode/streaming/transcriber/google_transcriber.py index 97c2c34b56..0241ee7305 100644 --- a/vocode/streaming/transcriber/google_transcriber.py +++ b/vocode/streaming/transcriber/google_transcriber.py @@ -57,9 +57,9 @@ def _run_loop(self): responses = self.client.streaming_recognize(self.google_streaming_config, requests) self.process_responses_loop(responses) - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() def process_responses_loop(self, responses): for response in responses: diff --git a/vocode/streaming/transcriber/rev_ai_transcriber.py b/vocode/streaming/transcriber/rev_ai_transcriber.py index 0684f3dce6..bb5ae4d0c9 100644 --- a/vocode/streaming/transcriber/rev_ai_transcriber.py +++ b/vocode/streaming/transcriber/rev_ai_transcriber.py @@ -135,7 +135,8 @@ async def receiver(ws: WebSocketClientProtocol): await asyncio.gather(sender(ws), receiver(ws)) - def terminate(self): + async def terminate(self): terminate_msg = json.dumps({"type": "CloseStream"}) self.consume_nonblocking(terminate_msg) self.closed = True + await super().terminate() diff --git a/vocode/streaming/transcriber/whisper_cpp_transcriber.py b/vocode/streaming/transcriber/whisper_cpp_transcriber.py index c12c6333c3..f5c0205acc 100644 --- a/vocode/streaming/transcriber/whisper_cpp_transcriber.py +++ b/vocode/streaming/transcriber/whisper_cpp_transcriber.py @@ -78,5 +78,5 @@ def _run_loop(self): if is_final: message_buffer = "" - def terminate(self): + async def terminate(self): pass diff --git a/vocode/streaming/utils/audio_pipeline.py b/vocode/streaming/utils/audio_pipeline.py new file mode 100644 index 0000000000..60fcd2f449 --- /dev/null +++ b/vocode/streaming/utils/audio_pipeline.py @@ -0,0 +1,21 @@ +from abc import abstractmethod +from typing import Generic, TypeVar + +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice +from vocode.streaming.utils.events_manager import EventsManager +from vocode.streaming.utils.worker import AbstractWorker + +OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice) + + +class AudioPipeline(AbstractWorker[bytes], Generic[OutputDeviceType]): + output_device: OutputDeviceType + events_manager: EventsManager + id: str + + def receive_audio(self, chunk: bytes): + self.consume_nonblocking(chunk) + + @abstractmethod + def is_active(self): + raise NotImplementedError diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index d321ea193d..c98e0b653a 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -27,7 +27,7 @@ def start(self): def consume_nonblocking(self, item: WorkerInputType): raise NotImplementedError - def terminate(self): + async def terminate(self): pass @@ -66,7 +66,7 @@ def consume_nonblocking(self, item: WorkerInputType): async def _run_loop(self): raise NotImplementedError - def terminate(self): + async def terminate(self): if self.worker_task: return self.worker_task.cancel()