Skip to content

Commit 8a727a9

Browse files
committed
chore: allow custom tracer provider to Agent
1 parent 264f511 commit 8a727a9

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

src/strands/agent/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
record_direct_tool_call: bool = True,
217217
load_tools_from_directory: bool = True,
218218
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
219+
tracer_provider: Optional[trace.TracerProvider] = None,
219220
):
220221
"""Initialize the Agent with the specified configuration.
221222
@@ -248,6 +249,7 @@ def __init__(
248249
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
249250
Defaults to True.
250251
trace_attributes: Custom trace attributes to apply to the agent's trace span.
252+
tracer_provider: Custom trace provider to apply to the agents' tracer.
251253
252254
Raises:
253255
ValueError: If max_parallel_tools is less than 1.
@@ -306,7 +308,7 @@ def __init__(
306308
self.event_loop_metrics = EventLoopMetrics()
307309

308310
# Initialize tracer instance (no-op if not configured)
309-
self.tracer = get_tracer()
311+
self.tracer = get_tracer(tracer_provider=tracer_provider)
310312
self.trace_span: Optional[trace.Span] = None
311313

312314
self.tool_caller = Agent.ToolCaller(self)

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
140140
if results_truncated:
141141
logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
142142
return
143-
143+
144144
# Try to trim index id when tool result cannot be truncated anymore
145145
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
146146
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

src/strands/telemetry/tracer.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
from importlib.metadata import version
1212
from typing import Any, Dict, Mapping, Optional
1313

14-
from opentelemetry import trace
14+
from opentelemetry import propagate, trace
15+
from opentelemetry.baggage.propagation import W3CBaggagePropagator
1516
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
17+
from opentelemetry.propagators.composite import CompositePropagator
1618
from opentelemetry.sdk.resources import Resource
1719
from opentelemetry.sdk.trace import TracerProvider
1820
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
1921
from opentelemetry.trace import StatusCode
22+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
2023

2124
from ..agent.agent_result import AgentResult
2225
from ..types.content import Message, Messages
@@ -94,6 +97,7 @@ def __init__(
9497
otlp_endpoint: Optional[str] = None,
9598
otlp_headers: Optional[Dict[str, str]] = None,
9699
enable_console_export: Optional[bool] = None,
100+
tracer_provider: Optional[TracerProvider] = None,
97101
):
98102
"""Initialize the tracer.
99103
@@ -102,6 +106,7 @@ def __init__(
102106
otlp_endpoint: OTLP endpoint URL for sending traces.
103107
otlp_headers: Headers to include with OTLP requests.
104108
enable_console_export: Whether to also export traces to console.
109+
tracer_provider: Optional existing TracerProvider to use instead of creating a new one.
105110
"""
106111
# Check environment variables first
107112
env_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
@@ -133,10 +138,22 @@ def __init__(
133138

134139
self.service_name = service_name
135140
self.otlp_headers = otlp_headers or {}
136-
self.tracer_provider: Optional[TracerProvider] = None
141+
self.tracer_provider = tracer_provider
137142
self.tracer: Optional[trace.Tracer] = None
138143

139-
if self.otlp_endpoint or self.enable_console_export:
144+
propagate.set_global_textmap(
145+
CompositePropagator(
146+
[
147+
W3CBaggagePropagator(),
148+
TraceContextTextMapPropagator(),
149+
]
150+
)
151+
)
152+
if self.tracer_provider:
153+
# Use the provided tracer provider directly
154+
self.tracer = self.tracer_provider.get_tracer(self.service_name)
155+
elif self.otlp_endpoint or self.enable_console_export:
156+
# Create our own tracer provider
140157
self._initialize_tracer()
141158

142159
def _initialize_tracer(self) -> None:
@@ -549,6 +566,7 @@ def get_tracer(
549566
otlp_endpoint: Optional[str] = None,
550567
otlp_headers: Optional[Dict[str, str]] = None,
551568
enable_console_export: Optional[bool] = None,
569+
tracer_provider: Optional[TracerProvider] = None,
552570
) -> Tracer:
553571
"""Get or create the global tracer.
554572
@@ -557,18 +575,24 @@ def get_tracer(
557575
otlp_endpoint: OTLP endpoint URL for sending traces.
558576
otlp_headers: Headers to include with OTLP requests.
559577
enable_console_export: Whether to also export traces to console.
578+
tracer_provider: Optional existing TracerProvider to use instead of creating a new one.
560579
561580
Returns:
562581
The global tracer instance.
563582
"""
564583
global _tracer_instance
565584

566-
if _tracer_instance is None or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint): # type: ignore[unreachable]
585+
if (
586+
_tracer_instance is None
587+
or (otlp_endpoint and _tracer_instance.otlp_endpoint != otlp_endpoint)
588+
or (tracer_provider is not None and _tracer_instance.tracer_provider != tracer_provider)
589+
):
567590
_tracer_instance = Tracer(
568591
service_name=service_name,
569592
otlp_endpoint=otlp_endpoint,
570593
otlp_headers=otlp_headers,
571594
enable_console_export=enable_console_export,
595+
tracer_provider=tracer_provider,
572596
)
573597

574598
return _tracer_instance

tests/strands/telemetry/test_tracer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from unittest import mock
55

66
import pytest
7-
from opentelemetry.trace import StatusCode # type: ignore
7+
from opentelemetry.trace import (
8+
NoOpTracerProvider,
9+
StatusCode, # type: ignore
10+
)
811

912
from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize
1013
from strands.types.streaming import Usage
@@ -104,6 +107,12 @@ def env_with_both():
104107
yield
105108

106109

110+
@pytest.fixture
111+
def mock_initialize():
112+
with mock.patch("strands.telemetry.tracer.Tracer._initialize_tracer") as mock_initialize:
113+
yield mock_initialize
114+
115+
107116
def test_init_default():
108117
"""Test initializing the Tracer with default parameters."""
109118
tracer = Tracer()
@@ -486,6 +495,22 @@ def test_initialize_tracer_with_invalid_otlp_endpoint(
486495
mock_set_tracer_provider.assert_called_once_with(mock_tracer_provider.return_value)
487496

488497

498+
def test_initialize_tracer_with_custom_tracer_provider(mock_initialize):
499+
"""Test initializing the tracer with NoOpTracerProvider."""
500+
noop_provider = NoOpTracerProvider()
501+
502+
tracer = Tracer(tracer_provider=noop_provider)
503+
504+
# Verify the NoOp provider is used
505+
assert tracer.tracer_provider == noop_provider
506+
507+
# Verify tracer is set (will be a NoOp tracer)
508+
assert tracer.tracer is not None
509+
510+
# Verify _initialize_tracer was NOT called
511+
mock_initialize.assert_not_called()
512+
513+
489514
def test_end_span_with_exception_handling(mock_span):
490515
"""Test ending a span with exception handling."""
491516
tracer = Tracer()

0 commit comments

Comments
 (0)