94 changes: 77 additions & 17 deletions libs/partners/groq/langchain_groq/chat_models.py
@@ -37,6 +37,11 @@
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import (
+    InputTokenDetails,
+    OutputTokenDetails,
+    UsageMetadata,
+)
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
@@ -726,15 +731,7 @@ def _create_chat_result(
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
             if token_usage and isinstance(message, AIMessage):
-                input_tokens = token_usage.get("prompt_tokens", 0)
-                output_tokens = token_usage.get("completion_tokens", 0)
-                message.usage_metadata = {
-                    "input_tokens": input_tokens,
-                    "output_tokens": output_tokens,
-                    "total_tokens": token_usage.get(
-                        "total_tokens", input_tokens + output_tokens
-                    ),
-                }
+                message.usage_metadata = _create_usage_metadata(token_usage)
             generation_info = {"finish_reason": res.get("finish_reason")}
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
@@ -774,7 +771,20 @@ def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:
             if token_usage is not None:
                 for k, v in token_usage.items():
                     if k in overall_token_usage and v is not None:
-                        overall_token_usage[k] += v
+                        # Handle nested dictionaries
+                        if isinstance(v, dict):
+                            if k not in overall_token_usage:
+                                overall_token_usage[k] = {}
+                            for nested_k, nested_v in v.items():
+                                if (
+                                    nested_k in overall_token_usage[k]
+                                    and nested_v is not None
+                                ):
+                                    overall_token_usage[k][nested_k] += nested_v
+                                else:
+                                    overall_token_usage[k][nested_k] = nested_v
+                        else:
+                            overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
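
A standalone sketch of the merge behavior this hunk adds, with made-up token counts; the nested `completion_tokens_details` key mirrors the detail dicts that `_create_usage_metadata` (added below) reads:

# Illustrative only: replicate the nested-dict merge over two fake payloads.
overall_token_usage: dict = {}
for token_usage in (
    {"prompt_tokens": 10, "completion_tokens_details": {"reasoning_tokens": 4}},
    {"prompt_tokens": 7, "completion_tokens_details": {"reasoning_tokens": 2}},
):
    for k, v in token_usage.items():
        if k in overall_token_usage and v is not None:
            if isinstance(v, dict):
                # Sum nested counters key by key instead of applying += to a dict.
                for nested_k, nested_v in v.items():
                    if nested_k in overall_token_usage[k] and nested_v is not None:
                        overall_token_usage[k][nested_k] += nested_v
                    else:
                        overall_token_usage[k][nested_k] = nested_v
            else:
                overall_token_usage[k] += v
        else:
            overall_token_usage[k] = v

assert overall_token_usage == {
    "prompt_tokens": 17,
    "completion_tokens_details": {"reasoning_tokens": 6},
}

Before this hunk, the bare `overall_token_usage[k] += v` raised `TypeError` (`+=` is not defined between two dicts) as soon as a response carried one of these nested detail dicts across multiple generations.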
@@ -1329,13 +1339,7 @@ def _convert_chunk_to_message_chunk(
                {k: executed_tool[k] for k in executed_tool if k != "output"}
            )
        if usage := (chunk.get("x_groq") or {}).get("usage"):
-            input_tokens = usage.get("prompt_tokens", 0)
-            output_tokens = usage.get("completion_tokens", 0)
-            usage_metadata = {
-                "input_tokens": input_tokens,
-                "output_tokens": output_tokens,
-                "total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
-            }
+            usage_metadata = _create_usage_metadata(usage)
        else:
            usage_metadata = None
        return AIMessageChunk(
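
In streaming mode the usage counts arrive under Groq's `x_groq` envelope rather than a top-level `usage` key. A minimal sketch of the extraction above; the chunk shape is illustrative, not copied from Groq docs:

# Hypothetical final stream chunk; real chunks carry deltas in "choices".
chunk = {
    "choices": [],
    "x_groq": {
        "usage": {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}
    },
}
# Same guarded walrus extraction as in the diff above:
if usage := (chunk.get("x_groq") or {}).get("usage"):
    print(usage["total_tokens"])  # prints 17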
@@ -1435,3 +1439,59 @@ def _lc_invalid_tool_call_to_groq_tool_call(
             "arguments": invalid_tool_call["args"],
         },
     }
+
+
+def _create_usage_metadata(groq_token_usage: dict) -> UsageMetadata:
+    """Create usage metadata from a Groq token usage response.
+
+    Args:
+        groq_token_usage: Token usage dict from a Groq API response.
+
+    Returns:
+        Usage metadata dict with input/output token details.
+    """
+    # Support both formats: the newer Responses API uses "input_tokens",
+    # the Chat Completions API uses "prompt_tokens".
+    input_tokens = (
+        groq_token_usage.get("input_tokens")
+        or groq_token_usage.get("prompt_tokens")
+        or 0
+    )
+    output_tokens = (
+        groq_token_usage.get("output_tokens")
+        or groq_token_usage.get("completion_tokens")
+        or 0
+    )
+    total_tokens = groq_token_usage.get("total_tokens") or input_tokens + output_tokens
+
+    # Support both formats for token details: the Responses API uses
+    # "input_tokens_details"/"output_tokens_details", the Chat Completions
+    # API may use "prompt_tokens_details"/"completion_tokens_details".
+    input_details_dict = (
+        groq_token_usage.get("input_tokens_details")
+        or groq_token_usage.get("prompt_tokens_details")
+        or {}
+    )
+    output_details_dict = (
+        groq_token_usage.get("output_tokens_details")
+        or groq_token_usage.get("completion_tokens_details")
+        or {}
+    )
+
+    input_token_details: dict = {
+        "cache_read": input_details_dict.get("cached_tokens"),
+    }
+    output_token_details: dict = {
+        "reasoning": output_details_dict.get("reasoning_tokens"),
+    }
+    usage_metadata: UsageMetadata = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    if filtered_input := {k: v for k, v in input_token_details.items() if v}:
+        usage_metadata["input_token_details"] = InputTokenDetails(**filtered_input)  # type: ignore[typeddict-item]
+    if filtered_output := {k: v for k, v in output_token_details.items() if v}:
+        usage_metadata["output_token_details"] = OutputTokenDetails(**filtered_output)  # type: ignore[typeddict-item]
+    return usage_metadata
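
A usage sketch for the new helper, assuming an install that includes this change; `_create_usage_metadata` is private, so importing it here is for illustration only, and both payloads are made up:

from langchain_groq.chat_models import _create_usage_metadata

# Chat Completions-style keys, including a nested details dict:
meta = _create_usage_metadata(
    {
        "prompt_tokens": 21,
        "completion_tokens": 8,
        "total_tokens": 29,
        "completion_tokens_details": {"reasoning_tokens": 3},
    }
)
# UsageMetadata is a TypedDict, so this prints a plain dict:
# {'input_tokens': 21, 'output_tokens': 8, 'total_tokens': 29,
#  'output_token_details': {'reasoning': 3}}
print(meta)

# Responses-style keys fall back through the `or` chains; the missing
# total is derived as input_tokens + output_tokens:
print(_create_usage_metadata({"input_tokens": 4, "output_tokens": 2}))
# {'input_tokens': 4, 'output_tokens': 2, 'total_tokens': 6}

One design note: the `or` chains treat an explicit `0` the same as a missing key, which is harmless here because the fallback value is also `0` (or a sum of zeros).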