
Commit 78d7de9

Support multiple outputs in chunk and response of chat service (#39)
# Pull Request Title

## Checklist

- [x] I have read both the [CONTRIBUTING.md](CONTRIBUTING.md) and [Contributor License Agreement](CLA.md) documents.
- [x] I have created an issue or feature request and received approval from xAI maintainers. (Minor changes like fixing typos can skip this step.)
- [x] I have tested my changes locally and they pass all CI checks.
- [x] I have added necessary documentation or updated existing documentation.

## Description

Provide a clear and concise description of the changes in this PR. Explain the purpose, the problem it solves, and any relevant context.

## Related Issue

If applicable, link to the related feature request or bug report issue (e.g., #123). If none, state "N/A".

## Type of Change

- [ ] Bug fix
- [x] New feature
- [ ] Documentation update
- [ ] Other (please specify)

## Additional Notes

Add any other information or context that might be helpful for reviewers.
1 parent fffa16d commit 78d7de9
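Before the diff, a minimal sketch of the consumer-visible effect. Only the `chat.create(...)`/`append`/`sample` usage, the `user` and `web_search` helpers, and the model name are taken from the tests added in this commit; the import paths and client construction are assumptions:

```python
import asyncio

# Assumed import locations; the commit's diff only shows how these names
# are used in the tests, not where they are imported from.
from xai_sdk import AsyncClient
from xai_sdk.chat import user
from xai_sdk.tools import web_search

async def main() -> None:
    client = AsyncClient()  # assumed to read the API key from the environment
    chat = client.chat.create("grok-4-fast", tools=[web_search()])
    chat.append(user("What is the weather in London?"))

    response = await chat.sample()
    # With a server-side tool configured, the SDK now constructs the Response
    # with index=None, so these accessors aggregate over all assistant outputs.
    print(response.content)     # text of the final assistant output
    print(response.tool_calls)  # tool calls gathered from assistant outputs

asyncio.run(main())
```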

6 files changed: +280 −37 lines changed

src/xai_sdk/aio/chat.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -106,8 +106,9 @@ async def sample(self) -> Response:
             kind=SpanKind.CLIENT,
             attributes=self._make_span_request_attributes(),
         ) as span:
+            index = None if self._uses_server_side_tools() else 0
             response_pb = await self._stub.GetCompletion(self._make_request(1))
-            response = Response(response_pb, 0)
+            response = Response(response_pb, index)
             span.set_attributes(self._make_span_response_attributes([response]))
             return response

@@ -180,7 +181,8 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
             kind=SpanKind.CLIENT,
             attributes=self._make_span_request_attributes(),
         ) as span:
-            response = Response(chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()]), 0)
+            index = None if self._uses_server_side_tools() else 0
+            response = Response(chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()]), index)
             stream = self._stub.GetCompletionChunk(self._make_request(1))

             async for chunk in stream:
@@ -191,7 +193,7 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
                     first_chunk_received = True

                 response.process_chunk(chunk)
-                chunk_obj = Chunk(chunk, 0)
+                chunk_obj = Chunk(chunk, index)
                 yield response, chunk_obj

             span.set_attributes(self._make_span_response_attributes([response]))
```
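The sync client below (src/xai_sdk/sync/chat.py) receives the identical change: whenever a configured tool is not a plain function tool, `Response` and `Chunk` are constructed with `index=None` and aggregate across outputs; otherwise the previous single-output behavior at index 0 is kept.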

src/xai_sdk/chat.py

Lines changed: 47 additions & 25 deletions

```diff
@@ -288,7 +288,7 @@ def append(self, message: Union[chat_pb2.Message, "Response"]) -> Self:
         elif isinstance(message, Response):
             self._proto.messages.append(
                 chat_pb2.Message(
-                    role=message._choice.message.role,
+                    role=message._get_output().message.role,
                     content=[text(message.content)],
                     tool_calls=message.tool_calls,
                 )
@@ -472,6 +472,10 @@ def _get_span_completion_attributes(self, responses: Sequence["Response"]) -> di

         return completion_attributes

+    def _uses_server_side_tools(self) -> bool:
+        """Returns True if a server-side tool is used in the request."""
+        return any(tool.WhichOneof("tool") != "function" for tool in self._proto.tools)
+
     @property
     def messages(self) -> Sequence[chat_pb2.Message]:
         """Returns the messages in the conversation."""
@@ -682,22 +686,26 @@ def _format_type_to_proto(format_type: ResponseFormat) -> chat_pb2.FormatType:
 class Chunk(ProtoDecorator[chat_pb2.GetChatCompletionChunk]):
     """Adds convenience functions to the chunk proto."""

-    _index: int
+    _index: int | None

-    def __init__(self, proto: chat_pb2.GetChatCompletionChunk, index: int):
+    def __init__(self, proto: chat_pb2.GetChatCompletionChunk, index: int | None):
         """Creates a new decorator instance.

         Args:
             proto: Chunk proto to wrap.
-            index: Index of the response to track.
+            index: Index of the response to track. If set to None, the chunk will expose all assistant outputs.
         """
         super().__init__(proto)
         self._index = index

     @property
     def choices(self) -> Sequence["ChoiceChunk"]:
         """Returns the choices belonging to this index."""
-        return [ChoiceChunk(c) for c in self.proto.outputs if c.index == self._index]
+        return [
+            ChoiceChunk(c)
+            for c in self.proto.outputs
+            if c.delta.role == chat_pb2.MessageRole.ROLE_ASSISTANT and (c.index == self._index or self._index is None)
+        ]

     @property
     def output(self) -> str:
@@ -777,6 +785,14 @@ def process_chunk(self, chunk: chat_pb2.GetChatCompletionChunk):
         self._proto.system_fingerprint = chunk.system_fingerprint
         self._proto.citations.extend(chunk.citations)

+        # Make sure all chunk outputs have corresponding response outputs.
+        if chunk.outputs:
+            max_index = max(c.index for c in chunk.outputs)
+            if max_index >= len(self._proto.outputs):
+                self._proto.outputs.extend(
+                    [chat_pb2.CompletionOutput() for _ in range(max_index + 1 - len(self._proto.outputs))]
+                )
+
         for c in chunk.outputs:
             choice = self._proto.outputs[c.index]
             choice.index = c.index
```
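The padding step above guards the `for c in chunk.outputs` loop that follows it: a chunk may reference output indices the aggregated response has not allocated yet. A stand-in sketch with plain lists (the real field is a repeated `chat_pb2.CompletionOutput`, which supports `len` and `extend` the same way):

```python
# Stand-in for self._proto.outputs (a repeated proto field in the SDK).
outputs: list[dict] = []

# A chunk may reference a not-yet-seen output index, e.g. a tool output
# arriving at index 2 while the response so far holds no outputs at all.
chunk_output_indices = [2]

max_index = max(chunk_output_indices)
if max_index >= len(outputs):
    # Grow the list with empty placeholders up to max_index inclusive.
    outputs.extend({} for _ in range(max_index + 1 - len(outputs)))

assert len(outputs) == 3  # indices 0..2 are now safely addressable
```

The remaining hunks of src/xai_sdk/chat.py rework the `Response` class: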
```diff
@@ -792,29 +808,30 @@ class Response(_ResponseProtoDecorator):

     # A single request can produce multiple responses. This index is used to retrieve the content of
     # a single answer from the response proto.
-    _index: int
-    # Cache to the answer indexed by this response.
-    _choice: chat_pb2.CompletionOutput
+    _index: int | None

-    def __init__(self, response: chat_pb2.GetChatCompletionResponse, index: int) -> None:
+    def __init__(self, response: chat_pb2.GetChatCompletionResponse, index: int | None) -> None:
         """Initializes a new instance of the `Response` class.

         Args:
             response: The response proto, which can hold multiple answers.
             index: The index of the answer this class exposes via its convenience methods.
+                If set to None, the response will expose all answers; content and reasoning
+                content are taken from the final assistant output.
         """
         super().__init__(response)
         self._index = index

-        # Find and cache the answer identified by the index.
-        choices = [c for c in response.outputs if c.index == index]
-
-        if not choices:
-            raise ValueError(f"Invalid response proto or index. {response:} {index:}")
-        elif len(choices) > 1:
-            raise ValueError(f"More than one response for index {index:}. {response:}")
-        else:
-            self._choice = choices[0]
+    def _get_output(self) -> chat_pb2.CompletionOutput:
+        outputs = [
+            output
+            for output in self.proto.outputs
+            if output.message.role == chat_pb2.MessageRole.ROLE_ASSISTANT
+            and (output.index == self._index or self._index is None)
+        ]
+        if not outputs:
+            return chat_pb2.CompletionOutput()
+        return outputs[-1]

     @property
     def id(self) -> str:
@@ -824,12 +841,12 @@ def id(self) -> str:
     @property
     def content(self) -> str:
         """Returns the answer content of this response."""
-        return self._choice.message.content
+        return self._get_output().message.content

     @property
     def role(self) -> str:
         """Returns the role of this response."""
-        return chat_pb2.MessageRole.Name(self._choice.message.role)
+        return chat_pb2.MessageRole.Name(self._get_output().message.role)

     @property
     def usage(self) -> usage_pb2.SamplingUsage:
@@ -842,17 +859,17 @@ def reasoning_content(self) -> str:

         This is only available for models that support reasoning.
         """
-        return self._choice.message.reasoning_content
+        return self._get_output().message.reasoning_content

     @property
     def finish_reason(self) -> str:
         """Returns the finish reason of this response."""
-        return sample_pb2.FinishReason.Name(self._choice.finish_reason)
+        return sample_pb2.FinishReason.Name(self._get_output().finish_reason)

     @property
     def logprobs(self) -> chat_pb2.LogProbs:
         """Returns the logprobs of this response."""
-        return self._choice.logprobs
+        return self._get_output().logprobs

     @property
     def system_fingerprint(self) -> str:
@@ -861,8 +878,13 @@ def system_fingerprint(self) -> str:

     @property
     def tool_calls(self) -> Sequence[chat_pb2.ToolCall]:
-        """Returns the tool calls of this response."""
-        return self._choice.message.tool_calls
+        """Returns all tool calls of this response."""
+        return [
+            tc
+            for c in self.proto.outputs
+            if c.message.role == chat_pb2.MessageRole.ROLE_ASSISTANT
+            for tc in c.message.tool_calls
+        ]

     @property
     def citations(self) -> Sequence[str]:
```
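To pin down the selection rule `_get_output` implements, here is a self-contained sketch with a dataclass standing in for `chat_pb2.CompletionOutput` (the roles and strings mirror the test fixtures added below; everything else is illustrative):

```python
from dataclasses import dataclass
from typing import Optional

# Stand-ins for chat_pb2.MessageRole values used in the diff.
ROLE_ASSISTANT = "ROLE_ASSISTANT"
ROLE_TOOL = "ROLE_TOOL"

@dataclass
class Output:  # stand-in for chat_pb2.CompletionOutput
    index: int
    role: str
    content: str = ""

def get_output(outputs: list[Output], index: Optional[int]) -> Optional[Output]:
    # Keep assistant outputs only; match the tracked index unless it is None,
    # then return the last match. (The real method returns an empty proto
    # instead of None when nothing matches.)
    matches = [
        o for o in outputs
        if o.role == ROLE_ASSISTANT and (index is None or o.index == index)
    ]
    return matches[-1] if matches else None

outputs = [
    Output(0, ROLE_ASSISTANT, ""),                 # assistant turn holding the tool call
    Output(1, ROLE_TOOL, "I am tool response"),    # server-side tool result
    Output(2, ROLE_ASSISTANT, "I am searching."),  # final assistant answer
]
assert get_output(outputs, None).content == "I am searching."  # index=None: last assistant output
assert get_output(outputs, 2).content == "I am searching."     # explicit index still works
```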

src/xai_sdk/sync/chat.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -104,8 +104,9 @@ def sample(self) -> Response:
             kind=SpanKind.CLIENT,
             attributes=self._make_span_request_attributes(),
         ) as span:
+            index = None if self._uses_server_side_tools() else 0
             response_pb = self._stub.GetCompletion(self._make_request(1))
-            response = Response(response_pb, 0)
+            response = Response(response_pb, index)
             span.set_attributes(self._make_span_response_attributes([response]))
             return response

@@ -176,7 +177,8 @@ def stream(self) -> Iterator[tuple[Response, Chunk]]:
             kind=SpanKind.CLIENT,
             attributes=self._make_span_request_attributes(),
         ) as span:
-            response = Response(chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()]), 0)
+            index = None if self._uses_server_side_tools() else 0
+            response = Response(chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()]), index)
             stream = self._stub.GetCompletionChunk(self._make_request(1))

             for chunk in stream:
@@ -187,7 +189,7 @@ def stream(self) -> Iterator[tuple[Response, Chunk]]:
                     first_chunk_received = True

                 response.process_chunk(chunk)
-                chunk_obj = Chunk(chunk, 0)
+                chunk_obj = Chunk(chunk, index)
                 yield response, chunk_obj

             span.set_attributes(self._make_span_response_attributes([response]))
```

tests/aio/chat_test.py

Lines changed: 61 additions & 0 deletions

```diff
@@ -352,6 +352,67 @@ async def test_function_calling_streaming_batch(client):
     assert response.tool_calls[0].function.arguments == '{"city":"London","units":"C"}'


+@pytest.mark.asyncio(loop_scope="session")
+async def test_agentic_tool_calling_streaming(client):
+    chat = client.chat.create(
+        "grok-4-fast",
+        tools=[web_search()],
+    )
+    chat.append(user("What is the weather in London?"))
+    stream = chat.stream()
+
+    expected_chunks = [
+        "I",
+        " am",
+        " searching",
+        ".",
+        "",  # Final chunk is a tool call which has no content set
+    ]
+
+    last_response = None
+    i = 0
+    async for response, chunk in stream:
+        last_response = response
+        if i == 0:
+            assert chunk.tool_calls[0].function.name == "web_search"
+            assert chunk.tool_calls[0].function.arguments == '{"query":"What is the weather in London?"}'
+        elif i == 1:
+            assert chunk.proto.outputs[0].delta.role == chat_pb2.ROLE_TOOL
+            assert chunk.proto.outputs[0].delta.content == "I am tool response"
+            assert chunk.content == ""
+        else:
+            assert chunk.content == expected_chunks[i - 2]
+        i += 1
+
+    assert last_response is not None
+    assert last_response.content == "I am searching."
+    assert len(last_response.tool_calls) == 1
+    assert last_response.finish_reason == "REASON_STOP"
+    assert last_response.role == "ROLE_ASSISTANT"
+    assert last_response.tool_calls[0].function.name == "web_search"
+    assert last_response.tool_calls[0].function.arguments == '{"query":"What is the weather in London?"}'
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_agentic_tool_calling_non_streaming(client):
+    chat = client.chat.create(
+        "grok-4-fast",
+        tools=[web_search()],
+    )
+    chat.append(user("What is the weather in London?"))
+    response = await chat.sample()
+
+    assert len(response.proto.outputs) == 3
+    assert response.proto.outputs[1].message.role == chat_pb2.ROLE_TOOL
+    assert response.proto.outputs[1].message.content == "I am tool response"
+    assert response.content == "I am searching."
+    assert len(response.tool_calls) == 1
+    assert response.finish_reason == "REASON_STOP"
+    assert response.role == "ROLE_ASSISTANT"
+    assert response.tool_calls[0].function.name == "web_search"
+    assert response.tool_calls[0].function.arguments == '{"query":"What is the weather in London?"}'
+
+
 @pytest.mark.asyncio(loop_scope="session")
 async def test_structured_output(client):
     class Weather(BaseModel):
```
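Note what the streaming test pins down: tool-role deltas remain visible through `chunk.proto.outputs`, but `chunk.content` stays empty for them because the `Chunk` accessors filter to assistant outputs, while the aggregated `Response` ends up with the final assistant text, the collected tool calls, and `finish_reason`/`role` taken from the assistant output.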
