Skip to content

Commit

Permalink
make terminate async
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 committed Jul 10, 2024
1 parent 48cba3e commit 2b1f4c4
Show file tree
Hide file tree
Showing 17 changed files with 42 additions and 45 deletions.
4 changes: 2 additions & 2 deletions playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def sender():

await asyncio.gather(receiver(), sender())
if actions_worker is not None:
actions_worker.terminate()
await actions_worker.terminate()


async def agent_main():
Expand Down Expand Up @@ -233,7 +233,7 @@ async def agent_main():
try:
await run_agent(agent, interruption_probability=0, backchannel_probability=0)
except KeyboardInterrupt:
agent.terminate()
await agent.terminate()


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/agent/websocket_user_implemented_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ async def receiver(ws: WebSocketClientProtocol) -> None:

await asyncio.gather(sender(ws), receiver(ws))

def terminate(self):
async def terminate(self):
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
AgentResponseStop()
)
)
super().terminate()
await super().terminate()
10 changes: 5 additions & 5 deletions vocode/streaming/output_device/blocking_speaker_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def _run_loop(self):
except queue.Empty:
continue

def terminate(self):
async def terminate(self):
self._ended = True
super().terminate()
await super().terminate()
self.stream.close()


Expand All @@ -65,9 +65,9 @@ def start(self) -> asyncio.Task:
self.playback_worker.start()
return super().start()

def terminate(self):
self.playback_worker.terminate()
super().terminate()
async def terminate(self):
await self.playback_worker.terminate()
await super().terminate()

@classmethod
def from_default_device(
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/output_device/file_output_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ def __init__(
async def play(self, chunk: bytes):
await asyncio.to_thread(lambda: self.wav.writeframes(chunk))

def terminate(self):
async def terminate(self):
self.wav.close()
super().terminate()
await super().terminate()
2 changes: 1 addition & 1 deletion vocode/streaming/output_device/speaker_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def consume_nonblocking(self, chunk):
block[:size] = chunk_arr[i : i + size]
self.queue.put_nowait(block)

def terminate(self):
async def terminate(self):
self.stream.close()

@classmethod
Expand Down
16 changes: 8 additions & 8 deletions vocode/streaming/streaming_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,23 +1016,23 @@ async def terminate(self):
# `agent.terminate()` is not async.
logger.debug("Terminating vector db")
await self.agent.vector_db.tear_down()
self.agent.terminate()
await self.agent.terminate()
logger.debug("Terminating output device")
self.output_device.terminate()
await self.output_device.terminate()
logger.debug("Terminating speech transcriber")
self.transcriber.terminate()
await self.transcriber.terminate()
logger.debug("Terminating transcriptions worker")
self.transcriptions_worker.terminate()
await self.transcriptions_worker.terminate()
logger.debug("Terminating final transcriptions worker")
self.agent_responses_worker.terminate()
await self.agent_responses_worker.terminate()
logger.debug("Terminating synthesis results worker")
self.synthesis_results_worker.terminate()
await self.synthesis_results_worker.terminate()
if self.filler_audio_worker is not None:
logger.debug("Terminating filler audio worker")
self.filler_audio_worker.terminate()
await self.filler_audio_worker.terminate()
if self.actions_worker is not None:
logger.debug("Terminating actions worker")
self.actions_worker.terminate()
await self.actions_worker.terminate()
logger.debug("Successfully terminated")

def is_active(self):
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/synthesizer/base_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ async def send_chunks():
except asyncio.CancelledError:
pass
finally:
miniaudio_worker.terminate()
await miniaudio_worker.terminate()

def _resample_chunk(
self,
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/synthesizer/miniaudio_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,6 @@ def _run_loop(self):
current_wav_output_buffer = current_wav_output_buffer[output_buffer_idx:]
current_wav_buffer.extend(new_bytes)

def terminate(self):
async def terminate(self):
self._ended = True
super().terminate()
await super().terminate()
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/assembly_ai_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def send_audio(self, chunk):
self.consume_nonblocking(self.buffer)
self.buffer = bytearray()

def terminate(self):
async def terminate(self):
self._ended = True
super().terminate()
await super().terminate()

def get_assembly_ai_url(self):
url_params = {"sample_rate": self.transcriber_config.sampling_rate}
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/azure_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def generator(self):

yield b"".join(data)

def terminate(self):
async def terminate(self):
self._ended = True
self.speech.stop_continuous_recognition_async()
super().terminate()
await super().terminate()
12 changes: 4 additions & 8 deletions vocode/streaming/transcriber/base_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,14 @@ def send_audio(self, chunk: bytes):
def produce_nonblocking(self, item: Transcription):
self.consumer.consume_nonblocking(item)

@abstractmethod
def terminate(self):
pass


class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore
def __init__(self, transcriber_config: TranscriberConfigType):
AbstractTranscriber.__init__(self, transcriber_config)
AsyncWorker.__init__(self)

def terminate(self):
AsyncWorker.terminate(self)
async def terminate(self):
await AsyncWorker.terminate(self)


class BaseThreadAsyncTranscriber( # type: ignore
Expand Down Expand Up @@ -101,8 +97,8 @@ async def _forward_from_thread(self):
def produce_nonblocking(self, item: Transcription):
self.output_janus_queue.sync_q.put_nowait(item)

def terminate(self):
ThreadAsyncWorker.terminate(self)
async def terminate(self):
await ThreadAsyncWorker.terminate(self)


BaseTranscriber = Union[
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/deepgram_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ async def _run_loop(self):

logger.error("Deepgram connection died, not restarting")

def terminate(self):
async def terminate(self):
self._track_latency_of_transcription_start()
# Put this in logs until we sentry metrics show up
# properly on dashboard
Expand All @@ -192,7 +192,7 @@ def terminate(self):
terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8")
self.consume_nonblocking(terminate_msg) # todo (dow-107): typing
self._ended = True
super().terminate()
await super().terminate()

def get_input_sample_width(self):
encoding = self.transcriber_config.audio_encoding
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/gladia_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def send_audio(self, chunk):
self.consume_nonblocking(self.buffer)
self.buffer = bytearray()

def terminate(self):
async def terminate(self):
self._ended = True
super().terminate()
await super().terminate()

async def process(self):
async with websockets.connect(GLADIA_URL) as ws:
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/google_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def _run_loop(self):
responses = self.client.streaming_recognize(self.google_streaming_config, requests)
self.process_responses_loop(responses)

def terminate(self):
async def terminate(self):
self._ended = True
super().terminate()
await super().terminate()

def process_responses_loop(self, responses):
for response in responses:
Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/transcriber/rev_ai_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ async def receiver(ws: WebSocketClientProtocol):

await asyncio.gather(sender(ws), receiver(ws))

def terminate(self):
async def terminate(self):
terminate_msg = json.dumps({"type": "CloseStream"})
self.consume_nonblocking(terminate_msg)
self.closed = True
await super().terminate()
2 changes: 1 addition & 1 deletion vocode/streaming/transcriber/whisper_cpp_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ def _run_loop(self):
if is_final:
message_buffer = ""

def terminate(self):
async def terminate(self):
pass
4 changes: 2 additions & 2 deletions vocode/streaming/utils/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def start(self):
def consume_nonblocking(self, item: WorkerInputType):
raise NotImplementedError

def terminate(self):
async def terminate(self):
pass


Expand Down Expand Up @@ -66,7 +66,7 @@ def consume_nonblocking(self, item: WorkerInputType):
async def _run_loop(self):
raise NotImplementedError

def terminate(self):
async def terminate(self):
if self.worker_task:
return self.worker_task.cancel()

Expand Down

0 comments on commit 2b1f4c4

Please sign in to comment.