Add ability in Token Counter to retrieve Open AI cached_tokens run-ll…
“sangwon.ku” committed Dec 27, 2024
1 parent 1be7b0b commit 6863ec8
Showing 1 changed file with 24 additions and 8 deletions.
llama-index-core/llama_index/core/callbacks/token_counting.py (32 changes: 24 additions & 8 deletions)
@@ -30,6 +30,7 @@ class TokenCountingEvent:
completion_token_count: int
prompt_token_count: int
total_token_count: int = 0
cached_tokens: int = 0
event_id: str = ""

def __post_init__(self) -> None:
@@ -56,7 +57,8 @@ def get_tokens_from_response(

possible_input_keys = ("prompt_tokens", "input_tokens")
possible_output_keys = ("completion_tokens", "output_tokens")

openai_prompt_tokens_details_key = 'prompt_tokens_details'

prompt_tokens = 0
for input_key in possible_input_keys:
if input_key in usage:
@@ -68,8 +70,12 @@
if output_key in usage:
completion_tokens = usage[output_key]
break

return prompt_tokens, completion_tokens

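# OpenAI's usage payload reports prompt-cache hits under "prompt_tokens_details" -> "cached_tokens".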
cached_tokens = 0
if openai_prompt_tokens_details_key in usage:
cached_tokens = usage[openai_prompt_tokens_details_key]['cached_tokens']

return prompt_tokens, completion_tokens, cached_tokens


def get_llm_token_counts(
@@ -83,9 +89,9 @@

if completion:
# get from raw or additional_kwargs
prompt_tokens, completion_tokens = get_tokens_from_response(completion)
prompt_tokens, completion_tokens, cached_tokens = get_tokens_from_response(completion)
else:
prompt_tokens, completion_tokens = 0, 0
prompt_tokens, completion_tokens, cached_tokens = 0, 0, 0

if prompt_tokens == 0:
prompt_tokens = token_counter.get_string_tokens(str(prompt))
@@ -99,6 +105,7 @@
prompt_token_count=prompt_tokens,
completion=str(completion),
completion_token_count=completion_tokens,
cached_tokens=cached_tokens,
)

elif EventPayload.MESSAGES in payload:
@@ -109,9 +116,9 @@
response_str = str(response)

if response:
prompt_tokens, completion_tokens = get_tokens_from_response(response)
prompt_tokens, completion_tokens, cached_tokens = get_tokens_from_response(response)
else:
prompt_tokens, completion_tokens = 0, 0
prompt_tokens, completion_tokens, cached_tokens = 0, 0, 0

if prompt_tokens == 0:
prompt_tokens = token_counter.estimate_tokens_in_messages(messages)
@@ -125,6 +132,7 @@
prompt_token_count=prompt_tokens,
completion=response_str,
completion_token_count=completion_tokens,
cached_tokens=cached_tokens,
)
else:
return TokenCountingEvent(
@@ -133,6 +141,7 @@
prompt_token_count=0,
completion="",
completion_token_count=0,
cached_tokens=0,
)


@@ -214,7 +223,9 @@ def on_event_end(
"LLM Prompt Token Usage: "
f"{self.llm_token_counts[-1].prompt_token_count}\n"
"LLM Completion Token Usage: "
f"{self.llm_token_counts[-1].completion_token_count}",
f"{self.llm_token_counts[-1].completion_token_count}"
"LLM Cached Tokens: "
f"{self.llm_token_counts[-1].cached_tokens}",
)
elif (
event_type == CBEventType.EMBEDDING
@@ -251,6 +262,11 @@ def prompt_llm_token_count(self) -> int:
def completion_llm_token_count(self) -> int:
"""Get the current total LLM completion token count."""
return sum([x.completion_token_count for x in self.llm_token_counts])

@property
def total_cached_token_count(self) -> int:
"""Get the current total cached token count."""
return sum([x.cached_tokens for x in self.llm_token_counts])

@property
def total_embedding_token_count(self) -> int:

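For reference, OpenAI's chat completions usage block reports prompt-cache hits under `prompt_tokens_details.cached_tokens`, which is the key the updated `get_tokens_from_response` reads. Below is a minimal sketch of that lookup against a hand-written usage dict; the dict literal is an assumed payload shape for illustration, not captured API output.

```python
# Assumed shape of an OpenAI chat completions usage payload, as a plain dict.
usage = {
    "prompt_tokens": 2006,
    "completion_tokens": 300,
    "total_tokens": 2306,
    "prompt_tokens_details": {"cached_tokens": 1920},
}

# Same lookup the commit adds: fall back to 0 when the details block is absent.
cached_tokens = 0
if "prompt_tokens_details" in usage:
    cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]

print(cached_tokens)  # 1920
```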
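And a hedged usage sketch of the new `total_cached_token_count` property. It follows the handler's usual wiring (a `TokenCountingHandler` inside a `CallbackManager`); the model name and tokenizer are placeholders, the `llama-index-llms-openai` package and an `OPENAI_API_KEY` are assumed, and the cached count stays 0 unless the request actually hits OpenAI's prompt cache.

```python
import tiktoken
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI  # assumes llama-index-llms-openai is installed

# Tokenizer choice is a placeholder; match it to the model you actually call.
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-4o").encode
)

llm = OpenAI(model="gpt-4o", callback_manager=CallbackManager([token_counter]))
_ = llm.chat([ChatMessage(role="user", content="Summarize prompt caching in one line.")])

print("prompt tokens:    ", token_counter.prompt_llm_token_count)
print("completion tokens:", token_counter.completion_llm_token_count)
print("cached tokens:    ", token_counter.total_cached_token_count)  # added in 6863ec8
```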