241 changes: 225 additions & 16 deletions src/google/adk/models/anthropic_llm.py
@@ -30,6 +30,7 @@
from typing import Union

from anthropic import AnthropicVertex
from anthropic import AsyncAnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
@@ -84,8 +85,20 @@ def part_to_message_block(
anthropic_types.ImageBlockParam,
anthropic_types.ToolUseBlockParam,
anthropic_types.ToolResultBlockParam,
dict, # For thinking blocks
]:
if part.text:
# Handle thinking blocks (must check thought=True BEFORE text)
# Thinking is stored as Part(text=..., thought=True, thought_signature=...)
if part.text and hasattr(part, 'thought') and part.thought:
thinking_block = {"type": "thinking", "thinking": part.text}
if hasattr(part, 'thought_signature') and part.thought_signature:
# thought_signature is stored as bytes in Part, but API expects base64 string
thinking_block["signature"] = base64.b64encode(part.thought_signature).decode('utf-8')
logger.debug(f"Including signature with thinking block")
else:
logger.warning(f"No signature found for thinking block - this may cause API errors")
return thinking_block
elif part.text:
return anthropic_types.TextBlockParam(text=part.text, type="text")
elif part.function_call:
assert part.function_call.name
@@ -139,14 +152,26 @@ def part_to_message_block(
def content_to_message_param(
content: types.Content,
) -> anthropic_types.MessageParam:
message_block = []
thinking_blocks = []
other_blocks = []

for part in content.parts or []:
# Image data is not supported in Claude for model turns.
if _is_image_part(part):
logger.warning("Image data is not supported in Claude for model turns.")
continue

message_block.append(part_to_message_block(part))
block = part_to_message_block(part)

# Separate thinking blocks from other blocks
# Anthropic requires thinking blocks to come FIRST in assistant messages
if isinstance(block, dict) and block.get("type") == "thinking":
thinking_blocks.append(block)
else:
other_blocks.append(block)

# Thinking blocks MUST come first (Anthropic API requirement)
message_block = thinking_blocks + other_blocks
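# e.g. an assistant turn that used extended thinking serializes as:
#   [{"type": "thinking", ...}, {"type": "text", ...}, {"type": "tool_use", ...}]
# (illustrative shape only; which blocks appear depends on the parts present)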

return {
"role": to_claude_role(content.role),
@@ -166,7 +191,84 @@ def content_block_to_part(
)
part.function_call.id = content_block.id
return part
raise NotImplementedError("Not supported yet.")

# Handle thinking blocks from Anthropic extended thinking feature
# Thinking blocks have a 'thinking' attribute containing the reasoning text
if hasattr(content_block, "thinking"):
thinking_text = content_block.thinking
signature = getattr(content_block, 'signature', None)
logger.info(f"Received thinking block ({len(thinking_text)} chars, signature={'present' if signature else 'missing'})")
# Return as Part with thought=True and preserve signature (standard GenAI format)
return types.Part(text=thinking_text, thought=True, thought_signature=signature)

# Alternative check: some versions may use type attribute
if (
hasattr(content_block, "type")
and getattr(content_block, "type", None) == "thinking"
):
thinking_text = str(content_block)
signature = getattr(content_block, 'signature', None)
logger.info(
f"Received thinking block via type check ({len(thinking_text)} chars, signature={'present' if signature else 'missing'})"
)
# Return as Part with thought=True and preserve signature (standard GenAI format)
return types.Part(text=thinking_text, thought=True, thought_signature=signature)

raise NotImplementedError(
f"Not supported yet: {type(content_block).__name__}"
)
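
For orientation, here is a minimal round-trip sketch of what these two helpers do with thinking content. It uses a stand-in object rather than a real Anthropic SDK content block, and it assumes the signature is raw bytes on the Part side (matching the encode step in part_to_message_block above):

import base64
from types import SimpleNamespace
from google.genai import types

# Hypothetical thinking block as received from the API (stand-in, not the SDK type).
block = SimpleNamespace(thinking="model reasoning...", signature=b"opaque-signature-bytes")

# Inbound (content_block_to_part): keep the reasoning text and signature on a Part.
part = types.Part(text=block.thinking, thought=True, thought_signature=block.signature)

# Outbound (part_to_message_block): re-encode the signature as base64 for the API.
outbound = {
    "type": "thinking",
    "thinking": part.text,
    "signature": base64.b64encode(part.thought_signature).decode("utf-8"),
}
assert base64.b64decode(outbound["signature"]) == block.signature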


def streaming_event_to_llm_response(
event: anthropic_types.MessageStreamEvent,
) -> Optional[LlmResponse]:
"""Convert Anthropic streaming events to ADK LlmResponse format.

Args:
event: Anthropic streaming event

Returns:
LlmResponse or None if event should be skipped
"""
# Handle content block deltas
if event.type == "content_block_delta":
delta = event.delta

# Text delta
if delta.type == "text_delta":
return LlmResponse(
content=types.Content(
role="model",
parts=[types.Part.from_text(text=delta.text)],
),
partial=True,
)

# Thinking delta
elif delta.type == "thinking_delta":
return LlmResponse(
content=types.Content(
role="model",
parts=[types.Part(text=delta.thinking, thought=True)],
),
partial=True,
)

# Handle message deltas (usage updates)
elif event.type == "message_delta":
if hasattr(event, "usage"):
input_tokens = getattr(event.usage, "input_tokens", 0) or 0
output_tokens = getattr(event.usage, "output_tokens", 0) or 0
return LlmResponse(
usage_metadata=types.GenerateContentResponseUsageMetadata(
prompt_token_count=input_tokens,
candidates_token_count=output_tokens,
total_token_count=input_tokens + output_tokens,
),
)

# Ignore start/stop events
return None
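
A rough usage sketch for the converter above, using a stand-in namespace object in place of a real anthropic_types.MessageStreamEvent (this assumes it runs somewhere streaming_event_to_llm_response is importable):

from types import SimpleNamespace

fake_event = SimpleNamespace(
    type="content_block_delta",
    delta=SimpleNamespace(type="text_delta", text="Hello"),
)
response = streaming_event_to_llm_response(fake_event)
assert response.partial is True
assert response.content.parts[0].text == "Hello"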


def message_to_generate_content_response(
@@ -250,10 +352,12 @@ class Claude(BaseLlm):
Attributes:
model: The name of the Claude model.
max_tokens: The maximum number of tokens to generate.
extra_headers: Optional extra headers to pass to the Anthropic API.
"""

model: str = "claude-3-5-sonnet-v2@20241022"
max_tokens: int = 8192
extra_headers: Optional[dict[str, str]] = None

@classmethod
@override
@@ -283,19 +387,124 @@ async def generate_content_async(
if llm_request.tools_dict
else NOT_GIVEN
)
# TODO(b/421255973): Enable streaming for anthropic models.
message = self._anthropic_client.messages.create(
model=llm_request.model,
system=llm_request.config.system_instruction,
messages=messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=self.max_tokens,
)
yield message_to_generate_content_response(message)

# Extract and convert thinking config from ADK to Anthropic format
thinking = NOT_GIVEN

if llm_request.config and llm_request.config.thinking_config:
budget = llm_request.config.thinking_config.thinking_budget
if budget:
if budget == -1:
raise ValueError(
"Unlimited thinking budget (-1) is not supported with Claude."
)
elif budget > 0:
thinking = {"type": "enabled", "budget_tokens": budget}
logger.info(
f"Extended thinking enabled (budget: {budget} tokens)"
)
else:
logger.warning(f"Ignoring unsupported thinking budget: {budget}")
else:
logger.warning("No thinking_config found in llm_request.config")

# Use extra headers if provided
extra_headers = self.extra_headers or NOT_GIVEN

if stream:
# Use streaming mode
logger.info(
f"Using streaming mode (stream={stream}, "
f"has_thinking={thinking != NOT_GIVEN}, "
f"large_max_tokens={self.max_tokens >= 8192})"
)

# Accumulators for text and thinking
accumulated_text = ""
accumulated_thinking = ""

async with self._anthropic_client.messages.stream(
model=llm_request.model,
system=llm_request.config.system_instruction,
messages=messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=self.max_tokens,
thinking=thinking,
extra_headers=extra_headers,
) as anthropic_stream:
# Process streaming events
async for event in anthropic_stream:
# Convert Anthropic event to LlmResponse
if llm_response := streaming_event_to_llm_response(event):
# Track accumulated content
is_thought = False
if llm_response.content and llm_response.content.parts:
for part in llm_response.content.parts:
if part.text:
if hasattr(part, "thought") and part.thought:
accumulated_thinking += part.text
is_thought = True
else:
accumulated_text += part.text

# If we have accumulated thinking and now getting text,
# yield the accumulated thinking first
# NOTE: This partial response is for UI display only
# The final response with signature will be yielded after the stream ends
if accumulated_thinking and accumulated_text and not is_thought:
yield LlmResponse(
content=types.Content(
role="model",
parts=[
types.Part(text=accumulated_thinking, thought=True)
],
),
partial=True,
)
accumulated_thinking = "" # Reset after yielding

# Yield partial response (but skip individual thought deltas)
if not is_thought:
yield llm_response

# Get final message with complete content blocks (includes signatures)
final_message = await anthropic_stream.get_final_message()

# Build final response from complete content blocks to preserve thinking signatures
# IMPORTANT: Use final_message.content instead of accumulated strings
# because accumulated strings don't have signatures
if final_message.content:
parts = [content_block_to_part(cb) for cb in final_message.content]
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
yield LlmResponse(
content=types.Content(role="model", parts=parts),
usage_metadata=types.GenerateContentResponseUsageMetadata(
prompt_token_count=input_tokens,
candidates_token_count=output_tokens,
total_token_count=input_tokens + output_tokens,
),
)

else:
# Non-streaming mode
logger.info("Using non-streaming mode")
message = await self._anthropic_client.messages.create(
model=llm_request.model,
system=llm_request.config.system_instruction,
messages=messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=self.max_tokens,
thinking=thinking,
extra_headers=extra_headers,
)
yield message_to_generate_content_response(message)

@cached_property
def _anthropic_client(self) -> AnthropicVertex:
def _anthropic_client(self) -> AsyncAnthropicVertex:
if (
"GOOGLE_CLOUD_PROJECT" not in os.environ
or "GOOGLE_CLOUD_LOCATION" not in os.environ
@@ -305,7 +514,7 @@ def _anthropic_client(self) -> AnthropicVertex:
" Anthropic on Vertex."
)

return AnthropicVertex(
return AsyncAnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
)
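
Taken together, a hedged end-to-end sketch of how the new options are exercised from ADK. The LlmRequest construction is elided (it is built by the framework), and GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION must be set as noted above:

from google.adk.models.anthropic_llm import Claude

async def run(llm_request):
    # A positive llm_request.config.thinking_config.thinking_budget enables
    # Anthropic extended thinking via the conversion in generate_content_async.
    llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=8192)
    async for response in llm.generate_content_async(llm_request, stream=True):
        if response.partial:
            pass  # incremental text / accumulated thinking, for UI display
        elif response.content:
            pass  # final message with thinking signatures and usage metadata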
4 changes: 3 additions & 1 deletion tests/unittests/models/test_anthropic_llm.py
@@ -293,8 +293,10 @@ async def test_function_declaration_to_tool_param(

@pytest.mark.asyncio
async def test_generate_content_async(
claude_llm, llm_request, generate_content_response, generate_llm_response
llm_request, generate_content_response, generate_llm_response
):
# Use max_tokens < 8192 to trigger non-streaming mode
claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
with mock.patch.object(
anthropic_llm,