Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 committed Jun 28, 2024
1 parent 700b17a commit 9e28363
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 52 deletions.
48 changes: 15 additions & 33 deletions tests/streaming/agent/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from vocode.streaming.models.transcriber import Transcription
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import ConversationStateManager
from vocode.streaming.utils.worker import InterruptibleEvent
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
QueueConsumer,
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -51,11 +55,16 @@ def _create_agent(
return agent


async def _consume_until_end_of_turn(agent: BaseAgent, timeout: float = 0.1) -> List[AgentResponse]:
async def _consume_until_end_of_turn(
agent_consumer: QueueConsumer[InterruptibleAgentResponseEvent[AgentResponse]],
timeout: float = 0.1,
) -> List[AgentResponse]:
agent_responses = []
try:
while True:
agent_response = await asyncio.wait_for(agent.output_queue.get(), timeout=timeout)
agent_response = await asyncio.wait_for(
agent_consumer.input_queue.get(), timeout=timeout
)
agent_responses.append(agent_response.payload)
if isinstance(agent_response.payload, AgentResponseMessage) and isinstance(
agent_response.payload.message, EndOfTurn
Expand Down Expand Up @@ -127,37 +136,10 @@ async def test_generate_responses(mocker: MockerFixture):
agent,
Transcription(message="Hello?", confidence=1.0, is_final=True),
)
agent_consumer = QueueConsumer()
agent.agent_responses_consumer = agent_consumer
agent.start()
agent_responses = await _consume_until_end_of_turn(agent)
agent.terminate()

messages = [response.message for response in agent_responses]

assert messages == [BaseMessage(text="Hi, how are you doing today?"), EndOfTurn()]


@pytest.mark.asyncio
async def test_generate_response(mocker: MockerFixture):
agent_config = ChatGPTAgentConfig(
prompt_preamble="Have a pleasant conversation about life",
generate_responses=True,
)
agent = _create_agent(mocker, agent_config)
_mock_generate_response(
mocker,
agent,
[
GeneratedResponse(
message=BaseMessage(text="Hi, how are you doing today?"), is_interruptible=True
)
],
)
_send_transcription(
agent,
Transcription(message="Hello?", confidence=1.0, is_final=True),
)
agent.start()
agent_responses = await _consume_until_end_of_turn(agent)
agent_responses = await _consume_until_end_of_turn(agent_consumer)
agent.terminate()

messages = [response.message for response in agent_responses]
Expand Down
47 changes: 28 additions & 19 deletions tests/streaming/test_streaming_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from vocode.streaming.models.transcriber import Transcription
from vocode.streaming.models.transcript import ActionStart, Message, Transcript
from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult
from vocode.streaming.utils.worker import AsyncWorker
from vocode.streaming.utils.worker import AsyncWorker, QueueConsumer


class ShouldIgnoreUtteranceTestCase(BaseModel):
Expand All @@ -27,9 +27,9 @@ class ShouldIgnoreUtteranceTestCase(BaseModel):
expected: bool


async def _consume_worker_output(worker: AsyncWorker, timeout: float = 0.1):
async def _get_from_consumer_queue_if_exists(queue_consumer: QueueConsumer, timeout: float = 0.1):
try:
return await asyncio.wait_for(worker.output_queue.get(), timeout=timeout)
return await asyncio.wait_for(queue_consumer.input_queue.get(), timeout=timeout)
except asyncio.TimeoutError:
return None

Expand Down Expand Up @@ -174,8 +174,6 @@ def test_should_ignore_utterance(

conversation = mocker.MagicMock()
transcriptions_worker = StreamingConversation.TranscriptionsWorker(
input_queue=mocker.MagicMock(),
output_queue=mocker.MagicMock(),
conversation=conversation,
interruptible_event_factory=mocker.MagicMock(),
)
Expand Down Expand Up @@ -253,7 +251,9 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message(
is_final=True,
),
)
assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None
transcriptions_worker_consumer = QueueConsumer()
streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer
assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert not streaming_conversation.broadcast_interrupt.called

streaming_conversation.transcript.add_bot_message(
Expand All @@ -269,8 +269,8 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message(
),
)

transcription_agent_input = await _consume_worker_output(
streaming_conversation.transcriptions_worker
transcription_agent_input = await _get_from_consumer_queue_if_exists(
transcriptions_worker_consumer
)
assert transcription_agent_input.payload.transcription.message == "hi, who is this?"
assert streaming_conversation.broadcast_interrupt.called
Expand Down Expand Up @@ -310,7 +310,10 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance(
is_final=False,
),
)
assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None
transcriptions_worker_consumer = QueueConsumer()
streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer

assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert not streaming_conversation.broadcast_interrupt.called # ignored for length of response

streaming_conversation.transcript.event_logs[-1].text = (
Expand All @@ -325,7 +328,7 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance(
),
)

assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None
assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert not streaming_conversation.broadcast_interrupt.called # ignored for length of response

streaming_conversation.transcriptions_worker.consume_nonblocking(
Expand All @@ -336,8 +339,8 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance(
),
)

transcription_agent_input = await _consume_worker_output(
streaming_conversation.transcriptions_worker
transcription_agent_input = await _get_from_consumer_queue_if_exists(
transcriptions_worker_consumer
)
assert (
transcription_agent_input.payload.transcription.message == "I have not yet gotten a chance."
Expand Down Expand Up @@ -377,7 +380,10 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts(
),
)

assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None
transcriptions_worker_consumer = QueueConsumer()
streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer

assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert streaming_conversation.broadcast_interrupt.called

streaming_conversation.transcriptions_worker.consume_nonblocking(
Expand All @@ -388,8 +394,8 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts(
),
)

transcription_agent_input = await _consume_worker_output(
streaming_conversation.transcriptions_worker
transcription_agent_input = await _get_from_consumer_queue_if_exists(
transcriptions_worker_consumer
)
assert (
transcription_agent_input.payload.transcription.message
Expand Down Expand Up @@ -421,7 +427,10 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun
is_final=False,
),
)
assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None
transcriptions_worker_consumer = QueueConsumer()
streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer

assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert streaming_conversation.broadcast_interrupt.called

streaming_conversation.transcriptions_worker.consume_nonblocking(
Expand All @@ -431,8 +440,8 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun
is_final=True,
),
)
transcription_agent_input = await _consume_worker_output(
streaming_conversation.transcriptions_worker
transcription_agent_input = await _get_from_consumer_queue_if_exists(
transcriptions_worker_consumer
)
assert transcription_agent_input.payload.transcription.message == "Sorry, what?"
assert streaming_conversation.broadcast_interrupt.called
Expand All @@ -449,7 +458,7 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun
),
)

assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None
assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None
assert streaming_conversation.broadcast_interrupt.called

streaming_conversation.transcriptions_worker.terminate()
Expand Down

0 comments on commit 9e28363

Please sign in to comment.