Skip to content

Commit

Permalink
Merge branch 'main' into ajar98/update-vocodehq-public-20240705202045
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 committed Jul 6, 2024
2 parents 87308ba + 6b41941 commit 29a9f3f
Show file tree
Hide file tree
Showing 31 changed files with 482 additions and 330 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"rewrap.wrappingColumn": 100,
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"isort.check": true
"isort.check": true,
"python.testing.pytestArgs": ["tests"]
}
28 changes: 17 additions & 11 deletions playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import AbstractConversationStateManager
from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent
from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent, QueueConsumer

load_dotenv()

from vocode.streaming.agent import ChatGPTAgent
from vocode.streaming.agent.base_agent import (
AgentResponse,
AgentResponseMessage,
AgentResponseType,
BaseAgent,
Expand Down Expand Up @@ -96,6 +97,11 @@ async def run_agent(
):
ended = False
conversation_id = create_conversation_id()
agent_response_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] = (
asyncio.Queue()
)
agent_consumer = QueueConsumer(input_queue=agent_response_queue)
agent.agent_responses_consumer = agent_consumer

async def receiver():
nonlocal ended
Expand All @@ -106,7 +112,7 @@ async def receiver():

while not ended:
try:
event = await agent.get_output_queue().get()
event = await agent_response_queue.get()
response = event.payload
if response.type == AgentResponseType.FILLER_AUDIO:
print("Would have sent filler audio")
Expand Down Expand Up @@ -152,6 +158,13 @@ async def receiver():
break

async def sender():
if agent.agent_config.initial_message is not None:
agent.agent_responses_consumer.consume_nonblocking(
InterruptibleAgentResponseEvent(
payload=AgentResponseMessage(message=agent.agent_config.initial_message),
agent_response_tracker=asyncio.Event(),
)
)
while not ended:
try:
message = await asyncio.get_event_loop().run_in_executor(
Expand All @@ -175,10 +188,10 @@ async def sender():
actions_worker = None
if isinstance(agent, ChatGPTAgent):
actions_worker = ActionsWorker(
input_queue=agent.actions_queue,
output_queue=agent.get_input_queue(),
action_factory=agent.action_factory,
)
actions_worker.consumer = agent
agent.actions_consumer = actions_worker
actions_worker.attach_conversation_state_manager(agent.conversation_state_manager)
actions_worker.start()

Expand Down Expand Up @@ -215,13 +228,6 @@ async def agent_main():
)
agent.attach_conversation_state_manager(DummyConversationManager())
agent.attach_transcript(transcript)
if agent.agent_config.initial_message is not None:
agent.output_queue.put_nowait(
InterruptibleAgentResponseEvent(
payload=AgentResponseMessage(message=agent.agent_config.initial_message),
agent_response_tracker=asyncio.Event(),
)
)
agent.start()

try:
Expand Down
26 changes: 17 additions & 9 deletions playground/streaming/transcriber/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from vocode.streaming.input_device.file_input_device import FileInputDevice
from vocode.streaming.input_device.microphone_input import MicrophoneInput
from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, Transcription
from vocode.streaming.transcriber.base_transcriber import BaseTranscriber
from vocode.streaming.transcriber.deepgram_transcriber import (
DeepgramEndpointingConfig,
DeepgramTranscriber,
)
from vocode.streaming.utils.worker import AsyncWorker


class TranscriptionPrinter(AsyncWorker[Transcription]):
async def _run_loop(self):
while True:
transcription: Transcription = await self._input_queue.get()
print(transcription)


if __name__ == "__main__":
import asyncio
Expand All @@ -13,25 +23,23 @@

load_dotenv()

async def print_output(transcriber: BaseTranscriber):
while True:
transcription: Transcription = await transcriber.output_queue.get()
print(transcription)

async def listen():
microphone_input = MicrophoneInput.from_default_device()
input_device = MicrophoneInput.from_default_device()
# input_device = FileInputDevice(file_path="spacewalk.wav")

# replace with the transcriber you want to test
transcriber = DeepgramTranscriber(
DeepgramTranscriberConfig.from_input_device(
microphone_input, endpointing_config=DeepgramEndpointingConfig()
input_device, endpointing_config=DeepgramEndpointingConfig()
)
)
transcriber.start()
asyncio.create_task(print_output(transcriber))
transcription_printer = TranscriptionPrinter()
transcriber.consumer = transcription_printer
transcription_printer.start()
print("Start speaking...press Ctrl+C to end. ")
while True:
chunk = await microphone_input.get_audio()
chunk = await input_device.get_audio()
transcriber.send_audio(chunk)

asyncio.run(listen())
7 changes: 4 additions & 3 deletions tests/fakedata/conversation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import time
from typing import Optional

from pytest_mock import MockerFixture
Expand Down Expand Up @@ -52,6 +51,7 @@ def __init__(
self.wait_for_interrupt = wait_for_interrupt
self.chunks_before_interrupt = chunks_before_interrupt
self.interrupt_event = asyncio.Event()
self.dummy_playback_queue = asyncio.Queue()

async def process(self, item):
self.interruptible_event = item
Expand All @@ -61,6 +61,7 @@ async def process(self, item):
audio_chunk.on_interrupt()
audio_chunk.state = ChunkState.INTERRUPTED
else:
self.dummy_playback_queue.put_nowait(audio_chunk)
audio_chunk.on_play()
audio_chunk.state = ChunkState.PLAYED
self.interruptible_event.is_interruptible = False
Expand All @@ -69,7 +70,7 @@ async def _run_loop(self):
chunk_counter = 0
while True:
try:
item = await self.input_queue.get()
item = await self._input_queue.get()
except asyncio.CancelledError:
return
if self.wait_for_interrupt and chunk_counter == self.chunks_before_interrupt:
Expand All @@ -80,7 +81,7 @@ async def _run_loop(self):
def flush(self):
while True:
try:
item = self.input_queue.get_nowait()
item = self._input_queue.get_nowait()
except asyncio.QueueEmpty:
break
self.process(item)
Expand Down
50 changes: 50 additions & 0 deletions tests/fixtures/synthesizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import wave
from io import BytesIO

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import SynthesizerConfig
from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer, SynthesisResult


def create_fake_audio(message: str, synthesizer_config: SynthesizerConfig):
file = BytesIO()
with wave.open(file, "wb") as wave_file:
wave_file.setnchannels(1)
wave_file.setsampwidth(2)
wave_file.setframerate(synthesizer_config.sampling_rate)
wave_file.writeframes(message.encode())
file.seek(0)
return file


class TestSynthesizerConfig(SynthesizerConfig, type="synthesizer_test"):
__test__ = False


class TestSynthesizer(BaseSynthesizer[TestSynthesizerConfig]):
"""Accepts text and creates a SynthesisResult containing audio data which is the same as the text as bytes."""

__test__ = False

def __init__(self, synthesizer_config: SynthesizerConfig):
super().__init__(synthesizer_config)

async def create_speech_uncached(
self,
message: BaseMessage,
chunk_size: int,
is_first_text_chunk: bool = False,
is_sole_text_chunk: bool = False,
) -> SynthesisResult:
return self.create_synthesis_result_from_wav(
synthesizer_config=self.synthesizer_config,
message=message,
chunk_size=chunk_size,
file=create_fake_audio(
message=message.text, synthesizer_config=self.synthesizer_config
),
)

@classmethod
def get_voice_identifier(cls, synthesizer_config: TestSynthesizerConfig) -> str:
return "test_voice"
24 changes: 24 additions & 0 deletions tests/fixtures/transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio

from vocode.streaming.models.transcriber import TranscriberConfig
from vocode.streaming.transcriber.base_transcriber import BaseAsyncTranscriber, Transcription


class TestTranscriberConfig(TranscriberConfig, type="transcriber_test"):
__test__ = False


class TestAsyncTranscriber(BaseAsyncTranscriber[TestTranscriberConfig]):
"""Accepts fake audio chunks and sends out transcriptions which are the same as the audio chunks."""

__test__ = False

async def _run_loop(self):
while True:
try:
audio_chunk = await self._input_queue.get()
self.produce_nonblocking(
Transcription(message=audio_chunk.decode("utf-8"), confidence=1, is_final=True)
)
except asyncio.CancelledError:
return
48 changes: 15 additions & 33 deletions tests/streaming/agent/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from vocode.streaming.models.transcriber import Transcription
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import ConversationStateManager
from vocode.streaming.utils.worker import InterruptibleEvent
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
QueueConsumer,
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -51,11 +55,16 @@ def _create_agent(
return agent


async def _consume_until_end_of_turn(agent: BaseAgent, timeout: float = 0.1) -> List[AgentResponse]:
async def _consume_until_end_of_turn(
agent_consumer: QueueConsumer[InterruptibleAgentResponseEvent[AgentResponse]],
timeout: float = 0.1,
) -> List[AgentResponse]:
agent_responses = []
try:
while True:
agent_response = await asyncio.wait_for(agent.output_queue.get(), timeout=timeout)
agent_response = await asyncio.wait_for(
agent_consumer.input_queue.get(), timeout=timeout
)
agent_responses.append(agent_response.payload)
if isinstance(agent_response.payload, AgentResponseMessage) and isinstance(
agent_response.payload.message, EndOfTurn
Expand Down Expand Up @@ -127,37 +136,10 @@ async def test_generate_responses(mocker: MockerFixture):
agent,
Transcription(message="Hello?", confidence=1.0, is_final=True),
)
agent_consumer = QueueConsumer()
agent.agent_responses_consumer = agent_consumer
agent.start()
agent_responses = await _consume_until_end_of_turn(agent)
agent.terminate()

messages = [response.message for response in agent_responses]

assert messages == [BaseMessage(text="Hi, how are you doing today?"), EndOfTurn()]


@pytest.mark.asyncio
async def test_generate_response(mocker: MockerFixture):
agent_config = ChatGPTAgentConfig(
prompt_preamble="Have a pleasant conversation about life",
generate_responses=True,
)
agent = _create_agent(mocker, agent_config)
_mock_generate_response(
mocker,
agent,
[
GeneratedResponse(
message=BaseMessage(text="Hi, how are you doing today?"), is_interruptible=True
)
],
)
_send_transcription(
agent,
Transcription(message="Hello?", confidence=1.0, is_final=True),
)
agent.start()
agent_responses = await _consume_until_end_of_turn(agent)
agent_responses = await _consume_until_end_of_turn(agent_consumer)
agent.terminate()

messages = [response.message for response in agent_responses]
Expand Down
Loading

0 comments on commit 29a9f3f

Please sign in to comment.