Skip to content

Commit cc8e98c

Browse files
authored
feat(llmobs): add span processor (#13426)
Add capability to add a span processor. The processor can be used to mutate or redact sensitive data contained in inputs and outputs from LLM calls. ```python from ddtrace.llmobs import LLMObsSpan def my_processor(span: LLMObsSpan): for message in span.output: message["content"] = "" LLMObs.enable(span_processor=my_processor) LLMObs.register_processor(my_processor) ``` Public docs: DataDog/documentation#29365 Shared tests: TODO Closes: #11179
1 parent 8ee6868 commit cc8e98c

File tree

8 files changed

+307
-20
lines changed

8 files changed

+307
-20
lines changed

ddtrace/llmobs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
from ._llmobs import LLMObs
10+
from ._llmobs import LLMObsSpan
1011

1112

12-
__all__ = ["LLMObs"]
13+
__all__ = ["LLMObs", "LLMObsSpan"]

ddtrace/llmobs/_llmobs.py

Lines changed: 109 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
from dataclasses import dataclass
2+
from dataclasses import field
13
import json
24
import os
35
import time
46
from typing import Any
7+
from typing import Callable
58
from typing import Dict
69
from typing import List
10+
from typing import Literal
711
from typing import Optional
812
from typing import Tuple
13+
from typing import TypedDict
914
from typing import Union
1015
from typing import cast
1116

@@ -101,14 +106,50 @@
101106
}
102107

103108

109+
@dataclass
110+
class LLMObsSpan:
111+
"""LLMObs span object.
112+
113+
Passed to the `span_processor` function in the `enable` or `register_processor` methods.
114+
115+
Example::
116+
def span_processor(span: LLMObsSpan) -> LLMObsSpan:
117+
if span.get_tag("no_input") == "1":
118+
span.input = []
119+
return span
120+
"""
121+
122+
class Message(TypedDict):
123+
content: str
124+
role: str
125+
126+
input: List[Message] = field(default_factory=list)
127+
output: List[Message] = field(default_factory=list)
128+
_tags: Dict[str, str] = field(default_factory=dict)
129+
130+
def get_tag(self, key: str) -> Optional[str]:
131+
"""Get a tag from the span.
132+
133+
:param str key: The key of the tag to get.
134+
:return: The value of the tag or None if the tag does not exist.
135+
:rtype: Optional[str]
136+
"""
137+
return self._tags.get(key)
138+
139+
104140
class LLMObs(Service):
105141
_instance = None # type: LLMObs
106142
enabled = False
107143

108-
def __init__(self, tracer: Optional[Tracer] = None):
144+
def __init__(
145+
self,
146+
tracer: Optional[Tracer] = None,
147+
span_processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None,
148+
) -> None:
109149
super(LLMObs, self).__init__()
110150
self.tracer = tracer or ddtrace.tracer
111151
self._llmobs_context_provider = LLMObsContextProvider()
152+
self._user_span_processor = span_processor
112153
agentless_enabled = config._llmobs_agentless_enabled if config._llmobs_agentless_enabled is not None else True
113154
self._llmobs_span_writer = LLMObsSpanWriter(
114155
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
@@ -160,33 +201,46 @@ def _submit_llmobs_span(self, span: Span) -> None:
160201
if self._evaluator_runner:
161202
self._evaluator_runner.enqueue(span_event, span)
162203

163-
@classmethod
164-
def _llmobs_span_event(cls, span: Span) -> LLMObsSpanEvent:
204+
def _llmobs_span_event(self, span: Span) -> LLMObsSpanEvent:
165205
"""Span event object structure."""
166206
span_kind = span._get_ctx_item(SPAN_KIND)
167207
if not span_kind:
168208
raise KeyError("Span kind not found in span context")
209+
210+
llmobs_span = LLMObsSpan()
211+
169212
meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}}
170213
if span_kind in ("llm", "embedding") and span._get_ctx_item(MODEL_NAME) is not None:
171214
meta["model_name"] = span._get_ctx_item(MODEL_NAME)
172215
meta["model_provider"] = (span._get_ctx_item(MODEL_PROVIDER) or "custom").lower()
173216
meta["metadata"] = span._get_ctx_item(METADATA) or {}
174217

218+
input_type: Literal["value", "messages", ""] = ""
219+
output_type: Literal["value", "messages", ""] = ""
220+
if span._get_ctx_item(INPUT_VALUE) is not None:
221+
input_type = "value"
222+
llmobs_span.input = [
223+
{"content": safe_json(span._get_ctx_item(INPUT_VALUE), ensure_ascii=False), "role": ""}
224+
]
225+
175226
input_messages = span._get_ctx_item(INPUT_MESSAGES)
176227
if span_kind == "llm" and input_messages is not None:
177-
meta["input"]["messages"] = enforce_message_role(input_messages)
228+
input_type = "messages"
229+
llmobs_span.input = cast(List[LLMObsSpan.Message], enforce_message_role(input_messages))
178230

179-
if span._get_ctx_item(INPUT_VALUE) is not None:
180-
meta["input"]["value"] = safe_json(span._get_ctx_item(INPUT_VALUE), ensure_ascii=False)
231+
if span._get_ctx_item(OUTPUT_VALUE) is not None:
232+
output_type = "value"
233+
llmobs_span.output = [
234+
{"content": safe_json(span._get_ctx_item(OUTPUT_VALUE), ensure_ascii=False), "role": ""}
235+
]
181236

182237
output_messages = span._get_ctx_item(OUTPUT_MESSAGES)
183238
if span_kind == "llm" and output_messages is not None:
184-
meta["output"]["messages"] = enforce_message_role(output_messages)
239+
output_type = "messages"
240+
llmobs_span.output = cast(List[LLMObsSpan.Message], enforce_message_role(output_messages))
185241

186242
if span_kind == "embedding" and span._get_ctx_item(INPUT_DOCUMENTS) is not None:
187243
meta["input"]["documents"] = span._get_ctx_item(INPUT_DOCUMENTS)
188-
if span._get_ctx_item(OUTPUT_VALUE) is not None:
189-
meta["output"]["value"] = safe_json(span._get_ctx_item(OUTPUT_VALUE), ensure_ascii=False)
190244
if span_kind == "retrieval" and span._get_ctx_item(OUTPUT_DOCUMENTS) is not None:
191245
meta["output"]["documents"] = span._get_ctx_item(OUTPUT_DOCUMENTS)
192246
if span._get_ctx_item(INPUT_PROMPT) is not None:
@@ -205,6 +259,32 @@ def _llmobs_span_event(cls, span: Span) -> LLMObsSpanEvent:
205259
ERROR_TYPE: span.get_tag(ERROR_TYPE),
206260
}
207261
)
262+
263+
if self._user_span_processor:
264+
error = False
265+
try:
266+
llmobs_span._tags = cast(Dict[str, str], span._get_ctx_item(TAGS))
267+
user_llmobs_span = self._user_span_processor(llmobs_span)
268+
if not isinstance(user_llmobs_span, LLMObsSpan):
269+
raise TypeError("User span processor must return an LLMObsSpan, got %r" % type(user_llmobs_span))
270+
llmobs_span = user_llmobs_span
271+
except Exception as e:
272+
log.error("Error in LLMObs span processor (%r): %r", self._user_span_processor, e)
273+
error = True
274+
finally:
275+
telemetry.record_llmobs_user_processor_called(error)
276+
277+
if llmobs_span.input is not None:
278+
if input_type == "messages":
279+
meta["input"]["messages"] = llmobs_span.input
280+
elif input_type == "value":
281+
meta["input"]["value"] = llmobs_span.input[0]["content"]
282+
if llmobs_span.output is not None:
283+
if output_type == "messages":
284+
meta["output"]["messages"] = llmobs_span.output
285+
elif output_type == "value":
286+
meta["output"]["value"] = llmobs_span.output[0]["content"]
287+
208288
if not meta["input"]:
209289
meta.pop("input")
210290
if not meta["output"]:
@@ -233,7 +313,7 @@ def _llmobs_span_event(cls, span: Span) -> LLMObsSpanEvent:
233313
span._set_ctx_item(SESSION_ID, session_id)
234314
llmobs_span_event["session_id"] = session_id
235315

236-
llmobs_span_event["tags"] = cls._llmobs_tags(span, ml_app, session_id)
316+
llmobs_span_event["tags"] = self._llmobs_tags(span, ml_app, session_id)
237317

238318
span_links = span._get_ctx_item(SPAN_LINKS)
239319
if isinstance(span_links, list) and span_links:
@@ -339,6 +419,7 @@ def enable(
339419
api_key: Optional[str] = None,
340420
env: Optional[str] = None,
341421
service: Optional[str] = None,
422+
span_processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None,
342423
_tracer: Optional[Tracer] = None,
343424
_auto: bool = False,
344425
) -> None:
@@ -352,6 +433,8 @@ def enable(
352433
:param str api_key: Your datadog api key.
353434
:param str env: Your environment name.
354435
:param str service: Your service name.
436+
:param Callable[[LLMObsSpan], LLMObsSpan] span_processor: A function that takes an LLMObsSpan and returns an
437+
LLMObsSpan.
355438
"""
356439
if cls.enabled:
357440
log.debug("%s already enabled", cls.__name__)
@@ -379,9 +462,9 @@ def enable(
379462
)
380463

381464
config._llmobs_agentless_enabled = should_use_agentless(
382-
user_defined_agentless_enabled=agentless_enabled
383-
if agentless_enabled is not None
384-
else config._llmobs_agentless_enabled
465+
user_defined_agentless_enabled=(
466+
agentless_enabled if agentless_enabled is not None else config._llmobs_agentless_enabled
467+
)
385468
)
386469

387470
if config._llmobs_agentless_enabled:
@@ -411,7 +494,7 @@ def enable(
411494
cls._patch_integrations()
412495

413496
# override the default _instance with a new tracer
414-
cls._instance = cls(tracer=_tracer)
497+
cls._instance = cls(tracer=_tracer, span_processor=span_processor)
415498
cls.enabled = True
416499
cls._instance.start()
417500

@@ -434,6 +517,18 @@ def enable(
434517
finally:
435518
telemetry.record_llmobs_enabled(error, config._llmobs_agentless_enabled, config._dd_site, start_ns, _auto)
436519

520+
@classmethod
521+
def register_processor(cls, processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None) -> None:
522+
"""Register a processor to be called on each LLMObs span.
523+
524+
This can be used to modify the span before it is sent to LLMObs. For example, you can modify the input/output.
525+
526+
To deregister the processor, call `register_processor(None)`.
527+
528+
:param processor: A function that takes an LLMObsSpan and returns an LLMObsSpan.
529+
"""
530+
cls._instance._user_span_processor = processor
531+
437532
@classmethod
438533
def _integration_is_enabled(cls, integration: str) -> bool:
439534
if integration not in SUPPORTED_LLMOBS_INTEGRATIONS:

ddtrace/llmobs/_telemetry.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class LLMObsTelemetryMetrics:
3131
USER_FLUSHES = "user_flush"
3232
INJECT_HEADERS = "inject_distributed_headers"
3333
ACTIVATE_HEADERS = "activate_distributed_headers"
34+
USER_PROCESSOR_CALLED = "user_processor_called"
3435

3536

3637
def _find_integration_from_tags(tags):
@@ -156,6 +157,16 @@ def record_llmobs_annotate(span: Optional[Span], error: Optional[str]):
156157
)
157158

158159

160+
def record_llmobs_user_processor_called(error: bool) -> None:
161+
tags = [("error", "1" if error else "0")]
162+
telemetry_writer.add_count_metric(
163+
namespace=TELEMETRY_NAMESPACE.MLOBS,
164+
name=LLMObsTelemetryMetrics.USER_PROCESSOR_CALLED,
165+
value=1,
166+
tags=tuple(tags),
167+
)
168+
169+
159170
def record_llmobs_submit_evaluation(join_on: Dict[str, Any], metric_type: str, error: Optional[str]):
160171
_metric_type = metric_type if metric_type in ("categorical", "score") else "other"
161172
custom_joining_key = str(int(join_on.get("tag") is not None))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
LLM Observability: add processor capability to process span inputs and outputs. See usage documentation [here](https://docs.datadoghq.com/llm_observability/setup/sdk/python/#span-processing).

tests/llmobs/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,8 @@ def _expected_span_link(span_event, link_from, link_to):
832832

833833

834834
class TestLLMObsSpanWriter(LLMObsSpanWriter):
835+
__test__ = False
836+
835837
def __init__(self, *args, **kwargs):
836838
super().__init__(*args, **kwargs)
837839
self.events = []

tests/llmobs/conftest.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from http.server import HTTPServer
33
import json
44
import os
5+
import pprint
56
import threading
67
import time
78

@@ -224,26 +225,35 @@ def _llmobs_backend():
224225

225226
@pytest.fixture
226227
def llmobs_backend(_llmobs_backend):
227-
_, reqs = _llmobs_backend
228+
_url, reqs = _llmobs_backend
228229

229230
class _LLMObsBackend:
231+
def url(self):
232+
return _url
233+
230234
def wait_for_num_events(self, num, attempts=1000):
231235
for _ in range(attempts):
232236
if len(reqs) == num:
233237
return [json.loads(r["body"]) for r in reqs]
234238
# time.sleep will yield the GIL so the server can process the request
235239
time.sleep(0.001)
236240
else:
237-
raise TimeoutError(f"Expected {num} events, got {len(reqs)}")
241+
raise TimeoutError(f"Expected {num} events, got {len(reqs)}: {pprint.pprint(reqs)}")
238242

239243
return _LLMObsBackend()
240244

241245

246+
@pytest.fixture
247+
def llmobs_enable_opts():
248+
yield {}
249+
250+
242251
@pytest.fixture
243252
def llmobs(
244253
ddtrace_global_config,
245254
monkeypatch,
246255
tracer,
256+
llmobs_enable_opts,
247257
llmobs_env,
248258
llmobs_span_writer,
249259
mock_llmobs_eval_metric_writer,
@@ -256,7 +266,7 @@ def llmobs(
256266
global_config.update(ddtrace_global_config)
257267
# TODO: remove once rest of tests are moved off of global config tampering
258268
with override_global_config(global_config):
259-
llmobs_service.enable(_tracer=tracer)
269+
llmobs_service.enable(_tracer=tracer, **llmobs_enable_opts)
260270
llmobs_service._instance._llmobs_span_writer = llmobs_span_writer
261271
llmobs_service._instance._llmobs_span_writer.start()
262272
yield llmobs_service

0 commit comments

Comments
 (0)