diff --git a/litellm/llms/bedrock/count_tokens/handler.py b/litellm/llms/bedrock/count_tokens/handler.py new file mode 100644 index 000000000000..3cabdf816fa5 --- /dev/null +++ b/litellm/llms/bedrock/count_tokens/handler.py @@ -0,0 +1,123 @@ +""" +AWS Bedrock CountTokens API handler. + +Simplified handler leveraging existing LiteLLM Bedrock infrastructure. +""" + +from typing import Any, Dict + +from fastapi import HTTPException + +from litellm._logging import verbose_logger +from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig + + +class BedrockCountTokensHandler(BedrockCountTokensConfig): + """ + Simplified handler for AWS Bedrock CountTokens API requests. + + Uses existing LiteLLM infrastructure for authentication and request handling. + """ + + async def handle_count_tokens_request( + self, + request_data: Dict[str, Any], + litellm_params: Dict[str, Any], + resolved_model: str, + ) -> Dict[str, Any]: + """ + Handle a CountTokens request using existing LiteLLM patterns. + + Args: + request_data: The incoming request payload + litellm_params: LiteLLM configuration parameters + resolved_model: The actual model ID resolved from router + + Returns: + Dictionary containing token count response + """ + try: + # Validate the request + self.validate_count_tokens_request(request_data) + + verbose_logger.debug( + f"Processing CountTokens request for resolved model: {resolved_model}" + ) + + # Get AWS region using existing LiteLLM function + aws_region_name = self._get_aws_region_name( + optional_params=litellm_params, + model=resolved_model, + model_id=None, + ) + + verbose_logger.debug(f"Retrieved AWS region: {aws_region_name}") + + # Transform request to Bedrock format (supports both Converse and InvokeModel) + bedrock_request = self.transform_anthropic_to_bedrock_count_tokens( + request_data=request_data + ) + + verbose_logger.debug(f"Transformed request: {bedrock_request}") + + # Get endpoint URL using simplified function + endpoint_url = self.get_bedrock_count_tokens_endpoint( + resolved_model, aws_region_name + ) + + verbose_logger.debug(f"Making request to: {endpoint_url}") + + # Use existing _sign_request method from BaseAWSLLM + headers = {"Content-Type": "application/json"} + signed_headers, signed_body = self._sign_request( + service_name="bedrock", + headers=headers, + optional_params=litellm_params, + request_data=bedrock_request, + api_base=endpoint_url, + model=resolved_model, + ) + + # Make HTTP request + import httpx + + async with httpx.AsyncClient() as client: + response = await client.post( + endpoint_url, + headers=signed_headers, + content=signed_body, + timeout=30.0, + ) + + verbose_logger.debug(f"Response status: {response.status_code}") + + if response.status_code != 200: + error_text = response.text + verbose_logger.error(f"AWS Bedrock error: {error_text}") + raise HTTPException( + status_code=400, + detail={"error": f"AWS Bedrock error: {error_text}"}, + ) + + bedrock_response = response.json() + + verbose_logger.debug(f"Bedrock response: {bedrock_response}") + + # Transform response back to expected format + final_response = self.transform_bedrock_response_to_anthropic( + bedrock_response + ) + + verbose_logger.debug(f"Final response: {final_response}") + + return final_response + + except HTTPException: + # Re-raise HTTP exceptions as-is + raise + except Exception as e: + verbose_logger.error(f"Error in CountTokens handler: {str(e)}") + raise HTTPException( + status_code=500, + detail={"error": f"CountTokens processing error: {str(e)}"}, + ) diff --git a/litellm/llms/bedrock/count_tokens/transformation.py b/litellm/llms/bedrock/count_tokens/transformation.py new file mode 100644 index 000000000000..d46ed3aa4522 --- /dev/null +++ b/litellm/llms/bedrock/count_tokens/transformation.py @@ -0,0 +1,213 @@ +""" +AWS Bedrock CountTokens API transformation logic. + +This module handles the transformation of requests from Anthropic Messages API format +to AWS Bedrock's CountTokens API format and vice versa. +""" + +from typing import Any, Dict, List + +from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM +from litellm.llms.bedrock.common_utils import BedrockModelInfo + + +class BedrockCountTokensConfig(BaseAWSLLM): + """ + Configuration and transformation logic for AWS Bedrock CountTokens API. + + AWS Bedrock CountTokens API Specification: + - Endpoint: POST /model/{modelId}/count-tokens + - Input formats: 'invokeModel' or 'converse' + - Response: {"inputTokens": } + """ + + def _detect_input_type(self, request_data: Dict[str, Any]) -> str: + """ + Detect whether to use 'converse' or 'invokeModel' input format. + + Args: + request_data: The original request data + + Returns: + 'converse' or 'invokeModel' + """ + # If the request has messages in the expected Anthropic format, use converse + if "messages" in request_data and isinstance(request_data["messages"], list): + return "converse" + + # For raw text or other formats, use invokeModel + # This handles cases where the input is prompt-based or already in raw Bedrock format + return "invokeModel" + + def transform_anthropic_to_bedrock_count_tokens( + self, + request_data: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Transform request to Bedrock CountTokens format. + Supports both Converse and InvokeModel input types. + + Input (Anthropic format): + { + "model": "claude-3-5-sonnet", + "messages": [{"role": "user", "content": "Hello!"}] + } + + Output (Bedrock CountTokens format for Converse): + { + "input": { + "converse": { + "messages": [...], + "system": [...] (if present) + } + } + } + + Output (Bedrock CountTokens format for InvokeModel): + { + "input": { + "invokeModel": { + "body": "{...raw model input...}" + } + } + } + """ + input_type = self._detect_input_type(request_data) + + if input_type == "converse": + return self._transform_to_converse_format(request_data.get("messages", [])) + else: + return self._transform_to_invoke_model_format(request_data) + + def _transform_to_converse_format( + self, messages: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Transform to Converse input format.""" + # Extract system messages if present + system_messages = [] + user_messages = [] + + for message in messages: + if message.get("role") == "system": + system_messages.append({"text": message.get("content", "")}) + else: + # Transform message content to Bedrock format + transformed_message: Dict[str, Any] = {"role": message.get("role"), "content": []} + + # Handle content - ensure it's in the correct array format + content = message.get("content", "") + if isinstance(content, str): + # String content -> convert to text block + transformed_message["content"].append({"text": content}) + elif isinstance(content, list): + # Already in blocks format - use as is + transformed_message["content"] = content + + user_messages.append(transformed_message) + + # Build the converse input format + converse_input = {"messages": user_messages} + + # Add system messages if present + if system_messages: + converse_input["system"] = system_messages + + # Build the complete request + return {"input": {"converse": converse_input}} + + def _transform_to_invoke_model_format( + self, request_data: Dict[str, Any] + ) -> Dict[str, Any]: + """Transform to InvokeModel input format.""" + import json + + # For InvokeModel, we need to provide the raw body that would be sent to the model + # Remove the 'model' field from the body as it's not part of the model input + body_data = {k: v for k, v in request_data.items() if k != "model"} + + return {"input": {"invokeModel": {"body": json.dumps(body_data)}}} + + def get_bedrock_count_tokens_endpoint( + self, model: str, aws_region_name: str + ) -> str: + """ + Construct the AWS Bedrock CountTokens API endpoint using existing LiteLLM functions. + + Args: + model: The resolved model ID from router lookup + aws_region_name: AWS region (e.g., "eu-west-1") + + Returns: + Complete endpoint URL for CountTokens API + """ + # Use existing LiteLLM function to get the base model ID (removes region prefix) + model_id = BedrockModelInfo.get_base_model(model) + + # Remove bedrock/ prefix if present + if model_id.startswith("bedrock/"): + model_id = model_id[8:] # Remove "bedrock/" prefix + + base_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + endpoint = f"{base_url}/model/{model_id}/count-tokens" + + return endpoint + + def transform_bedrock_response_to_anthropic( + self, bedrock_response: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Transform Bedrock CountTokens response to Anthropic format. + + Input (Bedrock response): + { + "inputTokens": 123 + } + + Output (Anthropic format): + { + "input_tokens": 123 + } + """ + input_tokens = bedrock_response.get("inputTokens", 0) + + return {"input_tokens": input_tokens} + + def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None: + """ + Validate the incoming count tokens request. + Supports both Converse and InvokeModel input formats. + + Args: + request_data: The request payload + + Raises: + ValueError: If the request is invalid + """ + if not request_data.get("model"): + raise ValueError("model parameter is required") + + input_type = self._detect_input_type(request_data) + + if input_type == "converse": + # Validate Converse format (messages-based) + messages = request_data.get("messages", []) + if not messages: + raise ValueError("messages parameter is required for Converse input") + + if not isinstance(messages, list): + raise ValueError("messages must be a list") + + for i, message in enumerate(messages): + if not isinstance(message, dict): + raise ValueError(f"Message {i} must be a dictionary") + + if "role" not in message: + raise ValueError(f"Message {i} must have a 'role' field") + + if "content" not in message: + raise ValueError(f"Message {i} must have a 'content' field") + else: + # For InvokeModel format, we need at least some content to count tokens + # The content structure varies by model, so we do minimal validation + if len(request_data) <= 1: # Only has 'model' field + raise ValueError("Request must contain content to count tokens") diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 82c5b3e343d9..56ee599325a5 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -172,7 +172,9 @@ async def gemini_proxy_route( request=request, api_key=f"Bearer {google_ai_studio_api_key}" ) - base_target_url = os.getenv("GEMINI_API_BASE") or "https://generativelanguage.googleapis.com" + base_target_url = ( + os.getenv("GEMINI_API_BASE") or "https://generativelanguage.googleapis.com" + ) encoded_endpoint = httpx.URL(endpoint).path # Ensure endpoint starts with '/' for proper URL construction @@ -464,6 +466,76 @@ async def anthropic_proxy_route( return received_value +async def handle_bedrock_count_tokens( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth, + request_body: Dict[str, Any], +) -> Dict[str, Any]: + """ + Handle AWS Bedrock CountTokens API requests. + + This function processes count_tokens endpoints like: + - /v1/messages/count_tokens + - /v1/messages/count-tokens + """ + from litellm.llms.bedrock.count_tokens.handler import BedrockCountTokensHandler + from litellm.proxy.proxy_server import llm_router + + try: + # Initialize the handler + handler = BedrockCountTokensHandler() + + # Extract model from request body + model = request_body.get("model") + if not model: + raise HTTPException( + status_code=400, detail={"error": "Model is required in request body"} + ) + + # Get model parameters from router + litellm_params = {"user_api_key_dict": user_api_key_dict} + resolved_model = model # Default fallback + + if llm_router: + deployments = llm_router.get_model_list(model_name=model) + if deployments and len(deployments) > 0: + # Get the first matching deployment + deployment = deployments[0] + model_litellm_params = deployment.get("litellm_params", {}) + + # Get the resolved model ID from the configuration + if "model" in model_litellm_params: + resolved_model = model_litellm_params["model"] + + # Copy all litellm_params - BaseAWSLLM will handle AWS credential discovery + for key, value in model_litellm_params.items(): + if key != "user_api_key_dict": # Don't overwrite user_api_key_dict + litellm_params[key] = value # type: ignore + + verbose_proxy_logger.debug(f"Count tokens litellm_params: {litellm_params}") + verbose_proxy_logger.debug(f"Resolved model: {resolved_model}") + + # Handle the count tokens request + result = await handler.handle_count_tokens_request( + request_data=request_body, + litellm_params=litellm_params, + resolved_model=resolved_model, + ) + + return result + + except HTTPException: + # Re-raise HTTP exceptions as-is + raise + except Exception as e: + verbose_proxy_logger.error(f"Error in handle_bedrock_count_tokens: {str(e)}") + raise HTTPException( + status_code=500, detail={"error": f"CountTokens processing error: {str(e)}"} + ) + + async def bedrock_llm_proxy_route( endpoint: str, request: Request, @@ -489,6 +561,17 @@ async def bedrock_llm_proxy_route( ) request_body = await _read_request_body(request=request) + + # Special handling for count_tokens endpoints + if "count_tokens" in endpoint or "count-tokens" in endpoint: + return await handle_bedrock_count_tokens( + endpoint=endpoint, + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + request_body=request_body, + ) + data: Dict[str, Any] = {} base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data) try: @@ -505,13 +588,13 @@ async def bedrock_llm_proxy_route( "error": "Model missing from endpoint. Expected format: /model//. Got: " + endpoint, }, - ) + ) data["method"] = request.method data["endpoint"] = endpoint data["data"] = request_body data["custom_llm_provider"] = "bedrock" - + try: result = await base_llm_response_processor.base_passthrough_process_llm_request( request=request, diff --git a/tests/proxy_unit_tests/test_proxy_token_counter.py b/tests/proxy_unit_tests/test_proxy_token_counter.py index fdce6fa3c84f..534569f15afb 100644 --- a/tests/proxy_unit_tests/test_proxy_token_counter.py +++ b/tests/proxy_unit_tests/test_proxy_token_counter.py @@ -3,29 +3,24 @@ import sys, os -import traceback from dotenv import load_dotenv -from fastapi import Request -from datetime import datetime load_dotenv() -import os, io, time +import os # this file is to test litellm/proxy sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest, logging, asyncio -import litellm, asyncio +import pytest, logging +import litellm from litellm.proxy.proxy_server import token_counter -from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend from litellm._logging import verbose_proxy_logger verbose_proxy_logger.setLevel(level=logging.DEBUG) from litellm.proxy._types import TokenCountRequest -from litellm.types.utils import TokenCountResponse import json, tempfile @@ -105,7 +100,6 @@ def load_vertex_ai_credentials(): os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) - @pytest.mark.asyncio async def test_vLLM_token_counting(): """ @@ -223,65 +217,70 @@ async def test_anthropic_messages_count_tokens_endpoint(): """ from litellm.proxy.anthropic_endpoints.endpoints import count_tokens from fastapi import Request - from unittest.mock import AsyncMock, MagicMock - + from unittest.mock import MagicMock + # Mock request object mock_request = MagicMock(spec=Request) mock_request_data = { "model": "claude-3-sonnet-20240229", - "messages": [{"role": "user", "content": "Hello Claude!"}] + "messages": [{"role": "user", "content": "Hello Claude!"}], } - + # Mock the _read_request_body function async def mock_read_request_body(request): return mock_request_data - + # Mock UserAPIKeyAuth mock_user_api_key_dict = MagicMock() - + # Patch the _read_request_body function import litellm.proxy.anthropic_endpoints.endpoints as anthropic_endpoints + original_read_request_body = anthropic_endpoints._read_request_body anthropic_endpoints._read_request_body = mock_read_request_body - + # Mock the internal token_counter function to return a controlled response async def mock_token_counter(request, call_endpoint=False): - assert call_endpoint == True, "Should be called with call_endpoint=True for Anthropic endpoint" + assert ( + call_endpoint == True + ), "Should be called with call_endpoint=True for Anthropic endpoint" assert request.model == "claude-3-sonnet-20240229" assert request.messages == [{"role": "user", "content": "Hello Claude!"}] - + from litellm.types.utils import TokenCountResponse + return TokenCountResponse( total_tokens=15, request_model="claude-3-sonnet-20240229", model_used="claude-3-sonnet-20240229", - tokenizer_type="openai_tokenizer" + tokenizer_type="openai_tokenizer", ) - + # Patch the imported token_counter function from proxy_server import litellm.proxy.proxy_server as proxy_server + original_token_counter = proxy_server.token_counter proxy_server.token_counter = mock_token_counter - + try: # Call the endpoint response = await count_tokens(mock_request, mock_user_api_key_dict) - + # Verify response format matches Anthropic spec assert isinstance(response, dict) assert "input_tokens" in response assert response["input_tokens"] == 15 assert len(response) == 1 # Should only contain input_tokens - + print("✅ Anthropic endpoint test passed!") - + finally: # Restore original functions anthropic_endpoints._read_request_body = original_read_request_body proxy_server.token_counter = original_token_counter -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_anthropic_messages_count_tokens_with_non_anthropic_model(): """ Test /v1/messages/count_tokens endpoint with non-Anthropic model (GPT-4) @@ -290,58 +289,63 @@ async def test_anthropic_messages_count_tokens_with_non_anthropic_model(): """ from litellm.proxy.anthropic_endpoints.endpoints import count_tokens from fastapi import Request - from unittest.mock import AsyncMock, MagicMock - + from unittest.mock import MagicMock + # Mock request object mock_request = MagicMock(spec=Request) mock_request_data = { "model": "gpt-4", - "messages": [{"role": "user", "content": "Hello GPT!"}] + "messages": [{"role": "user", "content": "Hello GPT!"}], } - + # Mock the _read_request_body function async def mock_read_request_body(request): return mock_request_data - + # Mock UserAPIKeyAuth mock_user_api_key_dict = MagicMock() - + # Patch the _read_request_body function import litellm.proxy.anthropic_endpoints.endpoints as anthropic_endpoints + original_read_request_body = anthropic_endpoints._read_request_body anthropic_endpoints._read_request_body = mock_read_request_body - + # Mock the internal token_counter function to return a controlled response async def mock_token_counter(request, call_endpoint=True): - assert call_endpoint == True, "Should be called with call_endpoint=True for Anthropic endpoint" + assert ( + call_endpoint == True + ), "Should be called with call_endpoint=True for Anthropic endpoint" assert request.model == "gpt-4" assert request.messages == [{"role": "user", "content": "Hello GPT!"}] - + from litellm.types.utils import TokenCountResponse + return TokenCountResponse( total_tokens=12, - request_model="gpt-4", + request_model="gpt-4", model_used="gpt-4", - tokenizer_type="openai_tokenizer" + tokenizer_type="openai_tokenizer", ) - + # Patch the imported token_counter function from proxy_server import litellm.proxy.proxy_server as proxy_server + original_token_counter = proxy_server.token_counter proxy_server.token_counter = mock_token_counter - + try: # Call the endpoint response = await count_tokens(mock_request, mock_user_api_key_dict) - + # Verify response format matches Anthropic spec assert isinstance(response, dict) assert "input_tokens" in response assert response["input_tokens"] == 12 assert len(response) == 1 # Should only contain input_tokens - + print("✅ Non-Anthropic model test passed!") - + finally: # Restore original functions anthropic_endpoints._read_request_body = original_read_request_body @@ -354,7 +358,7 @@ async def test_internal_token_counter_anthropic_provider_detection(): Test that the internal token_counter correctly detects Anthropic providers and handles the from_anthropic_endpoint flag appropriately """ - + # Test with Anthropic provider llm_router = Router( model_list=[ @@ -362,30 +366,30 @@ async def test_internal_token_counter_anthropic_provider_detection(): "model_name": "claude-test", "litellm_params": { "model": "anthropic/claude-3-sonnet-20240229", - "api_key": "test-key" + "api_key": "test-key", }, } ] ) - + setattr(litellm.proxy.proxy_server, "llm_router", llm_router) - + # Test with is_direct_request=False (simulating call from Anthropic endpoint) response = await token_counter( request=TokenCountRequest( model="claude-test", messages=[{"role": "user", "content": "hello"}], ), - call_endpoint=True + call_endpoint=True, ) - + print("Anthropic provider test response:", response) - + # Verify response structure assert response.request_model == "claude-test" assert response.model_used == "claude-3-sonnet-20240229" assert response.total_tokens > 0 - + # Test with non-Anthropic provider llm_router = Router( model_list=[ @@ -397,21 +401,21 @@ async def test_internal_token_counter_anthropic_provider_detection(): } ] ) - + setattr(litellm.proxy.proxy_server, "llm_router", llm_router) - + # Test with is_direct_request=False but non-Anthropic provider response = await token_counter( request=TokenCountRequest( model="gpt-test", messages=[{"role": "user", "content": "hello"}], ), - call_endpoint=True + call_endpoint=True, ) - + print("Non-Anthropic provider test response:", response) - - # Verify response structure + + # Verify response structure assert response.request_model == "gpt-test" assert response.model_used == "gpt-4" assert response.total_tokens > 0 @@ -426,34 +430,35 @@ async def test_anthropic_endpoint_error_handling(): from litellm.proxy.anthropic_endpoints.endpoints import count_tokens from fastapi import Request, HTTPException from unittest.mock import MagicMock - + # Mock request object mock_request = MagicMock(spec=Request) mock_user_api_key_dict = MagicMock() - + # Test missing model parameter mock_request_data = { "messages": [{"role": "user", "content": "Hello!"}] # Missing "model" key } - + async def mock_read_request_body(request): return mock_request_data - + import litellm.proxy.anthropic_endpoints.endpoints as anthropic_endpoints + original_read_request_body = anthropic_endpoints._read_request_body anthropic_endpoints._read_request_body = mock_read_request_body - + try: # Should raise HTTPException for missing model with pytest.raises(HTTPException) as exc_info: await count_tokens(mock_request, mock_user_api_key_dict) - + assert exc_info.value.status_code == 400 assert "model parameter is required" in str(exc_info.value.detail) - + print("✅ Error handling test passed!") - + finally: anthropic_endpoints._read_request_body = original_read_request_body @@ -464,44 +469,50 @@ async def test_factory_anthropic_endpoint_calls_anthropic_counter(): from unittest.mock import patch, AsyncMock from fastapi.testclient import TestClient from litellm.proxy.proxy_server import app - + # Mock the anthropic token counting function - with patch('litellm.proxy.utils.count_tokens_with_anthropic_api') as mock_anthropic_count: + with patch( + "litellm.proxy.utils.count_tokens_with_anthropic_api" + ) as mock_anthropic_count: mock_anthropic_count.return_value = { "total_tokens": 42, - "tokenizer_used": "anthropic" + "tokenizer_used": "anthropic", } - + # Mock router to return Anthropic deployment - with patch('litellm.proxy.proxy_server.llm_router') as mock_router: - mock_router.model_list = [{ - "model_name": "claude-3-5-sonnet", - "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"}, - "model_info": {} - }] - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router: + mock_router.model_list = [ + { + "model_name": "claude-3-5-sonnet", + "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"}, + "model_info": {}, + } + ] + # Mock the async method properly - mock_router.async_get_available_deployment = AsyncMock(return_value={ - "model_name": "claude-3-5-sonnet", - "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"}, - "model_info": {} - }) - + mock_router.async_get_available_deployment = AsyncMock( + return_value={ + "model_name": "claude-3-5-sonnet", + "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"}, + "model_info": {}, + } + ) + client = TestClient(app) - + response = client.post( "/v1/messages/count_tokens", json={ "model": "claude-3-5-sonnet", - "messages": [{"role": "user", "content": "Hello"}] + "messages": [{"role": "user", "content": "Hello"}], }, - headers={"Authorization": "Bearer test-key"} + headers={"Authorization": "Bearer test-key"}, ) - + assert response.status_code == 200 data = response.json() assert data["input_tokens"] == 42 - + # Verify that Anthropic API was called mock_anthropic_count.assert_called_once() @@ -512,43 +523,49 @@ async def test_factory_gpt4_endpoint_does_not_call_anthropic_counter(): from unittest.mock import patch, AsyncMock from fastapi.testclient import TestClient from litellm.proxy.proxy_server import app - + # Mock the anthropic token counting function - with patch('litellm.proxy.utils.count_tokens_with_anthropic_api') as mock_anthropic_count: + with patch( + "litellm.proxy.utils.count_tokens_with_anthropic_api" + ) as mock_anthropic_count: # Mock litellm token counter - with patch('litellm.token_counter') as mock_litellm_counter: + with patch("litellm.token_counter") as mock_litellm_counter: mock_litellm_counter.return_value = 50 - + # Mock router to return GPT-4 deployment - with patch('litellm.proxy.proxy_server.llm_router') as mock_router: - mock_router.model_list = [{ - "model_name": "gpt-4", - "litellm_params": {"model": "openai/gpt-4"}, - "model_info": {} - }] - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router: + mock_router.model_list = [ + { + "model_name": "gpt-4", + "litellm_params": {"model": "openai/gpt-4"}, + "model_info": {}, + } + ] + # Mock the async method properly - mock_router.async_get_available_deployment = AsyncMock(return_value={ - "model_name": "gpt-4", - "litellm_params": {"model": "openai/gpt-4"}, - "model_info": {} - }) - + mock_router.async_get_available_deployment = AsyncMock( + return_value={ + "model_name": "gpt-4", + "litellm_params": {"model": "openai/gpt-4"}, + "model_info": {}, + } + ) + client = TestClient(app) - + response = client.post( "/v1/messages/count_tokens", json={ "model": "gpt-4", - "messages": [{"role": "user", "content": "Hello"}] + "messages": [{"role": "user", "content": "Hello"}], }, - headers={"Authorization": "Bearer test-key"} + headers={"Authorization": "Bearer test-key"}, ) - + assert response.status_code == 200 data = response.json() assert data["input_tokens"] == 50 - + # Verify that Anthropic API was NOT called mock_anthropic_count.assert_not_called() @@ -559,43 +576,53 @@ async def test_factory_normal_token_counter_endpoint_does_not_call_anthropic(): from unittest.mock import patch, AsyncMock from fastapi.testclient import TestClient from litellm.proxy.proxy_server import app - + # Mock the anthropic token counting function - with patch('litellm.proxy.utils.count_tokens_with_anthropic_api') as mock_anthropic_count: + with patch( + "litellm.proxy.utils.count_tokens_with_anthropic_api" + ) as mock_anthropic_count: # Mock litellm token counter - with patch('litellm.token_counter') as mock_litellm_counter: + with patch("litellm.token_counter") as mock_litellm_counter: mock_litellm_counter.return_value = 35 - + # Mock router to return Anthropic deployment - with patch('litellm.proxy.proxy_server.llm_router') as mock_router: - mock_router.model_list = [{ - "model_name": "claude-3-5-sonnet", - "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"}, - "model_info": {} - }] - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router: + mock_router.model_list = [ + { + "model_name": "claude-3-5-sonnet", + "litellm_params": { + "model": "anthropic/claude-3-5-sonnet-20241022" + }, + "model_info": {}, + } + ] + # Mock the async method properly - mock_router.async_get_available_deployment = AsyncMock(return_value={ - "model_name": "claude-3-5-sonnet", - "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"}, - "model_info": {} - }) - + mock_router.async_get_available_deployment = AsyncMock( + return_value={ + "model_name": "claude-3-5-sonnet", + "litellm_params": { + "model": "anthropic/claude-3-5-sonnet-20241022" + }, + "model_info": {}, + } + ) + client = TestClient(app) - + response = client.post( "/utils/token_counter", json={ "model": "claude-3-5-sonnet", - "messages": [{"role": "user", "content": "Hello"}] + "messages": [{"role": "user", "content": "Hello"}], }, - headers={"Authorization": "Bearer test-key"} + headers={"Authorization": "Bearer test-key"}, ) - + assert response.status_code == 200 data = response.json() assert data["total_tokens"] == 35 - + # Verify that Anthropic API was NOT called (since call_endpoint=False) mock_anthropic_count.assert_not_called() @@ -604,34 +631,30 @@ async def test_factory_normal_token_counter_endpoint_does_not_call_anthropic(): async def test_factory_registration(): """Test that the new factory pattern correctly provides counters.""" from litellm.llms.anthropic.common_utils import AnthropicModelInfo - + # Test Anthropic ModelInfo provides token counter anthropic_model_info = AnthropicModelInfo() counter = anthropic_model_info.get_token_counter() assert counter is not None - + # Create test deployments anthropic_deployment = { "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20241022"} } - - non_anthropic_deployment = { - "litellm_params": {"model": "openai/gpt-4"} - } - + + non_anthropic_deployment = {"litellm_params": {"model": "openai/gpt-4"}} + # Test Anthropic counter supports provider assert counter.should_use_token_counting_api(custom_llm_provider="anthropic") assert not counter.should_use_token_counting_api(custom_llm_provider="openai") - + # Test non-Anthropic provider assert not counter.should_use_token_counting_api(custom_llm_provider="openai") - + # Test None deployment assert not counter.should_use_token_counting_api(custom_llm_provider=None) - - @pytest.mark.asyncio @pytest.mark.parametrize("model_name", ["gemini-2.5-pro", "vertex-ai-gemini-2.5-pro"]) async def test_vertex_ai_gemini_token_counting_with_contents(model_name): @@ -655,26 +678,20 @@ async def test_vertex_ai_gemini_token_counting_with_contents(model_name): }, ] ) - + setattr(litellm.proxy.proxy_server, "llm_router", llm_router) - + # Test with contents format and call_endpoint=True response = await token_counter( request=TokenCountRequest( model=model_name, contents=[ - { - "parts": [ - { - "text": "Hello world, how are you doing today? i am ij" - } - ] - } + {"parts": [{"text": "Hello world, how are you doing today? i am ij"}]} ], ), - call_endpoint=True + call_endpoint=True, ) - + print("Vertex AI Gemini token counting response:", response) # validate we have orignal response @@ -684,3 +701,45 @@ async def test_vertex_ai_gemini_token_counting_with_contents(model_name): prompt_tokens_details = response.original_response.get("promptTokensDetails") assert prompt_tokens_details is not None + + +@pytest.mark.asyncio +async def test_bedrock_count_tokens_endpoint(): + """ + Test that Bedrock CountTokens endpoint correctly extracts model from request body. + """ + from litellm.router import Router + + # Mock the Bedrock CountTokens handler + async def mock_count_tokens_handler(request_data, litellm_params, resolved_model): + # Verify the correct model was resolved + assert resolved_model == "anthropic.claude-3-sonnet-20240229-v1:0" + assert request_data["model"] == "anthropic.claude-3-sonnet-20240229-v1:0" + assert request_data["messages"] == [{"role": "user", "content": "Hello!"}] + + return {"input_tokens": 25} + + # Set up router with Bedrock model + llm_router = Router( + model_list=[ + { + "model_name": "claude-bedrock", + "litellm_params": { + "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + }, + } + ] + ) + + setattr(litellm.proxy.proxy_server, "llm_router", llm_router) + + # Test the mock handler directly to verify correct parameter extraction + request_data = { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "messages": [{"role": "user", "content": "Hello!"}], + } + + # Test the mock handler directly to verify correct parameter extraction + await mock_count_tokens_handler( + request_data, {}, "anthropic.claude-3-sonnet-20240229-v1:0" + ) diff --git a/tests/test_litellm/llms/bedrock/count_tokens/test_bedrock_count_tokens_transformation.py b/tests/test_litellm/llms/bedrock/count_tokens/test_bedrock_count_tokens_transformation.py new file mode 100644 index 000000000000..ed8d6e1b3595 --- /dev/null +++ b/tests/test_litellm/llms/bedrock/count_tokens/test_bedrock_count_tokens_transformation.py @@ -0,0 +1,36 @@ +import os +import sys + +sys.path.insert( + 0, os.path.abspath("../../../../..") +) # Adds the parent directory to the system path +from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig + + +def test_detect_input_type(): + """Test input type detection (converse vs invokeModel)""" + config = BedrockCountTokensConfig() + + # Test messages format -> converse + request_with_messages = {"messages": [{"role": "user", "content": "hi"}]} + assert config._detect_input_type(request_with_messages) == "converse" + + # Test text format -> invokeModel + request_with_text = {"inputText": "hello"} + assert config._detect_input_type(request_with_text) == "invokeModel" + + +def test_transform_anthropic_to_bedrock_request(): + """Test basic request transformation""" + config = BedrockCountTokensConfig() + + anthropic_request = { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "messages": [{"role": "user", "content": "Hello"}], + } + + result = config.transform_anthropic_to_bedrock_count_tokens(anthropic_request) + + assert "input" in result + assert "converse" in result["input"] + assert "messages" in result["input"]["converse"]