Skip to content

Commit

Permalink
[DOW-105] refactor interruptions into the output device (#586)
Browse files Browse the repository at this point in the history
* [DOW-105] refactor interruptions into the output device (#562)

* initial refactor works

* remove notion of UtteranceAudioChunk and put all of the state in the callback

* move per_chunk_allowance_seconds into output device

* onboard onto vonage

* rename to abstract output device and onboard other output devices

* initial work to onboard twilio output device

* twilio conversation works

* some cleanup with better comments

* unset poetry.lock

* move abstract play method into ratelimitoutputdevice + dispatch to thread in fileoutputdevice

* rename back to AsyncWorker

* comments

* work through a bit of mypy

* asyncio.gather is g2g:

* create interrupt lock

* remove todo

* remove last todo

* remove log for interrupts

* fmt

* fix mypy

* fix mypy

* isort

* creates first test and adds scaffolding

* adds two other send_speech_to_output tests

* make send_speech_to_output more efficient

* adds tests for rate limit interruptions output device

* makes some variables private and also makes the chunk id coming back from the mark match the incoming audio chunk

* adds twilio output device tests

* make typing better for output devices

* fix mypy

* resolve PR comments

* resolve PR comments
  • Loading branch information
ajar98 authored Jul 3, 2024
1 parent 50318ca commit 60d2187
Show file tree
Hide file tree
Showing 33 changed files with 848 additions and 461 deletions.
10 changes: 7 additions & 3 deletions playground/streaming/synthesizer/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import AzureSynthesizerConfig
from vocode.streaming.output_device.base_output_device import BaseOutputDevice
from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
from vocode.streaming.output_device.audio_chunk import AudioChunk
from vocode.streaming.output_device.blocking_speaker_output import BlockingSpeakerOutput
from vocode.streaming.synthesizer.azure_synthesizer import AzureSynthesizer
from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer
from vocode.streaming.utils import get_chunk_size_per_second
from vocode.streaming.utils.worker import InterruptibleEvent

if __name__ == "__main__":
import asyncio
Expand All @@ -19,7 +21,7 @@

async def speak(
synthesizer: BaseSynthesizer,
output_device: BaseOutputDevice,
output_device: AbstractOutputDevice,
message: BaseMessage,
):
message_sent = message.text
Expand All @@ -38,7 +40,9 @@ async def speak(
try:
start_time = time.time()
speech_length_seconds = seconds_per_chunk * (len(chunk_result.chunk) / chunk_size)
output_device.consume_nonblocking(chunk_result.chunk)
output_device.consume_nonblocking(
InterruptibleEvent(payload=AudioChunk(data=chunk_result.chunk))
)
end_time = time.time()
await asyncio.sleep(
max(
Expand Down
54 changes: 51 additions & 3 deletions tests/fakedata/conversation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Optional

from pytest_mock import MockerFixture
Expand All @@ -8,7 +10,8 @@
from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import PlayHtSynthesizerConfig, SynthesizerConfig
from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, TranscriberConfig
from vocode.streaming.output_device.base_output_device import BaseOutputDevice
from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
from vocode.streaming.output_device.audio_chunk import ChunkState
from vocode.streaming.streaming_conversation import StreamingConversation
from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer
from vocode.streaming.telephony.constants import DEFAULT_CHUNK_SIZE, DEFAULT_SAMPLING_RATE
Expand Down Expand Up @@ -36,8 +39,53 @@
)


class DummyOutputDevice(BaseOutputDevice):
def consume_nonblocking(self, chunk: bytes):
class DummyOutputDevice(AbstractOutputDevice):

def __init__(
self,
sampling_rate: int,
audio_encoding: AudioEncoding,
wait_for_interrupt: bool = False,
chunks_before_interrupt: int = 1,
):
super().__init__(sampling_rate, audio_encoding)
self.wait_for_interrupt = wait_for_interrupt
self.chunks_before_interrupt = chunks_before_interrupt
self.interrupt_event = asyncio.Event()

async def process(self, item):
self.interruptible_event = item
audio_chunk = item.payload

if item.is_interrupted():
audio_chunk.on_interrupt()
audio_chunk.state = ChunkState.INTERRUPTED
else:
audio_chunk.on_play()
audio_chunk.state = ChunkState.PLAYED
self.interruptible_event.is_interruptible = False

async def _run_loop(self):
chunk_counter = 0
while True:
try:
item = await self.input_queue.get()
except asyncio.CancelledError:
return
if self.wait_for_interrupt and chunk_counter == self.chunks_before_interrupt:
await self.interrupt_event.wait()
await self.process(item)
chunk_counter += 1

def flush(self):
while True:
try:
item = self.input_queue.get_nowait()
except asyncio.QueueEmpty:
break
self.process(item)

def interrupt(self):
pass


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import asyncio

import pytest

from vocode.streaming.models.audio import AudioEncoding
from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState
from vocode.streaming.output_device.rate_limit_interruptions_output_device import (
RateLimitInterruptionsOutputDevice,
)
from vocode.streaming.utils.worker import InterruptibleEvent


class DummyRateLimitInterruptionsOutputDevice(RateLimitInterruptionsOutputDevice):
async def play(self, chunk: bytes):
pass


@pytest.mark.asyncio
async def test_calls_callbacks():
output_device = DummyRateLimitInterruptionsOutputDevice(
sampling_rate=16000, audio_encoding=AudioEncoding.LINEAR16
)

played_event = asyncio.Event()
interrupted_event = asyncio.Event()
uninterruptible_played_event = asyncio.Event()

def on_play():
played_event.set()

def on_interrupt():
interrupted_event.set()

def uninterruptible_on_play():
uninterruptible_played_event.set()

played_audio_chunk = AudioChunk(data=b"")
played_audio_chunk.on_play = on_play

interrupted_audio_chunk = AudioChunk(data=b"")
interrupted_audio_chunk.on_interrupt = on_interrupt

uninterruptible_audio_chunk = AudioChunk(data=b"")
uninterruptible_audio_chunk.on_play = uninterruptible_on_play

interruptible_event = InterruptibleEvent(payload=interrupted_audio_chunk, is_interruptible=True)
interruptible_event.interruption_event.set()

uninterruptible_event = InterruptibleEvent(
payload=uninterruptible_audio_chunk, is_interruptible=False
)
uninterruptible_event.interruption_event.set()

output_device.consume_nonblocking(InterruptibleEvent(payload=played_audio_chunk))
output_device.consume_nonblocking(interruptible_event)
output_device.consume_nonblocking(uninterruptible_event)
output_device.start()

await played_event.wait()
assert played_audio_chunk.state == ChunkState.PLAYED

await interrupted_event.wait()
assert interrupted_audio_chunk.state == ChunkState.INTERRUPTED

await uninterruptible_played_event.wait()
assert uninterruptible_audio_chunk.state == ChunkState.PLAYED

output_device.terminate()
139 changes: 139 additions & 0 deletions tests/streaming/output_device/test_twilio_output_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import asyncio
import base64
import json

import pytest
from pytest_mock import MockerFixture

from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState
from vocode.streaming.output_device.twilio_output_device import (
ChunkFinishedMarkMessage,
TwilioOutputDevice,
)
from vocode.streaming.utils.worker import InterruptibleEvent


@pytest.fixture
def mock_ws(mocker: MockerFixture):
return mocker.AsyncMock()


@pytest.fixture
def mock_stream_sid():
return "stream_sid"


@pytest.fixture
def twilio_output_device(mock_ws, mock_stream_sid):
return TwilioOutputDevice(ws=mock_ws, stream_sid=mock_stream_sid)


@pytest.mark.asyncio
async def test_calls_play_callbacks(twilio_output_device: TwilioOutputDevice):
played_event = asyncio.Event()

def on_play():
played_event.set()

audio_chunk = AudioChunk(data=b"")
audio_chunk.on_play = on_play

twilio_output_device.consume_nonblocking(InterruptibleEvent(payload=audio_chunk))
twilio_output_device.start()
twilio_output_device.enqueue_mark_message(
ChunkFinishedMarkMessage(chunk_id=str(audio_chunk.chunk_id))
)

await played_event.wait()
assert audio_chunk.state == ChunkState.PLAYED

media_message = json.loads(twilio_output_device.ws.send_text.call_args_list[0][0][0])
assert media_message["streamSid"] == twilio_output_device.stream_sid
assert media_message["media"] == {"payload": base64.b64encode(audio_chunk.data).decode("utf-8")}

mark_message = json.loads(twilio_output_device.ws.send_text.call_args_list[1][0][0])
assert mark_message["streamSid"] == twilio_output_device.stream_sid
assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id)

twilio_output_device.terminate()


@pytest.mark.asyncio
async def test_calls_interrupt_callbacks(twilio_output_device: TwilioOutputDevice):
interrupted_event = asyncio.Event()

def on_interrupt():
interrupted_event.set()

audio_chunk = AudioChunk(data=b"")
audio_chunk.on_interrupt = on_interrupt

interruptible_event = InterruptibleEvent(payload=audio_chunk, is_interruptible=True)

twilio_output_device.consume_nonblocking(interruptible_event)
# we start the twilio events task and the mark messages task manually to test this particular case

# step 1: media is sent into the websocket
send_twilio_messages_task = asyncio.create_task(twilio_output_device._send_twilio_messages())

while not twilio_output_device._twilio_events_queue.empty():
await asyncio.sleep(0.1)

# step 2: we get an interrupt
interruptible_event.interrupt()
twilio_output_device.interrupt()

# note: this means that the time between the events being interrupted and the clear message being sent, chunks
# will be marked interrupted - this is OK since the clear message is sent almost instantaneously once queued
# this is required because it stops queueing new chunks to be sent to the WS immediately

while not twilio_output_device._twilio_events_queue.empty():
await asyncio.sleep(0.1)

# step 3: we get a mark message for the interrupted audio chunk after the clear message
twilio_output_device.enqueue_mark_message(
ChunkFinishedMarkMessage(chunk_id=str(audio_chunk.chunk_id))
)
process_mark_messages_task = asyncio.create_task(twilio_output_device._process_mark_messages())

await interrupted_event.wait()
assert audio_chunk.state == ChunkState.INTERRUPTED

media_message = json.loads(twilio_output_device.ws.send_text.call_args_list[0][0][0])
assert media_message["streamSid"] == twilio_output_device.stream_sid
assert media_message["media"] == {"payload": base64.b64encode(audio_chunk.data).decode("utf-8")}

mark_message = json.loads(twilio_output_device.ws.send_text.call_args_list[1][0][0])
assert mark_message["streamSid"] == twilio_output_device.stream_sid
assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id)

clear_message = json.loads(twilio_output_device.ws.send_text.call_args_list[2][0][0])
assert clear_message["streamSid"] == twilio_output_device.stream_sid
assert clear_message["event"] == "clear"

send_twilio_messages_task.cancel()
process_mark_messages_task.cancel()


@pytest.mark.asyncio
async def test_interrupted_audio_chunks_are_not_sent_but_are_marked_interrupted(
twilio_output_device: TwilioOutputDevice,
):
interrupted_event = asyncio.Event()

def on_interrupt():
interrupted_event.set()

audio_chunk = AudioChunk(data=b"")
audio_chunk.on_interrupt = on_interrupt

interruptible_event = InterruptibleEvent(payload=audio_chunk, is_interruptible=True)
interruptible_event.interrupt()

twilio_output_device.consume_nonblocking(interruptible_event)
twilio_output_device.start()

await interrupted_event.wait()
assert audio_chunk.state == ChunkState.INTERRUPTED

twilio_output_device.ws.send_text.assert_not_called()
Loading

0 comments on commit 60d2187

Please sign in to comment.