
Commit 230b1ca

feat(llmobs): add span processor
Add the capability to register a span processor. The processor can be used to mutate or redact sensitive data contained in the inputs and outputs of LLM calls.

```python
def my_processor(span):
    for message in span.output_messages:
        message["content"] = ""
    return span


LLMObs.enable(span_processor=my_processor)
# or, after enabling:
LLMObs.register_processor(my_processor)
```
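A processor can also branch on span tags via `get_tag()`; a minimal sketch (the `sensitive` tag key is hypothetical):

```python
from ddtrace.llmobs import LLMObsSpan


def scrub_if_sensitive(span: LLMObsSpan) -> LLMObsSpan:
    # "sensitive" is a hypothetical user-set tag; get_tag() returns None if unset.
    if span.get_tag("sensitive") == "true":
        for message in span.input_messages + span.output_messages:
            message["content"] = ""
    return span
```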
1 parent 83dea4c commit 230b1ca

7 files changed (+255 lines, -14 lines)

ddtrace/llmobs/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 """

 from ._llmobs import LLMObs
+from ._llmobs import LLMObsSpan


-__all__ = ["LLMObs"]
+__all__ = ["LLMObs", "LLMObsSpan"]

ddtrace/llmobs/_llmobs.py

Lines changed: 68 additions & 10 deletions

@@ -1,11 +1,16 @@
+from dataclasses import dataclass
+from dataclasses import field
 import json
 import os
 import time
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import List
 from typing import Optional
+from typing import TypedDict
 from typing import Union
+from typing import cast

 import ddtrace
 from ddtrace import config

@@ -97,14 +102,32 @@
 }


+@dataclass
+class LLMObsSpan:
+    class Message(TypedDict):
+        content: str
+
+    input_messages: List[Message] = field(default_factory=list)
+    output_messages: List[Message] = field(default_factory=list)
+    _tags: Dict[str, str] = field(default_factory=dict)
+
+    def get_tag(self, key: str) -> Optional[str]:
+        return self._tags.get(key)
+
+
 class LLMObs(Service):
     _instance = None  # type: LLMObs
     enabled = False

-    def __init__(self, tracer=None):
+    def __init__(
+        self,
+        tracer: Tracer = None,
+        span_processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None,
+    ):
         super(LLMObs, self).__init__()
         self.tracer = tracer or ddtrace.tracer
         self._llmobs_context_provider = LLMObsContextProvider()
+        self._user_span_processor = span_processor
         agentless_enabled = config._llmobs_agentless_enabled if config._llmobs_agentless_enabled is not None else True
         self._llmobs_span_writer = LLMObsSpanWriter(
             interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),

@@ -156,12 +179,14 @@ def _submit_llmobs_span(self, span: Span) -> None:
         if self._evaluator_runner:
             self._evaluator_runner.enqueue(span_event, span)

-    @classmethod
-    def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
+    def _llmobs_span_event(self, span: Span) -> Dict[str, Any]:
         """Span event object structure."""
         span_kind = span._get_ctx_item(SPAN_KIND)
         if not span_kind:
             raise KeyError("Span kind not found in span context")
+
+        llmobs_span = LLMObsSpan()
+
         meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}}
         if span_kind in ("llm", "embedding") and span._get_ctx_item(MODEL_NAME) is not None:
             meta["model_name"] = span._get_ctx_item(MODEL_NAME)

@@ -170,14 +195,14 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:

         input_messages = span._get_ctx_item(INPUT_MESSAGES)
         if span_kind == "llm" and input_messages is not None:
-            meta["input"]["messages"] = enforce_message_role(input_messages)
+            llmobs_span.input_messages = cast(List[LLMObsSpan.Message], enforce_message_role(input_messages))

         if span._get_ctx_item(INPUT_VALUE) is not None:
             meta["input"]["value"] = safe_json(span._get_ctx_item(INPUT_VALUE), ensure_ascii=False)

         output_messages = span._get_ctx_item(OUTPUT_MESSAGES)
         if span_kind == "llm" and output_messages is not None:
-            meta["output"]["messages"] = enforce_message_role(output_messages)
+            llmobs_span.output_messages = cast(List[LLMObsSpan.Message], enforce_message_role(output_messages))

         if span_kind == "embedding" and span._get_ctx_item(INPUT_DOCUMENTS) is not None:
             meta["input"]["documents"] = span._get_ctx_item(INPUT_DOCUMENTS)

@@ -201,6 +226,26 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
                 ERROR_TYPE: span.get_tag(ERROR_TYPE),
             }
         )
+
+        if self._user_span_processor:
+            error = False
+            try:
+                llmobs_span._tags = cast(Dict[str, str], span._get_ctx_item(TAGS))
+                user_llmobs_span = self._user_span_processor(llmobs_span)
+                if not isinstance(user_llmobs_span, LLMObsSpan):
+                    raise TypeError("User span processor must return an LLMObsSpan, got %r" % type(user_llmobs_span))
+                llmobs_span = user_llmobs_span
+            except Exception as e:
+                log.error("Error in LLMObs span processor (%r): %r", self._user_span_processor, e)
+                error = True
+            finally:
+                telemetry.record_llmobs_user_processor_called(error)
+
+        if llmobs_span.input_messages is not None:
+            meta["input"]["messages"] = llmobs_span.input_messages
+        if llmobs_span.output_messages is not None:
+            meta["output"]["messages"] = llmobs_span.output_messages
+
         if not meta["input"]:
             meta.pop("input")
         if not meta["output"]:

@@ -228,7 +273,7 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
             span._set_ctx_item(SESSION_ID, session_id)
             llmobs_span_event["session_id"] = session_id

-        llmobs_span_event["tags"] = cls._llmobs_tags(span, ml_app, session_id)
+        llmobs_span_event["tags"] = self._llmobs_tags(span, ml_app, session_id)

         span_links = span._get_ctx_item(SPAN_LINKS)
         if isinstance(span_links, list) and span_links:

@@ -332,6 +377,7 @@ def enable(
         api_key: Optional[str] = None,
         env: Optional[str] = None,
         service: Optional[str] = None,
+        span_processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None,
         _tracer: Optional[Tracer] = None,
         _auto: bool = False,
     ) -> None:

@@ -372,9 +418,9 @@ def enable(
            )

            config._llmobs_agentless_enabled = should_use_agentless(
-                user_defined_agentless_enabled=agentless_enabled
-                if agentless_enabled is not None
-                else config._llmobs_agentless_enabled
+                user_defined_agentless_enabled=(
+                    agentless_enabled if agentless_enabled is not None else config._llmobs_agentless_enabled
+                )
            )

            if config._llmobs_agentless_enabled:

@@ -404,7 +450,7 @@ def enable(
            cls._patch_integrations()

            # override the default _instance with a new tracer
-            cls._instance = cls(tracer=_tracer)
+            cls._instance = cls(tracer=_tracer, span_processor=span_processor)
            cls.enabled = True
            cls._instance.start()

@@ -427,6 +473,18 @@ def enable(
        finally:
            telemetry.record_llmobs_enabled(error, config._llmobs_agentless_enabled, config._dd_site, start_ns, _auto)

+    @classmethod
+    def register_processor(cls, processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None) -> None:
+        """Register a processor to be called on each LLMObs span.
+
+        This can be used to modify the span before it is sent to LLMObs. For example, you can modify the input/output.
+
+        To deregister the processor, call `register_processor(None)`.
+
+        :param processor: A function that takes an LLMObsSpan and returns an LLMObsSpan.
+        """
+        cls._instance._user_span_processor = processor
+
     @classmethod
     def _integration_is_enabled(cls, integration: str) -> bool:
         if integration not in SUPPORTED_LLMOBS_INTEGRATIONS:
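Taken together, the changes above give the following shape of usage; a minimal sketch, assuming LLM Observability is otherwise configured as usual (the redaction string is illustrative):

```python
from ddtrace.llmobs import LLMObs
from ddtrace.llmobs import LLMObsSpan


def redact_outputs(span: LLMObsSpan) -> LLMObsSpan:
    # Replace model output content while keeping the message structure.
    for message in span.output_messages:
        message["content"] = "[redacted]"
    return span  # the processor must return an LLMObsSpan


# Pass the processor at enable time...
LLMObs.enable(span_processor=redact_outputs)

# ...or swap it at runtime; passing None deregisters it.
LLMObs.register_processor(redact_outputs)
LLMObs.register_processor(None)
```

If the processor raises or returns something other than an `LLMObsSpan`, the error is logged, the span event is still submitted unmodified, and the `user_processor_called` telemetry metric is emitted with `error:1`.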

ddtrace/llmobs/_telemetry.py

Lines changed: 11 additions & 0 deletions

@@ -31,6 +31,7 @@ class LLMObsTelemetryMetrics:
     USER_FLUSHES = "user_flush"
     INJECT_HEADERS = "inject_distributed_headers"
     ACTIVATE_HEADERS = "activate_distributed_headers"
+    USER_PROCESSOR_CALLED = "user_processor_called"


 def _find_integration_from_tags(tags):

@@ -156,6 +157,16 @@ def record_llmobs_annotate(span: Optional[Span], error: Optional[str]):
     )


+def record_llmobs_user_processor_called(error: bool):
+    tags = [("error", "1" if error else "0")]
+    telemetry_writer.add_count_metric(
+        namespace=TELEMETRY_NAMESPACE.MLOBS,
+        name=LLMObsTelemetryMetrics.USER_PROCESSOR_CALLED,
+        value=1,
+        tags=tuple(tags),
+    )
+
+
 def record_llmobs_submit_evaluation(join_on: Dict[str, Any], metric_type: str, error: Optional[str]):
     _metric_type = metric_type if metric_type in ("categorical", "score") else "other"
     custom_joining_key = str(int(join_on.get("tag") is not None))
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    LLM Observability: add processor capability to process span inputs and outputs. See usage documentation [here](https://docs.datadoghq.com/llm_observability/setup/sdk/python/#span-processing).

tests/llmobs/_utils.py

Lines changed: 2 additions & 0 deletions

@@ -832,6 +832,8 @@ def _expected_span_link(span_event, link_from, link_to):


 class TestLLMObsSpanWriter(LLMObsSpanWriter):
+    __test__ = False
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.events = []

tests/llmobs/conftest.py

Lines changed: 13 additions & 3 deletions

@@ -2,6 +2,7 @@
 from http.server import HTTPServer
 import json
 import os
+import pprint
 import threading
 import time


@@ -224,26 +225,35 @@ def _llmobs_backend():

 @pytest.fixture
 def llmobs_backend(_llmobs_backend):
-    _, reqs = _llmobs_backend
+    _url, reqs = _llmobs_backend

     class _LLMObsBackend:
+        def url(self):
+            return _url
+
         def wait_for_num_events(self, num, attempts=1000):
             for _ in range(attempts):
                 if len(reqs) == num:
                     return [json.loads(r["body"]) for r in reqs]
                 # time.sleep will yield the GIL so the server can process the request
                 time.sleep(0.001)
             else:
-                raise TimeoutError(f"Expected {num} events, got {len(reqs)}")
+                raise TimeoutError(f"Expected {num} events, got {len(reqs)}: {pprint.pprint(reqs)}")

     return _LLMObsBackend()


+@pytest.fixture
+def llmobs_enable_opts():
+    yield {}
+
+
 @pytest.fixture
 def llmobs(
     ddtrace_global_config,
     monkeypatch,
     tracer,
+    llmobs_enable_opts,
     llmobs_env,
     llmobs_span_writer,
     mock_llmobs_eval_metric_writer,

@@ -256,7 +266,7 @@ def llmobs(
     global_config.update(ddtrace_global_config)
     # TODO: remove once rest of tests are moved off of global config tampering
     with override_global_config(global_config):
-        llmobs_service.enable(_tracer=tracer)
+        llmobs_service.enable(_tracer=tracer, **llmobs_enable_opts)
         llmobs_service._instance._llmobs_span_writer = llmobs_span_writer
         llmobs_service._instance._llmobs_span_writer.start()
         yield llmobs_service
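For the test suite, a sketch of how a test module could override the new `llmobs_enable_opts` fixture to route a span processor into `LLMObs.enable()`; the processor here is illustrative and not part of this commit:

```python
import pytest


@pytest.fixture
def llmobs_enable_opts():
    # Forwarded as llmobs_service.enable(_tracer=tracer, **llmobs_enable_opts).
    def _drop_outputs(span):
        span.output_messages = []
        return span

    yield {"span_processor": _drop_outputs}
```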
