Skip to content

Commit

Permalink
support citations in streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Feb 4, 2025
1 parent 5771e56 commit 56d708d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
48 changes: 32 additions & 16 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ def _stream(
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = self._client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
coerce_content_to_string = not _tools_in_params(
payload
) and not _documents_in_params(payload)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
Expand All @@ -745,7 +747,9 @@ async def _astream(
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
stream = await self._async_client.messages.create(**payload)
coerce_content_to_string = not _tools_in_params(payload)
coerce_content_to_string = not _tools_in_params(
payload
) and not _documents_in_params(payload)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
event,
Expand Down Expand Up @@ -1254,6 +1258,16 @@ def _tools_in_params(params: dict) -> bool:
)


def _documents_in_params(params: dict) -> bool:
for message in params.get("messages", []):
for block in message.get("content", []):
if block.get("type") == "document" and block.get("citations", {}).get(
"enabled"
):
return True
return False


class _AnthropicToolUse(TypedDict):
type: Literal["tool_use"]
name: str
Expand Down Expand Up @@ -1296,34 +1310,36 @@ def _make_message_chunk_from_anthropic_event(
content="" if coerce_content_to_string else [],
usage_metadata=usage_metadata,
)
elif (
event.type == "content_block_start"
and event.content_block is not None
and event.content_block.type == "tool_use"
):
elif event.type == "content_block_start" and event.content_block is not None:
if coerce_content_to_string:
warnings.warn("Received unexpected tool content block.")
content_block = event.content_block.model_dump()
content_block["index"] = event.index
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=event.content_block.id,
name=event.content_block.name,
args="",
)
if event.content_block.type == "tool_use":
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=event.content_block.id,
name=event.content_block.name,
args="",
)
tool_call_chunks = [tool_call_chunk]
else:
tool_call_chunks = []
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore
tool_call_chunks=tool_call_chunks, # type: ignore
)
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
if coerce_content_to_string:
if event.delta.type in ("text_delta", "citations_delta"):
if coerce_content_to_string and hasattr(event.delta, "text"):
text = event.delta.text
message_chunk = AIMessageChunk(content=text)
else:
content_block = event.delta.model_dump()
content_block["index"] = event.index
content_block["type"] = "text"
if "citation" in content_block:
content_block["citations"] = [content_block.pop("citation")]
message_chunk = AIMessageChunk(content=[content_block])
elif event.delta.type == "input_json_delta":
content_block = event.delta.model_dump()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from base64 import b64encode
from typing import List, Optional
from typing import List, Optional, cast

import pytest
import requests
Expand Down Expand Up @@ -649,3 +649,23 @@ def test_citations() -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
assert any("citations" in block for block in response.content)

# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(messages):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, list)
assert any("citations" in block for block in full.content)
for block in full.content:
assert "citation" not in block

streamed_citations = [
cast(dict, block)["citations"] for block in full.content if "citations" in block
]
invoked_citations = [
cast(dict, block)["citations"]
for block in response.content
if "citations" in block
]
assert streamed_citations == invoked_citations

0 comments on commit 56d708d

Please sign in to comment.