|
| 1 | +import json |
1 | 2 | from typing import Any, cast
|
2 | 3 | from unittest.mock import AsyncMock, Mock, patch
|
3 | 4 |
|
@@ -96,6 +97,88 @@ def mock_create_task_func(coro):
|
96 | 97 | assert model._websocket_task is not None
|
97 | 98 | assert model.model == "gpt-4o-realtime-preview"
|
98 | 99 |
|
| 100 | + @pytest.mark.asyncio |
| 101 | + async def test_session_update_includes_noise_reduction(self, model, mock_websocket): |
| 102 | + """Session.update should pass through input_audio_noise_reduction config.""" |
| 103 | + config = { |
| 104 | + "api_key": "test-api-key-123", |
| 105 | + "initial_model_settings": { |
| 106 | + "model_name": "gpt-4o-realtime-preview", |
| 107 | + "input_audio_noise_reduction": {"type": "near_field"}, |
| 108 | + }, |
| 109 | + } |
| 110 | + |
| 111 | + sent_messages: list[dict[str, Any]] = [] |
| 112 | + |
| 113 | + async def async_websocket(*args, **kwargs): |
| 114 | + async def send(payload: str): |
| 115 | + sent_messages.append(json.loads(payload)) |
| 116 | + return None |
| 117 | + |
| 118 | + mock_websocket.send.side_effect = send |
| 119 | + return mock_websocket |
| 120 | + |
| 121 | + with patch("websockets.connect", side_effect=async_websocket): |
| 122 | + with patch("asyncio.create_task") as mock_create_task: |
| 123 | + mock_task = AsyncMock() |
| 124 | + |
| 125 | + def mock_create_task_func(coro): |
| 126 | + coro.close() |
| 127 | + return mock_task |
| 128 | + |
| 129 | + mock_create_task.side_effect = mock_create_task_func |
| 130 | + await model.connect(config) |
| 131 | + |
| 132 | + # Find the session.update events |
| 133 | + session_updates = [m for m in sent_messages if m.get("type") == "session.update"] |
| 134 | + assert len(session_updates) >= 1 |
| 135 | + # Verify the last session.update contains the noise_reduction field |
| 136 | + session = session_updates[-1]["session"] |
| 137 | + assert session.get("audio", {}).get("input", {}).get("noise_reduction") == { |
| 138 | + "type": "near_field" |
| 139 | + } |
| 140 | + |
| 141 | + @pytest.mark.asyncio |
| 142 | + async def test_session_update_omits_noise_reduction_when_not_provided( |
| 143 | + self, model, mock_websocket |
| 144 | + ): |
| 145 | + """Session.update should omit input_audio_noise_reduction when not provided.""" |
| 146 | + config = { |
| 147 | + "api_key": "test-api-key-123", |
| 148 | + "initial_model_settings": { |
| 149 | + "model_name": "gpt-4o-realtime-preview", |
| 150 | + }, |
| 151 | + } |
| 152 | + |
| 153 | + sent_messages: list[dict[str, Any]] = [] |
| 154 | + |
| 155 | + async def async_websocket(*args, **kwargs): |
| 156 | + async def send(payload: str): |
| 157 | + sent_messages.append(json.loads(payload)) |
| 158 | + return None |
| 159 | + |
| 160 | + mock_websocket.send.side_effect = send |
| 161 | + return mock_websocket |
| 162 | + |
| 163 | + with patch("websockets.connect", side_effect=async_websocket): |
| 164 | + with patch("asyncio.create_task") as mock_create_task: |
| 165 | + mock_task = AsyncMock() |
| 166 | + |
| 167 | + def mock_create_task_func(coro): |
| 168 | + coro.close() |
| 169 | + return mock_task |
| 170 | + |
| 171 | + mock_create_task.side_effect = mock_create_task_func |
| 172 | + await model.connect(config) |
| 173 | + |
| 174 | + # Find the session.update events |
| 175 | + session_updates = [m for m in sent_messages if m.get("type") == "session.update"] |
| 176 | + assert len(session_updates) >= 1 |
| 177 | + # Verify the last session.update omits the noise_reduction field |
| 178 | + session = session_updates[-1]["session"] |
| 179 | + assert "audio" in session and "input" in session["audio"] |
| 180 | + assert "noise_reduction" not in session["audio"]["input"] |
| 181 | + |
99 | 182 | @pytest.mark.asyncio
|
100 | 183 | async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket):
|
101 | 184 | """If custom headers are provided, use them verbatim without adding defaults."""
|
|
0 commit comments