From 2b1f4c4cdaf3ee3b4195c0e3c060e544542cdfd0 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 28 Jun 2024 16:34:16 -0700 Subject: [PATCH] make terminate async --- playground/streaming/agent/chat.py | 4 ++-- .../agent/websocket_user_implemented_agent.py | 4 ++-- .../output_device/blocking_speaker_output.py | 10 +++++----- .../output_device/file_output_device.py | 4 ++-- vocode/streaming/output_device/speaker_output.py | 2 +- vocode/streaming/streaming_conversation.py | 16 ++++++++-------- vocode/streaming/synthesizer/base_synthesizer.py | 2 +- vocode/streaming/synthesizer/miniaudio_worker.py | 4 ++-- .../transcriber/assembly_ai_transcriber.py | 4 ++-- .../streaming/transcriber/azure_transcriber.py | 4 ++-- vocode/streaming/transcriber/base_transcriber.py | 12 ++++-------- .../transcriber/deepgram_transcriber.py | 4 ++-- .../streaming/transcriber/gladia_transcriber.py | 4 ++-- .../streaming/transcriber/google_transcriber.py | 4 ++-- .../streaming/transcriber/rev_ai_transcriber.py | 3 ++- .../transcriber/whisper_cpp_transcriber.py | 2 +- vocode/streaming/utils/worker.py | 4 ++-- 17 files changed, 42 insertions(+), 45 deletions(-) diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index e93b9eeda..dfcede2de 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -197,7 +197,7 @@ async def sender(): await asyncio.gather(receiver(), sender()) if actions_worker is not None: - actions_worker.terminate() + await actions_worker.terminate() async def agent_main(): @@ -233,7 +233,7 @@ async def agent_main(): try: await run_agent(agent, interruption_probability=0, backchannel_probability=0) except KeyboardInterrupt: - agent.terminate() + await agent.terminate() if __name__ == "__main__": diff --git a/vocode/streaming/agent/websocket_user_implemented_agent.py b/vocode/streaming/agent/websocket_user_implemented_agent.py index f8233588f..0faa52254 100644 --- a/vocode/streaming/agent/websocket_user_implemented_agent.py +++ b/vocode/streaming/agent/websocket_user_implemented_agent.py @@ -138,10 +138,10 @@ async def receiver(ws: WebSocketClientProtocol) -> None: await asyncio.gather(sender(ws), receiver(ws)) - def terminate(self): + async def terminate(self): self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( AgentResponseStop() ) ) - super().terminate() + await super().terminate() diff --git a/vocode/streaming/output_device/blocking_speaker_output.py b/vocode/streaming/output_device/blocking_speaker_output.py index 4679fb0a7..c872b4280 100644 --- a/vocode/streaming/output_device/blocking_speaker_output.py +++ b/vocode/streaming/output_device/blocking_speaker_output.py @@ -38,9 +38,9 @@ def _run_loop(self): except queue.Empty: continue - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() self.stream.close() @@ -65,9 +65,9 @@ def start(self) -> asyncio.Task: self.playback_worker.start() return super().start() - def terminate(self): - self.playback_worker.terminate() - super().terminate() + async def terminate(self): + await self.playback_worker.terminate() + await super().terminate() @classmethod def from_default_device( diff --git a/vocode/streaming/output_device/file_output_device.py b/vocode/streaming/output_device/file_output_device.py index 607fd5b3e..d2df2aa62 100644 --- a/vocode/streaming/output_device/file_output_device.py +++ b/vocode/streaming/output_device/file_output_device.py @@ -30,6 +30,6 @@ def __init__( async def play(self, chunk: bytes): await asyncio.to_thread(lambda: self.wav.writeframes(chunk)) - def terminate(self): + async def terminate(self): self.wav.close() - super().terminate() + await super().terminate() diff --git a/vocode/streaming/output_device/speaker_output.py b/vocode/streaming/output_device/speaker_output.py index c416c814e..17acc9e26 100644 --- a/vocode/streaming/output_device/speaker_output.py +++ b/vocode/streaming/output_device/speaker_output.py @@ -53,7 +53,7 @@ def consume_nonblocking(self, chunk): block[:size] = chunk_arr[i : i + size] self.queue.put_nowait(block) - def terminate(self): + async def terminate(self): self.stream.close() @classmethod diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index f6fc02500..aed9b372c 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -1016,23 +1016,23 @@ async def terminate(self): # `agent.terminate()` is not async. logger.debug("Terminating vector db") await self.agent.vector_db.tear_down() - self.agent.terminate() + await self.agent.terminate() logger.debug("Terminating output device") - self.output_device.terminate() + await self.output_device.terminate() logger.debug("Terminating speech transcriber") - self.transcriber.terminate() + await self.transcriber.terminate() logger.debug("Terminating transcriptions worker") - self.transcriptions_worker.terminate() + await self.transcriptions_worker.terminate() logger.debug("Terminating final transcriptions worker") - self.agent_responses_worker.terminate() + await self.agent_responses_worker.terminate() logger.debug("Terminating synthesis results worker") - self.synthesis_results_worker.terminate() + await self.synthesis_results_worker.terminate() if self.filler_audio_worker is not None: logger.debug("Terminating filler audio worker") - self.filler_audio_worker.terminate() + await self.filler_audio_worker.terminate() if self.actions_worker is not None: logger.debug("Terminating actions worker") - self.actions_worker.terminate() + await self.actions_worker.terminate() logger.debug("Successfully terminated") def is_active(self): diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index 8ccd6aad2..f6e70aca0 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -446,7 +446,7 @@ async def send_chunks(): except asyncio.CancelledError: pass finally: - miniaudio_worker.terminate() + await miniaudio_worker.terminate() def _resample_chunk( self, diff --git a/vocode/streaming/synthesizer/miniaudio_worker.py b/vocode/streaming/synthesizer/miniaudio_worker.py index fcba60460..9cb67e4d7 100644 --- a/vocode/streaming/synthesizer/miniaudio_worker.py +++ b/vocode/streaming/synthesizer/miniaudio_worker.py @@ -96,6 +96,6 @@ def _run_loop(self): current_wav_output_buffer = current_wav_output_buffer[output_buffer_idx:] current_wav_buffer.extend(new_bytes) - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() diff --git a/vocode/streaming/transcriber/assembly_ai_transcriber.py b/vocode/streaming/transcriber/assembly_ai_transcriber.py index 331810e73..00953bd44 100644 --- a/vocode/streaming/transcriber/assembly_ai_transcriber.py +++ b/vocode/streaming/transcriber/assembly_ai_transcriber.py @@ -78,9 +78,9 @@ def send_audio(self, chunk): self.consume_nonblocking(self.buffer) self.buffer = bytearray() - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() def get_assembly_ai_url(self): url_params = {"sample_rate": self.transcriber_config.sampling_rate} diff --git a/vocode/streaming/transcriber/azure_transcriber.py b/vocode/streaming/transcriber/azure_transcriber.py index 2cc666340..00a7a7a7f 100644 --- a/vocode/streaming/transcriber/azure_transcriber.py +++ b/vocode/streaming/transcriber/azure_transcriber.py @@ -141,7 +141,7 @@ def generator(self): yield b"".join(data) - def terminate(self): + async def terminate(self): self._ended = True self.speech.stop_continuous_recognition_async() - super().terminate() + await super().terminate() diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index beea852b1..edd1009d4 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -57,18 +57,14 @@ def send_audio(self, chunk: bytes): def produce_nonblocking(self, item: Transcription): self.consumer.consume_nonblocking(item) - @abstractmethod - def terminate(self): - pass - class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) AsyncWorker.__init__(self) - def terminate(self): - AsyncWorker.terminate(self) + async def terminate(self): + await AsyncWorker.terminate(self) class BaseThreadAsyncTranscriber( # type: ignore @@ -101,8 +97,8 @@ async def _forward_from_thread(self): def produce_nonblocking(self, item: Transcription): self.output_janus_queue.sync_q.put_nowait(item) - def terminate(self): - ThreadAsyncWorker.terminate(self) + async def terminate(self): + await ThreadAsyncWorker.terminate(self) BaseTranscriber = Union[ diff --git a/vocode/streaming/transcriber/deepgram_transcriber.py b/vocode/streaming/transcriber/deepgram_transcriber.py index 4835d9646..7f64f4a97 100644 --- a/vocode/streaming/transcriber/deepgram_transcriber.py +++ b/vocode/streaming/transcriber/deepgram_transcriber.py @@ -177,7 +177,7 @@ async def _run_loop(self): logger.error("Deepgram connection died, not restarting") - def terminate(self): + async def terminate(self): self._track_latency_of_transcription_start() # Put this in logs until we sentry metrics show up # properly on dashboard @@ -192,7 +192,7 @@ def terminate(self): terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8") self.consume_nonblocking(terminate_msg) # todo (dow-107): typing self._ended = True - super().terminate() + await super().terminate() def get_input_sample_width(self): encoding = self.transcriber_config.audio_encoding diff --git a/vocode/streaming/transcriber/gladia_transcriber.py b/vocode/streaming/transcriber/gladia_transcriber.py index 1cd8f32b0..c01b82fc5 100644 --- a/vocode/streaming/transcriber/gladia_transcriber.py +++ b/vocode/streaming/transcriber/gladia_transcriber.py @@ -56,9 +56,9 @@ def send_audio(self, chunk): self.consume_nonblocking(self.buffer) self.buffer = bytearray() - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() async def process(self): async with websockets.connect(GLADIA_URL) as ws: diff --git a/vocode/streaming/transcriber/google_transcriber.py b/vocode/streaming/transcriber/google_transcriber.py index 97c2c34b5..0241ee730 100644 --- a/vocode/streaming/transcriber/google_transcriber.py +++ b/vocode/streaming/transcriber/google_transcriber.py @@ -57,9 +57,9 @@ def _run_loop(self): responses = self.client.streaming_recognize(self.google_streaming_config, requests) self.process_responses_loop(responses) - def terminate(self): + async def terminate(self): self._ended = True - super().terminate() + await super().terminate() def process_responses_loop(self, responses): for response in responses: diff --git a/vocode/streaming/transcriber/rev_ai_transcriber.py b/vocode/streaming/transcriber/rev_ai_transcriber.py index 0684f3dce..bb5ae4d0c 100644 --- a/vocode/streaming/transcriber/rev_ai_transcriber.py +++ b/vocode/streaming/transcriber/rev_ai_transcriber.py @@ -135,7 +135,8 @@ async def receiver(ws: WebSocketClientProtocol): await asyncio.gather(sender(ws), receiver(ws)) - def terminate(self): + async def terminate(self): terminate_msg = json.dumps({"type": "CloseStream"}) self.consume_nonblocking(terminate_msg) self.closed = True + await super().terminate() diff --git a/vocode/streaming/transcriber/whisper_cpp_transcriber.py b/vocode/streaming/transcriber/whisper_cpp_transcriber.py index c12c6333c..f5c0205ac 100644 --- a/vocode/streaming/transcriber/whisper_cpp_transcriber.py +++ b/vocode/streaming/transcriber/whisper_cpp_transcriber.py @@ -78,5 +78,5 @@ def _run_loop(self): if is_final: message_buffer = "" - def terminate(self): + async def terminate(self): pass diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index d321ea193..c98e0b653 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -27,7 +27,7 @@ def start(self): def consume_nonblocking(self, item: WorkerInputType): raise NotImplementedError - def terminate(self): + async def terminate(self): pass @@ -66,7 +66,7 @@ def consume_nonblocking(self, item: WorkerInputType): async def _run_loop(self): raise NotImplementedError - def terminate(self): + async def terminate(self): if self.worker_task: return self.worker_task.cancel()