diff --git a/src/ragas/cost.py b/src/ragas/cost.py index 144f66a12..d665db9e1 100644 --- a/src/ragas/cost.py +++ b/src/ragas/cost.py @@ -127,6 +127,19 @@ def get_token_usage_for_bedrock( return TokenUsage(input_tokens=0, output_tokens=0) +def get_token_usage_for_azure_ai( + llm_result: t.Union[LLMResult, ChatResult], +) -> TokenUsage: + # AzureAI like interfaces + llm_output = llm_result.llm_output + if llm_output is None: + logger.info("No llm_output found in the LLMResult") + return TokenUsage(input_tokens=0, output_tokens=0) + output_tokens = get_from_dict(llm_output, "token_usage.input_tokens", 0) + input_tokens = get_from_dict(llm_output, "token_usage.output_tokens", 0) + return TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens) + + class CostCallbackHandler(BaseCallbackHandler): def __init__(self, token_usage_parser: TokenUsageParser): self.token_usage_parser = token_usage_parser diff --git a/tests/unit/test_cost.py b/tests/unit/test_cost.py index 715f28f94..6a7a91a20 100644 --- a/tests/unit/test_cost.py +++ b/tests/unit/test_cost.py @@ -6,6 +6,7 @@ CostCallbackHandler, TokenUsage, get_token_usage_for_anthropic, + get_token_usage_for_azure_ai, get_token_usage_for_bedrock, get_token_usage_for_openai, ) @@ -129,6 +130,18 @@ def test_token_usage_cost(): llm_output={}, ) +azure_ai_result = LLMResult( + generations=[[ChatGeneration(message=AIMessage(content="Hello, world!"))]], + llm_output={ + "token_usage": { + "input_tokens": 10, + "output_tokens": 10, + "total_tokens": 20, + }, + "model_name": "mistral-small-2503", + }, +) + def test_parse_llm_results(): # openai @@ -147,6 +160,10 @@ def test_parse_llm_results(): token_usage = get_token_usage_for_bedrock(bedrock_claude_result) assert token_usage == TokenUsage(input_tokens=10, output_tokens=10) + # Azure AI + token_usage = get_token_usage_for_azure_ai(azure_ai_result) + assert token_usage == TokenUsage(input_tokens=10, output_tokens=10) + def test_cost_callback_handler(): cost_cb = CostCallbackHandler(token_usage_parser=get_token_usage_for_openai)