Update vocodehq-public #632

Merged 2 commits on Jul 12, 2024
4 changes: 2 additions & 2 deletions playground/streaming/agent/chat.py
@@ -197,7 +197,7 @@ async def sender():

await asyncio.gather(receiver(), sender())
if actions_worker is not None:
-actions_worker.terminate()
+await actions_worker.terminate()


async def agent_main():
@@ -233,7 +233,7 @@ async def agent_main():
try:
await run_agent(agent, interruption_probability=0, backchannel_probability=0)
except KeyboardInterrupt:
-agent.terminate()
+await agent.terminate()


if __name__ == "__main__":
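Nearly every hunk in this PR is the same mechanical migration: `terminate()` becomes a coroutine, so each call site gains an `await`. A minimal, self-contained sketch of the caller-side pattern; the `Worker` class and `main()` below are illustrative stand-ins, not vocode's API:

```python
import asyncio
from typing import Optional


class Worker:
    """Toy stand-in for a vocode-style worker whose terminate() is now async."""

    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None

    def start(self) -> None:
        self._task = asyncio.create_task(self._run())

    async def _run(self) -> None:
        while True:
            await asyncio.sleep(1)

    async def terminate(self) -> None:
        # Cleanup is awaitable now, so callers must await it.
        if self._task is not None:
            self._task.cancel()


async def main() -> None:
    worker = Worker()
    worker.start()
    try:
        await asyncio.sleep(0.1)
    finally:
        # Before this PR this was a bare `worker.terminate()`; calling the new
        # coroutine without awaiting it would never actually run the cleanup.
        await worker.terminate()


if __name__ == "__main__":
    asyncio.run(main())
```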
2 changes: 1 addition & 1 deletion tests/streaming/action/test_dtmf.py
@@ -118,7 +118,7 @@ async def test_twilio_dtmf_press_digits(
assert False, "Timed out waiting for DTMF tones to be sent"

assert action_output.response.success
-mock_twilio_output_device.terminate()
+await mock_twilio_output_device.terminate()

for digit, call in zip(digits, mock_twilio_output_device.ws.send_text.call_args_list):
expected_dtmf = DTMFToneGenerator().generate(
2 changes: 1 addition & 1 deletion tests/streaming/agent/test_base_agent.py
@@ -140,7 +140,7 @@ async def test_generate_responses(mocker: MockerFixture):
agent.agent_responses_consumer = agent_consumer
agent.start()
agent_responses = await _consume_until_end_of_turn(agent_consumer)
-agent.terminate()
+await agent.terminate()

messages = [response.message for response in agent_responses]

@@ -65,4 +65,4 @@ def uninterruptible_on_play():
await uninterruptible_played_event.wait()
assert uninterruptible_audio_chunk.state == ChunkState.PLAYED

-output_device.terminate()
+await output_device.terminate()
2 changes: 1 addition & 1 deletion tests/streaming/output_device/test_twilio_output_device.py
@@ -57,7 +57,7 @@ def on_play():
assert mark_message["streamSid"] == twilio_output_device.stream_sid
assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id)

-twilio_output_device.terminate()
+await twilio_output_device.terminate()


@pytest.mark.asyncio
10 changes: 5 additions & 5 deletions tests/streaming/test_streaming_conversation.py
@@ -287,7 +287,7 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message(
backchannel.sender == Sender.HUMAN and backchannel.is_backchannel
for backchannel in human_backchannels
)
-streaming_conversation.transcriptions_worker.terminate()
+await streaming_conversation.transcriptions_worker.terminate()


@pytest.mark.asyncio
@@ -355,7 +355,7 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance(
"I'm listening.",
]
assert streaming_conversation.transcript.event_logs[-1].is_backchannel
-streaming_conversation.transcriptions_worker.terminate()
+await streaming_conversation.transcriptions_worker.terminate()


@pytest.mark.asyncio
@@ -405,7 +405,7 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts(

assert streaming_conversation.transcript.event_logs[-1].sender == Sender.BOT
assert streaming_conversation.transcript.event_logs[-1].text == "Hi, I was wondering"
-streaming_conversation.transcriptions_worker.terminate()
+await streaming_conversation.transcriptions_worker.terminate()


@pytest.mark.asyncio
@@ -460,7 +460,7 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun
assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert streaming_conversation.broadcast_interrupt.called

-streaming_conversation.transcriptions_worker.terminate()
+await streaming_conversation.transcriptions_worker.terminate()


def _create_dummy_synthesis_result(
@@ -576,7 +576,7 @@ async def chunk_generator():
stop_event.set()
streaming_conversation.output_device.interrupt_event.set()
message_sent, cut_off = await send_speech_to_output_task
-streaming_conversation.output_device.terminate()
+await streaming_conversation.output_device.terminate()

assert message_sent != "Hi there"
assert cut_off
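The test updates follow for the same reason: once `terminate()` is a coroutine, forgetting the `await` only creates a coroutine object, the cleanup never runs, and pytest typically surfaces a "coroutine ... was never awaited" RuntimeWarning. A standalone illustration with a made-up `FakeOutputDevice` (not the repo's fixtures), assuming pytest-asyncio, which these tests already use:

```python
import pytest


class FakeOutputDevice:
    """Minimal async-terminate object for demonstration purposes only."""

    def __init__(self) -> None:
        self.terminated = False

    async def terminate(self) -> None:
        self.terminated = True


@pytest.mark.asyncio
async def test_terminate_must_be_awaited():
    device = FakeOutputDevice()
    await device.terminate()  # without `await`, `terminated` would stay False
    assert device.terminated
```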
5 changes: 5 additions & 0 deletions vocode/streaming/agent/chat_gpt_agent.py
@@ -290,3 +290,8 @@ async def generate_response(
message=message,
is_interruptible=True,
)

+    async def terminate(self):
+        if hasattr(self, "vector_db") and self.vector_db is not None:
+            await self.vector_db.tear_down()
+        return await super().terminate()
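This is the main non-mechanical change on the agent side: with an async `terminate()`, the agent can own its vector-db teardown instead of `StreamingConversation` doing it on the agent's behalf (see the comment block removed from streaming_conversation.py further down). A rough sketch of the pattern; `VectorDB`, `BaseAgent`, and `ChatAgent` here are hypothetical stand-ins:

```python
from typing import Optional


class VectorDB:
    async def tear_down(self) -> None:
        ...  # close HTTP sessions, flush indexes, etc.


class BaseAgent:
    async def terminate(self) -> None:
        ...  # cancel tasks, drain queues


class ChatAgent(BaseAgent):
    def __init__(self, vector_db: Optional[VectorDB] = None) -> None:
        self.vector_db = vector_db

    async def terminate(self) -> None:
        # Guarding against a missing/None vector_db mirrors the diff above,
        # so agents built without one can still terminate cleanly.
        if getattr(self, "vector_db", None) is not None:
            await self.vector_db.tear_down()
        await super().terminate()
```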
4 changes: 2 additions & 2 deletions vocode/streaming/agent/websocket_user_implemented_agent.py
@@ -138,10 +138,10 @@ async def receiver(ws: WebSocketClientProtocol) -> None:

await asyncio.gather(sender(ws), receiver(ws))

-def terminate(self):
+async def terminate(self):
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
AgentResponseStop()
)
)
-super().terminate()
+await super().terminate()
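The override keeps its existing behavior, pushing an `AgentResponseStop` sentinel so downstream consumers can drain and exit, and only the delegation to the base class gains an `await`. A generic sketch of the "sentinel first, then await super()" shutdown; `QueueWorker` and `WebsocketAgent` are made-up names:

```python
import asyncio

STOP = object()  # sentinel telling the consumer to drain and exit


class QueueWorker:
    def __init__(self) -> None:
        self.queue: asyncio.Queue = asyncio.Queue()

    async def terminate(self) -> None:
        ...  # base-class cleanup (cancel tasks, close sockets, etc.)


class WebsocketAgent(QueueWorker):
    async def terminate(self) -> None:
        # Enqueue the stop sentinel before tearing anything down so the
        # consumer on the other side of the queue can finish cleanly.
        self.queue.put_nowait(STOP)
        await super().terminate()
```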
10 changes: 5 additions & 5 deletions vocode/streaming/output_device/blocking_speaker_output.py
@@ -38,9 +38,9 @@ def _run_loop(self):
except queue.Empty:
continue

-def terminate(self):
+async def terminate(self):
self._ended = True
-super().terminate()
+await super().terminate()
self.stream.close()


@@ -65,9 +65,9 @@ def start(self) -> asyncio.Task:
self.playback_worker.start()
return super().start()

-    def terminate(self):
-        self.playback_worker.terminate()
-        super().terminate()
+    async def terminate(self):
+        await self.playback_worker.terminate()
+        await super().terminate()

@classmethod
def from_default_device(
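Both classes in this file wrap a blocking playback thread, so the ordering inside `terminate()` matters: flip `_ended` so `_run_loop` can exit, await the base-class teardown (which joins the thread), and only then close the stream the thread was writing to. A simplified sketch under those assumptions; `ThreadedPlayer` is not the real sounddevice-backed class:

```python
import asyncio
import queue
import threading


class ThreadedPlayer:
    def __init__(self) -> None:
        self._queue: "queue.Queue[bytes]" = queue.Queue()
        self._ended = False
        self._thread = threading.Thread(target=self._run_loop, daemon=True)

    def start(self) -> None:
        self._thread.start()

    def _run_loop(self) -> None:
        while not self._ended:
            try:
                chunk = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue
            _ = chunk  # write the chunk to the audio stream here

    async def terminate(self) -> None:
        # 1) let the loop observe the flag, 2) join the thread without
        # blocking the event loop, 3) only then close the output stream.
        self._ended = True
        await asyncio.to_thread(self._thread.join)
        # self.stream.close() would go here
```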
4 changes: 2 additions & 2 deletions vocode/streaming/output_device/file_output_device.py
@@ -30,6 +30,6 @@ def __init__(
async def play(self, chunk: bytes):
await asyncio.to_thread(lambda: self.wav.writeframes(chunk))

-def terminate(self):
+async def terminate(self):
self.wav.close()
-super().terminate()
+await super().terminate()
2 changes: 1 addition & 1 deletion vocode/streaming/output_device/speaker_output.py
@@ -53,7 +53,7 @@ def consume_nonblocking(self, chunk):
block[:size] = chunk_arr[i : i + size]
self.queue.put_nowait(block)

-def terminate(self):
+async def terminate(self):
self.stream.close()

@classmethod
35 changes: 15 additions & 20 deletions vocode/streaming/streaming_conversation.py
@@ -24,6 +24,7 @@
from loguru import logger
from sentry_sdk.tracing import Span

+from vocode import conversation_id as ctx_conversation_id
from vocode.streaming.action.worker import ActionsWorker
from vocode.streaming.agent.base_agent import (
AgentInput,
@@ -61,6 +62,7 @@
enumerate_async_iter,
get_chunk_size_per_second,
)
+from vocode.streaming.utils.audio_pipeline import AudioPipeline, OutputDeviceType
from vocode.streaming.utils.create_task import asyncio_create_task
from vocode.streaming.utils.events_manager import EventsManager
from vocode.streaming.utils.speed_manager import SpeedManager
@@ -107,10 +109,7 @@
LOW_INTERRUPT_SENSITIVITY_BACKCHANNEL_UTTERANCE_LENGTH_THRESHOLD = 3


-OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice)


-class StreamingConversation(Generic[OutputDeviceType]):
+class StreamingConversation(AudioPipeline[OutputDeviceType]):
class QueueingInterruptibleEventFactory(InterruptibleEventFactory):
def __init__(self, conversation: "StreamingConversation"):
self.conversation = conversation
@@ -593,6 +592,8 @@ def __init__(
events_manager: Optional[EventsManager] = None,
):
self.id = conversation_id or create_conversation_id()
+ctx_conversation_id.set(self.id)

self.output_device = output_device
self.transcriber = transcriber
self.agent = agent
@@ -800,8 +801,8 @@ def receive_message(self, message: str):
)
self.transcriptions_worker.consume_nonblocking(transcription)

-    def receive_audio(self, chunk: bytes):
-        self.transcriber.send_audio(chunk)
+    def consume_nonblocking(self, item: bytes):
+        self.transcriber.send_audio(item)

def warmup_synthesizer(self):
self.synthesizer.ready_synthesizer(self._get_synthesizer_chunk_size())
@@ -1010,29 +1011,23 @@ async def terminate(self):
logger.debug("Tearing down synthesizer")
await self.synthesizer.tear_down()
logger.debug("Terminating agent")
-        if isinstance(self.agent, ChatGPTAgent) and self.agent.agent_config.vector_db_config:
-            # Shutting down the vector db should be done in the agent's terminate method,
-            # but it is done here because `vector_db.tear_down()` is async and
-            # `agent.terminate()` is not async.
-            logger.debug("Terminating vector db")
-            await self.agent.vector_db.tear_down()
-        self.agent.terminate()
+        await self.agent.terminate()
logger.debug("Terminating output device")
-self.output_device.terminate()
+await self.output_device.terminate()
logger.debug("Terminating speech transcriber")
-self.transcriber.terminate()
+await self.transcriber.terminate()
logger.debug("Terminating transcriptions worker")
-self.transcriptions_worker.terminate()
+await self.transcriptions_worker.terminate()
logger.debug("Terminating final transcriptions worker")
-self.agent_responses_worker.terminate()
+await self.agent_responses_worker.terminate()
logger.debug("Terminating synthesis results worker")
-self.synthesis_results_worker.terminate()
+await self.synthesis_results_worker.terminate()
if self.filler_audio_worker is not None:
logger.debug("Terminating filler audio worker")
-self.filler_audio_worker.terminate()
+await self.filler_audio_worker.terminate()
if self.actions_worker is not None:
logger.debug("Terminating actions worker")
-self.actions_worker.terminate()
+await self.actions_worker.terminate()
logger.debug("Successfully terminated")

def is_active(self):
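Beyond the await migration, three structural changes land in this file: the conversation id is published via `ctx_conversation_id.set`, `receive_audio` is renamed to `consume_nonblocking` so the conversation consumes raw audio bytes like any other worker, and the class now derives from `AudioPipeline[OutputDeviceType]`, which also takes over the `OutputDeviceType` TypeVar deleted above. The base class itself is not part of this diff; the sketch below is only a guess at its shape for orientation, not vocode's actual `audio_pipeline` module:

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar


class AbstractOutputDevice(ABC):
    """Stand-in for vocode's output device base class."""

    @abstractmethod
    async def terminate(self) -> None:
        ...


OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice)


class AudioPipeline(ABC, Generic[OutputDeviceType]):
    """Guessed shape: consumes raw audio and owns a typed output device."""

    output_device: OutputDeviceType

    @abstractmethod
    def consume_nonblocking(self, item: bytes) -> None:
        """Accept raw audio bytes without blocking the caller."""

    @abstractmethod
    async def terminate(self) -> None:
        """Tear down the pipeline, awaiting each component's cleanup."""
```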
2 changes: 1 addition & 1 deletion vocode/streaming/synthesizer/base_synthesizer.py
@@ -446,7 +446,7 @@ async def send_chunks():
except asyncio.CancelledError:
pass
finally:
-miniaudio_worker.terminate()
+await miniaudio_worker.terminate()

def _resample_chunk(
self,
8 changes: 4 additions & 4 deletions vocode/streaming/synthesizer/google_synthesizer.py
@@ -5,7 +5,7 @@
from typing import Any

import google.auth
-from google.cloud import texttospeech as tts # type: ignore
+from google.cloud import texttospeech_v1beta1 as tts # type: ignore

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import GoogleSynthesizerConfig
@@ -34,7 +34,7 @@ def __init__(
# Select the type of audio file you want returned
self.audio_config = tts.AudioConfig(
audio_encoding=tts.AudioEncoding.LINEAR16,
-sample_rate_hertz=24000,
+sample_rate_hertz=synthesizer_config.sampling_rate,
speaking_rate=synthesizer_config.speaking_rate,
pitch=synthesizer_config.pitch,
effects_profile_id=["telephony-class-application"],
@@ -56,7 +56,7 @@ def synthesize(self, message: str) -> Any:
)

# TODO: make this nonblocking, see speech.TextToSpeechAsyncClient
-async def create_speech(
+async def create_speech_uncached(
self,
message: BaseMessage,
chunk_size: int,
@@ -75,7 +75,7 @@ async def create_speech(
in_memory_wav.setnchannels(1)
in_memory_wav.setsampwidth(2)
in_memory_wav.setframerate(output_sample_rate)
-in_memory_wav.writeframes(response.audio_content)
+in_memory_wav.writeframes(response.audio_content[44:])
output_bytes_io.seek(0)

result = self.create_synthesis_result_from_wav(
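Two behavior changes ride along with the rename to `create_speech_uncached`: the request now asks Google for audio at the configured sampling rate instead of a hard-coded 24000 Hz, and the first 44 bytes of `response.audio_content` are skipped before `writeframes`. LINEAR16 responses come back as WAV files, i.e. PCM prefixed by a RIFF header (44 bytes in the canonical layout), so writing those header bytes as frames would put a brief click of garbage at the start of every utterance. A standalone sketch of the re-wrapping, under that 44-byte-header assumption; `rewrap_linear16` is a hypothetical helper, not vocode code:

```python
import io
import struct
import wave


def rewrap_linear16(audio_content: bytes, sample_rate: int) -> bytes:
    """Re-wrap a LINEAR16 TTS payload (itself a WAV file) into a fresh WAV.

    Assumes the canonical 44-byte RIFF/WAVE header for 16-bit mono PCM, so
    the raw samples start at offset 44.
    """
    pcm = audio_content[44:]
    buf = io.BytesIO()
    with wave.open(buf, "wb") as out:
        out.setnchannels(1)
        out.setsampwidth(2)  # 16-bit samples
        out.setframerate(sample_rate)
        out.writeframes(pcm)
    return buf.getvalue()


# Tiny self-check: build a fake header-prefixed payload and re-wrap it at 8 kHz.
if __name__ == "__main__":
    samples = struct.pack("<8h", *range(8))
    src = io.BytesIO()
    with wave.open(src, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(24000)
        w.writeframes(samples)
    rewrapped = rewrap_linear16(src.getvalue(), sample_rate=8000)
    with wave.open(io.BytesIO(rewrapped)) as check:
        assert check.getframerate() == 8000
        assert check.readframes(check.getnframes()) == samples
```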
4 changes: 2 additions & 2 deletions vocode/streaming/synthesizer/miniaudio_worker.py
@@ -96,6 +96,6 @@ def _run_loop(self):
current_wav_output_buffer = current_wav_output_buffer[output_buffer_idx:]
current_wav_buffer.extend(new_bytes)

-def terminate(self):
+async def terminate(self):
self._ended = True
-super().terminate()
+await super().terminate()
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/assembly_ai_transcriber.py
@@ -78,9 +78,9 @@ def send_audio(self, chunk):
self.consume_nonblocking(self.buffer)
self.buffer = bytearray()

-def terminate(self):
+async def terminate(self):
self._ended = True
-super().terminate()
+await super().terminate()

def get_assembly_ai_url(self):
url_params = {"sample_rate": self.transcriber_config.sampling_rate}
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/azure_transcriber.py
@@ -141,7 +141,7 @@ def generator(self):

yield b"".join(data)

-def terminate(self):
+async def terminate(self):
self._ended = True
self.speech.stop_continuous_recognition_async()
-super().terminate()
+await super().terminate()
12 changes: 4 additions & 8 deletions vocode/streaming/transcriber/base_transcriber.py
@@ -57,18 +57,14 @@ def send_audio(self, chunk: bytes):
def produce_nonblocking(self, item: Transcription):
self.consumer.consume_nonblocking(item)

-    @abstractmethod
-    def terminate(self):
-        pass


class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore
def __init__(self, transcriber_config: TranscriberConfigType):
AbstractTranscriber.__init__(self, transcriber_config)
AsyncWorker.__init__(self)

-    def terminate(self):
-        AsyncWorker.terminate(self)
+    async def terminate(self):
+        await AsyncWorker.terminate(self)


class BaseThreadAsyncTranscriber( # type: ignore
@@ -101,8 +97,8 @@ async def _forward_from_thread(self):
def produce_nonblocking(self, item: Transcription):
self.output_janus_queue.sync_q.put_nowait(item)

-    def terminate(self):
-        ThreadAsyncWorker.terminate(self)
+    async def terminate(self):
+        await ThreadAsyncWorker.terminate(self)


BaseTranscriber = Union[
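The abstract sync `terminate()` disappears from `AbstractTranscriber`, and the two concrete bases now simply await the worker class they inherit from (`AsyncWorker` / `ThreadAsyncWorker`). A toy version of that delegation, assuming an `AsyncWorker` whose async terminate cancels and then awaits its run-loop task; neither class below is the real vocode implementation:

```python
import asyncio
from typing import Optional


class AsyncWorker:
    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None

    def start(self) -> None:
        self._task = asyncio.create_task(self._run_loop())

    async def _run_loop(self) -> None:
        while True:
            await asyncio.sleep(0.1)

    async def terminate(self) -> None:
        # Being async lets terminate() actually wait for the task to unwind.
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass


class Transcriber(AsyncWorker):
    async def terminate(self) -> None:
        # Subclasses close their own resources first, then delegate.
        await super().terminate()
```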
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/deepgram_transcriber.py
@@ -177,7 +177,7 @@ async def _run_loop(self):

logger.error("Deepgram connection died, not restarting")

-def terminate(self):
+async def terminate(self):
self._track_latency_of_transcription_start()
# Put this in logs until we sentry metrics show up
# properly on dashboard
@@ -192,7 +192,7 @@ def terminate(self):
terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8")
self.consume_nonblocking(terminate_msg) # todo (dow-107): typing
self._ended = True
-super().terminate()
+await super().terminate()

def get_input_sample_width(self):
encoding = self.transcriber_config.audio_encoding
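Deepgram's terminate does more than set `_ended`: it first queues a `{"type": "CloseStream"}` text frame so the server flushes any pending transcript before the socket goes away. A minimal sketch of that graceful close using the `websockets` package directly; `close_deepgram_stream` is a hypothetical helper, whereas the real transcriber routes the message through its own send path via `consume_nonblocking`:

```python
import json

import websockets


async def close_deepgram_stream(ws: websockets.WebSocketClientProtocol) -> None:
    # Ask Deepgram to flush pending results, then close the socket.
    await ws.send(json.dumps({"type": "CloseStream"}))
    await ws.close()
```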
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/gladia_transcriber.py
@@ -56,9 +56,9 @@ def send_audio(self, chunk):
self.consume_nonblocking(self.buffer)
self.buffer = bytearray()

-def terminate(self):
+async def terminate(self):
self._ended = True
-super().terminate()
+await super().terminate()

async def process(self):
async with websockets.connect(GLADIA_URL) as ws:
4 changes: 2 additions & 2 deletions vocode/streaming/transcriber/google_transcriber.py
@@ -57,9 +57,9 @@ def _run_loop(self):
responses = self.client.streaming_recognize(self.google_streaming_config, requests)
self.process_responses_loop(responses)

-def terminate(self):
+async def terminate(self):
self._ended = True
-super().terminate()
+await super().terminate()

def process_responses_loop(self, responses):
for response in responses: