Skip to content

Commit 5e7942a

Browse files
committed
wip
1 parent 5def693 commit 5e7942a

File tree

14 files changed

+1766
-64
lines changed

14 files changed

+1766
-64
lines changed

examples/realtime/demo.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import asyncio
2+
import base64
3+
import os
4+
import sys
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
9+
# Add the current directory to path so we can import ui
10+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
11+
12+
from agents import function_tool
13+
from agents.realtime import RealtimeAgent, RealtimeSession, RealtimeSessionEvent
14+
15+
if TYPE_CHECKING:
16+
from .ui import AppUI
17+
else:
18+
# At runtime, try both import styles
19+
try:
20+
# Try relative import first (when used as a package)
21+
from .ui import AppUI
22+
except ImportError:
23+
# Fall back to direct import (when run as a script)
24+
from ui import AppUI
25+
26+
27+
@function_tool
28+
def get_weather(city: str) -> str:
29+
"""Get the weather in a city."""
30+
return f"The weather in {city} is sunny."
31+
32+
33+
agent = RealtimeAgent(
34+
name="Assistant",
35+
instructions="You always greet the user with 'Top of the morning to you'.",
36+
tools=[get_weather],
37+
)
38+
39+
40+
class Example:
41+
def __init__(self) -> None:
42+
self.session = RealtimeSession(agent)
43+
self.ui = AppUI()
44+
self.ui.connected = asyncio.Event()
45+
self.ui.last_audio_item_id = None
46+
# Set the audio callback
47+
self.ui.set_audio_callback(self.on_audio_recorded)
48+
49+
async def run(self) -> None:
50+
self.session.add_listener(self.on_event)
51+
await self.session.connect()
52+
self.ui.set_is_connected(True)
53+
await self.ui.run_async()
54+
55+
async def on_audio_recorded(self, audio_bytes: bytes) -> None:
56+
"""Called when audio is recorded by the UI."""
57+
try:
58+
# Send the audio to the session
59+
await self.session.send_audio(audio_bytes)
60+
except Exception as e:
61+
self.ui.log_message(f"Error sending audio: {e}")
62+
63+
async def on_event(self, event: RealtimeSessionEvent) -> None:
64+
# Display event in the UI
65+
try:
66+
if event.type == "raw_transport_event" and event.data.type == "other":
67+
# self.ui.log_message(f"{event.data}, {type(event.data.data)}")
68+
if event.data.data["type"] == "response.audio.delta":
69+
self.ui.log_message("audio deltas")
70+
delta_b64_string = event.data.data["delta"]
71+
delta_bytes = base64.b64decode(delta_b64_string)
72+
audio_data = np.frombuffer(delta_bytes, dtype=np.int16)
73+
self.ui.play_audio(audio_data)
74+
75+
# Handle audio from model
76+
if event.type == "audio":
77+
try:
78+
# Convert bytes to numpy array for audio player
79+
audio_data = np.frombuffer(event.audio.data, dtype=np.int16)
80+
self.ui.play_audio(audio_data)
81+
except Exception as e:
82+
self.ui.log_message(f"Audio play error: {e}")
83+
except Exception:
84+
# This can happen if the UI has already exited
85+
pass
86+
87+
88+
if __name__ == "__main__":
89+
example = Example()
90+
asyncio.run(example.run())

examples/realtime/ui.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import Coroutine
5+
from typing import Any, Callable
6+
7+
import numpy as np
8+
import numpy.typing as npt
9+
import sounddevice as sd
10+
from textual import events
11+
from textual.app import App, ComposeResult
12+
from textual.containers import Container
13+
from textual.reactive import reactive
14+
from textual.widgets import RichLog, Static
15+
from typing_extensions import override
16+
17+
CHUNK_LENGTH_S = 0.05 # 50ms
18+
SAMPLE_RATE = 24000
19+
FORMAT = np.int16
20+
CHANNELS = 1
21+
22+
23+
class Header(Static):
24+
"""A header widget."""
25+
26+
@override
27+
def render(self) -> str:
28+
return "Realtime Demo"
29+
30+
31+
class AudioStatusIndicator(Static):
32+
"""A widget that shows the current audio recording status."""
33+
34+
is_recording = reactive(False)
35+
36+
@override
37+
def render(self) -> str:
38+
status = (
39+
"🔴 Conversation started."
40+
if self.is_recording
41+
else "⚪ Press SPACE to start the conversation (q to quit)"
42+
)
43+
return status
44+
45+
46+
class AppUI(App[None]):
47+
CSS = """
48+
Screen {
49+
background: #1a1b26; /* Dark blue-grey background */
50+
}
51+
52+
Container {
53+
border: double rgb(91, 164, 91);
54+
}
55+
56+
#input-container {
57+
height: 5; /* Explicit height for input container */
58+
margin: 1 1;
59+
padding: 1 2;
60+
}
61+
62+
#bottom-pane {
63+
width: 100%;
64+
height: 82%; /* Reduced to make room for session display */
65+
border: round rgb(205, 133, 63);
66+
content-align: center middle;
67+
}
68+
69+
#status-indicator {
70+
height: 3;
71+
content-align: center middle;
72+
background: #2a2b36;
73+
border: solid rgb(91, 164, 91);
74+
margin: 1 1;
75+
}
76+
77+
#session-display {
78+
height: 3;
79+
content-align: center middle;
80+
background: #2a2b36;
81+
border: solid rgb(91, 164, 91);
82+
margin: 1 1;
83+
}
84+
85+
Static {
86+
color: white;
87+
}
88+
"""
89+
90+
should_send_audio: asyncio.Event
91+
connected: asyncio.Event
92+
last_audio_item_id: str | None
93+
audio_callback: Callable[[bytes], Coroutine[Any, Any, None]] | None
94+
95+
def __init__(self) -> None:
96+
super().__init__()
97+
self.audio_player = sd.OutputStream(
98+
samplerate=SAMPLE_RATE,
99+
channels=CHANNELS,
100+
dtype=FORMAT,
101+
)
102+
self.should_send_audio = asyncio.Event()
103+
self.connected = asyncio.Event()
104+
self.audio_callback = None
105+
106+
@override
107+
def compose(self) -> ComposeResult:
108+
"""Create child widgets for the app."""
109+
with Container():
110+
yield Header(id="session-display")
111+
yield AudioStatusIndicator(id="status-indicator")
112+
yield RichLog(id="bottom-pane", wrap=True, highlight=True, markup=True)
113+
114+
def set_is_connected(self, is_connected: bool) -> None:
115+
self.connected.set() if is_connected else self.connected.clear()
116+
117+
def set_audio_callback(self, callback: Callable[[bytes], Coroutine[Any, Any, None]]) -> None:
118+
"""Set a callback function to be called when audio is recorded."""
119+
self.audio_callback = callback
120+
121+
# High-level methods for UI operations
122+
def set_header_text(self, text: str) -> None:
123+
"""Update the header text."""
124+
header = self.query_one("#session-display", Header)
125+
header.update(text)
126+
127+
def set_recording_status(self, is_recording: bool) -> None:
128+
"""Set the recording status indicator."""
129+
status_indicator = self.query_one(AudioStatusIndicator)
130+
status_indicator.is_recording = is_recording
131+
132+
def log_message(self, message: str) -> None:
133+
"""Add a message to the log pane."""
134+
try:
135+
bottom_pane = self.query_one("#bottom-pane", RichLog)
136+
bottom_pane.write(message)
137+
except Exception:
138+
# Handle the case where the widget might not be available
139+
pass
140+
141+
def play_audio(self, audio_data: npt.NDArray[np.int16]) -> None:
142+
"""Play audio data through the audio player."""
143+
try:
144+
self.audio_player.write(audio_data)
145+
except Exception as e:
146+
self.log_message(f"Audio play error: {e}")
147+
148+
async def on_mount(self) -> None:
149+
"""Set up audio player and start the audio capture worker."""
150+
self.audio_player.start()
151+
self.run_worker(self.capture_audio())
152+
153+
async def capture_audio(self) -> None:
154+
"""Capture audio from the microphone and send to the session."""
155+
# Wait for connection to be established
156+
await self.connected.wait()
157+
158+
self.log_message("Connected to agent. Press space to start the conversation")
159+
160+
# Set up audio input stream
161+
stream = sd.InputStream(
162+
channels=CHANNELS,
163+
samplerate=SAMPLE_RATE,
164+
dtype=FORMAT,
165+
)
166+
167+
try:
168+
# Wait for user to press spacebar to start
169+
await self.should_send_audio.wait()
170+
171+
stream.start()
172+
self.set_recording_status(True)
173+
self.log_message("Recording started - speak to the agent")
174+
175+
# Buffer size in samples
176+
read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S)
177+
178+
while True:
179+
# Check if there's enough data to read
180+
if stream.read_available < read_size:
181+
await asyncio.sleep(0.01) # Small sleep to avoid CPU hogging
182+
continue
183+
184+
# Read audio data
185+
data, _ = stream.read(read_size)
186+
187+
# Convert numpy array to bytes
188+
audio_bytes = data.tobytes()
189+
190+
# Call audio callback if set
191+
if self.audio_callback:
192+
try:
193+
await self.audio_callback(audio_bytes)
194+
except Exception as e:
195+
self.log_message(f"Audio callback error: {e}")
196+
197+
# Yield control back to event loop
198+
await asyncio.sleep(0)
199+
200+
except Exception as e:
201+
self.log_message(f"Audio capture error: {e}")
202+
finally:
203+
if stream.active:
204+
stream.stop()
205+
stream.close()
206+
207+
async def on_key(self, event: events.Key) -> None:
208+
"""Handle key press events."""
209+
# add the keypress to the log
210+
self.log_message(f"Key pressed: {event.key}")
211+
212+
if event.key == "q":
213+
self.audio_player.stop()
214+
self.audio_player.close()
215+
self.exit()
216+
return
217+
218+
if event.key == "space": # Spacebar
219+
if not self.should_send_audio.is_set():
220+
self.should_send_audio.set()
221+
self.set_recording_status(True)

src/agents/agent.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,33 @@ class AgentBase:
9494
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
9595
"""Configuration for MCP servers."""
9696

97+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
98+
"""Fetches the available tools from the MCP servers."""
99+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
100+
return await MCPUtil.get_all_function_tools(
101+
self.mcp_servers, convert_schemas_to_strict, run_context, self
102+
)
103+
104+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
105+
"""All agent tools, including MCP tools and function tools."""
106+
mcp_tools = await self.get_mcp_tools(run_context)
107+
108+
async def _check_tool_enabled(tool: Tool) -> bool:
109+
if not isinstance(tool, FunctionTool):
110+
return True
111+
112+
attr = tool.is_enabled
113+
if isinstance(attr, bool):
114+
return attr
115+
res = attr(run_context, self)
116+
if inspect.isawaitable(res):
117+
return bool(await res)
118+
return bool(res)
119+
120+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
121+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
122+
return [*mcp_tools, *enabled]
123+
97124

98125
@dataclass
99126
class Agent(AgentBase, Generic[TContext]):
@@ -262,30 +289,3 @@ async def get_prompt(
262289
) -> ResponsePromptParam | None:
263290
"""Get the prompt for the agent."""
264291
return await PromptUtil.to_model_input(self.prompt, run_context, self)
265-
266-
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
267-
"""Fetches the available tools from the MCP servers."""
268-
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
269-
return await MCPUtil.get_all_function_tools(
270-
self.mcp_servers, convert_schemas_to_strict, run_context, self
271-
)
272-
273-
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
274-
"""All agent tools, including MCP tools and function tools."""
275-
mcp_tools = await self.get_mcp_tools(run_context)
276-
277-
async def _check_tool_enabled(tool: Tool) -> bool:
278-
if not isinstance(tool, FunctionTool):
279-
return True
280-
281-
attr = tool.is_enabled
282-
if isinstance(attr, bool):
283-
return attr
284-
res = attr(run_context, self)
285-
if inspect.isawaitable(res):
286-
return bool(await res)
287-
return bool(res)
288-
289-
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
290-
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
291-
return [*mcp_tools, *enabled]

0 commit comments

Comments
 (0)