diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py
index 83fa32abc..6133af344 100644
--- a/src/agents/models/chatcmpl_stream_handler.py
+++ b/src/agents/models/chatcmpl_stream_handler.py
@@ -53,6 +53,9 @@ class StreamingState:
     refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
     reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None
     function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
+    # Fields for real-time function call streaming
+    function_call_streaming: dict[int, bool] = field(default_factory=dict)
+    function_call_output_idx: dict[int, int] = field(default_factory=dict)


 class SequenceNumber:
@@ -255,9 +258,7 @@ async def handle_stream(
                 # Accumulate the refusal string in the output part
                 state.refusal_content_index_and_output[1].refusal += delta.refusal

-            # Handle tool calls
-            # Because we don't know the name of the function until the end of the stream, we'll
-            # save everything and yield events at the end
+            # Handle tool calls with real-time streaming support
             if delta.tool_calls:
                 for tc_delta in delta.tool_calls:
                     if tc_delta.index not in state.function_calls:
@@ -268,15 +269,76 @@ async def handle_stream(
                             type="function_call",
                             call_id="",
                         )
+                        state.function_call_streaming[tc_delta.index] = False
+
                     tc_function = tc_delta.function

+                    # Accumulate arguments as they come in
                     state.function_calls[tc_delta.index].arguments += (
                         tc_function.arguments if tc_function else ""
                     ) or ""
-                    state.function_calls[tc_delta.index].name += (
-                        tc_function.name if tc_function else ""
-                    ) or ""
-                    state.function_calls[tc_delta.index].call_id = tc_delta.id or ""
+
+                    # Set the function name directly (it arrives complete in the first chunk)
+                    if tc_function and tc_function.name:
+                        state.function_calls[tc_delta.index].name = tc_function.name
+
+                    if tc_delta.id:
+                        state.function_calls[tc_delta.index].call_id = tc_delta.id
+
+                    function_call = state.function_calls[tc_delta.index]
+
+                    # Start streaming as soon as we have function name and call_id
+                    if (not state.function_call_streaming[tc_delta.index] and
+                            function_call.name and
+                            function_call.call_id):
+
+                        # Calculate the output index for this function call
+                        function_call_starting_index = 0
+                        if state.reasoning_content_index_and_output:
+                            function_call_starting_index += 1
+                        if state.text_content_index_and_output:
+                            function_call_starting_index += 1
+                        if state.refusal_content_index_and_output:
+                            function_call_starting_index += 1
+
+                        # Add offset for already started function calls
+                        function_call_starting_index += sum(
+                            1 for streaming in state.function_call_streaming.values() if streaming
+                        )
+
+                        # Mark this function call as streaming and store its output index
+                        state.function_call_streaming[tc_delta.index] = True
+                        state.function_call_output_idx[
+                            tc_delta.index
+                        ] = function_call_starting_index
+
+                        # Send initial function call added event
+                        yield ResponseOutputItemAddedEvent(
+                            item=ResponseFunctionToolCall(
+                                id=FAKE_RESPONSES_ID,
+                                call_id=function_call.call_id,
+                                arguments="",  # Start with empty arguments
+                                name=function_call.name,
+                                type="function_call",
+                            ),
+                            output_index=function_call_starting_index,
+                            type="response.output_item.added",
+                            sequence_number=sequence_number.get_and_increment(),
+                        )
+
+                    # Stream arguments if we've started streaming this function call
+                    if (state.function_call_streaming.get(tc_delta.index, False) and
+                            tc_function and
+                            tc_function.arguments):
+
+                        output_index = state.function_call_output_idx[tc_delta.index]
+                        yield ResponseFunctionCallArgumentsDeltaEvent(
+                            delta=tc_function.arguments,
+                            item_id=FAKE_RESPONSES_ID,
+                            output_index=output_index,
+                            type="response.function_call_arguments.delta",
+                            sequence_number=sequence_number.get_and_increment(),
+                        )

         if state.reasoning_content_index_and_output:
             yield ResponseReasoningSummaryPartDoneEvent(
@@ -327,42 +389,71 @@ async def handle_stream(
                 sequence_number=sequence_number.get_and_increment(),
             )

-        # Actually send events for the function calls
-        for function_call in state.function_calls.values():
-            # First, a ResponseOutputItemAdded for the function call
-            yield ResponseOutputItemAddedEvent(
-                item=ResponseFunctionToolCall(
-                    id=FAKE_RESPONSES_ID,
-                    call_id=function_call.call_id,
-                    arguments=function_call.arguments,
-                    name=function_call.name,
-                    type="function_call",
-                ),
-                output_index=function_call_starting_index,
-                type="response.output_item.added",
-                sequence_number=sequence_number.get_and_increment(),
-            )
-            # Then, yield the args
-            yield ResponseFunctionCallArgumentsDeltaEvent(
-                delta=function_call.arguments,
-                item_id=FAKE_RESPONSES_ID,
-                output_index=function_call_starting_index,
-                type="response.function_call_arguments.delta",
-                sequence_number=sequence_number.get_and_increment(),
-            )
-            # Finally, the ResponseOutputItemDone
-            yield ResponseOutputItemDoneEvent(
-                item=ResponseFunctionToolCall(
-                    id=FAKE_RESPONSES_ID,
-                    call_id=function_call.call_id,
-                    arguments=function_call.arguments,
-                    name=function_call.name,
-                    type="function_call",
-                ),
-                output_index=function_call_starting_index,
-                type="response.output_item.done",
-                sequence_number=sequence_number.get_and_increment(),
-            )
+        # Send completion events for function calls
+        for index, function_call in state.function_calls.items():
+            if state.function_call_streaming.get(index, False):
+                # Function call was streamed; just send the completion event
+                output_index = state.function_call_output_idx[index]
+                yield ResponseOutputItemDoneEvent(
+                    item=ResponseFunctionToolCall(
+                        id=FAKE_RESPONSES_ID,
+                        call_id=function_call.call_id,
+                        arguments=function_call.arguments,
+                        name=function_call.name,
+                        type="function_call",
+                    ),
+                    output_index=output_index,
+                    type="response.output_item.done",
+                    sequence_number=sequence_number.get_and_increment(),
+                )
+            else:
+                # Function call was not streamed (fallback to old behavior);
+                # this handles edge cases where the function name never arrived
+                fallback_starting_index = 0
+                if state.reasoning_content_index_and_output:
+                    fallback_starting_index += 1
+                if state.text_content_index_and_output:
+                    fallback_starting_index += 1
+                if state.refusal_content_index_and_output:
+                    fallback_starting_index += 1

+                # Add offset for already started function calls
+                fallback_starting_index += sum(
+                    1 for streaming in state.function_call_streaming.values() if streaming
+                )

+                # Send all events at once (backward compatibility)
+                yield ResponseOutputItemAddedEvent(
+                    item=ResponseFunctionToolCall(
+                        id=FAKE_RESPONSES_ID,
+                        call_id=function_call.call_id,
+                        arguments=function_call.arguments,
+                        name=function_call.name,
+                        type="function_call",
+                    ),
+                    output_index=fallback_starting_index,
+                    type="response.output_item.added",
+                    sequence_number=sequence_number.get_and_increment(),
+                )
+                yield ResponseFunctionCallArgumentsDeltaEvent(
+                    delta=function_call.arguments,
+                    item_id=FAKE_RESPONSES_ID,
+                    output_index=fallback_starting_index,
+                    type="response.function_call_arguments.delta",
+                    sequence_number=sequence_number.get_and_increment(),
+                )
+                yield ResponseOutputItemDoneEvent(
+                    item=ResponseFunctionToolCall(
+                        id=FAKE_RESPONSES_ID,
+                        call_id=function_call.call_id,
+                        arguments=function_call.arguments,
+                        name=function_call.name,
+                        type="function_call",
+                    ),
+                    output_index=fallback_starting_index,
+                    type="response.output_item.done",
+                    sequence_number=sequence_number.get_and_increment(),
+                )

     # Finally, send the Response completed event
     outputs: list[ResponseOutputItem] = []
diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py
index cd342e444..bd38f8759 100644
--- a/tests/models/test_litellm_chatcompletions_stream.py
+++ b/tests/models/test_litellm_chatcompletions_stream.py
@@ -214,17 +214,18 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
     the model is streaming a function/tool call instead of plain text.
     The function call will be split across two chunks.
     """
-    # Simulate a single tool call whose ID stays constant and function name/args built over chunks.
+    # Simulate a single tool call with complete function name in first chunk
+    # and arguments split across chunks (reflecting real API behavior)
     tool_call_delta1 = ChoiceDeltaToolCall(
         index=0,
         id="tool-id",
-        function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"),
+        function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg1"),
         type="function",
     )
     tool_call_delta2 = ChoiceDeltaToolCall(
         index=0,
         id="tool-id",
-        function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"),
+        function=ChoiceDeltaToolCallFunction(name=None, arguments="arg2"),
         type="function",
     )
     chunk1 = ChatCompletionChunk(
@@ -284,18 +285,131 @@ async def patched_fetch_response(self, *args, **kwargs):
     # The added item should be a ResponseFunctionToolCall.
     added_fn = output_events[1].item
     assert isinstance(added_fn, ResponseFunctionToolCall)
-    assert added_fn.name == "my_func"  # Name should be concatenation of both chunks.
-    assert added_fn.arguments == "arg1arg2"
+    assert added_fn.name == "my_func"  # Name should be complete from first chunk
+    assert added_fn.arguments == ""  # Arguments start empty
     assert output_events[2].type == "response.function_call_arguments.delta"
-    assert output_events[2].delta == "arg1arg2"
-    assert output_events[3].type == "response.output_item.done"
-    assert output_events[4].type == "response.completed"
-    assert output_events[2].delta == "arg1arg2"
-    assert output_events[3].type == "response.output_item.done"
-    assert output_events[4].type == "response.completed"
-    assert added_fn.name == "my_func"  # Name should be concatenation of both chunks.
-    assert added_fn.arguments == "arg1arg2"
-    assert output_events[2].type == "response.function_call_arguments.delta"
-    assert output_events[2].delta == "arg1arg2"
-    assert output_events[3].type == "response.output_item.done"
-    assert output_events[4].type == "response.completed"
+    assert output_events[2].delta == "arg1"  # First argument chunk
+    assert output_events[3].type == "response.function_call_arguments.delta"
+    assert output_events[3].delta == "arg2"  # Second argument chunk
+    assert output_events[4].type == "response.output_item.done"
+    assert output_events[5].type == "response.completed"
+    # Final function call should have complete arguments
+    final_fn = output_events[4].item
+    assert isinstance(final_fn, ResponseFunctionToolCall)
+    assert final_fn.name == "my_func"
+    assert final_fn.arguments == "arg1arg2"
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_yields_real_time_function_call_arguments(monkeypatch) -> None:
+    """
+    Validate that LiteLLM `stream_response` also emits function call arguments in real time
+    as they are received, ensuring consistent behavior across model providers.
+    """
+    # Simulate realistic chunks: name first, then arguments incrementally
+    tool_call_delta1 = ChoiceDeltaToolCall(
+        index=0,
+        id="litellm-call-456",
+        function=ChoiceDeltaToolCallFunction(name="generate_code", arguments=""),
+        type="function",
+    )
+    tool_call_delta2 = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(arguments='{"language": "'),
+        type="function",
+    )
+    tool_call_delta3 = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(arguments='python", "task": "'),
+        type="function",
+    )
+    tool_call_delta4 = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(arguments='hello world"}'),
+        type="function",
+    )
+
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
+    )
+    chunk3 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta3]))],
+    )
+    chunk4 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta4]))],
+        usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for c in (chunk1, chunk2, chunk3, chunk4):
+            yield c
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        resp = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return resp, fake_stream()
+
+    monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
+    model = LitellmProvider().get_model("gpt-4")
+    output_events = []
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    # Extract events by type
+ function_args_delta_events = [ + e for e in output_events if e.type == "response.function_call_arguments.delta" + ] + output_item_added_events = [e for e in output_events if e.type == "response.output_item.added"] + + # Verify we got real-time streaming (3 argument delta events) + assert len(function_args_delta_events) == 3 + assert len(output_item_added_events) == 1 + + # Verify the deltas were streamed correctly + expected_deltas = ['{"language": "', 'python", "task": "', 'hello world"}'] + for i, delta_event in enumerate(function_args_delta_events): + assert delta_event.delta == expected_deltas[i] + + # Verify function call metadata + added_event = output_item_added_events[0] + assert isinstance(added_event.item, ResponseFunctionToolCall) + assert added_event.item.name == "generate_code" + assert added_event.item.call_id == "litellm-call-456" diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py index 49e7bc2f4..cbb3c5dae 100644 --- a/tests/test_openai_chatcompletions_stream.py +++ b/tests/test_openai_chatcompletions_stream.py @@ -214,17 +214,18 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None: the model is streaming a function/tool call instead of plain text. The function call will be split across two chunks. """ - # Simulate a single tool call whose ID stays constant and function name/args built over chunks. + # Simulate a single tool call with complete function name in first chunk + # and arguments split across chunks (reflecting real OpenAI API behavior) tool_call_delta1 = ChoiceDeltaToolCall( index=0, id="tool-id", - function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"), + function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg1"), type="function", ) tool_call_delta2 = ChoiceDeltaToolCall( index=0, id="tool-id", - function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"), + function=ChoiceDeltaToolCallFunction(name=None, arguments="arg2"), type="function", ) chunk1 = ChatCompletionChunk( @@ -284,18 +285,154 @@ async def patched_fetch_response(self, *args, **kwargs): # The added item should be a ResponseFunctionToolCall. added_fn = output_events[1].item assert isinstance(added_fn, ResponseFunctionToolCall) - assert added_fn.name == "my_func" # Name should be concatenation of both chunks. - assert added_fn.arguments == "arg1arg2" + assert added_fn.name == "my_func" # Name should be complete from first chunk + assert added_fn.arguments == "" # Arguments start empty assert output_events[2].type == "response.function_call_arguments.delta" - assert output_events[2].delta == "arg1arg2" - assert output_events[3].type == "response.output_item.done" - assert output_events[4].type == "response.completed" - assert output_events[2].delta == "arg1arg2" - assert output_events[3].type == "response.output_item.done" - assert output_events[4].type == "response.completed" - assert added_fn.name == "my_func" # Name should be concatenation of both chunks. 
-    assert added_fn.arguments == "arg1arg2"
-    assert output_events[2].type == "response.function_call_arguments.delta"
-    assert output_events[2].delta == "arg1arg2"
-    assert output_events[3].type == "response.output_item.done"
-    assert output_events[4].type == "response.completed"
+    assert output_events[2].delta == "arg1"  # First argument chunk
+    assert output_events[3].type == "response.function_call_arguments.delta"
+    assert output_events[3].delta == "arg2"  # Second argument chunk
+    assert output_events[4].type == "response.output_item.done"
+    assert output_events[5].type == "response.completed"
+    # Final function call should have complete arguments
+    final_fn = output_events[4].item
+    assert isinstance(final_fn, ResponseFunctionToolCall)
+    assert final_fn.name == "my_func"
+    assert final_fn.arguments == "arg1arg2"
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_yields_real_time_function_call_arguments(monkeypatch) -> None:
+    """
+    Validate that `stream_response` emits function call arguments in real time as they
+    are received, not just at the end. This test simulates the real OpenAI API behavior
+    where the function name comes first, then arguments are streamed incrementally.
+    """
+    # Simulate realistic OpenAI API chunks: name first, then arguments incrementally
+    tool_call_delta1 = ChoiceDeltaToolCall(
+        index=0,
+        id="tool-call-123",
+        function=ChoiceDeltaToolCallFunction(name="write_file", arguments=""),
+        type="function",
+    )
+    tool_call_delta2 = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(arguments='{"filename": "'),
+        type="function",
+    )
+    tool_call_delta3 = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(arguments='test.py", "content": "'),
+        type="function",
+    )
+    tool_call_delta4 = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(arguments='print(hello)"}'),
+        type="function",
+    )
+
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
+    )
+    chunk3 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta3]))],
+    )
+    chunk4 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta4]))],
+        usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for c in (chunk1, chunk2, chunk3, chunk4):
+            yield c
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        resp = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return resp, fake_stream()
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    output_events = []
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    # Extract events by type
+    created_events = [e for e in output_events if e.type == "response.created"]
+    output_item_added_events = [e for e in output_events if e.type == "response.output_item.added"]
+    function_args_delta_events = [
+        e for e in output_events if e.type == "response.function_call_arguments.delta"
+    ]
+    output_item_done_events = [e for e in output_events if e.type == "response.output_item.done"]
+    completed_events = [e for e in output_events if e.type == "response.completed"]
+
+    # Verify event structure
+    assert len(created_events) == 1
+    assert len(output_item_added_events) == 1
+    assert len(function_args_delta_events) == 3  # Three incremental argument chunks
+    assert len(output_item_done_events) == 1
+    assert len(completed_events) == 1
+
+    # Verify the function call started as soon as we had name and ID
+    added_event = output_item_added_events[0]
+    assert isinstance(added_event.item, ResponseFunctionToolCall)
+    assert added_event.item.name == "write_file"
+    assert added_event.item.call_id == "tool-call-123"
+    assert added_event.item.arguments == ""  # Should be empty at start
+
+    # Verify real-time argument streaming
+    expected_deltas = ['{"filename": "', 'test.py", "content": "', 'print(hello)"}']
+    for i, delta_event in enumerate(function_args_delta_events):
+        assert delta_event.delta == expected_deltas[i]
+        assert delta_event.item_id == "__fake_id__"  # FAKE_RESPONSES_ID
+        assert delta_event.output_index == 0
+
+    # Verify completion event has full arguments
+    done_event = output_item_done_events[0]
+    assert isinstance(done_event.item, ResponseFunctionToolCall)
+    assert done_event.item.name == "write_file"
+    assert done_event.item.arguments == '{"filename": "test.py", "content": "print(hello)"}'
+
+    # Verify final response
+    completed_event = completed_events[0]
+    function_call_output = completed_event.response.output[0]
+    assert isinstance(function_call_output, ResponseFunctionToolCall)
+    assert function_call_output.name == "write_file"
+    assert function_call_output.arguments == '{"filename": "test.py", "content": "print(hello)"}'
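The gating in `handle_stream` above can be read as a small per-tool-call state machine: argument deltas accumulate unconditionally, the `response.output_item.added` event fires exactly once as soon as both the function name and `call_id` are known (with `arguments=""`), and argument deltas are emitted as events only from that point on. A toy restatement of that ordering, not the handler itself (`ToolCallState` and `feed` are illustrative names):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToolCallState:
    name: str = ""
    call_id: str = ""
    arguments: str = ""
    streaming: bool = False


def feed(
    state: ToolCallState,
    name: str | None,
    call_id: str | None,
    args: str | None,
) -> list[str]:
    """Return the event types one tool-call delta would trigger, in order."""
    events: list[str] = []
    if args:
        state.arguments += args  # accumulated whether or not streaming has started
    if name:
        state.name = name  # the name arrives complete in a single chunk
    if call_id:
        state.call_id = call_id
    if not state.streaming and state.name and state.call_id:
        state.streaming = True
        events.append("response.output_item.added")  # emitted with arguments=""
    if state.streaming and args:
        events.append("response.function_call_arguments.delta")
    return events


# Mirrors the four-chunk tests above: `added` fires on chunk 1, deltas on chunks 2-4.
s = ToolCallState()
assert feed(s, "write_file", "tool-call-123", "") == ["response.output_item.added"]
assert feed(s, None, None, '{"filename": "') == ["response.function_call_arguments.delta"]
assert s.arguments == '{"filename": "'
```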
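The `output_index` bookkeeping is the subtle part: a tool call's index is frozen at the moment its `added` event is emitted, so it must count every output item that already holds a slot (reasoning, message text, refusal) plus every tool call that started streaming earlier; the fallback branch repeats the same arithmetic for calls that never streamed. A standalone restatement of that calculation (hypothetical helper name; the real logic lives inline in `handle_stream` above):

```python
def next_function_call_index(
    has_reasoning_item: bool,
    has_text_item: bool,
    has_refusal_item: bool,
    function_call_streaming: dict[int, bool],
) -> int:
    index = 0
    # Reasoning, message text, and refusal items each claim one output slot first.
    index += int(has_reasoning_item) + int(has_text_item) + int(has_refusal_item)
    # Each tool call that already started streaming occupies its own slot.
    index += sum(1 for started in function_call_streaming.values() if started)
    return index


# With an open text message and one tool call already streaming,
# the next tool call is assigned output_index 2.
assert next_function_call_index(False, True, False, {0: True, 1: False}) == 2
```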
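Downstream, the visible effect is that `response.function_call_arguments.delta` events now arrive while the model is still generating, instead of in one batch just before `response.completed`, so a UI can render tool-call arguments as they stream. A minimal consumer sketch under the same `stream_response` signature the tests exercise; the import paths are assumed from this repo's layout, and `model` is assumed to be constructed as in the tests:

```python
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing


async def show_tool_calls_live(model) -> None:
    async for event in model.stream_response(
        system_instructions=None,
        input="Write hello world to test.py",  # illustrative prompt
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        prompt=None,
    ):
        if event.type == "response.output_item.added":
            if getattr(event.item, "type", None) == "function_call":
                # Fires as soon as the name and call_id are known, arguments still empty.
                print(f"\n[tool call] {event.item.name}(", end="", flush=True)
        elif event.type == "response.function_call_arguments.delta":
            # Arrives incrementally now rather than as one blob at the end.
            print(event.delta, end="", flush=True)
        elif event.type == "response.output_item.done":
            if getattr(event.item, "type", None) == "function_call":
                print(")", flush=True)
```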