
Add Streaming of Function Call Arguments to Chat Completions #999


Open · wants to merge 9 commits into main
181 changes: 141 additions & 40 deletions src/agents/models/chatcmpl_stream_handler.py
@@ -53,6 +53,9 @@ class StreamingState:
refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
# Fields for real-time function call streaming
function_call_streaming: dict[int, bool] = field(default_factory=dict)
function_call_output_idx: dict[int, int] = field(default_factory=dict)


class SequenceNumber:
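
For orientation, here is a minimal sketch of how the two new StreamingState fields fit together. The field names come from the hunk above; the rest of the class is elided, so this is illustrative rather than the full implementation:

from dataclasses import dataclass, field

from openai.types.responses import ResponseFunctionToolCall

# Illustrative sketch only; the real StreamingState also carries the
# text/refusal/reasoning bookkeeping shown in the hunk above.
@dataclass
class StreamingState:
    # Partial tool calls accumulated per choice-delta index.
    function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
    # Whether response.output_item.added has been emitted for an index yet.
    function_call_streaming: dict[int, bool] = field(default_factory=dict)
    # Output index assigned when streaming began, reused for delta/done events.
    function_call_output_idx: dict[int, int] = field(default_factory=dict)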
@@ -255,9 +258,7 @@ async def handle_stream(
# Accumulate the refusal string in the output part
state.refusal_content_index_and_output[1].refusal += delta.refusal

- # Handle tool calls
- # Because we don't know the name of the function until the end of the stream, we'll
- # save everything and yield events at the end
+ # Handle tool calls with real-time streaming support
if delta.tool_calls:
for tc_delta in delta.tool_calls:
if tc_delta.index not in state.function_calls:
@@ -268,15 +269,86 @@
type="function_call",
call_id="",
)
state.function_call_streaming[tc_delta.index] = False

tc_function = tc_delta.function

# Accumulate the data as before
state.function_calls[tc_delta.index].arguments += (
tc_function.arguments if tc_function else ""
) or ""
state.function_calls[tc_delta.index].name += (
tc_function.name if tc_function else ""
) or ""
- state.function_calls[tc_delta.index].call_id = tc_delta.id or ""
+ if tc_delta.id:
+     state.function_calls[tc_delta.index].call_id = tc_delta.id

# Check if we have enough info to start streaming this function call
function_call = state.function_calls[tc_delta.index]

# Strategy: Only start streaming when we see arguments coming in
# but no new name information, indicating the name is finalized
current_chunk_has_name = tc_function and tc_function.name
current_chunk_has_args = tc_function and tc_function.arguments

# If this chunk has a name, it means the function name might still be building
# We should wait until we get a chunk with only arguments (no name)
name_seems_finalized = not current_chunk_has_name and current_chunk_has_args

if (not state.function_call_streaming[tc_delta.index] and
function_call.name and
function_call.call_id and
# Only start streaming when we're confident the name is finalized
# This happens when we get args but no new name chunk
name_seems_finalized):

# Calculate the output index for this function call
function_call_starting_index = 0
if state.reasoning_content_index_and_output:
function_call_starting_index += 1
if state.text_content_index_and_output:
function_call_starting_index += 1
if state.refusal_content_index_and_output:
function_call_starting_index += 1

# Add offset for already started function calls
function_call_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Mark this function call as streaming and store its output index
state.function_call_streaming[tc_delta.index] = True
state.function_call_output_idx[
tc_delta.index
] = function_call_starting_index

# Send initial function call added event
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments="", # Start with empty arguments
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)

# Stream arguments if we've started streaming this function call
if (state.function_call_streaming[tc_delta.index] and
tc_function and
tc_function.arguments):

output_index = state.function_call_output_idx[tc_delta.index]
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=tc_function.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=output_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)

if state.reasoning_content_index_and_output:
yield ResponseReasoningSummaryPartDoneEvent(
@@ -327,42 +399,71 @@
sequence_number=sequence_number.get_and_increment(),
)

- # Actually send events for the function calls
- for function_call in state.function_calls.values():
- # First, a ResponseOutputItemAdded for the function call
- yield ResponseOutputItemAddedEvent(
- item=ResponseFunctionToolCall(
- id=FAKE_RESPONSES_ID,
- call_id=function_call.call_id,
- arguments=function_call.arguments,
- name=function_call.name,
- type="function_call",
- ),
- output_index=function_call_starting_index,
- type="response.output_item.added",
- sequence_number=sequence_number.get_and_increment(),
- )
- # Then, yield the args
- yield ResponseFunctionCallArgumentsDeltaEvent(
- delta=function_call.arguments,
- item_id=FAKE_RESPONSES_ID,
- output_index=function_call_starting_index,
- type="response.function_call_arguments.delta",
- sequence_number=sequence_number.get_and_increment(),
- )
- # Finally, the ResponseOutputItemDone
- yield ResponseOutputItemDoneEvent(
- item=ResponseFunctionToolCall(
- id=FAKE_RESPONSES_ID,
- call_id=function_call.call_id,
- arguments=function_call.arguments,
- name=function_call.name,
- type="function_call",
- ),
- output_index=function_call_starting_index,
- type="response.output_item.done",
- sequence_number=sequence_number.get_and_increment(),
- )
# Send completion events for function calls
for index, function_call in state.function_calls.items():
if state.function_call_streaming.get(index, False):
# Function call was streamed, just send the completion event
output_index = state.function_call_output_idx[index]
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=output_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
else:
# Function call was not streamed (fallback to old behavior)
# This handles edge cases where function name never arrived
fallback_starting_index = 0
if state.reasoning_content_index_and_output:
fallback_starting_index += 1
if state.text_content_index_and_output:
fallback_starting_index += 1
if state.refusal_content_index_and_output:
fallback_starting_index += 1

# Add offset for already started function calls
fallback_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Send all events at once (backward compatibility)
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=fallback_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=function_call.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=fallback_starting_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=fallback_starting_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)

# Finally, send the Response completed event
outputs: list[ResponseOutputItem] = []
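With the handler changes above, a downstream consumer can render tool-call arguments as they arrive instead of waiting for the end of the stream. A minimal, hypothetical sketch (`render_tool_call` and its `stream` argument are illustrative; the event types are the ones emitted by the handler):

# Hypothetical consumer; `stream` is any async iterator of the events
# yielded by handle_stream above.
async def render_tool_call(stream) -> None:
    async for event in stream:
        if event.type == "response.output_item.added" and event.item.type == "function_call":
            # Emitted once the function name looks finalized.
            print(f"{event.item.name}(", end="", flush=True)
        elif event.type == "response.function_call_arguments.delta":
            # Arrives chunk by chunk rather than all at once at stream end.
            print(event.delta, end="", flush=True)
        elif event.type == "response.output_item.done":
            print(")")
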
115 changes: 115 additions & 0 deletions tests/models/test_litellm_chatcompletions_stream.py
@@ -299,3 +299,118 @@ async def patched_fetch_response(self, *args, **kwargs):
assert output_events[2].delta == "arg1arg2"
assert output_events[3].type == "response.output_item.done"
assert output_events[4].type == "response.completed"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_real_time_function_call_arguments(monkeypatch) -> None:
"""
Validate that LiteLLM `stream_response` also emits function call arguments in real time
as they are received, ensuring consistent behavior across model providers.
"""
# Simulate realistic chunks: name first, then arguments incrementally
tool_call_delta1 = ChoiceDeltaToolCall(
index=0,
id="litellm-call-456",
function=ChoiceDeltaToolCallFunction(name="generate_code", arguments=""),
type="function",
)
tool_call_delta2 = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(arguments='{"language": "'),
type="function",
)
tool_call_delta3 = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(arguments='python", "task": "'),
type="function",
)
tool_call_delta4 = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(arguments='hello world"}'),
type="function",
)

chunk1 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
)
chunk2 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
)
chunk3 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta3]))],
)
chunk4 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta4]))],
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
)

async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
for c in (chunk1, chunk2, chunk3, chunk4):
yield c

async def patched_fetch_response(self, *args, **kwargs):
resp = Response(
id="resp-id",
created_at=0,
model="fake-model",
object="response",
output=[],
tool_choice="none",
tools=[],
parallel_tool_calls=False,
)
return resp, fake_stream()

monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
model = LitellmProvider().get_model("gpt-4")
output_events = []
async for event in model.stream_response(
system_instructions=None,
input="",
model_settings=ModelSettings(),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
prompt=None,
):
output_events.append(event)

# Extract events by type
function_args_delta_events = [
e for e in output_events if e.type == "response.function_call_arguments.delta"
]
output_item_added_events = [e for e in output_events if e.type == "response.output_item.added"]

# Verify we got real-time streaming (3 argument delta events)
assert len(function_args_delta_events) == 3
assert len(output_item_added_events) == 1

# Verify the deltas were streamed correctly
expected_deltas = ['{"language": "', 'python", "task": "', 'hello world"}']
for i, delta_event in enumerate(function_args_delta_events):
assert delta_event.delta == expected_deltas[i]

# Verify function call metadata
added_event = output_item_added_events[0]
assert isinstance(added_event.item, ResponseFunctionToolCall)
assert added_event.item.name == "generate_code"
assert added_event.item.call_id == "litellm-call-456"
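
As an extra sanity check one could append to the test above (not part of the diff), the streamed deltas should reassemble into valid JSON:

import json

# The three deltas concatenate back into the complete argument payload.
reassembled = "".join(e.delta for e in function_args_delta_events)
assert json.loads(reassembled) == {"language": "python", "task": "hello world"}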