Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds twilio dtmf action #623

Merged
merged 9 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions tests/streaming/action/test_dtmf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import asyncio
import base64
import json

import pytest
from aioresponses import aioresponses
from pytest_mock import MockerFixture

from tests.fakedata.id import generate_uuid
from vocode.streaming.action.dtmf import (
Expand All @@ -12,9 +17,15 @@
TwilioPhoneConversationActionInput,
VonagePhoneConversationActionInput,
)
from vocode.streaming.models.audio import AudioEncoding
from vocode.streaming.models.telephony import VonageConfig
from vocode.streaming.output_device.twilio_output_device import TwilioOutputDevice
from vocode.streaming.utils import create_conversation_id
from vocode.streaming.utils.state_manager import VonagePhoneConversationStateManager
from vocode.streaming.utils.dtmf_utils import DTMFToneGenerator
from vocode.streaming.utils.state_manager import (
TwilioPhoneConversationStateManager,
VonagePhoneConversationStateManager,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -59,12 +70,79 @@ async def test_vonage_dtmf_press_digits(mocker, mock_env):
assert action_output.response.success is True


@pytest.fixture
def mock_twilio_output_device(mocker: MockerFixture):
    """Provide a ``TwilioOutputDevice`` whose websocket is an ``AsyncMock``.

    The mocked websocket lets tests inspect what the device sends, and the
    stream SID is pinned to a fixed placeholder value.
    """
    device = TwilioOutputDevice()
    device.ws = mocker.AsyncMock()
    device.stream_sid = "stream_sid"
    return device


@pytest.fixture
def mock_twilio_phone_conversation(
    mocker: MockerFixture, mock_twilio_output_device: TwilioOutputDevice
):
    """Provide a ``MagicMock`` conversation that exposes the mocked output device."""
    conversation = mocker.MagicMock()
    conversation.output_device = mock_twilio_output_device
    return conversation


@pytest.mark.asyncio
async def test_twilio_dtmf_press_digits(
    mocker, mock_env, mock_twilio_phone_conversation, mock_twilio_output_device: TwilioOutputDevice
):
    """Pressing digits via ``TwilioDTMF`` sends one media message per digit.

    The action is run first, then the output device is started and the test
    polls the mocked websocket until every digit has been sent (or a timeout
    elapses). Each sent message is compared with the tone the generator
    produces for 8 kHz mu-law audio.
    """
    action = TwilioDTMF(action_config=DTMFVocodeActionConfig())
    digits = "1234"
    twilio_sid = "twilio_sid"

    action.attach_conversation_state_manager(
        TwilioPhoneConversationStateManager(mock_twilio_phone_conversation)
    )

    action_output = await action.run(
        action_input=TwilioPhoneConversationActionInput(
            action_config=DTMFVocodeActionConfig(),
            conversation_id=create_conversation_id(),
            params=DTMFParameters(buttons=digits),
            twilio_sid=twilio_sid,
        )
    )

    send_text = mock_twilio_output_device.ws.send_text
    max_wait_seconds = 1
    poll_interval_seconds = 0.1

    mock_twilio_output_device.start()
    try:
        # Use the loop clock for the deadline rather than summing float
        # sleep intervals, which drifts and ignores time spent elsewhere.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + max_wait_seconds
        while send_text.call_count < len(digits):
            if loop.time() > deadline:
                # pytest.fail is not stripped under ``python -O`` the way
                # ``assert False`` is.
                pytest.fail("Timed out waiting for DTMF tones to be sent")
            await asyncio.sleep(poll_interval_seconds)

        assert action_output.response.success
    finally:
        # Always stop the device, even when the wait or assertion fails.
        mock_twilio_output_device.terminate()

    for digit, call in zip(digits, send_text.call_args_list):
        expected_dtmf = DTMFToneGenerator().generate(
            digit, sampling_rate=8000, audio_encoding=AudioEncoding.MULAW
        )
        # First positional argument of send_text is the JSON-encoded message.
        media_message = json.loads(call[0][0])
        assert media_message["streamSid"] == mock_twilio_output_device.stream_sid
        assert media_message["media"] == {
            "payload": base64.b64encode(expected_dtmf).decode("utf-8")
        }


@pytest.mark.asyncio
async def test_twilio_dtmf_failure(
mocker, mock_env, mock_twilio_phone_conversation, mock_twilio_output_device: TwilioOutputDevice
):
action = TwilioDTMF(action_config=DTMFVocodeActionConfig())
digits = "****"
twilio_sid = "twilio_sid"

action.attach_conversation_state_manager(
TwilioPhoneConversationStateManager(mock_twilio_phone_conversation)
)

action_output = await action.run(
action_input=TwilioPhoneConversationActionInput(
action_config=DTMFVocodeActionConfig(),
Expand All @@ -74,4 +152,5 @@ async def test_twilio_dtmf_press_digits(mocker, mock_env):
)
)

assert action_output.response.success is False # Twilio does not support DTMF
assert not action_output.response.success
assert action_output.response.message == "Invalid DTMF buttons, can only accept 0-9"
16 changes: 16 additions & 0 deletions tests/streaming/output_device/test_twilio_output_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
ChunkFinishedMarkMessage,
TwilioOutputDevice,
)
from vocode.streaming.utils.dtmf_utils import DTMFToneGenerator, KeypadEntry
from vocode.streaming.utils.singleton import SingletonMeta
from vocode.streaming.utils.worker import InterruptibleEvent


Expand Down Expand Up @@ -137,3 +139,17 @@ def on_interrupt():
assert audio_chunk.state == ChunkState.INTERRUPTED

twilio_output_device.ws.send_text.assert_not_called()


def test_dtmf_tone_generator_caches(
    twilio_output_device: TwilioOutputDevice, mocker: MockerFixture
):
    """Sending the same key twice must encode the tone only once."""
    # Start from a fresh generator so the tone cache is empty. ``pop`` with a
    # default avoids a KeyError when no generator has been created yet (for
    # example when this test runs in isolation or first).
    # NOTE(review): assumes ``_instances`` is a dict-like mapping - confirm.
    SingletonMeta._instances.pop(DTMFToneGenerator, None)
    lin2ulaw_mock = mocker.patch(
        "audioop.lin2ulaw",
        return_value=b"ulaw_encoded",
    )

    try:
        twilio_output_device.send_dtmf_tones([KeypadEntry.ONE, KeypadEntry.ONE])

        lin2ulaw_mock.assert_called_once()
    finally:
        # Discard the generator whose cache now holds the mocked bytes so it
        # cannot leak into other tests.
        SingletonMeta._instances.pop(DTMFToneGenerator, None)
29 changes: 25 additions & 4 deletions vocode/streaming/action/dtmf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type
from typing import List, Optional, Type

from loguru import logger
from pydantic.v1 import BaseModel, Field
Expand All @@ -9,6 +9,7 @@
)
from vocode.streaming.models.actions import ActionConfig as VocodeActionConfig
from vocode.streaming.models.actions import ActionInput, ActionOutput
from vocode.streaming.utils.dtmf_utils import DTMFToneGenerator, KeypadEntry
from vocode.streaming.utils.state_manager import (
TwilioPhoneConversationStateManager,
VonagePhoneConversationStateManager,
Expand All @@ -21,6 +22,7 @@ class DTMFParameters(BaseModel):

# Result payload returned by the DTMF actions.
class DTMFResponse(BaseModel):
    # Whether the buttons were pressed.
    success: bool
    # Optional detail about the outcome; set to an explanation when the
    # requested buttons are rejected as invalid.
    message: Optional[str] = None


class DTMFVocodeActionConfig(VocodeActionConfig, type="action_dtmf"): # type: ignore
Expand All @@ -31,7 +33,12 @@ def action_attempt_to_string(self, input: ActionInput) -> str:
def action_result_to_string(self, input: ActionInput, output: ActionOutput) -> str:
assert isinstance(input.params, DTMFParameters)
assert isinstance(output.response, DTMFResponse)
return f"Pressed numbers {list(input.params.buttons)} successfully"
if output.response.success:
return f"Pressed numbers {list(input.params.buttons)} successfully"
else:
return (
f"Failed to press numbers {list(input.params.buttons)}: {output.response.message}"
)


# Human-readable description of what the DTMF action does.
FUNCTION_DESCRIPTION = "Presses a string of numbers using DTMF tones."
Expand Down Expand Up @@ -76,8 +83,22 @@ def __init__(self, action_config: DTMFVocodeActionConfig):
)

async def run(self, action_input: ActionInput[DTMFParameters]) -> ActionOutput[DTMFResponse]:
    """Press the requested buttons by sending DTMF tones to the Twilio call.

    Args:
        action_input: Action input whose ``params.buttons`` is the string of
            buttons to press.

    Returns:
        An ``ActionOutput`` whose response has ``success=True`` once the
        tones have been handed to the output device, or ``success=False``
        with an explanatory ``message`` when any button is not a valid
        keypad entry. In the failure case nothing is sent.
    """
    buttons = action_input.params.buttons
    keypad_entries: List[KeypadEntry]
    try:
        # Validate every button up front so a bad entry sends no tones at all.
        keypad_entries = [KeypadEntry(button) for button in buttons]
    except ValueError:
        logger.warning(f"Invalid DTMF buttons: {buttons}")
        return ActionOutput(
            action_type=action_input.action_config.type,
            response=DTMFResponse(
                success=False, message="Invalid DTMF buttons, can only accept 0-9"
            ),
        )
    self.conversation_state_manager._twilio_phone_conversation.output_device.send_dtmf_tones(
        keypad_entries=keypad_entries
    )
    return ActionOutput(
        action_type=action_input.action_config.type,
        response=DTMFResponse(success=True),
    )
19 changes: 17 additions & 2 deletions vocode/streaming/output_device/twilio_output_device.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import asyncio
import audioop
import base64
import json
import uuid
from typing import Optional, Union
from typing import List, Optional, Union

from fastapi import WebSocket
from fastapi.websockets import WebSocketState
Expand All @@ -15,6 +15,7 @@
from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState
from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE
from vocode.streaming.utils.create_task import asyncio_create_task
from vocode.streaming.utils.dtmf_utils import DTMFToneGenerator, KeypadEntry
from vocode.streaming.utils.worker import InterruptibleEvent


Expand Down Expand Up @@ -55,6 +56,20 @@ def interrupt(self):
def enqueue_mark_message(self, mark_message: MarkMessage):
self._mark_message_queue.put_nowait(mark_message)

def send_dtmf_tones(self, keypad_entries: List[KeypadEntry]):
    """Queue one Twilio ``media`` event per keypad entry.

    Each tone is generated at this device's sampling rate and audio
    encoding, base64-encoded, and placed on the Twilio events queue as a
    JSON string. No mark message is enqueued for these events.
    """
    generator = DTMFToneGenerator()
    for entry in keypad_entries:
        logger.info(f"Sending DTMF tone {entry.value}")
        tone_bytes = generator.generate(
            entry, sampling_rate=self.sampling_rate, audio_encoding=self.audio_encoding
        )
        encoded_payload = base64.b64encode(tone_bytes).decode("utf-8")
        event = json.dumps(
            {
                "event": "media",
                "streamSid": self.stream_sid,
                "media": {"payload": encoded_payload},
            }
        )
        self._twilio_events_queue.put_nowait(event)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these get interrupted? If so, is a mark message necessary as well, i.e. should this use _send_audio_chunk_and_mark()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These cannot get interrupted (which mirrors the Vonage implementation). If we send an event to _twilio_events_queue, we can expect it to be sent to the websocket unless the output device gets torn down (when the conversation ends).


async def _send_twilio_messages(self):
while True:
try:
Expand Down
65 changes: 65 additions & 0 deletions vocode/streaming/utils/dtmf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import audioop
from enum import Enum
from typing import Dict, Tuple

import numpy as np

from vocode.streaming.models.audio import AudioEncoding
from vocode.streaming.utils.singleton import Singleton

DEFAULT_DTMF_TONE_LENGTH_SECONDS = 0.3
MAX_INT = 32767


class KeypadEntry(str, Enum):
    """Keypad buttons that can be sent as DTMF tones.

    Each member's value is the single-character digit string, so
    ``KeypadEntry("1")`` looks up ``KeypadEntry.ONE``; any string that is not
    one of the values below raises ``ValueError``. Only the digits 0-9 are
    defined.
    """

    ONE = "1"
    TWO = "2"
    THREE = "3"
    FOUR = "4"
    FIVE = "5"
    SIX = "6"
    SEVEN = "7"
    EIGHT = "8"
    NINE = "9"
    ZERO = "0"


# Maps each keypad entry to the pair of sine-wave frequencies, in Hz, that are
# summed to form its DTMF tone: (lower frequency, higher frequency).
DTMF_FREQUENCIES = {
    KeypadEntry.ONE: (697, 1209),
    KeypadEntry.TWO: (697, 1336),
    KeypadEntry.THREE: (697, 1477),
    KeypadEntry.FOUR: (770, 1209),
    KeypadEntry.FIVE: (770, 1336),
    KeypadEntry.SIX: (770, 1477),
    KeypadEntry.SEVEN: (852, 1209),
    KeypadEntry.EIGHT: (852, 1336),
    KeypadEntry.NINE: (852, 1477),
    KeypadEntry.ZERO: (941, 1336),
}


class DTMFToneGenerator(Singleton):
    """Generates DTMF tones as raw audio bytes and caches the results.

    The class derives from ``Singleton``, so the tone cache is intended to be
    shared by every caller and to persist for the lifetime of the process
    (for example, for as long as a server stays up), not only within a
    single action run.
    """

    def __init__(self):
        # Keyed by every input that affects the rendered bytes: keypad entry,
        # sampling rate, audio encoding and tone duration.
        self.tone_cache: Dict[Tuple[KeypadEntry, int, AudioEncoding, float], bytes] = {}

    def generate(
        self,
        keypad_entry: KeypadEntry,
        sampling_rate: int,
        audio_encoding: AudioEncoding,
        duration_seconds: float = DEFAULT_DTMF_TONE_LENGTH_SECONDS,
    ) -> bytes:
        """Return the audio bytes for one DTMF tone, generating it if needed.

        Args:
            keypad_entry: The key whose tone should be produced.
            sampling_rate: Samples per second of the generated audio.
            audio_encoding: ``AudioEncoding.MULAW`` yields mu-law bytes; any
                other encoding yields 16-bit PCM bytes.
            duration_seconds: Length of the tone. Current callers rely on the
                default; the parameter is kept so the tone can be shortened
                later (shorter tones should let the action finish sooner).

        Returns:
            The encoded tone. Repeated calls with the same arguments return
            the cached bytes without re-encoding.
        """
        # The duration is part of the key so that a request for a different
        # tone length is never answered with a cached tone of another length.
        cache_key = (keypad_entry, sampling_rate, audio_encoding, duration_seconds)
        cached = self.tone_cache.get(cache_key)
        if cached is not None:
            return cached

        f1, f2 = DTMF_FREQUENCIES[keypad_entry]
        t = np.linspace(0, duration_seconds, int(sampling_rate * duration_seconds), endpoint=False)
        tone = np.sin(2 * np.pi * f1 * t) + np.sin(2 * np.pi * f2 * t)
        tone = tone / np.max(np.abs(tone))  # Normalize to [-1, 1]
        # Scale to the signed 16-bit range before serialising the samples.
        pcm = (tone * MAX_INT).astype(np.int16).tobytes()
        if audio_encoding == AudioEncoding.MULAW:
            # Second argument is the sample width in bytes (16-bit input).
            output = audioop.lin2ulaw(pcm, 2)
        else:
            output = pcm
        self.tone_cache[cache_key] = output
        return output
Loading