Skip to content

Commit

Permalink
adds back streamingconversation test
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 committed Jul 3, 2024
1 parent f9c3124 commit c3011b6
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"rewrap.wrappingColumn": 100,
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"isort.check": true
"isort.check": true,
"python.testing.pytestArgs": ["tests"]
}
3 changes: 2 additions & 1 deletion tests/fakedata/conversation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import time
from typing import Optional

from pytest_mock import MockerFixture
Expand Down Expand Up @@ -52,6 +51,7 @@ def __init__(
self.wait_for_interrupt = wait_for_interrupt
self.chunks_before_interrupt = chunks_before_interrupt
self.interrupt_event = asyncio.Event()
self.dummy_playback_queue = asyncio.Queue()

async def process(self, item):
self.interruptible_event = item
Expand All @@ -61,6 +61,7 @@ async def process(self, item):
audio_chunk.on_interrupt()
audio_chunk.state = ChunkState.INTERRUPTED
else:
self.dummy_playback_queue.put_nowait(audio_chunk)
audio_chunk.on_play()
audio_chunk.state = ChunkState.PLAYED
self.interruptible_event.is_interruptible = False
Expand Down
50 changes: 50 additions & 0 deletions tests/fixtures/synthesizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import wave
from io import BytesIO

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import SynthesizerConfig
from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer, SynthesisResult


def create_fake_audio(message: str, synthesizer_config: SynthesizerConfig):
file = BytesIO()
with wave.open(file, "wb") as wave_file:
wave_file.setnchannels(1)
wave_file.setsampwidth(2)
wave_file.setframerate(synthesizer_config.sampling_rate)
wave_file.writeframes(message.encode())
file.seek(0)
return file


class TestSynthesizerConfig(SynthesizerConfig, type="synthesizer_test"):
__test__ = False


class TestSynthesizer(BaseSynthesizer[TestSynthesizerConfig]):
"""Accepts text and creates a SynthesisResult containing audio data which is the same as the text as bytes."""

__test__ = False

def __init__(self, synthesizer_config: SynthesizerConfig):
super().__init__(synthesizer_config)

async def create_speech_uncached(
self,
message: BaseMessage,
chunk_size: int,
is_first_text_chunk: bool = False,
is_sole_text_chunk: bool = False,
) -> SynthesisResult:
return self.create_synthesis_result_from_wav(
synthesizer_config=self.synthesizer_config,
message=message,
chunk_size=chunk_size,
file=create_fake_audio(
message=message.text, synthesizer_config=self.synthesizer_config
),
)

@classmethod
def get_voice_identifier(cls, synthesizer_config: TestSynthesizerConfig) -> str:
return "test_voice"
24 changes: 24 additions & 0 deletions tests/fixtures/transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio

from vocode.streaming.models.transcriber import TranscriberConfig
from vocode.streaming.transcriber.base_transcriber import BaseAsyncTranscriber, Transcription


class TestTranscriberConfig(TranscriberConfig, type="transcriber_test"):
__test__ = False


class TestAsyncTranscriber(BaseAsyncTranscriber[TestTranscriberConfig]):
"""Accepts fake audio chunks and sends out transcriptions which are the same as the audio chunks."""

__test__ = False

async def _run_loop(self):
while True:
try:
audio_chunk = await self.input_queue.get()
self.produce_nonblocking(
Transcription(message=audio_chunk.decode("utf-8"), confidence=1, is_final=True)
)
except asyncio.CancelledError:
return
38 changes: 37 additions & 1 deletion tests/streaming/test_streaming_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,23 @@

from tests.fakedata.conversation import (
DEFAULT_CHAT_GPT_AGENT_CONFIG,
DummyOutputDevice,
create_fake_agent,
create_fake_streaming_conversation,
)
from tests.fixtures.synthesizer import TestSynthesizer, TestSynthesizerConfig
from tests.fixtures.transcriber import TestAsyncTranscriber, TestTranscriberConfig
from vocode.streaming.agent.echo_agent import EchoAgent
from vocode.streaming.models.actions import ActionInput
from vocode.streaming.models.agent import InterruptSensitivity
from vocode.streaming.models.agent import EchoAgentConfig, InterruptSensitivity
from vocode.streaming.models.audio import AudioEncoding
from vocode.streaming.models.events import Sender
from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.transcriber import Transcription
from vocode.streaming.models.transcript import ActionStart, Message, Transcript
from vocode.streaming.streaming_conversation import StreamingConversation
from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult
from vocode.streaming.telephony.constants import DEFAULT_SAMPLING_RATE
from vocode.streaming.utils.worker import AsyncWorker, QueueConsumer


Expand Down Expand Up @@ -583,3 +591,31 @@ async def chunk_generator():
assert cut_off
assert transcript_message.text != "Hi there"
assert not transcript_message.is_final


@pytest.mark.asyncio
async def test_streaming_conversation_pipeline(
mocker: MockerFixture,
):
output_device = DummyOutputDevice(sampling_rate=48000, audio_encoding=AudioEncoding.LINEAR16)
streaming_conversation = StreamingConversation(
output_device=output_device,
transcriber=TestAsyncTranscriber(
TestTranscriberConfig(
sampling_rate=48000,
audio_encoding=AudioEncoding.LINEAR16,
chunk_size=480,
)
),
agent=EchoAgent(
EchoAgentConfig(initial_message=BaseMessage(text="Hi there")),
),
synthesizer=TestSynthesizer(TestSynthesizerConfig.from_output_device(output_device)),
)
await streaming_conversation.start()
await streaming_conversation.initial_message_tracker.wait()
streaming_conversation.receive_audio(b"test")
initial_message_audio_chunk = await output_device.dummy_playback_queue.get()
assert initial_message_audio_chunk.data == b"Hi there"
first_response_audio_chunk = await output_device.dummy_playback_queue.get()
assert first_response_audio_chunk.data == b"test"

0 comments on commit c3011b6

Please sign in to comment.