Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 163 additions & 28 deletions litellm/litellm_core_utils/token_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
import base64
import io
import struct
from typing import Callable, List, Literal, Optional, Tuple, Union, cast
from typing import (
Any,
Callable,
List,
Literal,
Optional,
Tuple,
Union,
cast,
get_type_hints,
)

import tiktoken

Expand All @@ -20,6 +30,10 @@
)
from litellm.litellm_core_utils.default_encoding import encoding as default_encoding
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.types.llms.anthropic import (
AnthropicMessagesToolResultParam,
AnthropicMessagesToolUseParam,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionNamedToolChoiceParam,
Expand Down Expand Up @@ -552,57 +566,178 @@ def _fix_model_name(model: str) -> str:
return "gpt-3.5-turbo"


def _count_image_tokens(
image_url: Any,
use_default_image_token_count: bool,
) -> int:
"""
Count tokens for an image_url content block.

Args:
image_url: The image URL data - can be a string URL or dict with 'url' and 'detail'
use_default_image_token_count: Whether to use default image token counts

Returns:
int: Number of tokens for the image

Raises:
ValueError: If image_url is invalid type or detail value is invalid
"""
if isinstance(image_url, dict):
detail = image_url.get("detail", "auto")
if detail not in ["low", "high", "auto"]:
raise ValueError(
f"Invalid detail value: {detail}. Expected 'low', 'high', or 'auto'."
)
url = image_url.get("url")
if not url:
raise ValueError("Missing required key 'url' in image_url dict.")
return calculate_img_tokens(
data=url,
mode=detail, # type: ignore
use_default_image_token_count=use_default_image_token_count,
)

elif isinstance(image_url, str):
if not image_url.strip():
raise ValueError("Empty image_url string is not valid.")
return calculate_img_tokens(
data=image_url,
mode="auto",
use_default_image_token_count=use_default_image_token_count,
)

else:
raise ValueError(
f"Invalid image_url type: {type(image_url).__name__}. "
"Expected str or dict with 'url' field."
)


def _validate_anthropic_content(content: dict) -> type:
"""
Validate and determine which Anthropic TypedDict applies.

Returns the corresponding TypedDict class if recognized, otherwise raises.
"""
content_type = content.get("type")
if not content_type:
raise ValueError("Anthropic content missing required field: 'type'")

mapping = {
"tool_use": AnthropicMessagesToolUseParam,
"tool_result": AnthropicMessagesToolResultParam,
}

expected_cls = mapping.get(content_type)
if expected_cls is None:
raise ValueError(f"Unknown Anthropic content type: '{content_type}'")

missing = [
k for k in getattr(expected_cls, "__required_keys__", set()) if k not in content
]
if missing:
raise ValueError(
f"Missing required fields in {content_type} block: {', '.join(missing)}"
)

return expected_cls


def _count_anthropic_content(
content: dict,
count_function: TokenCounterFunction,
use_default_image_token_count: bool,
default_token_count: Optional[int],
) -> int:
"""
Count tokens in Anthropic-specific content blocks (tool_use, tool_result, etc.).

Uses TypedDict definitions from litellm.types.llms.anthropic to determine
what fields to count and how to handle nested structures.

Dynamically infers which fields to count based on the TypedDict definition,
avoiding hardcoded field names.
"""
typeddict_cls = _validate_anthropic_content(content)
type_hints = getattr(typeddict_cls, "__annotations__", {})
tokens = 0

# Fields to skip (metadata/identifiers that don't contribute to prompt tokens)
skip_fields = {"type", "id", "tool_use_id", "cache_control", "is_error"}

# Iterate over all fields defined in the TypedDict
for field_name, field_type in type_hints.items():
if field_name in skip_fields:
continue

field_value = content.get(field_name)
if field_value is None:
continue
try:
if isinstance(field_value, str):
tokens += count_function(field_value)
elif isinstance(field_value, list):
tokens += _count_content_list(
count_function,
field_value, # type: ignore
use_default_image_token_count,
default_token_count,
)
elif isinstance(field_value, dict):
tokens += count_function(str(field_value))
except Exception as e:
if default_token_count is not None:
return default_token_count
raise ValueError(f"Error counting field '{field_name}': {e}")
return tokens


def _count_content_list(
count_function: TokenCounterFunction,
content_list: OpenAIMessageContent,
use_default_image_token_count: bool,
default_token_count: Optional[int],
) -> int:
"""
Get the number of tokens from a list of content.
Recursively count tokens from a list of content blocks.
"""
try:
num_tokens = 0
for c in content_list:
if isinstance(c, str):
num_tokens += count_function(c)
elif c["type"] == "text":
num_tokens += count_function(c["text"])
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
if detail not in ["low", "high", "auto"]:
raise ValueError(
f"Invalid detail value: {detail}. Expected 'low', 'high', or 'auto'."
)
url = image_url_dict.get("url")
num_tokens += calculate_img_tokens(
data=url,
mode=detail, # type: ignore
use_default_image_token_count=use_default_image_token_count,
elif isinstance(c, dict):
ctype = c.get("type")
if ctype == "text":
num_tokens += count_function(c.get("text", ""))
elif ctype == "image_url":
image_url = c.get("image_url")
num_tokens += _count_image_tokens(
image_url, use_default_image_token_count
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count,
elif ctype in ("tool_use", "tool_result"):
num_tokens += _count_anthropic_content(
c,
count_function,
use_default_image_token_count,
default_token_count,
)
else:
raise ValueError(
f"Invalid image_url type: {type(c['image_url'])}. Expected str or dict."
)
raise ValueError(f"Invalid content type: {ctype}")
else:
raise ValueError(
f"Invalid content type: {type(c)}. Expected str or dict."
f"Invalid content item type: {type(c).__name__}. "
f"Expected str or dict with 'type' field. "
f"Value: {c!r}"
)
return num_tokens
except Exception as e:
if default_token_count is not None:
return default_token_count
raise ValueError(
f"Error getting number of tokens from content list: {e}, default_token_count={default_token_count}"
f"Error getting number of tokens from content list: {e}, "
f"default_token_count={default_token_count}"
)


Expand Down
Loading
Loading