Skip to content

Commit c27271f

Browse files
authored
fix(openai): update file index key name (#33350)
1 parent a3e4f4c commit c27271f

File tree

4 files changed

+76
-11
lines changed

4 files changed

+76
-11
lines changed

libs/partners/openai/langchain_openai/chat_models/_compat.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,19 @@ def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
187187
new_ann["title"] = annotation["title"]
188188
new_ann["type"] = "url_citation"
189189
new_ann["url"] = annotation["url"]
190+
191+
if extra_fields := annotation.get("extras"):
192+
new_ann.update(dict(extra_fields.items()))
190193
else:
191194
# Document citation
192195
new_ann["type"] = "file_citation"
196+
197+
if extra_fields := annotation.get("extras"):
198+
new_ann.update(dict(extra_fields.items()))
199+
193200
if "title" in annotation:
194201
new_ann["filename"] = annotation["title"]
195202

196-
if extra_fields := annotation.get("extras"):
197-
new_ann.update(dict(extra_fields.items()))
198-
199203
return new_ann
200204

201205
if annotation["type"] == "non_standard_annotation":

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3589,6 +3589,24 @@ def _construct_responses_api_payload(
35893589
return payload
35903590

35913591

3592+
def _format_annotation_to_lc(annotation: dict[str, Any]) -> dict[str, Any]:
3593+
# langchain-core reserves the `"index"` key for streaming aggregation.
3594+
# Here we rename it to "file_index".
3595+
if annotation.get("type") == "file_citation" and "index" in annotation:
3596+
new_annotation = annotation.copy()
3597+
new_annotation["file_index"] = new_annotation.pop("index")
3598+
return new_annotation
3599+
return annotation
3600+
3601+
3602+
def _format_annotation_from_lc(annotation: dict[str, Any]) -> dict[str, Any]:
3603+
if annotation.get("type") == "file_citation" and "file_index" in annotation:
3604+
new_annotation = annotation.copy()
3605+
new_annotation["index"] = new_annotation.pop("file_index")
3606+
return new_annotation
3607+
return annotation
3608+
3609+
35923610
def _convert_chat_completions_blocks_to_responses(
35933611
block: dict[str, Any],
35943612
) -> dict[str, Any]:
@@ -3775,7 +3793,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
37753793
new_block = {
37763794
"type": "output_text",
37773795
"text": block["text"],
3778-
"annotations": block.get("annotations") or [],
3796+
"annotations": [
3797+
_format_annotation_from_lc(annotation)
3798+
for annotation in block.get("annotations") or []
3799+
],
37793800
}
37803801
elif block_type == "refusal":
37813802
new_block = {
@@ -3951,7 +3972,7 @@ def _construct_lc_result_from_responses_api(
39513972
"type": "text",
39523973
"text": content.text,
39533974
"annotations": [
3954-
annotation.model_dump()
3975+
_format_annotation_to_lc(annotation.model_dump())
39553976
for annotation in content.annotations
39563977
]
39573978
if isinstance(content.annotations, list)
@@ -4142,7 +4163,11 @@ def _advance(output_idx: int, sub_idx: int | None = None) -> None:
41424163
annotation = chunk.annotation.model_dump(exclude_none=True, mode="json")
41434164

41444165
content.append(
4145-
{"type": "text", "annotations": [annotation], "index": current_index}
4166+
{
4167+
"type": "text",
4168+
"annotations": [_format_annotation_to_lc(annotation)],
4169+
"index": current_index,
4170+
}
41464171
)
41474172
elif chunk.type == "response.output_text.done":
41484173
content.append({"type": "text", "id": chunk.item_id, "index": current_index})
26 KB
Binary file not shown.

libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _check_response(response: BaseMessage | None) -> None:
3434
if annotation["type"] == "file_citation":
3535
assert all(
3636
key in annotation
37-
for key in ["file_id", "filename", "index", "type"]
37+
for key in ["file_id", "filename", "file_index", "type"]
3838
)
3939
elif annotation["type"] == "web_search":
4040
assert all(
@@ -374,9 +374,17 @@ def test_computer_calls() -> None:
374374
assert response.additional_kwargs["tool_outputs"]
375375

376376

377-
def test_file_search() -> None:
378-
pytest.skip() # TODO: set up infra
379-
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
377+
@pytest.mark.default_cassette("test_file_search.yaml.gz")
378+
@pytest.mark.vcr
379+
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
380+
def test_file_search(
381+
output_version: Literal["responses/v1", "v1"],
382+
) -> None:
383+
llm = ChatOpenAI(
384+
model=MODEL_NAME,
385+
use_responses_api=True,
386+
output_version=output_version,
387+
)
380388
tool = {
381389
"type": "file_search",
382390
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
@@ -386,16 +394,44 @@ def test_file_search() -> None:
386394
response = llm.invoke([input_message], tools=[tool])
387395
_check_response(response)
388396

389-
full: BaseMessageChunk | None = None
397+
if output_version == "v1":
398+
assert [block["type"] for block in response.content] == [ # type: ignore[index]
399+
"server_tool_call",
400+
"server_tool_result",
401+
"text",
402+
]
403+
else:
404+
assert [block["type"] for block in response.content] == [ # type: ignore[index]
405+
"file_search_call",
406+
"text",
407+
]
408+
409+
full: AIMessageChunk | None = None
390410
for chunk in llm.stream([input_message], tools=[tool]):
391411
assert isinstance(chunk, AIMessageChunk)
392412
full = chunk if full is None else full + chunk
393413
assert isinstance(full, AIMessageChunk)
394414
_check_response(full)
395415

416+
if output_version == "v1":
417+
assert [block["type"] for block in full.content] == [ # type: ignore[index]
418+
"server_tool_call",
419+
"server_tool_result",
420+
"text",
421+
]
422+
else:
423+
assert [block["type"] for block in full.content] == ["file_search_call", "text"] # type: ignore[index]
424+
396425
next_message = {"role": "user", "content": "Thank you."}
397426
_ = llm.invoke([input_message, full, next_message])
398427

428+
for message in [response, full]:
429+
assert [block["type"] for block in message.content_blocks] == [
430+
"server_tool_call",
431+
"server_tool_result",
432+
"text",
433+
]
434+
399435

400436
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
401437
@pytest.mark.vcr

0 commit comments

Comments (0)