diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index fff0fd6f..f43053f7 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -275,7 +275,7 @@ async def process_stream( } state["content"] = state["message"]["content"] - usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0, cacheReadInputTokens=0, cacheWriteInputTokens=0) metrics: Metrics = Metrics(latencyMs=0) async for chunk in chunks: diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index dae05394..e357eb18 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -331,6 +331,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: "inputTokens": usage["input_tokens"], "outputTokens": usage["output_tokens"], "totalTokens": usage["input_tokens"] + usage["output_tokens"], + "cacheReadInputTokens": usage.get("cache_read_input_tokens", 0), + "cacheWriteInputTokens": usage.get("cache_creation_input_tokens", 0), }, "metrics": { "latencyMs": 0, # TODO diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 82bbb1ea..f4bb68cc 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -177,7 +177,15 @@ async def stream( async for event in response: _ = event - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + usage = event.usage + cache_read = max( + getattr(usage, "cache_read_input_tokens", 0), + getattr(getattr(usage, "prompt_tokens_details", {}), "cached_tokens", 0), + ) + + usage.prompt_tokens_details.cached_tokens = cache_read + + yield self.format_chunk({"chunk_type": "metadata", "data": usage}) logger.debug("finished streaming response from model") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 3bae2233..889164fd 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -308,6 +308,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: inputTokens=usage["inputTokens"], outputTokens=usage["outputTokens"], totalTokens=usage["totalTokens"], + # TODO does not seem to support caching as of July 2025 + cacheWriteInputTokens=0, + cacheReadInputTokens=0, ) return { "metadata": { diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 151b423d..e2af10e7 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -338,6 +338,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: "inputTokens": usage.prompt_tokens, "outputTokens": usage.completion_tokens, "totalTokens": usage.total_tokens, + # TODO does not seem to support caching as of July 2025 + "cacheWriteInputTokens": 0, + "cacheReadInputTokens": 0, }, "metrics": { "latencyMs": event.get("latency_ms", 0), diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 5fb0c1ff..80b111f4 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -268,6 +268,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: "inputTokens": event["data"].eval_count, "outputTokens": event["data"].prompt_eval_count, "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, + # TODO add cache metrics + "cacheWriteInputTokens": 0, + "cacheReadInputTokens": 0, }, "metrics": { "latencyMs": event["data"].total_duration / 1e6, diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 6374590b..28edc532 
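For reference, the LiteLLM hunk above reconciles the two places a provider may report cache hits (a top-level `cache_read_input_tokens` field on Anthropic-style responses and `prompt_tokens_details.cached_tokens` on OpenAI-style ones) before handing the usage object to `format_chunk`. A minimal standalone sketch of that normalization, written defensively for the case where either attribute is `None` or `prompt_tokens_details` is missing; the `SimpleNamespace` stand-in for the LiteLLM usage object is illustrative only:

```python
from types import SimpleNamespace


def normalized_cache_read(usage) -> int:
    """Return the cached-prompt token count, whichever field the provider populated."""
    top_level = getattr(usage, "cache_read_input_tokens", 0) or 0
    details = getattr(usage, "prompt_tokens_details", None)
    nested = (getattr(details, "cached_tokens", 0) or 0) if details else 0
    return max(top_level, nested)


# Illustrative object mimicking a LiteLLM streaming chunk's `usage` attribute.
usage = SimpleNamespace(
    prompt_tokens=100,
    completion_tokens=20,
    total_tokens=120,
    cache_read_input_tokens=None,
    prompt_tokens_details=SimpleNamespace(cached_tokens=64),
)

assert normalized_cache_read(usage) == 64
```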
100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -310,6 +310,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: "inputTokens": event["data"].prompt_tokens, "outputTokens": event["data"].completion_tokens, "totalTokens": event["data"].total_tokens, + "cacheReadInputTokens": event["data"].prompt_tokens_details.cached_tokens, + "cacheWriteInputTokens": 0, # OpenAI does not return cache write information }, "metrics": { "latencyMs": 0, # TODO diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 332ab2ae..3a256708 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -168,7 +168,11 @@ class EventLoopMetrics: tool_metrics: Dict[str, ToolMetrics] = field(default_factory=dict) cycle_durations: List[float] = field(default_factory=list) traces: List[Trace] = field(default_factory=list) - accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) + accumulated_usage: Usage = field( + default_factory=lambda: Usage( + inputTokens=0, outputTokens=0, totalTokens=0, cacheReadInputTokens=0, cacheWriteInputTokens=0 + ) + ) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) @property @@ -263,6 +267,8 @@ def update_usage(self, usage: Usage) -> None: self.accumulated_usage["inputTokens"] += usage["inputTokens"] self.accumulated_usage["outputTokens"] += usage["outputTokens"] self.accumulated_usage["totalTokens"] += usage["totalTokens"] + self.accumulated_usage["cacheReadInputTokens"] += usage.get("cacheReadInputTokens", 0) + self.accumulated_usage["cacheWriteInputTokens"] += usage.get("cacheWriteInputTokens", 0) def update_metrics(self, metrics: Metrics) -> None: """Update the accumulated performance metrics with new metrics data. @@ -320,15 +326,18 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name An iterable of formatted text lines representing the metrics. 
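The `update_usage` change above accumulates the two new counters with `dict.get(..., 0)` so that usage payloads from providers or callers that predate the cache fields keep working. A small self-contained sketch of that accumulation behaviour, using plain dicts in place of the `Usage` TypedDict:

```python
def accumulate(accumulated: dict, usage: dict) -> None:
    # Required counters are always present; the cache counters may be absent.
    accumulated["inputTokens"] += usage["inputTokens"]
    accumulated["outputTokens"] += usage["outputTokens"]
    accumulated["totalTokens"] += usage["totalTokens"]
    accumulated["cacheReadInputTokens"] += usage.get("cacheReadInputTokens", 0)
    accumulated["cacheWriteInputTokens"] += usage.get("cacheWriteInputTokens", 0)


totals = {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0,
          "cacheReadInputTokens": 0, "cacheWriteInputTokens": 0}
accumulate(totals, {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15})  # no cache keys
accumulate(totals, {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15,
                    "cacheReadInputTokens": 8, "cacheWriteInputTokens": 2})
assert totals["cacheReadInputTokens"] == 8 and totals["cacheWriteInputTokens"] == 2
```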
""" summary = event_loop_metrics.get_summary() + accumulated_usage = summary["accumulated_usage"] yield "Event Loop Metrics Summary:" yield ( f"├─ Cycles: total={summary['total_cycles']}, avg_time={summary['average_cycle_time']:.3f}s, " f"total_time={summary['total_duration']:.3f}s" ) yield ( - f"├─ Tokens: in={summary['accumulated_usage']['inputTokens']}, " - f"out={summary['accumulated_usage']['outputTokens']}, " - f"total={summary['accumulated_usage']['totalTokens']}" + f"├─ Tokens: in={accumulated_usage['inputTokens']}" + f" (cache_write={accumulated_usage.get('cacheWriteInputTokens', 0)}), " + f"out={accumulated_usage['outputTokens']}, " + f"total={accumulated_usage['totalTokens']}" + f" (cache_read={accumulated_usage.get('cacheReadInputTokens', 0)})" ) yield f"├─ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms" @@ -421,6 +430,8 @@ class MetricsClient: event_loop_latency: Histogram event_loop_input_tokens: Histogram event_loop_output_tokens: Histogram + event_loop_input_tokens_cache_read: Histogram + event_loop_input_tokens_cache_write: Histogram tool_call_count: Counter tool_success_count: Counter @@ -474,3 +485,9 @@ def create_instruments(self) -> None: self.event_loop_output_tokens = self.meter.create_histogram( name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token" ) + self.event_loop_input_tokens_cache_read = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_INPUT_TOKEN_CACHE_READ, unit="token" + ) + self.event_loop_input_tokens_cache_write = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_INPUT_TOKENS_CACHE_WRITE, unit="token" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py index b622eebf..caae0509 100644 --- a/src/strands/telemetry/metrics_constants.py +++ b/src/strands/telemetry/metrics_constants.py @@ -13,3 +13,5 @@ STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration" STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens" STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" +STRANDS_EVENT_LOOP_INPUT_TOKEN_CACHE_READ = "strands.event_loop.input.tokens.cache.read" +STRANDS_EVENT_LOOP_INPUT_TOKENS_CACHE_WRITE = "strands.event_loop.input.tokens.cache.write" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index ff3f832a..0995ea56 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -259,6 +259,8 @@ def end_model_invoke_span( attributes: Dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.input_tokens": usage["inputTokens"], + "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0), + "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.output_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], @@ -492,6 +494,8 @@ def end_agent_span( "gen_ai.usage.input_tokens": accumulated_usage["inputTokens"], "gen_ai.usage.output_tokens": accumulated_usage["outputTokens"], "gen_ai.usage.total_tokens": accumulated_usage["totalTokens"], + "gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0), + "gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0), } ) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 7be33b6f..99f15480 100644 --- a/src/strands/types/event_loop.py +++ 
b/src/strands/types/event_loop.py @@ -5,18 +5,22 @@ from typing_extensions import TypedDict -class Usage(TypedDict): +class Usage(TypedDict, total=False): """Token usage information for model interactions. Attributes: inputTokens: Number of tokens sent in the request to the model.. outputTokens: Number of tokens that the model generated for the request. totalTokens: Total number of tokens (input + output). + cacheReadInputTokens: Number of tokens read from cache. + cacheWriteInputTokens: Number of tokens written to cache. """ inputTokens: int outputTokens: int totalTokens: int + cacheReadInputTokens: int + cacheWriteInputTokens: int class Metrics(TypedDict): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 80d6a5ef..6733bab3 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -250,7 +250,13 @@ def test_handle_message_stop(): def test_extract_usage_metrics(): event = { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 0}, } @@ -279,7 +285,13 @@ def test_extract_usage_metrics(): }, { "metadata": { - "usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + "cacheReadInputTokens": 1, + "cacheWriteInputTokens": 1, + }, "metrics": {"latencyMs": 1}, } }, @@ -364,6 +376,8 @@ def test_extract_usage_metrics(): "inputTokens": 1, "outputTokens": 1, "totalTokens": 1, + "cacheReadInputTokens": 1, + "cacheWriteInputTokens": 1, }, }, }, @@ -376,7 +390,13 @@ def test_extract_usage_metrics(): "role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + "cacheReadInputTokens": 1, + "cacheWriteInputTokens": 1, + }, {"latencyMs": 1}, ) }, @@ -398,7 +418,13 @@ def test_extract_usage_metrics(): "role": "assistant", "content": [], }, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, {"latencyMs": 0}, ), }, @@ -426,7 +452,13 @@ def test_extract_usage_metrics(): }, { "metadata": { - "usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + "cacheReadInputTokens": 1, + "cacheWriteInputTokens": 1, + }, "metrics": {"latencyMs": 1}, } }, @@ -506,6 +538,8 @@ def test_extract_usage_metrics(): "inputTokens": 1, "outputTokens": 1, "totalTokens": 1, + "cacheReadInputTokens": 1, + "cacheWriteInputTokens": 1, }, }, }, @@ -518,7 +552,13 @@ def test_extract_usage_metrics(): "role": "assistant", "content": [{"text": "REDACTED."}], }, - {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + "cacheReadInputTokens": 1, + "cacheWriteInputTokens": 1, + }, {"latencyMs": 1}, ), }, @@ -584,7 +624,13 @@ async def test_stream_messages(agenerator, alist): "stop": ( "end_turn", {"role": "assistant", "content": [{"text": "test"}]}, - {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + 
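With `total=False`, every key on `Usage` becomes optional to type checkers, which is what lets existing call sites keep constructing `Usage(inputTokens=..., outputTokens=..., totalTokens=...)` without the cache fields; readers are then expected to fall back with `.get()`. A short illustration of both sides, assuming only the `Usage` definition above:

```python
from typing_extensions import TypedDict


class Usage(TypedDict, total=False):
    inputTokens: int
    outputTokens: int
    totalTokens: int
    cacheReadInputTokens: int
    cacheWriteInputTokens: int


# Older producers can still omit the cache counters entirely...
legacy = Usage(inputTokens=10, outputTokens=5, totalTokens=15)

# ...so consumers read them with a default instead of indexing directly.
assert legacy.get("cacheReadInputTokens", 0) == 0
```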
"cacheWriteInputTokens": 0, + }, {"latencyMs": 0}, ) }, diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 5e8d69ea..99580e49 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -597,7 +597,12 @@ def test_format_chunk_message_stop(model): def test_format_chunk_metadata(model): event = { "type": "metadata", - "usage": {"input_tokens": 1, "output_tokens": 2}, + "usage": { + "input_tokens": 1, + "output_tokens": 2, + "cache_read_input_tokens": 4, + "cache_creation_input_tokens": 5, + }, } tru_chunk = model.format_chunk(event) @@ -607,6 +612,8 @@ def test_format_chunk_metadata(model): "inputTokens": 1, "outputTokens": 2, "totalTokens": 3, + "cacheReadInputTokens": 4, + "cacheWriteInputTokens": 5, }, "metrics": { "latencyMs": 0, @@ -656,7 +663,18 @@ async def test_stream(anthropic_client, model, agenerator, alist): tru_events = await alist(response) exp_events = [ {"messageStart": {"role": "assistant"}}, - {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, + { + "metadata": { + "usage": { + "inputTokens": 1, + "outputTokens": 2, + "totalTokens": 3, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, + "metrics": {"latencyMs": 0}, + } + }, ] assert tru_events == exp_events diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 6060500b..2d58d70a 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -476,7 +476,13 @@ async def test_stream_stream_input_guardrails( ): metadata_event = { "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 245}, "trace": { "guardrail": { @@ -531,7 +537,13 @@ async def test_stream_stream_output_guardrails( model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) metadata_event = { "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 245}, "trace": { "guardrail": { @@ -588,7 +600,13 @@ async def test_stream_output_guardrails_redacts_input_and_output( model.update_config(guardrail_redact_output=True) metadata_event = { "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 245}, "trace": { "guardrail": { @@ -645,7 +663,13 @@ async def test_stream_output_no_blocked_guardrails_doesnt_redact( ): metadata_event = { "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 245}, "trace": { "guardrail": { @@ -698,7 +722,13 @@ async def test_stream_output_no_guardrail_redact( ): metadata_event = { "metadata": { - "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 245}, "trace": { 
"guardrail": { @@ -888,7 +918,13 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, - "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "usage": { + "inputTokens": 1234, + "outputTokens": 1234, + "totalTokens": 2468, + "cacheReadInputTokens": 128, + "cacheWriteInputTokens": 512, + }, "metrics": {"latencyMs": 1234}, "stopReason": "tool_use", } @@ -906,7 +942,13 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, { "metadata": { - "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "usage": { + "inputTokens": 1234, + "outputTokens": 1234, + "totalTokens": 2468, + "cacheReadInputTokens": 128, + "cacheWriteInputTokens": 512, + }, "metrics": {"latencyMs": 1234}, } }, diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index bddd44ab..c14b52ac 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -135,7 +135,15 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)]) mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)]) mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)]) - mock_event_6 = unittest.mock.Mock() + mock_event_6 = unittest.mock.Mock( + usage=unittest.mock.Mock( + prompt_tokens_details=unittest.mock.Mock( + audio_tokens=None, cached_tokens=0, text_tokens=None, image_tokens=None + ), + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + ) + ) litellm_acompletion.side_effect = unittest.mock.AsyncMock( return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]) @@ -178,6 +186,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, "inputTokens": mock_event_6.usage.prompt_tokens, "outputTokens": mock_event_6.usage.completion_tokens, "totalTokens": mock_event_6.usage.total_tokens, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, }, "metrics": {"latencyMs": 0}, } diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 309dac2e..ee8449e5 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -346,6 +346,8 @@ def test_format_chunk_metadata(model): "inputTokens": 100, "outputTokens": 50, "totalTokens": 150, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, }, "metrics": { "latencyMs": 0, diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 2a78024f..15dd2e5a 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -391,6 +391,8 @@ def test_format_chunk_metadata(model): "inputTokens": 100, "outputTokens": 50, "totalTokens": 150, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, }, "metrics": { "latencyMs": 250, @@ -419,6 +421,8 @@ def test_format_chunk_metadata_no_latency(model): "inputTokens": 100, "outputTokens": 50, "totalTokens": 150, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, }, "metrics": { "latencyMs": 0, diff --git 
a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index c3fb7736..1ba7caa9 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -396,6 +396,8 @@ def test_format_chunk_metadata(model): "inputTokens": 100, "outputTokens": 50, "totalTokens": 150, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, }, "metrics": { "latencyMs": 1.0, @@ -437,7 +439,13 @@ async def test_stream(ollama_client, model, agenerator, alist): {"messageStop": {"stopReason": "end_turn"}}, { "metadata": { - "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "usage": { + "inputTokens": 10, + "outputTokens": 5, + "totalTokens": 15, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 1.0}, } }, @@ -484,7 +492,13 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): {"messageStop": {"stopReason": "tool_use"}}, { "metadata": { - "usage": {"inputTokens": 15, "outputTokens": 8, "totalTokens": 23}, + "usage": { + "inputTokens": 15, + "outputTokens": 8, + "totalTokens": 23, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, "metrics": {"latencyMs": 2.0}, } }, diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 0a095ab9..2ce1f66d 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -352,7 +352,7 @@ def test_format_request(model, messages, tool_specs, system_prompt): ( { "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150, prompt_tokens_details=unittest.mock.Mock(cached_tokens=40)), }, { "metadata": { @@ -360,6 +360,8 @@ def test_format_request(model, messages, tool_specs, system_prompt): "inputTokens": 100, "outputTokens": 50, "totalTokens": 150, + "cacheReadInputTokens": 40, + "cacheWriteInputTokens": 0 }, "metrics": { "latencyMs": 0, diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 215e1efd..f1615e79 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -90,6 +90,8 @@ def usage(request): "inputTokens": 1, "outputTokens": 2, "totalTokens": 3, + "cacheReadInputTokens": 4, + "cacheWriteInputTokens": 5, } if hasattr(request, "param"): params.update(request.param) @@ -315,17 +317,15 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met event_loop_metrics.update_usage(usage) tru_usage = event_loop_metrics.accumulated_usage - exp_usage = Usage( - inputTokens=3, - outputTokens=6, - totalTokens=9, - ) + exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheReadInputTokens=12, cacheWriteInputTokens=15) assert tru_usage == exp_usage mock_get_meter_provider.return_value.get_meter.assert_called() metrics_client = event_loop_metrics._metrics_client metrics_client.event_loop_input_tokens.record.assert_called() metrics_client.event_loop_output_tokens.record.assert_called() + metrics_client.event_loop_input_tokens_cache_read.record.assert_called() + metrics_client.event_loop_input_tokens_cache_write.record.assert_called() def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): @@ -358,6 +358,8 @@ def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_ge "inputTokens": 0, "outputTokens": 0, "totalTokens": 0, + "cacheReadInputTokens": 0, + 
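The OpenAI test above relies on `prompt_tokens_details.cached_tokens` always being present on the usage object, which is true of current OpenAI chat-completions responses but may not hold for every OpenAI-compatible endpoint. A hedged sketch of the same mapping with a fallback; the shape of the `usage` object is assumed:

```python
def openai_cache_read_tokens(usage) -> int:
    """cached_tokens from prompt_tokens_details, or 0 when the endpoint omits it."""
    details = getattr(usage, "prompt_tokens_details", None)
    if details is None:
        return 0
    return getattr(details, "cached_tokens", 0) or 0
```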
"cacheWriteInputTokens": 0, }, "average_cycle_time": 0, "tool_usage": { @@ -394,7 +396,7 @@ def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_ge {}, "Event Loop Metrics Summary:\n" "├─ Cycles: total=0, avg_time=0.000s, total_time=0.000s\n" - "├─ Tokens: in=0, out=0, total=0\n" + "├─ Tokens: in=0 (cache_write=0), out=0, total=0 (cache_read=0)\n" "├─ Bedrock Latency: 0ms\n" "├─ Tool Usage:\n" " └─ tool1:\n" @@ -412,7 +414,7 @@ def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_ge {}, "Event Loop Metrics Summary:\n" "├─ Cycles: total=0, avg_time=0.000s, total_time=0.000s\n" - "├─ Tokens: in=0, out=0, total=0\n" + "├─ Tokens: in=0 (cache_write=0), out=0, total=0 (cache_read=0)\n" "├─ Bedrock Latency: 0ms\n" "├─ Tool Usage:\n" " └─ tool1:\n" diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 7623085f..4bd38da3 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -166,7 +166,7 @@ def test_end_model_invoke_span(mock_span): """Test ending a model invoke span.""" tracer = Tracer() message = {"role": "assistant", "content": [{"text": "Response"}]} - usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30, cacheReadInputTokens=4, cacheWriteInputTokens=25) stop_reason: StopReason = "end_turn" tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) @@ -176,6 +176,9 @@ def test_end_model_invoke_span(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 4) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 25) + mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -305,7 +308,13 @@ def test_end_agent_span(mock_span): # Mock AgentResult with metrics mock_metrics = mock.MagicMock() - mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150} + mock_metrics.accumulated_usage = { + "inputTokens": 50, + "outputTokens": 100, + "totalTokens": 150, + "cacheReadInputTokens": 60, + "cacheWriteInputTokens": 100, + } mock_response = mock.MagicMock() mock_response.metrics = mock_metrics @@ -319,6 +328,8 @@ def test_end_agent_span(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 60) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 100) mock_span.add_event.assert_any_call( "gen_ai.choice", attributes={"message": "Agent response", "finish_reason": "end_turn"}, diff --git a/tests_integ/test_bedrock_cache_point.py b/tests_integ/test_bedrock_cache_point.py index 82bca22a..8d6d3c16 100644 --- a/tests_integ/test_bedrock_cache_point.py +++ b/tests_integ/test_bedrock_cache_point.py @@ -16,16 +16,8 @@ def test_bedrock_cache_point(): {"role": "assistant", "content": [{"text": "Blue!"}]}, ] - cache_point_usage = 0 + agent = Agent(messages=messages, load_tools_from_directory=False) + 
response = agent("What is favorite color?") - def cache_point_callback_handler(**kwargs): - nonlocal cache_point_usage - if "event" in kwargs and kwargs["event"] and "metadata" in kwargs["event"] and kwargs["event"]["metadata"]: - metadata = kwargs["event"]["metadata"] - if "usage" in metadata and metadata["usage"]: - if "cacheReadInputTokens" in metadata["usage"] or "cacheWriteInputTokens" in metadata["usage"]: - cache_point_usage += 1 - - agent = Agent(messages=messages, callback_handler=cache_point_callback_handler, load_tools_from_directory=False) - agent("What is favorite color?") - assert cache_point_usage > 0 + usage = response.metrics.accumulated_usage + assert usage["cacheReadInputTokens"] > 0 or usage["cacheWriteInputTokens"] > 0 # At least one should have tokens
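For a stricter end-to-end check, the agent could be invoked twice with the same cached prefix: the first call should register cache writes and the second cache reads. A sketch of that pattern, reusing the `messages` list from the test above and assuming it contains a `cachePoint` block and that the configured Bedrock model and region support prompt caching:

```python
from strands import Agent

# Hypothetical tightening of the assertion above: run two turns so both counters move.
agent = Agent(messages=messages, load_tools_from_directory=False)

first = agent("What is favorite color?")
writes_after_first = first.metrics.accumulated_usage["cacheWriteInputTokens"]

second = agent("And what color is the sky?")
reads_after_second = second.metrics.accumulated_usage["cacheReadInputTokens"]

assert writes_after_first > 0   # first turn populates the cache
assert reads_after_second > 0   # second turn reuses the cached prefix
```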