Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOW-113] deprecate output queue and manually attach workers to each other #593

Merged
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"rewrap.wrappingColumn": 100,
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"isort.check": true
"isort.check": true,
"python.testing.pytestArgs": ["tests"]
}
28 changes: 17 additions & 11 deletions playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import AbstractConversationStateManager
from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent
from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent, QueueConsumer

load_dotenv()

from vocode.streaming.agent import ChatGPTAgent
from vocode.streaming.agent.base_agent import (
AgentResponse,
AgentResponseMessage,
AgentResponseType,
BaseAgent,
Expand Down Expand Up @@ -96,6 +97,11 @@ async def run_agent(
):
ended = False
conversation_id = create_conversation_id()
agent_response_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] = (
asyncio.Queue()
)
agent_consumer = QueueConsumer(input_queue=agent_response_queue)
agent.agent_responses_consumer = agent_consumer

async def receiver():
nonlocal ended
Expand All @@ -106,7 +112,7 @@ async def receiver():

while not ended:
try:
event = await agent.get_output_queue().get()
event = await agent_response_queue.get()
response = event.payload
if response.type == AgentResponseType.FILLER_AUDIO:
print("Would have sent filler audio")
Expand Down Expand Up @@ -152,6 +158,13 @@ async def receiver():
break

async def sender():
if agent.agent_config.initial_message is not None:
agent.agent_responses_consumer.consume_nonblocking(
InterruptibleAgentResponseEvent(
payload=AgentResponseMessage(message=agent.agent_config.initial_message),
agent_response_tracker=asyncio.Event(),
)
)
while not ended:
try:
message = await asyncio.get_event_loop().run_in_executor(
Expand All @@ -175,10 +188,10 @@ async def sender():
actions_worker = None
if isinstance(agent, ChatGPTAgent):
actions_worker = ActionsWorker(
input_queue=agent.actions_queue,
output_queue=agent.get_input_queue(),
action_factory=agent.action_factory,
)
actions_worker.consumer = agent
agent.actions_consumer = actions_worker
actions_worker.attach_conversation_state_manager(agent.conversation_state_manager)
actions_worker.start()

Expand Down Expand Up @@ -215,13 +228,6 @@ async def agent_main():
)
agent.attach_conversation_state_manager(DummyConversationManager())
agent.attach_transcript(transcript)
if agent.agent_config.initial_message is not None:
agent.output_queue.put_nowait(
InterruptibleAgentResponseEvent(
payload=AgentResponseMessage(message=agent.agent_config.initial_message),
agent_response_tracker=asyncio.Event(),
)
)
agent.start()

try:
Expand Down
18 changes: 12 additions & 6 deletions playground/streaming/transcriber/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
DeepgramEndpointingConfig,
DeepgramTranscriber,
)
from vocode.streaming.utils.worker import AsyncWorker


class TranscriptionPrinter(AsyncWorker[Transcription]):
    """Worker that prints every Transcription pulled from its input queue."""

    async def _run_loop(self):
        # Run forever; cancellation of the worker task ends the loop.
        while True:
            item = await self.input_queue.get()
            print(item)


if __name__ == "__main__":
import asyncio
Expand All @@ -13,11 +22,6 @@

load_dotenv()

async def print_output(transcriber: BaseTranscriber):
while True:
transcription: Transcription = await transcriber.output_queue.get()
print(transcription)

async def listen():
microphone_input = MicrophoneInput.from_default_device()

Expand All @@ -28,7 +32,9 @@ async def listen():
)
)
transcriber.start()
asyncio.create_task(print_output(transcriber))
transcription_printer = TranscriptionPrinter()
transcriber.consumer = transcription_printer
transcription_printer.start()
print("Start speaking...press Ctrl+C to end. ")
while True:
chunk = await microphone_input.get_audio()
Expand Down
3 changes: 2 additions & 1 deletion tests/fakedata/conversation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import time
from typing import Optional

from pytest_mock import MockerFixture
Expand Down Expand Up @@ -52,6 +51,7 @@ def __init__(
self.wait_for_interrupt = wait_for_interrupt
self.chunks_before_interrupt = chunks_before_interrupt
self.interrupt_event = asyncio.Event()
self.dummy_playback_queue = asyncio.Queue()

async def process(self, item):
self.interruptible_event = item
Expand All @@ -61,6 +61,7 @@ async def process(self, item):
audio_chunk.on_interrupt()
audio_chunk.state = ChunkState.INTERRUPTED
else:
self.dummy_playback_queue.put_nowait(audio_chunk)
audio_chunk.on_play()
audio_chunk.state = ChunkState.PLAYED
self.interruptible_event.is_interruptible = False
Expand Down
50 changes: 50 additions & 0 deletions tests/fixtures/synthesizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import wave
from io import BytesIO

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.models.synthesizer import SynthesizerConfig
from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer, SynthesisResult


def create_fake_audio(message: str, synthesizer_config: SynthesizerConfig):
    """Build an in-memory mono WAV whose frame payload is *message* encoded as bytes.

    The sample rate is taken from *synthesizer_config*; the returned BytesIO is
    rewound to the start so callers can read it immediately.
    """
    buffer = BytesIO()
    wave_file = wave.open(buffer, "wb")
    try:
        wave_file.setnchannels(1)  # mono
        wave_file.setsampwidth(2)  # 16-bit samples
        wave_file.setframerate(synthesizer_config.sampling_rate)
        wave_file.writeframes(message.encode())
    finally:
        wave_file.close()
    buffer.seek(0)
    return buffer


class TestSynthesizerConfig(SynthesizerConfig, type="synthesizer_test"):
    """Minimal synthesizer config for test fixtures, tagged "synthesizer_test"."""

    # NOTE(review): presumably the `type=` class kwarg registers this variant
    # with the project's typed-config machinery — confirm against SynthesizerConfig.
    __test__ = False  # keep pytest from collecting this class as a test


class TestSynthesizer(BaseSynthesizer[TestSynthesizerConfig]):
    """Synthesizer stub whose "audio" is simply the message text encoded as bytes."""

    __test__ = False  # keep pytest from collecting this class as a test

    def __init__(self, synthesizer_config: SynthesizerConfig):
        super().__init__(synthesizer_config)

    async def create_speech_uncached(
        self,
        message: BaseMessage,
        chunk_size: int,
        is_first_text_chunk: bool = False,
        is_sole_text_chunk: bool = False,
    ) -> SynthesisResult:
        # Fabricate a WAV whose frames are the message text, then wrap it in
        # the standard SynthesisResult via the base-class helper.
        fake_wav = create_fake_audio(
            message=message.text, synthesizer_config=self.synthesizer_config
        )
        return self.create_synthesis_result_from_wav(
            synthesizer_config=self.synthesizer_config,
            message=message,
            chunk_size=chunk_size,
            file=fake_wav,
        )

    @classmethod
    def get_voice_identifier(cls, synthesizer_config: TestSynthesizerConfig) -> str:
        # Single fixed voice for all test configs.
        return "test_voice"
24 changes: 24 additions & 0 deletions tests/fixtures/transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio

from vocode.streaming.models.transcriber import TranscriberConfig
from vocode.streaming.transcriber.base_transcriber import BaseAsyncTranscriber, Transcription


class TestTranscriberConfig(TranscriberConfig, type="transcriber_test"):
    """Minimal transcriber config for test fixtures, tagged "transcriber_test"."""

    # NOTE(review): presumably the `type=` class kwarg registers this variant
    # with the project's typed-config machinery — confirm against TranscriberConfig.
    __test__ = False  # keep pytest from collecting this class as a test


class TestAsyncTranscriber(BaseAsyncTranscriber[TestTranscriberConfig]):
    """Echo transcriber: each queued audio chunk is decoded as UTF-8 and emitted
    verbatim as a final, full-confidence transcription."""

    __test__ = False  # keep pytest from collecting this class as a test

    async def _run_loop(self):
        try:
            while True:
                chunk = await self.input_queue.get()
                text = chunk.decode("utf-8")
                self.produce_nonblocking(
                    Transcription(message=text, confidence=1, is_final=True)
                )
        except asyncio.CancelledError:
            # Worker task was cancelled; exit the loop cleanly.
            return
48 changes: 15 additions & 33 deletions tests/streaming/agent/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from vocode.streaming.models.transcriber import Transcription
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import ConversationStateManager
from vocode.streaming.utils.worker import InterruptibleEvent
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
QueueConsumer,
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -51,11 +55,16 @@ def _create_agent(
return agent


async def _consume_until_end_of_turn(agent: BaseAgent, timeout: float = 0.1) -> List[AgentResponse]:
async def _consume_until_end_of_turn(
agent_consumer: QueueConsumer[InterruptibleAgentResponseEvent[AgentResponse]],
timeout: float = 0.1,
) -> List[AgentResponse]:
agent_responses = []
try:
while True:
agent_response = await asyncio.wait_for(agent.output_queue.get(), timeout=timeout)
agent_response = await asyncio.wait_for(
agent_consumer.input_queue.get(), timeout=timeout
)
agent_responses.append(agent_response.payload)
if isinstance(agent_response.payload, AgentResponseMessage) and isinstance(
agent_response.payload.message, EndOfTurn
Expand Down Expand Up @@ -127,37 +136,10 @@ async def test_generate_responses(mocker: MockerFixture):
agent,
Transcription(message="Hello?", confidence=1.0, is_final=True),
)
agent_consumer = QueueConsumer()
agent.agent_responses_consumer = agent_consumer
agent.start()
agent_responses = await _consume_until_end_of_turn(agent)
agent.terminate()

messages = [response.message for response in agent_responses]

assert messages == [BaseMessage(text="Hi, how are you doing today?"), EndOfTurn()]


@pytest.mark.asyncio
async def test_generate_response(mocker: MockerFixture):
agent_config = ChatGPTAgentConfig(
prompt_preamble="Have a pleasant conversation about life",
generate_responses=True,
)
agent = _create_agent(mocker, agent_config)
_mock_generate_response(
mocker,
agent,
[
GeneratedResponse(
message=BaseMessage(text="Hi, how are you doing today?"), is_interruptible=True
)
],
)
_send_transcription(
agent,
Transcription(message="Hello?", confidence=1.0, is_final=True),
)
agent.start()
agent_responses = await _consume_until_end_of_turn(agent)
agent_responses = await _consume_until_end_of_turn(agent_consumer)
agent.terminate()

messages = [response.message for response in agent_responses]
Expand Down
Loading
Loading