[DOW-107] refactor synthesizer as a worker (#571)
* rename create_speech_uncached

* deprecate agentstop - should use terminate_conversation instead

* deprecate filleraudio from agent - if reimplemented it should go around the inner agent

* [unstable] move agentresponsesworker logic into synthesizer

* hook everything up

* deprecate AgentResponseMessage and just use AgentResponse

* few other respond refs

* add comment for tear_down vs terminate

* fix ref to create_speech_uncached

* fix playground
ajar98 committed Jun 27, 2024
1 parent ea2b754 commit 3f2f9bc
Showing 28 changed files with 289 additions and 348 deletions.
6 changes: 3 additions & 3 deletions apps/telephony_app/speller_agent.py
@@ -31,7 +31,7 @@ async def respond(
human_input: str,
conversation_id: str,
is_interrupt: bool = False,
-    ) -> Tuple[Optional[str], bool]:
+    ) -> Optional[str]:
"""Generates a response from the SpellerAgent.
The response is generated by joining each character in the human input with a space.
@@ -43,9 +43,9 @@ async def respond(
is_interrupt (bool): A flag indicating whether the agent was interrupted.
Returns:
-            Tuple[Optional[str], bool]: The generated response and a flag indicating whether to stop.
+            Optional[str]: The generated response
"""
-        return "".join(c + " " for c in human_input), False
+        return "".join(c + " " for c in human_input)


class SpellerAgentFactory(AbstractAgentFactory):
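The signature change above (dropping the `should_stop` flag from `respond()`) can be sketched as a standalone function; `speller_respond` is a hypothetical stand-in name, not part of the codebase:

```python
from typing import Optional


def speller_respond(human_input: str) -> Optional[str]:
    # Mirrors the new single-value contract: the response text only,
    # with no accompanying stop flag. Joins each character with a space,
    # e.g. "hi" -> "h i " (note the trailing space).
    return "".join(c + " " for c in human_input)
```

Stopping a conversation is no longer signaled through this return value; per the commit message, callers should use `terminate_conversation` instead.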
10 changes: 8 additions & 2 deletions docs/open-source/create-your-own-agent.mdx
@@ -7,15 +7,18 @@ You can subclass a [`RespondAgent`](https://github.com/vocodedev/vocode-python/b
To do so, you must create an agent type, create an agent config, and then create your agent subclass. In the examples below, we will create an agent that responds with the same message no matter what is said to it, called `BrokenRecordAgent`.

### Agent type

Each agent has a unique agent type string that is checked in various parts of Vocode, most notably in the factories that create agents. So, you must create a new type for your custom agent. See the `AgentType` enum in `vocode/streaming/models/agent.py` for examples.
For our `BrokenRecordAgent`, we will use "agent_broken_record" as our type.

### Agent config

Your agent must have a corresponding agent config that is a subclass of `AgentConfig` and is [JSON-serializable](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json). Serialization is handled automatically by [Pydantic](https://docs.pydantic.dev/latest/).

The agent config should only contain the information you need to deterministically create the same agent each time. This means with the same parameters in your config, the corresponding agent should have the same behavior each time you create it.

For our `BrokenRecordAgent`, we create a config like:

```python
from vocode.streaming.models.agent import AgentConfig

@@ -24,21 +27,24 @@ class BrokenRecordAgentConfig(AgentConfig, type="agent_broken_record"):
```

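To illustrate the serializability requirement, here is a minimal sketch using a plain dataclass as a stand-in for the Pydantic `AgentConfig` base class; the `message` field is an assumption about what `BrokenRecordAgentConfig` carries:

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class BrokenRecordAgentConfig:
    # Stand-in for the Pydantic-based config; `type` mirrors the
    # type="agent_broken_record" class argument in the real subclass.
    message: str
    type: str = "agent_broken_record"


config = BrokenRecordAgentConfig(message="Sorry, I can only say one thing.")
serialized = json.dumps(asdict(config))
restored = BrokenRecordAgentConfig(**json.loads(serialized))
assert restored == config  # same parameters, same agent behavior
```

The round-trip at the end captures the determinism requirement: a config that serializes and deserializes to an equal value will create an identically behaving agent each time.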
### Custom Agent

Now, you can create your custom agent subclass of `RespondAgent`. In your class header, pass in `RespondAgent` with your agent config type as a type parameter. This should look like `RespondAgent[Your_Agent_Type]`.

Each agent should override the `generate_response()` async method to support streaming and the `respond()` method to support turn-based conversations.

> If you want to only support turn-based conversations, you do not have to override `generate_response()`, but you MUST set `generate_response=False` in your agent config (see `ChatVertexAIAgentConfig` in `vocode/streaming/models/agent.py` for an example). Otherwise, you must ALWAYS implement the `generate_response()` async method.
The `generate_response()` method returns an `AsyncGenerator` of tuples containing each message/sentence and a boolean for whether that message can be interrupted by the human speaking. You can automatically create this generator by yielding instead of returning (see example below).

We will now define our `BrokenRecordAgent`. Since we simply return the same message each time, we can return and yield that message in `respond()` and `generate_response()`, respectively:

```python
class BrokenRecordAgent(RespondAgent[BrokenRecordAgentConfig]):

# is_interrupt is True when the human has just interrupted the bot's last response
def respond(
self, human_input, is_interrupt: bool = False
-    ) -> tuple[Optional[str], bool]:
+    ) -> Optional[str]:
return self.agent_config.message

async def generate_response(
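The diff above is truncated, so here is a hedged sketch of what a complete `generate_response()` for a broken-record agent might look like, with `BaseMessage` stubbed locally rather than imported from `vocode.streaming.models.message`:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Tuple


@dataclass
class BaseMessage:  # local stub for vocode's BaseMessage
    text: str


async def generate_response(human_input: str) -> AsyncGenerator[Tuple[BaseMessage, bool], None]:
    # Yields (message, is_interruptible) tuples, as described above;
    # a broken record yields the same message regardless of input.
    yield BaseMessage(text="I'm a broken record!"), True


async def main() -> None:
    async for message, is_interruptible in generate_response("hello"):
        print(message.text, is_interruptible)


asyncio.run(main())
```

Yielding instead of returning is what turns the method into the `AsyncGenerator` the streaming pipeline consumes.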
8 changes: 4 additions & 4 deletions docs/open-source/telephony.mdx
@@ -146,8 +146,8 @@ class SpellerAgent(RespondAgent[SpellerAgentConfig]):
human_input: str,
conversation_id: str,
is_interrupt: bool = False,
-    ) -> Tuple[Optional[str], bool]:
-        return "".join(c + " " for c in human_input), False
+    ) -> Optional[str]:
+        return "".join(c + " " for c in human_input)
class SpellerAgentFactory(AbstractAgentFactory):
@@ -182,10 +182,10 @@ class SpellerAgent(BaseAgent):
human_input: str,
conversation_id: str,
is_interrupt: bool = False,
-    ) -> Tuple[Optional[str], bool]:
+    ) -> Optional[str]:
call_config = self.config_manager.get_config(conversation_id)
if call_config is not None:
from_phone = call_config.twilio_from
to_phone = call_config.twilio_to
-        return "".join(c + " " for c in human_input), False
+        return "".join(c + " " for c in human_input)
```
75 changes: 33 additions & 42 deletions playground/streaming/agent/chat.py
@@ -28,8 +28,7 @@
from vocode.streaming.agent import ChatGPTAgent
from vocode.streaming.agent.base_agent import (
     AgentResponse,
-    AgentResponseMessage,
-    AgentResponseType,
+    AgentResponse,
BaseAgent,
TranscriptionAgentInput,
)
@@ -113,55 +112,47 @@ async def receiver():
while not ended:
try:
event = await agent_response_queue.get()
-                response = event.payload
-                if response.type == AgentResponseType.FILLER_AUDIO:
-                    print("Would have sent filler audio")
-                elif response.type == AgentResponseType.STOP:
-                    print("Agent returned stop")
-                    ended = True
-                    break
-                elif response.type == AgentResponseType.MESSAGE:
-                    agent_response = typing.cast(AgentResponseMessage, response)
-
-                    if isinstance(agent_response.message, EndOfTurn):
-                        ignore_until_end_of_turn = False
-                        if random.random() < backchannel_probability:
-                            backchannel = random.choice(BACKCHANNELS)
-                            print("Human: " + f"[{backchannel}]")
-                            agent.transcript.add_human_message(
-                                backchannel,
-                                conversation_id,
-                                is_backchannel=True,
-                            )
-                    elif isinstance(agent_response.message, BaseMessage):
-                        if ignore_until_end_of_turn:
-                            continue
-
-                        message_sent: str
-                        is_final: bool
-                        # TODO: consider allowing the user to interrupt the agent manually by responding fast
-                        if random.random() < interruption_probability:
-                            stop_idx = random.randint(0, len(agent_response.message.text))
-                            message_sent = agent_response.message.text[:stop_idx]
-                            ignore_until_end_of_turn = True
-                            is_final = False
-                        else:
-                            message_sent = agent_response.message.text
-                            is_final = True
-
-                        agent.transcript.add_bot_message(
-                            message_sent, conversation_id, is_final=is_final
-                        )
-
-                        print("AI: " + message_sent + ("-" if not is_final else ""))
+                agent_response = event.payload
+
+                if isinstance(agent_response.message, EndOfTurn):
+                    ignore_until_end_of_turn = False
+                    if random.random() < backchannel_probability:
+                        backchannel = random.choice(BACKCHANNELS)
+                        print("Human: " + f"[{backchannel}]")
+                        agent.transcript.add_human_message(
+                            backchannel,
+                            conversation_id,
+                            is_backchannel=True,
+                        )
+                elif isinstance(agent_response.message, BaseMessage):
+                    if ignore_until_end_of_turn:
+                        continue
+
+                    message_sent: str
+                    is_final: bool
+                    # TODO: consider allowing the user to interrupt the agent manually by responding fast
+                    if random.random() < interruption_probability:
+                        stop_idx = random.randint(0, len(agent_response.message.text))
+                        message_sent = agent_response.message.text[:stop_idx]
+                        ignore_until_end_of_turn = True
+                        is_final = False
+                    else:
+                        message_sent = agent_response.message.text
+                        is_final = True
+
+                    agent.transcript.add_bot_message(
+                        message_sent, conversation_id, is_final=is_final
+                    )
+
+                    print("AI: " + message_sent + ("-" if not is_final else ""))
except asyncio.CancelledError:
break

async def sender():
if agent.agent_config.initial_message is not None:
agent.agent_responses_consumer.consume_nonblocking(
InterruptibleAgentResponseEvent(
-                    payload=AgentResponseMessage(message=agent.agent_config.initial_message),
+                    payload=AgentResponse(message=agent.agent_config.initial_message),
agent_response_tracker=asyncio.Event(),
)
)
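The structural change in the receiver, from dispatching on `AgentResponseType` to branching on the type of `AgentResponse.message`, can be distilled into this sketch (all types here are local stand-ins, not the real vocode classes):

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class BaseMessage:
    text: str


@dataclass
class EndOfTurn:
    pass


@dataclass
class AgentResponse:
    # After the refactor there is a single response model; the STOP and
    # FILLER_AUDIO variants no longer exist.
    message: Union[BaseMessage, EndOfTurn]


def handle(response: AgentResponse) -> str:
    if isinstance(response.message, EndOfTurn):
        return "end of turn"
    if isinstance(response.message, BaseMessage):
        return f"AI: {response.message.text}"
    raise TypeError(f"unexpected message type: {type(response.message)}")
```

With only two message variants, `isinstance` checks replace the enum-plus-`typing.cast` pattern of the old code.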
2 changes: 1 addition & 1 deletion playground/streaming/synthesizer/synthesize.py
@@ -31,7 +31,7 @@ async def speak(
synthesizer.get_synthesizer_config().sampling_rate,
)
# ClientSession needs to be created within the async task
-    synthesis_result = await synthesizer.create_speech_uncached(
+    synthesis_result = await synthesizer.create_speech(
message=message,
chunk_size=int(chunk_size),
)
4 changes: 2 additions & 2 deletions tests/streaming/agent/test_base_agent.py
@@ -7,7 +7,7 @@
from vocode.streaming.action.abstract_factory import AbstractActionFactory
from vocode.streaming.agent.base_agent import (
     AgentResponse,
-    AgentResponseMessage,
+    AgentResponse,
BaseAgent,
GeneratedResponse,
TranscriptionAgentInput,
@@ -66,7 +66,7 @@ async def _consume_until_end_of_turn(
agent_consumer.input_queue.get(), timeout=timeout
)
agent_responses.append(agent_response.payload)
-        if isinstance(agent_response.payload, AgentResponseMessage) and isinstance(
+        if isinstance(agent_response.payload, AgentResponse) and isinstance(
agent_response.payload.message, EndOfTurn
):
break
70 changes: 14 additions & 56 deletions vocode/streaming/agent/base_agent.py
@@ -88,18 +88,7 @@ class ActionResultAgentInput(AgentInput, type=AgentInputType.ACTION_RESULT.value
is_quiet: bool = False


-class AgentResponseType(str, Enum):
-    BASE = "agent_response_base"
-    MESSAGE = "agent_response_message"
-    STOP = "agent_response_stop"
-    FILLER_AUDIO = "agent_response_filler_audio"
-
-
-class AgentResponse(TypedModel, type=AgentResponseType.BASE.value):  # type: ignore
-    pass
-
-
-class AgentResponseMessage(AgentResponse, type=AgentResponseType.MESSAGE.value):  # type: ignore
+class AgentResponse(BaseModel):
message: Union[BaseMessage, EndOfTurn]
is_interruptible: bool = True
# Whether the message is the first message in the response; has metrics implications
@@ -108,17 +97,6 @@ class AgentResponseMessage(AgentResponse, type=AgentResponseType.MESSAGE.value):
is_sole_text_chunk: bool = False


-class AgentResponseStop(AgentResponse, type=AgentResponseType.STOP.value):  # type: ignore
-    pass
-
-
-class AgentResponseFillerAudio(
-    AgentResponse,
-    type=AgentResponseType.FILLER_AUDIO.value,  # type: ignore
-):
-    pass


class GeneratedResponse(BaseModel):
message: Union[BaseMessage, FunctionCall, EndOfTurn]
is_interruptible: bool
@@ -248,7 +226,7 @@ async def handle_generate_response(
self,
transcription: Transcription,
agent_input: AgentInput,
-    ) -> bool:
+    ):
conversation_id = agent_input.conversation_id
responses = self._maybe_prepend_interrupt_responses(
transcription=transcription,
@@ -294,7 +272,7 @@ async def handle_generate_response(
agent_response_tracker = agent_input.agent_response_tracker or asyncio.Event()
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                AgentResponseMessage(
+                AgentResponse(
message=generated_response.message,
is_first=is_first_response_of_turn,
),
@@ -327,7 +305,7 @@ async def handle_generate_response(
)
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                AgentResponseMessage(
+                AgentResponse(
message=EndOfTurn(),
is_first=is_first_response_of_turn,
),
@@ -353,14 +331,12 @@ async def handle_generate_response(
)
self.enqueue_action_input(action, action_input, agent_input.conversation_id)

-        # TODO: implement should_stop for generate_responses
if function_call and self.agent_config.actions is not None:
await self.call_function(function_call, agent_input)
-        return False

async def handle_respond(self, transcription: Transcription, conversation_id: str) -> bool:
try:
-            response, should_stop = await self.respond(
+            response = await self.respond(
transcription.message,
is_interrupt=transcription.is_interrupt,
conversation_id=conversation_id,
@@ -372,17 +348,16 @@ async def handle_respond(self, transcription: Transcription, conversation_id: st
if response:
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                    AgentResponseMessage(message=BaseMessage(text=response)),
+                    AgentResponse(message=BaseMessage(text=response)),
is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
)
)
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                    AgentResponseMessage(message=EndOfTurn()),
+                    AgentResponse(message=EndOfTurn()),
is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
)
)
-            return should_stop
else:
logger.debug("No response generated")
return False
@@ -410,7 +385,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
if agent_input.action_output.canned_response is not None:
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                        AgentResponseMessage(
+                        AgentResponse(
message=agent_input.action_output.canned_response,
is_sole_text_chunk=True,
),
@@ -419,7 +394,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
)
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                        AgentResponseMessage(message=EndOfTurn()),
+                        AgentResponse(message=EndOfTurn()),
)
)
return
@@ -435,34 +410,17 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
logger.debug("Agent is muted, skipping processing")
return

-            if self.agent_config.send_filler_audio:
-                self.agent_responses_consumer.consume_nonblocking(
-                    self.interruptible_event_factory.create_interruptible_agent_response_event(
-                        AgentResponseFillerAudio(),
-                    )
-                )

logger.debug("Responding to transcription")
-            should_stop = False
if self.agent_config.generate_responses:
# TODO (EA): this is quite ugly but necessary to have the agent act properly after an action completes
if not isinstance(agent_input, ActionResultAgentInput):
sentry_create_span(
sentry_callable=sentry_sdk.start_span,
op=CustomSentrySpans.LANGUAGE_MODEL_TIME_TO_FIRST_TOKEN,
)
-                should_stop = await self.handle_generate_response(transcription, agent_input)
+                await self.handle_generate_response(transcription, agent_input)
else:
-                should_stop = await self.handle_respond(transcription, agent_input.conversation_id)
-
-            if should_stop:
-                logger.debug("Agent requested to stop")
-                self.agent_responses_consumer.consume_nonblocking(
-                    self.interruptible_event_factory.create_interruptible_agent_response_event(
-                        AgentResponseStop(),
-                    )
-                )
-                return
+                await self.handle_respond(transcription, agent_input.conversation_id)
except asyncio.CancelledError:
pass

@@ -490,7 +448,7 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp
user_message_tracker = asyncio.Event()
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                AgentResponseMessage(
+                AgentResponse(
message=BaseMessage(text=user_message),
is_sole_text_chunk=True,
),
@@ -499,7 +457,7 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp
)
self.agent_responses_consumer.consume_nonblocking(
self.interruptible_event_factory.create_interruptible_agent_response_event(
-                AgentResponseMessage(message=EndOfTurn()),
+                AgentResponse(message=EndOfTurn()),
agent_response_tracker=user_message_tracker,
)
)
@@ -567,7 +525,7 @@ async def respond(
human_input,
conversation_id: str,
is_interrupt: bool = False,
-    ) -> Tuple[Optional[str], bool]:
+    ) -> Optional[str]:
raise NotImplementedError

def generate_response(
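The net effect on the turn flow can be sketched as follows: a turn is now just the response message followed by an `EndOfTurn` marker, with no `should_stop` flag. The types below are stand-ins, not the real vocode classes, and `emit_turn` is a hypothetical helper:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class BaseMessage:
    text: str


@dataclass
class EndOfTurn:
    pass


@dataclass
class AgentResponse:
    message: Union[BaseMessage, EndOfTurn]


def emit_turn(response_text: Optional[str]) -> List[AgentResponse]:
    # The worker enqueues the response text (if any) followed by an
    # EndOfTurn marker; stopping the conversation is no longer signaled
    # here, since callers use terminate_conversation instead.
    events: List[AgentResponse] = []
    if response_text:
        events.append(AgentResponse(message=BaseMessage(text=response_text)))
        events.append(AgentResponse(message=EndOfTurn()))
    return events
```

Collapsing the four response subclasses into one model is what lets the synthesizer consume agent output directly as a worker, as the commit title describes.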
