diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index ad98c9c7a..ece72e755 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -83,6 +83,8 @@ class RealtimeSessionModelSettings(TypedDict): tool_choice: NotRequired[ToolChoice] tools: NotRequired[list[Tool]] + tracing: NotRequired[RealtimeModelTracingConfig | None] + class RealtimeGuardrailsSettings(TypedDict): """Settings for output guardrails in realtime sessions.""" @@ -95,6 +97,19 @@ class RealtimeGuardrailsSettings(TypedDict): """ +class RealtimeModelTracingConfig(TypedDict): + """Configuration for tracing in realtime model sessions.""" + + workflow_name: NotRequired[str] + """The workflow name to use for tracing.""" + + group_id: NotRequired[str] + """A group identifier to use for tracing, to link multiple traces together.""" + + metadata: NotRequired[dict[str, Any]] + """Additional metadata to include with the trace.""" + + class RealtimeRunConfig(TypedDict): model_settings: NotRequired[RealtimeSessionModelSettings] @@ -104,6 +119,7 @@ class RealtimeRunConfig(TypedDict): guardrails_settings: NotRequired[RealtimeGuardrailsSettings] """Settings for guardrail execution.""" - # TODO (rm) Add tracing support - # tracing: NotRequired[RealtimeTracingConfig | None] + tracing_disabled: NotRequired[bool] + """Whether tracing is disabled for this run.""" + # TODO (rm) Add history audio storage config diff --git a/src/agents/realtime/model.py b/src/agents/realtime/model.py index 2b41960e7..abb3a1eac 100644 --- a/src/agents/realtime/model.py +++ b/src/agents/realtime/model.py @@ -38,6 +38,7 @@ class RealtimeModelConfig(TypedDict): """ initial_model_settings: NotRequired[RealtimeSessionModelSettings] + """The initial model settings to use when connecting.""" class RealtimeModel(abc.ABC): diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index de8f57ac7..797753242 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -6,7 +6,7 @@ import json import os from datetime import datetime -from typing import Any, Callable +from typing import Any, Callable, Literal import websockets from openai.types.beta.realtime.conversation_item import ConversationItem @@ -23,6 +23,7 @@ from ..logger import logger from .config import ( RealtimeClientMessage, + RealtimeModelTracingConfig, RealtimeSessionModelSettings, RealtimeUserInput, ) @@ -73,6 +74,7 @@ def __init__(self) -> None: self._audio_length_ms: float = 0.0 self._ongoing_response: bool = False self._current_audio_content_index: int | None = None + self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None async def connect(self, options: RealtimeModelConfig) -> None: """Establish a connection to the model and keep it alive.""" @@ -84,6 +86,11 @@ async def connect(self, options: RealtimeModelConfig) -> None: self.model = model_settings.get("model_name", self.model) api_key = await get_api_key(options.get("api_key")) + if "tracing" in model_settings: + self._tracing_config = model_settings["tracing"] + else: + self._tracing_config = "auto" + if not api_key: raise UserError("API key is required but was not provided.") @@ -96,6 +103,15 @@ async def connect(self, options: RealtimeModelConfig) -> None: self._websocket = await websockets.connect(url, additional_headers=headers) self._websocket_task = asyncio.create_task(self._listen_for_messages()) + async def _send_tracing_config( + self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None + ) -> None: + """Update tracing configuration via session.update event.""" + if tracing_config is not None: + await self.send_event( + {"type": "session.update", "other_data": {"session": {"tracing": tracing_config}}} + ) + def add_listener(self, listener: RealtimeModelListener) -> None: """Add a listener to the model.""" self._listeners.append(listener) @@ -343,8 +359,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): self._ongoing_response = False await self._emit_event(RealtimeModelTurnEndedEvent()) elif parsed.type == "session.created": - # TODO (rm) tracing stuff here - pass + await self._send_tracing_config(self._tracing_config) elif parsed.type == "error": await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) elif parsed.type == "conversation.item.deleted": diff --git a/src/agents/realtime/runner.py b/src/agents/realtime/runner.py index 369267797..a7047a6f5 100644 --- a/src/agents/realtime/runner.py +++ b/src/agents/realtime/runner.py @@ -69,6 +69,7 @@ async def run( """ model_settings = await self._get_model_settings( agent=self._starting_agent, + disable_tracing=self._config.get("tracing_disabled", False) if self._config else False, initial_settings=model_config.get("initial_model_settings") if model_config else None, overrides=self._config.get("model_settings") if self._config else None, ) @@ -90,6 +91,7 @@ async def run( async def _get_model_settings( self, agent: RealtimeAgent, + disable_tracing: bool, context: TContext | None = None, initial_settings: RealtimeSessionModelSettings | None = None, overrides: RealtimeSessionModelSettings | None = None, @@ -110,4 +112,7 @@ async def _get_model_settings( if overrides: model_settings.update(overrides) + if disable_tracing: + model_settings["tracing"] = None + return model_settings diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py new file mode 100644 index 000000000..456ae125f --- /dev/null +++ b/tests/realtime/test_tracing.py @@ -0,0 +1,257 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel + + +class TestRealtimeTracingIntegration: + """Test tracing configuration and session.update integration.""" + + @pytest.fixture + def model(self): + """Create a fresh model instance for each test.""" + return OpenAIRealtimeWebSocketModel() + + @pytest.fixture + def mock_websocket(self): + """Create a mock websocket connection.""" + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.close = AsyncMock() + return mock_ws + + @pytest.mark.asyncio + async def test_tracing_config_storage_and_defaults(self, model, mock_websocket): + """Test that tracing config is stored correctly and defaults to 'auto'.""" + # Test with explicit tracing config + config_with_tracing = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": { + "workflow_name": "test_workflow", + "group_id": "group_123", + "metadata": {"version": "1.0"}, + } + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.return_value = mock_task + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config_with_tracing) + + # Should store the tracing config + assert model._tracing_config == { + "workflow_name": "test_workflow", + "group_id": "group_123", + "metadata": {"version": "1.0"}, + } + + # Test without tracing config - should default to "auto" + model2 = OpenAIRealtimeWebSocketModel() + config_no_tracing = { + "api_key": "test-key", + "initial_model_settings": {}, + } + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model2.connect(config_no_tracing) # type: ignore[arg-type] + assert model2._tracing_config == "auto" + + @pytest.mark.asyncio + async def test_send_tracing_config_on_session_created(self, model, mock_websocket): + """Test that tracing config is sent when session.created event is received.""" + config = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": {"workflow_name": "test_workflow", "group_id": "group_123"} + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + # Simulate session.created event + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should send session.update with tracing config + mock_send_event.assert_called_once_with( + { + "type": "session.update", + "other_data": { + "session": { + "tracing": { + "workflow_name": "test_workflow", + "group_id": "group_123", + } + } + }, + } + ) + + @pytest.mark.asyncio + async def test_send_tracing_config_auto_mode(self, model, mock_websocket): + """Test that 'auto' tracing config is sent correctly.""" + config = { + "api_key": "test-key", + "initial_model_settings": {}, # No tracing config - defaults to "auto" + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should send session.update with "auto" + mock_send_event.assert_called_once_with( + {"type": "session.update", "other_data": {"session": {"tracing": "auto"}}} + ) + + @pytest.mark.asyncio + async def test_tracing_config_none_skips_session_update(self, model, mock_websocket): + """Test that None tracing config skips sending session.update.""" + # Manually set tracing config to None (this would happen if explicitly set) + model._tracing_config = None + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should not send any session.update + mock_send_event.assert_not_called() + + @pytest.mark.asyncio + async def test_tracing_config_with_metadata_serialization(self, model, mock_websocket): + """Test that complex metadata in tracing config is handled correctly.""" + complex_metadata = { + "user_id": "user_123", + "session_type": "demo", + "features": ["audio", "tools"], + "config": {"timeout": 30, "retries": 3}, + } + + config = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": {"workflow_name": "complex_workflow", "metadata": complex_metadata} + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should send session.update with complete tracing config including metadata + expected_call = { + "type": "session.update", + "other_data": { + "session": { + "tracing": { + "workflow_name": "complex_workflow", + "metadata": complex_metadata, + } + } + }, + } + mock_send_event.assert_called_once_with(expected_call) + + @pytest.mark.asyncio + async def test_tracing_disabled_prevents_tracing(self, mock_websocket): + """Test that tracing_disabled=True prevents tracing configuration.""" + from agents.realtime.agent import RealtimeAgent + from agents.realtime.runner import RealtimeRunner + + # Create a test agent and runner with tracing disabled + agent = RealtimeAgent(name="test_agent", instructions="test") + + runner = RealtimeRunner( + starting_agent=agent, + config={"tracing_disabled": True} + ) + + # Test the _get_model_settings method directly since that's where the logic is + model_settings = await runner._get_model_settings( + agent=agent, + disable_tracing=True, # This should come from config["tracing_disabled"] + initial_settings=None, + overrides=None + ) + + # When tracing is disabled, model settings should have tracing=None + assert model_settings["tracing"] is None + + # Also test that the runner passes disable_tracing=True correctly + with patch.object(runner, '_get_model_settings') as mock_get_settings: + mock_get_settings.return_value = {"tracing": None} + + with patch('agents.realtime.session.RealtimeSession') as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value = mock_session + + await runner.run() + + # Verify that _get_model_settings was called with disable_tracing=True + mock_get_settings.assert_called_once_with( + agent=agent, + disable_tracing=True, + initial_settings=None, + overrides=None + )