1
+ from dataclasses import dataclass
2
+ from dataclasses import field
1
3
import json
2
4
import os
3
5
import time
4
6
from typing import Any
7
+ from typing import Callable
5
8
from typing import Dict
6
9
from typing import List
7
10
from typing import Optional
11
+ from typing import TypedDict
8
12
from typing import Union
13
+ from typing import cast
9
14
10
15
import ddtrace
11
16
from ddtrace import config
97
102
}
98
103
99
104
105
@dataclass
class LLMObsSpan:
    """Mutable view of an LLMObs span that is handed to a user span processor.

    The processor receives an instance, may modify ``input_messages`` and
    ``output_messages``, and must return an ``LLMObsSpan``; the returned
    messages are what get written into the span event.
    """

    class Message(TypedDict):
        """A single input or output message."""

        content: str

    # Messages sent to the model; an empty list when none were recorded.
    input_messages: List[Message] = field(default_factory=list)
    # Messages returned by the model; an empty list when none were recorded.
    output_messages: List[Message] = field(default_factory=list)
    # Tags attached to the span. The caller may assign ``None`` here when the
    # span carries no tags, so readers must not assume a dict is present.
    _tags: Dict[str, str] = field(default_factory=dict)

    def get_tag(self, key: str) -> Optional[str]:
        """Return the value of the tag ``key``, or ``None`` if it is not set.

        :param key: Name of the tag to look up.
        :returns: The tag value, or ``None`` when the tag is missing or the
            span has no tags at all.
        """
        tags = self._tags
        # ``_tags`` can be None (not just empty) when the span had no tags
        # stored; treat that the same as "tag not found" instead of raising.
        if not tags:
            return None
        return tags.get(key)
116
+
117
+
100
118
class LLMObs (Service ):
101
119
_instance = None # type: LLMObs
102
120
enabled = False
103
121
104
- def __init__ (self , tracer = None ):
122
+ def __init__ (
123
+ self ,
124
+ tracer : Tracer = None ,
125
+ span_processor : Optional [Callable [[LLMObsSpan ], LLMObsSpan ]] = None ,
126
+ ):
105
127
super (LLMObs , self ).__init__ ()
106
128
self .tracer = tracer or ddtrace .tracer
107
129
self ._llmobs_context_provider = LLMObsContextProvider ()
130
+ self ._user_span_processor = span_processor
108
131
agentless_enabled = config ._llmobs_agentless_enabled if config ._llmobs_agentless_enabled is not None else True
109
132
self ._llmobs_span_writer = LLMObsSpanWriter (
110
133
interval = float (os .getenv ("_DD_LLMOBS_WRITER_INTERVAL" , 1.0 )),
@@ -156,12 +179,14 @@ def _submit_llmobs_span(self, span: Span) -> None:
156
179
if self ._evaluator_runner :
157
180
self ._evaluator_runner .enqueue (span_event , span )
158
181
159
- @classmethod
160
- def _llmobs_span_event (cls , span : Span ) -> Dict [str , Any ]:
182
+ def _llmobs_span_event (self , span : Span ) -> Dict [str , Any ]:
161
183
"""Span event object structure."""
162
184
span_kind = span ._get_ctx_item (SPAN_KIND )
163
185
if not span_kind :
164
186
raise KeyError ("Span kind not found in span context" )
187
+
188
+ llmobs_span = LLMObsSpan ()
189
+
165
190
meta : Dict [str , Any ] = {"span.kind" : span_kind , "input" : {}, "output" : {}}
166
191
if span_kind in ("llm" , "embedding" ) and span ._get_ctx_item (MODEL_NAME ) is not None :
167
192
meta ["model_name" ] = span ._get_ctx_item (MODEL_NAME )
@@ -170,14 +195,14 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
170
195
171
196
input_messages = span ._get_ctx_item (INPUT_MESSAGES )
172
197
if span_kind == "llm" and input_messages is not None :
173
- meta [ "input" ][ "messages" ] = enforce_message_role (input_messages )
198
+ llmobs_span . input_messages = cast ( List [ LLMObsSpan . Message ], enforce_message_role (input_messages ) )
174
199
175
200
if span ._get_ctx_item (INPUT_VALUE ) is not None :
176
201
meta ["input" ]["value" ] = safe_json (span ._get_ctx_item (INPUT_VALUE ), ensure_ascii = False )
177
202
178
203
output_messages = span ._get_ctx_item (OUTPUT_MESSAGES )
179
204
if span_kind == "llm" and output_messages is not None :
180
- meta [ "output" ][ "messages" ] = enforce_message_role (output_messages )
205
+ llmobs_span . output_messages = cast ( List [ LLMObsSpan . Message ], enforce_message_role (output_messages ) )
181
206
182
207
if span_kind == "embedding" and span ._get_ctx_item (INPUT_DOCUMENTS ) is not None :
183
208
meta ["input" ]["documents" ] = span ._get_ctx_item (INPUT_DOCUMENTS )
@@ -201,6 +226,26 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
201
226
ERROR_TYPE : span .get_tag (ERROR_TYPE ),
202
227
}
203
228
)
229
+
230
+ if self ._user_span_processor :
231
+ error = False
232
+ try :
233
+ llmobs_span ._tags = cast (Dict [str , str ], span ._get_ctx_item (TAGS ))
234
+ user_llmobs_span = self ._user_span_processor (llmobs_span )
235
+ if not isinstance (user_llmobs_span , LLMObsSpan ):
236
+ raise TypeError ("User span processor must return an LLMObsSpan, got %r" % type (user_llmobs_span ))
237
+ llmobs_span = user_llmobs_span
238
+ except Exception as e :
239
+ log .error ("Error in LLMObs span processor (%r): %r" , self ._user_span_processor , e )
240
+ error = True
241
+ finally :
242
+ telemetry .record_llmobs_user_processor_called (error )
243
+
244
+ if llmobs_span .input_messages is not None :
245
+ meta ["input" ]["messages" ] = llmobs_span .input_messages
246
+ if llmobs_span .output_messages is not None :
247
+ meta ["output" ]["messages" ] = llmobs_span .output_messages
248
+
204
249
if not meta ["input" ]:
205
250
meta .pop ("input" )
206
251
if not meta ["output" ]:
@@ -228,7 +273,7 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
228
273
span ._set_ctx_item (SESSION_ID , session_id )
229
274
llmobs_span_event ["session_id" ] = session_id
230
275
231
- llmobs_span_event ["tags" ] = cls ._llmobs_tags (span , ml_app , session_id )
276
+ llmobs_span_event ["tags" ] = self ._llmobs_tags (span , ml_app , session_id )
232
277
233
278
span_links = span ._get_ctx_item (SPAN_LINKS )
234
279
if isinstance (span_links , list ) and span_links :
@@ -332,6 +377,7 @@ def enable(
332
377
api_key : Optional [str ] = None ,
333
378
env : Optional [str ] = None ,
334
379
service : Optional [str ] = None ,
380
+ span_processor : Optional [Callable [[LLMObsSpan ], LLMObsSpan ]] = None ,
335
381
_tracer : Optional [Tracer ] = None ,
336
382
_auto : bool = False ,
337
383
) -> None :
@@ -372,9 +418,9 @@ def enable(
372
418
)
373
419
374
420
config ._llmobs_agentless_enabled = should_use_agentless (
375
- user_defined_agentless_enabled = agentless_enabled
376
- if agentless_enabled is not None
377
- else config . _llmobs_agentless_enabled
421
+ user_defined_agentless_enabled = (
422
+ agentless_enabled if agentless_enabled is not None else config . _llmobs_agentless_enabled
423
+ )
378
424
)
379
425
380
426
if config ._llmobs_agentless_enabled :
@@ -404,7 +450,7 @@ def enable(
404
450
cls ._patch_integrations ()
405
451
406
452
# override the default _instance with a new tracer
407
- cls ._instance = cls (tracer = _tracer )
453
+ cls ._instance = cls (tracer = _tracer , span_processor = span_processor )
408
454
cls .enabled = True
409
455
cls ._instance .start ()
410
456
@@ -427,6 +473,18 @@ def enable(
427
473
finally :
428
474
telemetry .record_llmobs_enabled (error , config ._llmobs_agentless_enabled , config ._dd_site , start_ns , _auto )
429
475
476
@classmethod
def register_processor(cls, processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None) -> None:
    """Install (or remove) the user span processor on the active LLMObs instance.

    The processor is invoked with an ``LLMObsSpan`` while the span event is
    being built, which lets callers rewrite the span's input/output messages
    before the event is submitted.

    Passing ``None`` (the default) removes any previously registered processor.

    :param processor: A callable that accepts an LLMObsSpan and returns an LLMObsSpan,
        or ``None`` to deregister.
    """
    active_instance = cls._instance
    active_instance._user_span_processor = processor
487
+
430
488
@classmethod
431
489
def _integration_is_enabled (cls , integration : str ) -> bool :
432
490
if integration not in SUPPORTED_LLMOBS_INTEGRATIONS :
0 commit comments