Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions litellm/llms/bedrock/count_tokens/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
AWS Bedrock CountTokens API handler.

Simplified handler leveraging existing LiteLLM Bedrock infrastructure.
"""

from typing import Any, Dict

from fastapi import HTTPException

from litellm._logging import verbose_logger
from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig


class BedrockCountTokensHandler(BedrockCountTokensConfig):
    """
    Handler for AWS Bedrock CountTokens API requests.

    Uses existing LiteLLM infrastructure (authentication, request signing and
    request/response transformation inherited via BedrockCountTokensConfig /
    BaseAWSLLM) and performs the HTTP call itself.
    """

    # Timeout (seconds) for the CountTokens HTTP call to AWS.
    COUNT_TOKENS_TIMEOUT: float = 30.0

    async def handle_count_tokens_request(
        self,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
        resolved_model: str,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using existing LiteLLM patterns.

        Args:
            request_data: The incoming request payload (Anthropic-style).
            litellm_params: LiteLLM configuration parameters (credentials,
                region, and other AWS options).
            resolved_model: The actual model ID resolved from the router.

        Returns:
            Dictionary containing the token count response,
            e.g. ``{"input_tokens": 123}``.

        Raises:
            HTTPException: 400 for an invalid request payload; the upstream
                status code (e.g. 403, 429) for AWS Bedrock errors; 500 for
                unexpected processing failures.
        """
        try:
            # Client-side validation failures are the caller's fault:
            # surface them as 400, not as a generic 500 below.
            self.validate_count_tokens_request(request_data)
        except ValueError as e:
            raise HTTPException(status_code=400, detail={"error": str(e)})

        try:
            verbose_logger.debug(
                "Processing CountTokens request for resolved model: %s",
                resolved_model,
            )

            # Resolve the AWS region using the shared BaseAWSLLM helper.
            aws_region_name = self._get_aws_region_name(
                optional_params=litellm_params,
                model=resolved_model,
                model_id=None,
            )
            verbose_logger.debug("Retrieved AWS region: %s", aws_region_name)

            # Transform the request to Bedrock format (supports both the
            # Converse and InvokeModel input types).
            bedrock_request = self.transform_anthropic_to_bedrock_count_tokens(
                request_data=request_data
            )
            verbose_logger.debug("Transformed request: %s", bedrock_request)

            endpoint_url = self.get_bedrock_count_tokens_endpoint(
                resolved_model, aws_region_name
            )
            verbose_logger.debug("Making request to: %s", endpoint_url)

            # SigV4-sign the request via the shared BaseAWSLLM helper.
            signed_headers, signed_body = self._sign_request(
                service_name="bedrock",
                headers={"Content-Type": "application/json"},
                optional_params=litellm_params,
                request_data=bedrock_request,
                api_base=endpoint_url,
                model=resolved_model,
            )

            # Local import keeps httpx off the module import path.
            import httpx

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    endpoint_url,
                    headers=signed_headers,
                    content=signed_body,
                    timeout=self.COUNT_TOKENS_TIMEOUT,
                )

            verbose_logger.debug("Response status: %s", response.status_code)

            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error("AWS Bedrock error: %s", error_text)
                # Propagate the upstream status code (403 auth, 429
                # throttling, ...) instead of collapsing every AWS error
                # into a 400.
                raise HTTPException(
                    status_code=response.status_code,
                    detail={"error": f"AWS Bedrock error: {error_text}"},
                )

            bedrock_response = response.json()
            verbose_logger.debug("Bedrock response: %s", bedrock_response)

            # Transform the Bedrock response back to the Anthropic format.
            final_response = self.transform_bedrock_response_to_anthropic(
                bedrock_response
            )
            verbose_logger.debug("Final response: %s", final_response)
            return final_response

        except HTTPException:
            # Re-raise HTTP exceptions as-is.
            raise
        except Exception as e:
            verbose_logger.error("Error in CountTokens handler: %s", str(e))
            raise HTTPException(
                status_code=500,
                detail={"error": f"CountTokens processing error: {str(e)}"},
            )
213 changes: 213 additions & 0 deletions litellm/llms/bedrock/count_tokens/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""
AWS Bedrock CountTokens API transformation logic.

This module handles the transformation of requests from Anthropic Messages API format
to AWS Bedrock's CountTokens API format and vice versa.
"""

from typing import Any, Dict, List

from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import BedrockModelInfo


class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Configuration and transformation logic for AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API Specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Detect whether to use 'converse' or 'invokeModel' input format.

        Args:
            request_data: The original request data

        Returns:
            'converse' or 'invokeModel'
        """
        # If the request has messages in the expected Anthropic format, use converse
        if "messages" in request_data and isinstance(request_data["messages"], list):
            return "converse"

        # For raw text or other formats, use invokeModel
        # This handles cases where the input is prompt-based or already in raw Bedrock format
        return "invokeModel"

    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform request to Bedrock CountTokens format.
        Supports both Converse and InvokeModel input types.

        Input (Anthropic format):
        {
            "model": "claude-3-5-sonnet",
            "messages": [{"role": "user", "content": "Hello!"}]
        }

        Output (Bedrock CountTokens format for Converse):
        {
            "input": {
                "converse": {
                    "messages": [...],
                    "system": [...] (if present)
                }
            }
        }

        Output (Bedrock CountTokens format for InvokeModel):
        {
            "input": {
                "invokeModel": {
                    "body": "{...raw model input...}"
                }
            }
        }
        """
        input_type = self._detect_input_type(request_data)

        if input_type == "converse":
            return self._transform_to_converse_format(request_data.get("messages", []))
        return self._transform_to_invoke_model_format(request_data)

    @staticmethod
    def _system_content_to_blocks(content: Any) -> List[Dict[str, Any]]:
        """Convert a system message's content (str or content-block list) to
        Bedrock Converse system blocks of the form [{"text": ...}]."""
        if isinstance(content, str):
            return [{"text": content}]
        if isinstance(content, list):
            # Anthropic content blocks: extract the text of each block.
            blocks: List[Dict[str, Any]] = []
            for block in content:
                if isinstance(block, dict) and "text" in block:
                    blocks.append({"text": block["text"]})
            return blocks
        # Unknown shape: stringify as a single text block rather than
        # emitting an invalid entry.
        return [{"text": str(content)}]

    def _transform_to_converse_format(
        self, messages: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Transform Anthropic-style messages to the Converse input format."""
        system_messages: List[Dict[str, Any]] = []
        user_messages: List[Dict[str, Any]] = []

        for message in messages:
            if message.get("role") == "system":
                # System content may be a plain string or a list of content
                # blocks; either way Bedrock expects [{"text": ...}] entries.
                # (The original code wrapped a list as {"text": [...]},
                # which is not a valid Converse system block.)
                system_messages.extend(
                    self._system_content_to_blocks(message.get("content", ""))
                )
            else:
                # Transform message content to Bedrock format.
                transformed_message: Dict[str, Any] = {
                    "role": message.get("role"),
                    "content": [],
                }

                # Handle content - ensure it's in the correct array format
                content = message.get("content", "")
                if isinstance(content, str):
                    # String content -> convert to text block
                    transformed_message["content"].append({"text": content})
                elif isinstance(content, list):
                    # Already in blocks format - use as is
                    transformed_message["content"] = content

                user_messages.append(transformed_message)

        # Build the converse input format
        converse_input: Dict[str, Any] = {"messages": user_messages}

        # Add system messages if present
        if system_messages:
            converse_input["system"] = system_messages

        # Build the complete request
        return {"input": {"converse": converse_input}}

    def _transform_to_invoke_model_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transform to InvokeModel input format."""
        import json

        # For InvokeModel, we need to provide the raw body that would be sent
        # to the model. Remove the 'model' field from the body as it's not
        # part of the model input.
        body_data = {k: v for k, v in request_data.items() if k != "model"}

        return {"input": {"invokeModel": {"body": json.dumps(body_data)}}}

    def get_bedrock_count_tokens_endpoint(
        self, model: str, aws_region_name: str
    ) -> str:
        """
        Construct the AWS Bedrock CountTokens API endpoint using existing
        LiteLLM functions.

        Args:
            model: The resolved model ID from router lookup
            aws_region_name: AWS region (e.g., "eu-west-1")

        Returns:
            Complete endpoint URL for CountTokens API
        """
        # Use existing LiteLLM function to get the base model ID (removes region prefix)
        model_id = BedrockModelInfo.get_base_model(model)

        # Strip the provider prefix if present (named slice instead of a
        # magic offset).
        provider_prefix = "bedrock/"
        if model_id.startswith(provider_prefix):
            model_id = model_id[len(provider_prefix):]

        base_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
        return f"{base_url}/model/{model_id}/count-tokens"

    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Transform Bedrock CountTokens response to Anthropic format.

        Input (Bedrock response):
        {
            "inputTokens": 123
        }

        Output (Anthropic format):
        {
            "input_tokens": 123
        }
        """
        # Missing key defaults to 0 rather than raising — best-effort.
        input_tokens = bedrock_response.get("inputTokens", 0)

        return {"input_tokens": input_tokens}

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate the incoming count tokens request.
        Supports both Converse and InvokeModel input formats.

        Args:
            request_data: The request payload

        Raises:
            ValueError: If the request is invalid
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")

        input_type = self._detect_input_type(request_data)

        if input_type == "converse":
            # Validate Converse format (messages-based)
            messages = request_data.get("messages", [])
            if not messages:
                raise ValueError("messages parameter is required for Converse input")

            if not isinstance(messages, list):
                raise ValueError("messages must be a list")

            for i, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise ValueError(f"Message {i} must be a dictionary")

                if "role" not in message:
                    raise ValueError(f"Message {i} must have a 'role' field")

                if "content" not in message:
                    raise ValueError(f"Message {i} must have a 'content' field")
        else:
            # For InvokeModel format, we need at least some content to count tokens
            # The content structure varies by model, so we do minimal validation
            if len(request_data) <= 1:  # Only has 'model' field
                raise ValueError("Request must contain content to count tokens")
Loading
Loading