Skip to content

Commit aeaf83f

Browse files
notV3NOM and seratch authored
feat(realtime): #1560 add input audio noise reduction (#1749)
Co-authored-by: Kazuhiro Sera <[email protected]>
1 parent 605611c commit aeaf83f

File tree

5 files changed

+107
-1
lines changed

5 files changed

+107
-1
lines changed

docs/ref/realtime/config.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
## Audio Configuration
1212

1313
::: agents.realtime.config.RealtimeInputAudioTranscriptionConfig
14+
::: agents.realtime.config.RealtimeInputAudioNoiseReductionConfig
1415
::: agents.realtime.config.RealtimeTurnDetectionConfig
1516

1617
## Guardrails Settings

src/agents/realtime/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
RealtimeAudioFormat,
44
RealtimeClientMessage,
55
RealtimeGuardrailsSettings,
6+
RealtimeInputAudioNoiseReductionConfig,
67
RealtimeInputAudioTranscriptionConfig,
78
RealtimeModelName,
89
RealtimeModelTracingConfig,
@@ -101,6 +102,7 @@
101102
"RealtimeAudioFormat",
102103
"RealtimeClientMessage",
103104
"RealtimeGuardrailsSettings",
105+
"RealtimeInputAudioNoiseReductionConfig",
104106
"RealtimeInputAudioTranscriptionConfig",
105107
"RealtimeModelName",
106108
"RealtimeModelTracingConfig",

src/agents/realtime/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ class RealtimeInputAudioTranscriptionConfig(TypedDict):
6161
"""An optional prompt to guide transcription."""
6262

6363

64+
class RealtimeInputAudioNoiseReductionConfig(TypedDict):
65+
"""Noise reduction configuration for input audio."""
66+
67+
type: NotRequired[Literal["near_field", "far_field"]]
68+
"""Noise reduction mode to apply to input audio."""
69+
70+
6471
class RealtimeTurnDetectionConfig(TypedDict):
6572
"""Turn detection config. Allows extra vendor keys if needed."""
6673

@@ -119,6 +126,9 @@ class RealtimeSessionModelSettings(TypedDict):
119126
input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig]
120127
"""Configuration for transcribing input audio."""
121128

129+
input_audio_noise_reduction: NotRequired[RealtimeInputAudioNoiseReductionConfig | None]
130+
"""Noise reduction configuration for input audio."""
131+
122132
turn_detection: NotRequired[RealtimeTurnDetectionConfig]
123133
"""Configuration for detecting conversation turns."""
124134

src/agents/realtime/openai_realtime.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,14 +825,24 @@ def _get_session_config(
825825
"output_audio_format",
826826
DEFAULT_MODEL_SETTINGS.get("output_audio_format"),
827827
)
828+
input_audio_noise_reduction = model_settings.get(
829+
"input_audio_noise_reduction",
830+
DEFAULT_MODEL_SETTINGS.get("input_audio_noise_reduction"),
831+
)
828832

829833
input_audio_config = None
830834
if any(
831835
value is not None
832-
for value in [input_audio_format, input_audio_transcription, turn_detection]
836+
for value in [
837+
input_audio_format,
838+
input_audio_noise_reduction,
839+
input_audio_transcription,
840+
turn_detection,
841+
]
833842
):
834843
input_audio_config = OpenAIRealtimeAudioInput(
835844
format=to_realtime_audio_format(input_audio_format),
845+
noise_reduction=cast(Any, input_audio_noise_reduction),
836846
transcription=cast(Any, input_audio_transcription),
837847
turn_detection=cast(Any, turn_detection),
838848
)

tests/realtime/test_openai_realtime.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, cast
23
from unittest.mock import AsyncMock, Mock, patch
34

@@ -96,6 +97,88 @@ def mock_create_task_func(coro):
9697
assert model._websocket_task is not None
9798
assert model.model == "gpt-4o-realtime-preview"
9899

100+
@pytest.mark.asyncio
async def test_session_update_includes_noise_reduction(self, model, mock_websocket):
    """Session.update should pass through input_audio_noise_reduction config.

    Connects with ``input_audio_noise_reduction`` set in the initial model
    settings, captures every JSON payload sent over the mocked websocket, and
    checks that the last ``session.update`` carries the same value under
    ``session.audio.input.noise_reduction``.
    """
    config = {
        "api_key": "test-api-key-123",
        "initial_model_settings": {
            "model_name": "gpt-4o-realtime-preview",
            "input_audio_noise_reduction": {"type": "near_field"},
        },
    }

    sent_messages: list[dict[str, Any]] = []

    async def async_websocket(*args, **kwargs):
        # Stand-in for websockets.connect: returns the mock socket with a
        # send() that records each decoded payload.
        async def send(payload: str):
            sent_messages.append(json.loads(payload))
            return None

        mock_websocket.send.side_effect = send
        return mock_websocket

    with (
        patch("websockets.connect", side_effect=async_websocket),
        patch("asyncio.create_task") as mock_create_task,
    ):
        mock_task = AsyncMock()

        def mock_create_task_func(coro):
            # Close the coroutine so it is never scheduled and does not
            # trigger a "never awaited" warning.
            coro.close()
            return mock_task

        mock_create_task.side_effect = mock_create_task_func
        await model.connect(config)

    # Find the session.update events
    session_updates = [m for m in sent_messages if m.get("type") == "session.update"]
    assert len(session_updates) >= 1, "connect() did not send any session.update event"
    # Verify the last session.update contains the noise_reduction field
    session = session_updates[-1]["session"]
    assert session.get("audio", {}).get("input", {}).get("noise_reduction") == {
        "type": "near_field"
    }
140+
141+
@pytest.mark.asyncio
async def test_session_update_omits_noise_reduction_when_not_provided(
    self, model, mock_websocket
):
    """Session.update should omit input_audio_noise_reduction when not provided.

    Connects without ``input_audio_noise_reduction`` in the initial model
    settings, captures every JSON payload sent over the mocked websocket, and
    checks that the last ``session.update`` has an ``audio.input`` section
    with no ``noise_reduction`` key.
    """
    config = {
        "api_key": "test-api-key-123",
        "initial_model_settings": {
            "model_name": "gpt-4o-realtime-preview",
        },
    }

    sent_messages: list[dict[str, Any]] = []

    async def async_websocket(*args, **kwargs):
        # Stand-in for websockets.connect: returns the mock socket with a
        # send() that records each decoded payload.
        async def send(payload: str):
            sent_messages.append(json.loads(payload))
            return None

        mock_websocket.send.side_effect = send
        return mock_websocket

    with (
        patch("websockets.connect", side_effect=async_websocket),
        patch("asyncio.create_task") as mock_create_task,
    ):
        mock_task = AsyncMock()

        def mock_create_task_func(coro):
            # Close the coroutine so it is never scheduled and does not
            # trigger a "never awaited" warning.
            coro.close()
            return mock_task

        mock_create_task.side_effect = mock_create_task_func
        await model.connect(config)

    # Find the session.update events
    session_updates = [m for m in sent_messages if m.get("type") == "session.update"]
    assert len(session_updates) >= 1, "connect() did not send any session.update event"
    # Verify the last session.update omits the noise_reduction field
    session = session_updates[-1]["session"]
    # Checked separately so a failure names the missing level.
    assert "audio" in session
    assert "input" in session["audio"]
    assert "noise_reduction" not in session["audio"]["input"]
181+
99182
@pytest.mark.asyncio
100183
async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket):
101184
"""If custom headers are provided, use them verbatim without adding defaults."""

0 commit comments

Comments
 (0)