Update vocodehq-public (#631)
* [Bug #628] correct coding errors in the google synthesiser (#629)

* [Bug-628] correct coding errors in the google synthesiser

* create_speech --> create_speech_uncached

---------

Co-authored-by: Ajay Raj <[email protected]>

* [DOW-119] creates AudioPipeline abstraction (#625)

* make terminate async

* creates audio pipeline abstraction

* fix streaming conversation api

* make terminate() invocations in tests async

* removes the vector_db.tear_down() call in streaming conversation

---------

Co-authored-by: jstahlbaum-fibernetics <[email protected]>
ajar98 and jstahlbaum-fibernetics committed Jul 12, 2024
1 parent 48cba3e commit 2eb0115
Showing 25 changed files with 88 additions and 70 deletions.
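Most of the 88 added / 70 deleted lines below follow one mechanical pattern: terminate() becomes a coroutine, so every override gains async/await and every call site awaits it. A minimal sketch of the pattern; the AsyncWorker base here is an illustrative stand-in, not vocode's real implementation:

import asyncio


class AsyncWorker:
    async def terminate(self):  # was: def terminate(self)
        pass  # the real base class cancels its task / joins its thread here


class MyWorker(AsyncWorker):
    def __init__(self):
        self._ended = False

    async def terminate(self):  # overrides become coroutines too
        self._ended = True
        await super().terminate()  # was: super().terminate()


async def main():
    worker = MyWorker()
    await worker.terminate()  # every call site now awaits


asyncio.run(main())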
4 changes: 2 additions & 2 deletions playground/streaming/agent/chat.py
@@ -197,7 +197,7 @@ async def sender():

     await asyncio.gather(receiver(), sender())
     if actions_worker is not None:
-        actions_worker.terminate()
+        await actions_worker.terminate()


 async def agent_main():
@@ -233,7 +233,7 @@ async def agent_main():
     try:
         await run_agent(agent, interruption_probability=0, backchannel_probability=0)
     except KeyboardInterrupt:
-        agent.terminate()
+        await agent.terminate()


 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/streaming/action/test_dtmf.py
@@ -118,7 +118,7 @@ async def test_twilio_dtmf_press_digits(
         assert False, "Timed out waiting for DTMF tones to be sent"

     assert action_output.response.success
-    mock_twilio_output_device.terminate()
+    await mock_twilio_output_device.terminate()

     for digit, call in zip(digits, mock_twilio_output_device.ws.send_text.call_args_list):
         expected_dtmf = DTMFToneGenerator().generate(
2 changes: 1 addition & 1 deletion tests/streaming/agent/test_base_agent.py
@@ -140,7 +140,7 @@ async def test_generate_responses(mocker: MockerFixture):
     agent.agent_responses_consumer = agent_consumer
     agent.start()
     agent_responses = await _consume_until_end_of_turn(agent_consumer)
-    agent.terminate()
+    await agent.terminate()

     messages = [response.message for response in agent_responses]

@@ -65,4 +65,4 @@ def uninterruptible_on_play():
     await uninterruptible_played_event.wait()
     assert uninterruptible_audio_chunk.state == ChunkState.PLAYED

-    output_device.terminate()
+    await output_device.terminate()
2 changes: 1 addition & 1 deletion tests/streaming/output_device/test_twilio_output_device.py
@@ -57,7 +57,7 @@ def on_play():
     assert mark_message["streamSid"] == twilio_output_device.stream_sid
     assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id)

-    twilio_output_device.terminate()
+    await twilio_output_device.terminate()


 @pytest.mark.asyncio
10 changes: 5 additions & 5 deletions tests/streaming/test_streaming_conversation.py
@@ -287,7 +287,7 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message(
         backchannel.sender == Sender.HUMAN and backchannel.is_backchannel
         for backchannel in human_backchannels
     )
-    streaming_conversation.transcriptions_worker.terminate()
+    await streaming_conversation.transcriptions_worker.terminate()


 @pytest.mark.asyncio
@@ -355,7 +355,7 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance(
         "I'm listening.",
     ]
     assert streaming_conversation.transcript.event_logs[-1].is_backchannel
-    streaming_conversation.transcriptions_worker.terminate()
+    await streaming_conversation.transcriptions_worker.terminate()


 @pytest.mark.asyncio
@@ -405,7 +405,7 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts(

     assert streaming_conversation.transcript.event_logs[-1].sender == Sender.BOT
     assert streaming_conversation.transcript.event_logs[-1].text == "Hi, I was wondering"
-    streaming_conversation.transcriptions_worker.terminate()
+    await streaming_conversation.transcriptions_worker.terminate()


 @pytest.mark.asyncio
@@ -460,7 +460,7 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun
     assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
     assert streaming_conversation.broadcast_interrupt.called

-    streaming_conversation.transcriptions_worker.terminate()
+    await streaming_conversation.transcriptions_worker.terminate()


 def _create_dummy_synthesis_result(
@@ -576,7 +576,7 @@ async def chunk_generator():
     stop_event.set()
     streaming_conversation.output_device.interrupt_event.set()
     message_sent, cut_off = await send_speech_to_output_task
-    streaming_conversation.output_device.terminate()
+    await streaming_conversation.output_device.terminate()

     assert message_sent != "Hi there"
     assert cut_off
5 changes: 5 additions & 0 deletions vocode/streaming/agent/chat_gpt_agent.py
@@ -290,3 +290,8 @@ async def generate_response(
             message=message,
             is_interruptible=True,
         )
+
+    async def terminate(self):
+        if hasattr(self, "vector_db") and self.vector_db is not None:
+            await self.vector_db.tear_down()
+        return await super().terminate()
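Because terminate() is now a coroutine, the agent can await its own vector DB teardown; previously StreamingConversation.terminate() had to do it on the agent's behalf (see the comment removed from streaming_conversation.py further down). A toy sketch of that ownership change; FakeVectorDB and FakeAgent are illustrative stand-ins, and only the terminate/tear_down names match the diff:

import asyncio


class FakeVectorDB:
    async def tear_down(self):
        print("vector db torn down")


class FakeAgent:
    def __init__(self, vector_db=None):
        self.vector_db = vector_db

    async def terminate(self):
        # mirrors ChatGPTAgent.terminate above: the agent owns its vector DB cleanup
        if self.vector_db is not None:
            await self.vector_db.tear_down()


asyncio.run(FakeAgent(FakeVectorDB()).terminate())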
4 changes: 2 additions & 2 deletions vocode/streaming/agent/websocket_user_implemented_agent.py
@@ -138,10 +138,10 @@ async def receiver(ws: WebSocketClientProtocol) -> None:

         await asyncio.gather(sender(ws), receiver(ws))

-    def terminate(self):
+    async def terminate(self):
         self.agent_responses_consumer.consume_nonblocking(
             self.interruptible_event_factory.create_interruptible_agent_response_event(
                 AgentResponseStop()
             )
         )
-        super().terminate()
+        await super().terminate()
10 changes: 5 additions & 5 deletions vocode/streaming/output_device/blocking_speaker_output.py
@@ -38,9 +38,9 @@ def _run_loop(self):
             except queue.Empty:
                 continue

-    def terminate(self):
+    async def terminate(self):
         self._ended = True
-        super().terminate()
+        await super().terminate()
         self.stream.close()


@@ -65,9 +65,9 @@ def start(self) -> asyncio.Task:
         self.playback_worker.start()
         return super().start()

-    def terminate(self):
-        self.playback_worker.terminate()
-        super().terminate()
+    async def terminate(self):
+        await self.playback_worker.terminate()
+        await super().terminate()

     @classmethod
     def from_default_device(
4 changes: 2 additions & 2 deletions vocode/streaming/output_device/file_output_device.py
@@ -30,6 +30,6 @@ def __init__(
     async def play(self, chunk: bytes):
         await asyncio.to_thread(lambda: self.wav.writeframes(chunk))

-    def terminate(self):
+    async def terminate(self):
         self.wav.close()
-        super().terminate()
+        await super().terminate()
2 changes: 1 addition & 1 deletion vocode/streaming/output_device/speaker_output.py
@@ -53,7 +53,7 @@ def consume_nonblocking(self, chunk):
             block[:size] = chunk_arr[i : i + size]
             self.queue.put_nowait(block)

-    def terminate(self):
+    async def terminate(self):
         self.stream.close()

     @classmethod
35 changes: 15 additions & 20 deletions vocode/streaming/streaming_conversation.py
@@ -24,6 +24,7 @@
 from loguru import logger
 from sentry_sdk.tracing import Span

+from vocode import conversation_id as ctx_conversation_id
 from vocode.streaming.action.worker import ActionsWorker
 from vocode.streaming.agent.base_agent import (
     AgentInput,
@@ -61,6 +62,7 @@
     enumerate_async_iter,
     get_chunk_size_per_second,
 )
+from vocode.streaming.utils.audio_pipeline import AudioPipeline, OutputDeviceType
 from vocode.streaming.utils.create_task import asyncio_create_task
 from vocode.streaming.utils.events_manager import EventsManager
 from vocode.streaming.utils.speed_manager import SpeedManager
@@ -107,10 +109,7 @@
 LOW_INTERRUPT_SENSITIVITY_BACKCHANNEL_UTTERANCE_LENGTH_THRESHOLD = 3


-OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice)
-
-
-class StreamingConversation(Generic[OutputDeviceType]):
+class StreamingConversation(AudioPipeline[OutputDeviceType]):
     class QueueingInterruptibleEventFactory(InterruptibleEventFactory):
         def __init__(self, conversation: "StreamingConversation"):
             self.conversation = conversation
@@ -593,6 +592,8 @@ def __init__(
         events_manager: Optional[EventsManager] = None,
     ):
         self.id = conversation_id or create_conversation_id()
+        ctx_conversation_id.set(self.id)
+
         self.output_device = output_device
         self.transcriber = transcriber
         self.agent = agent
@@ -800,8 +801,8 @@ def receive_message(self, message: str):
         )
         self.transcriptions_worker.consume_nonblocking(transcription)

-    def receive_audio(self, chunk: bytes):
-        self.transcriber.send_audio(chunk)
+    def consume_nonblocking(self, item: bytes):
+        self.transcriber.send_audio(item)

     def warmup_synthesizer(self):
         self.synthesizer.ready_synthesizer(self._get_synthesizer_chunk_size())
@@ -1010,29 +1011,23 @@ async def terminate(self):
         logger.debug("Tearing down synthesizer")
         await self.synthesizer.tear_down()
         logger.debug("Terminating agent")
-        if isinstance(self.agent, ChatGPTAgent) and self.agent.agent_config.vector_db_config:
-            # Shutting down the vector db should be done in the agent's terminate method,
-            # but it is done here because `vector_db.tear_down()` is async and
-            # `agent.terminate()` is not async.
-            logger.debug("Terminating vector db")
-            await self.agent.vector_db.tear_down()
-        self.agent.terminate()
+        await self.agent.terminate()
         logger.debug("Terminating output device")
-        self.output_device.terminate()
+        await self.output_device.terminate()
         logger.debug("Terminating speech transcriber")
-        self.transcriber.terminate()
+        await self.transcriber.terminate()
         logger.debug("Terminating transcriptions worker")
-        self.transcriptions_worker.terminate()
+        await self.transcriptions_worker.terminate()
         logger.debug("Terminating final transcriptions worker")
-        self.agent_responses_worker.terminate()
+        await self.agent_responses_worker.terminate()
         logger.debug("Terminating synthesis results worker")
-        self.synthesis_results_worker.terminate()
+        await self.synthesis_results_worker.terminate()
         if self.filler_audio_worker is not None:
             logger.debug("Terminating filler audio worker")
-            self.filler_audio_worker.terminate()
+            await self.filler_audio_worker.terminate()
         if self.actions_worker is not None:
             logger.debug("Terminating actions worker")
-            self.actions_worker.terminate()
+            await self.actions_worker.terminate()
         logger.debug("Successfully terminated")

     def is_active(self):
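StreamingConversation now derives from the new AudioPipeline[OutputDeviceType] abstraction (DOW-119) and takes input audio through consume_nonblocking instead of receive_audio. The new vocode/streaming/utils/audio_pipeline.py module itself is not among the files rendered in this view, so the following is only an inferred sketch of the interface based on how it is used here; everything beyond the names AudioPipeline, OutputDeviceType, and consume_nonblocking is an assumption:

from abc import ABC, abstractmethod
from typing import Generic, TypeVar


class AbstractOutputDevice:  # stand-in for vocode's output device base class
    async def terminate(self):
        pass


OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice)


class AudioPipeline(ABC, Generic[OutputDeviceType]):
    """Anything that accepts raw input audio and drives an output device."""

    output_device: OutputDeviceType

    @abstractmethod
    def consume_nonblocking(self, item: bytes):
        """Feed a chunk of input audio into the pipeline (replaces receive_audio)."""

    @abstractmethod
    async def terminate(self):
        """Shut the pipeline and its output device down asynchronously."""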
2 changes: 1 addition & 1 deletion vocode/streaming/synthesizer/base_synthesizer.py
@@ -446,7 +446,7 @@ async def send_chunks():
         except asyncio.CancelledError:
             pass
         finally:
-            miniaudio_worker.terminate()
+            await miniaudio_worker.terminate()

     def _resample_chunk(
         self,
8 changes: 4 additions & 4 deletions vocode/streaming/synthesizer/google_synthesizer.py
@@ -5,7 +5,7 @@
 from typing import Any

 import google.auth
-from google.cloud import texttospeech as tts # type: ignore
+from google.cloud import texttospeech_v1beta1 as tts # type: ignore

 from vocode.streaming.models.message import BaseMessage
 from vocode.streaming.models.synthesizer import GoogleSynthesizerConfig
@@ -34,7 +34,7 @@ def __init__(
         # Select the type of audio file you want returned
         self.audio_config = tts.AudioConfig(
             audio_encoding=tts.AudioEncoding.LINEAR16,
-            sample_rate_hertz=24000,
+            sample_rate_hertz=synthesizer_config.sampling_rate,
             speaking_rate=synthesizer_config.speaking_rate,
             pitch=synthesizer_config.pitch,
             effects_profile_id=["telephony-class-application"],
@@ -56,7 +56,7 @@ def synthesize(self, message: str) -> Any:
         )

     # TODO: make this nonblocking, see speech.TextToSpeechAsyncClient
-    async def create_speech(
+    async def create_speech_uncached(
         self,
         message: BaseMessage,
         chunk_size: int,
@@ -75,7 +75,7 @@ async def create_speech(
         in_memory_wav.setnchannels(1)
         in_memory_wav.setsampwidth(2)
         in_memory_wav.setframerate(output_sample_rate)
-        in_memory_wav.writeframes(response.audio_content)
+        in_memory_wav.writeframes(response.audio_content[44:])
         output_bytes_io.seek(0)

         result = self.create_synthesis_result_from_wav(
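Two details in the synthesizer fix are easy to miss: the request now uses the configured sampling rate rather than a hard-coded 24000 Hz, and response.audio_content[44:] skips the 44-byte canonical RIFF/WAV header that Google's LINEAR16 responses prepend to the PCM frames, so the header bytes are no longer written out as if they were audio. A small self-contained sketch of that header arithmetic; the in-memory WAV below is fabricated for illustration and stands in for response.audio_content:

import io
import struct
import wave


def strip_wav_header(audio_content: bytes) -> bytes:
    """Return only the PCM frames, assuming a canonical 44-byte WAV header."""
    assert audio_content[:4] == b"RIFF" and audio_content[8:12] == b"WAVE"
    return audio_content[44:]


# Build a tiny 16-bit mono WAV in memory to stand in for the TTS response.
buf = io.BytesIO()
with wave.open(buf, "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)
    wav.setframerate(8000)
    wav.writeframes(struct.pack("<4h", 0, 1000, -1000, 0))

fake_audio_content = buf.getvalue()
print(len(fake_audio_content))                    # 52: 44-byte header + 8 bytes of samples
print(len(strip_wav_header(fake_audio_content)))  # 8: just the PCM frames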
4 changes: 2 additions & 2 deletions vocode/streaming/synthesizer/miniaudio_worker.py
@@ -96,6 +96,6 @@ def _run_loop(self):
             current_wav_output_buffer = current_wav_output_buffer[output_buffer_idx:]
             current_wav_buffer.extend(new_bytes)

-    def terminate(self):
+    async def terminate(self):
         self._ended = True
-        super().terminate()
+        await super().terminate()
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/assembly_ai_transcriber.py
@@ -78,9 +78,9 @@ def send_audio(self, chunk):
             self.consume_nonblocking(self.buffer)
             self.buffer = bytearray()

-    def terminate(self):
+    async def terminate(self):
         self._ended = True
-        super().terminate()
+        await super().terminate()

     def get_assembly_ai_url(self):
         url_params = {"sample_rate": self.transcriber_config.sampling_rate}
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/azure_transcriber.py
@@ -141,7 +141,7 @@ def generator(self):

             yield b"".join(data)

-    def terminate(self):
+    async def terminate(self):
         self._ended = True
         self.speech.stop_continuous_recognition_async()
-        super().terminate()
+        await super().terminate()
12 changes: 4 additions & 8 deletions vocode/streaming/transcriber/base_transcriber.py
@@ -57,18 +57,14 @@ def send_audio(self, chunk: bytes):
     def produce_nonblocking(self, item: Transcription):
         self.consumer.consume_nonblocking(item)

-    @abstractmethod
-    def terminate(self):
-        pass
-

 class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore
     def __init__(self, transcriber_config: TranscriberConfigType):
         AbstractTranscriber.__init__(self, transcriber_config)
         AsyncWorker.__init__(self)

-    def terminate(self):
-        AsyncWorker.terminate(self)
+    async def terminate(self):
+        await AsyncWorker.terminate(self)


 class BaseThreadAsyncTranscriber( # type: ignore
@@ -101,8 +97,8 @@ async def _forward_from_thread(self):
     def produce_nonblocking(self, item: Transcription):
         self.output_janus_queue.sync_q.put_nowait(item)

-    def terminate(self):
-        ThreadAsyncWorker.terminate(self)
+    async def terminate(self):
+        await ThreadAsyncWorker.terminate(self)


 BaseTranscriber = Union[
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/deepgram_transcriber.py
@@ -177,7 +177,7 @@ async def _run_loop(self):

         logger.error("Deepgram connection died, not restarting")

-    def terminate(self):
+    async def terminate(self):
         self._track_latency_of_transcription_start()
         # Put this in logs until we sentry metrics show up
         # properly on dashboard
@@ -192,7 +192,7 @@ def terminate(self):
         terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8")
         self.consume_nonblocking(terminate_msg) # todo (dow-107): typing
         self._ended = True
-        super().terminate()
+        await super().terminate()

     def get_input_sample_width(self):
         encoding = self.transcriber_config.audio_encoding
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/gladia_transcriber.py
@@ -56,9 +56,9 @@ def send_audio(self, chunk):
             self.consume_nonblocking(self.buffer)
             self.buffer = bytearray()

-    def terminate(self):
+    async def terminate(self):
         self._ended = True
-        super().terminate()
+        await super().terminate()

     async def process(self):
         async with websockets.connect(GLADIA_URL) as ws:
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/google_transcriber.py
@@ -57,9 +57,9 @@ def _run_loop(self):
         responses = self.client.streaming_recognize(self.google_streaming_config, requests)
         self.process_responses_loop(responses)

-    def terminate(self):
+    async def terminate(self):
         self._ended = True
-        super().terminate()
+        await super().terminate()

     def process_responses_loop(self, responses):
         for response in responses: