adding bedrock token usage callback handler #69

Open · wants to merge 2 commits into base: main
119 changes: 119 additions & 0 deletions libs/aws/langchain_aws/callbacks/bedrock_callback.py
@@ -0,0 +1,119 @@
import threading
from typing import Any, Dict, List, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

MODEL_COST_PER_1K_INPUT_TOKENS = {
    "anthropic.claude-instant-v1": 0.00265,
    "anthropic.claude-v2": 0.008,
    "anthropic.claude-v2:1": 0.008,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
    "meta.llama3-70b-instruct-v1:0": 0.00265,
    "meta.llama3-8b-instruct-v1:0": 0.00040,
    "meta.llama2-13b-chat-v1": 0.00075,
    "meta.llama2-70b-chat-v1": 0.00195,
}

MODEL_COST_PER_1K_OUTPUT_TOKENS = {
    "anthropic.claude-instant-v1": 0.0035,
    "anthropic.claude-v2": 0.024,
    "anthropic.claude-v2:1": 0.024,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
    "meta.llama3-70b-instruct-v1:0": 0.0035,
    "meta.llama3-8b-instruct-v1:0": 0.0006,
    "meta.llama2-13b-chat-v1": 0.00100,
    "meta.llama2-70b-chat-v1": 0.00256,
}


def _get_token_cost(
    prompt_tokens: int, completion_tokens: int, model_id: Union[str, None]
) -> float:
    """Get the cost of tokens for the given Bedrock model."""
    if model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
        raise ValueError(
            f"Unknown model: {model_id}. Please provide a valid Bedrock model name. "
            "Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())
        )
    return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[model_id] + (
        completion_tokens / 1000
    ) * MODEL_COST_PER_1K_OUTPUT_TOKENS[model_id]


class BedrockTokenUsageCallbackHandler(BaseCallbackHandler):
    """Callback handler that tracks Bedrock token usage and cost."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing on LLM start; usage is collected in ``on_llm_end``."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing on new tokens; usage is collected in ``on_llm_end``."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        if response.llm_output is None:
            return None

        if "usage" not in response.llm_output:
            with self._lock:
                self.successful_requests += 1
            return None

        # compute tokens and cost for this request
        token_usage = response.llm_output["usage"]
        completion_tokens = token_usage.get("completion_tokens", 0)
        prompt_tokens = token_usage.get("prompt_tokens", 0)
        total_tokens = token_usage.get("total_tokens", 0)
        model_id = response.llm_output.get("model_id", None)
        total_cost = _get_token_cost(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            model_id=model_id,
        )

        # update shared state behind lock
        with self._lock:
            self.total_cost += total_cost
            self.total_tokens += total_tokens
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.successful_requests += 1

    def __copy__(self) -> "BedrockTokenUsageCallbackHandler":
        """Return a copy of the callback handler."""
        # Returning self keeps one shared set of counters even when the
        # handler is copied into per-run callback managers.
        return self

    def __deepcopy__(self, memo: Any) -> "BedrockTokenUsageCallbackHandler":
        """Return a deep copy of the callback handler."""
        # Same as __copy__: state is intentionally shared across copies.
        return self
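Not part of the diff: a quick sanity check of the pricing helper above. This is a minimal sketch that assumes the import path this PR introduces; the token counts are made up, and the expected value follows from the per-1K prices hard-coded in this file.

# Illustrative only: exercises _get_token_cost with made-up token counts.
from langchain_aws.callbacks.bedrock_callback import _get_token_cost

cost = _get_token_cost(
    prompt_tokens=1000,
    completion_tokens=500,
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
)
# (1000 / 1000) * 0.00025 + (500 / 1000) * 0.00125 == 0.000875
print(f"${cost:.6f}")  # -> $0.000875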
45 changes: 45 additions & 0 deletions libs/aws/langchain_aws/callbacks/manager.py
@@ -0,0 +1,45 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
    Generator,
    Optional,
)

from langchain_core.tracers.context import register_configure_hook

from langchain_aws.callbacks.bedrock_callback import (
    BedrockTokenUsageCallbackHandler,
)

logger = logging.getLogger(__name__)


bedrock_callback_var: ContextVar[
    Optional[BedrockTokenUsageCallbackHandler]
] = ContextVar("bedrock_anthropic_callback", default=None)

register_configure_hook(bedrock_callback_var, True)


@contextmanager
def get_bedrock_callback() -> (
    Generator[BedrockTokenUsageCallbackHandler, None, None]
):
    """Get the Bedrock callback handler in a context manager,
    which conveniently exposes token and cost information.

    Returns:
        BedrockTokenUsageCallbackHandler:
            The Bedrock callback handler.

    Example:
        >>> with get_bedrock_callback() as cb:
        ...     # Use the Bedrock callback handler
    """
    cb = BedrockTokenUsageCallbackHandler()
    bedrock_callback_var.set(cb)
    yield cb
    bedrock_callback_var.set(None)
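Not part of the diff: a possible end-to-end usage sketch. It assumes the exports added by this PR, a BedrockLLM configured with valid AWS credentials, and that register_configure_hook(..., True) attaches the handler to runs entered inside the context, mirroring how get_openai_callback behaves; the model_id and region_name below are placeholders.

# Hypothetical usage sketch; model_id and region_name are placeholders.
from langchain_aws import BedrockLLM
from langchain_aws.callbacks.manager import get_bedrock_callback

llm = BedrockLLM(
    model_id="anthropic.claude-v2:1",
    region_name="us-east-1",
)

with get_bedrock_callback() as cb:
    llm.invoke("Name one AWS region.")
    llm.invoke("Name another one.")

# Counters aggregate across every call made inside the context.
print(cb)  # tokens used, successful requests, total cost (USD)
print(cb.prompt_tokens, cb.completion_tokens, cb.total_cost)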
10 changes: 5 additions & 5 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -136,7 +136,7 @@ def _nest_usage_info_token_counts(usage_info: dict) -> dict:


def _combine_generation_info_for_llm_result(
-    chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str
+    chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str, model_id: str
) -> Dict[str, Any]:
"""
Returns usage and stop reason information with the intent to pack into an LLMResult
@@ -171,7 +171,7 @@ def _combine_generation_info_for_llm_result(
total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"]
)

return {"usage": total_usage_info, "stop_reason": stop_reason}
return {"usage": total_usage_info, "stop_reason": stop_reason, "model_id" : model_id}


class LLMInputOutputAdapter:
@@ -631,7 +631,7 @@ def _prepare_input_and_invoke(
if stop is not None:
text = enforce_stop_tokens(text, stop)

llm_output = {"usage": usage_info, "stop_reason": stop_reason}
llm_output = {"usage": usage_info, "stop_reason": stop_reason, "model_id": self.model_id}

# Verify and raise a callback error if any intervention occurs or a signal is
# sent from a Bedrock service,
@@ -939,7 +939,7 @@ def _call(
if chunk.generation_info is not None
]
llm_output = _combine_generation_info_for_llm_result(
-                chunks_generation_info, provider_stop_code=provider_stop_reason_code
+                chunks_generation_info, provider_stop_code=provider_stop_reason_code, model_id=self.model_id
)
all_generations = [
Generation(text=chunk.text, generation_info=chunk.generation_info)
@@ -1031,7 +1031,7 @@ async def _acall(
if chunk.generation_info is not None
]
llm_output = _combine_generation_info_for_llm_result(
-                chunks_generation_info, provider_stop_code=provider_stop_reason_code
+                chunks_generation_info, provider_stop_code=provider_stop_reason_code, model_id=self.model_id
)
generations = [
Generation(text=chunk.text, generation_info=chunk.generation_info)
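The bedrock.py changes above thread model_id into llm_output next to the usage dict, which is what lets on_llm_end look up per-model pricing. A rough sketch of that payload fed to the handler directly, not part of the diff; the token counts and stop reason are illustrative.

# Illustrative: the llm_output shape after this PR, fed straight to the handler.
from langchain_core.outputs import Generation, LLMResult

from langchain_aws.callbacks.bedrock_callback import BedrockTokenUsageCallbackHandler

handler = BedrockTokenUsageCallbackHandler()
result = LLMResult(
    generations=[[Generation(text="...")]],
    llm_output={
        "usage": {"prompt_tokens": 120, "completion_tokens": 80, "total_tokens": 200},
        "stop_reason": "end_turn",
        "model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
    },
)
handler.on_llm_end(result)
print(handler.total_tokens)  # 200
print(handler.total_cost)    # 120/1000 * 0.003 + 80/1000 * 0.015 = 0.00156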