1
+ from dataclasses import dataclass
2
+ from dataclasses import field
1
3
import json
2
4
import os
3
5
import time
4
6
from typing import Any
7
+ from typing import Callable
5
8
from typing import Dict
6
9
from typing import List
10
+ from typing import Literal
7
11
from typing import Optional
8
12
from typing import Tuple
13
+ from typing import TypedDict
9
14
from typing import Union
10
15
from typing import cast
11
16
101
106
}
102
107
103
108
109
+ @dataclass
110
+ class LLMObsSpan :
111
+ """LLMObs span object.
112
+
113
+ Passed to the `span_processor` function in the `enable` or `register_processor` methods.
114
+
115
+ Example::
116
+ def span_processor(span: LLMObsSpan) -> LLMObsSpan:
117
+ if span.get_tag("no_input") == "1":
118
+ span.input = []
119
+ return span
120
+ """
121
+
122
+ class Message (TypedDict ):
123
+ content : str
124
+ role : str
125
+
126
+ input : List [Message ] = field (default_factory = list )
127
+ output : List [Message ] = field (default_factory = list )
128
+ _tags : Dict [str , str ] = field (default_factory = dict )
129
+
130
+ def get_tag (self , key : str ) -> Optional [str ]:
131
+ """Get a tag from the span.
132
+
133
+ :param str key: The key of the tag to get.
134
+ :return: The value of the tag or None if the tag does not exist.
135
+ :rtype: Optional[str]
136
+ """
137
+ return self ._tags .get (key )
138
+
139
+
104
140
class LLMObs (Service ):
105
141
_instance = None # type: LLMObs
106
142
enabled = False
107
143
108
- def __init__ (self , tracer : Optional [Tracer ] = None ):
144
+ def __init__ (
145
+ self ,
146
+ tracer : Optional [Tracer ] = None ,
147
+ span_processor : Optional [Callable [[LLMObsSpan ], LLMObsSpan ]] = None ,
148
+ ) -> None :
109
149
super (LLMObs , self ).__init__ ()
110
150
self .tracer = tracer or ddtrace .tracer
111
151
self ._llmobs_context_provider = LLMObsContextProvider ()
152
+ self ._user_span_processor = span_processor
112
153
agentless_enabled = config ._llmobs_agentless_enabled if config ._llmobs_agentless_enabled is not None else True
113
154
self ._llmobs_span_writer = LLMObsSpanWriter (
114
155
interval = float (os .getenv ("_DD_LLMOBS_WRITER_INTERVAL" , 1.0 )),
@@ -160,33 +201,46 @@ def _submit_llmobs_span(self, span: Span) -> None:
160
201
if self ._evaluator_runner :
161
202
self ._evaluator_runner .enqueue (span_event , span )
162
203
163
- @classmethod
164
- def _llmobs_span_event (cls , span : Span ) -> LLMObsSpanEvent :
204
+ def _llmobs_span_event (self , span : Span ) -> LLMObsSpanEvent :
165
205
"""Span event object structure."""
166
206
span_kind = span ._get_ctx_item (SPAN_KIND )
167
207
if not span_kind :
168
208
raise KeyError ("Span kind not found in span context" )
209
+
210
+ llmobs_span = LLMObsSpan ()
211
+
169
212
meta : Dict [str , Any ] = {"span.kind" : span_kind , "input" : {}, "output" : {}}
170
213
if span_kind in ("llm" , "embedding" ) and span ._get_ctx_item (MODEL_NAME ) is not None :
171
214
meta ["model_name" ] = span ._get_ctx_item (MODEL_NAME )
172
215
meta ["model_provider" ] = (span ._get_ctx_item (MODEL_PROVIDER ) or "custom" ).lower ()
173
216
meta ["metadata" ] = span ._get_ctx_item (METADATA ) or {}
174
217
218
+ input_type : Literal ["value" , "messages" , "" ] = ""
219
+ output_type : Literal ["value" , "messages" , "" ] = ""
220
+ if span ._get_ctx_item (INPUT_VALUE ) is not None :
221
+ input_type = "value"
222
+ llmobs_span .input = [
223
+ {"content" : safe_json (span ._get_ctx_item (INPUT_VALUE ), ensure_ascii = False ), "role" : "" }
224
+ ]
225
+
175
226
input_messages = span ._get_ctx_item (INPUT_MESSAGES )
176
227
if span_kind == "llm" and input_messages is not None :
177
- meta ["input" ]["messages" ] = enforce_message_role (input_messages )
228
+ input_type = "messages"
229
+ llmobs_span .input = cast (List [LLMObsSpan .Message ], enforce_message_role (input_messages ))
178
230
179
- if span ._get_ctx_item (INPUT_VALUE ) is not None :
180
- meta ["input" ]["value" ] = safe_json (span ._get_ctx_item (INPUT_VALUE ), ensure_ascii = False )
231
+ if span ._get_ctx_item (OUTPUT_VALUE ) is not None :
232
+ output_type = "value"
233
+ llmobs_span .output = [
234
+ {"content" : safe_json (span ._get_ctx_item (OUTPUT_VALUE ), ensure_ascii = False ), "role" : "" }
235
+ ]
181
236
182
237
output_messages = span ._get_ctx_item (OUTPUT_MESSAGES )
183
238
if span_kind == "llm" and output_messages is not None :
184
- meta ["output" ]["messages" ] = enforce_message_role (output_messages )
239
+ output_type = "messages"
240
+ llmobs_span .output = cast (List [LLMObsSpan .Message ], enforce_message_role (output_messages ))
185
241
186
242
if span_kind == "embedding" and span ._get_ctx_item (INPUT_DOCUMENTS ) is not None :
187
243
meta ["input" ]["documents" ] = span ._get_ctx_item (INPUT_DOCUMENTS )
188
- if span ._get_ctx_item (OUTPUT_VALUE ) is not None :
189
- meta ["output" ]["value" ] = safe_json (span ._get_ctx_item (OUTPUT_VALUE ), ensure_ascii = False )
190
244
if span_kind == "retrieval" and span ._get_ctx_item (OUTPUT_DOCUMENTS ) is not None :
191
245
meta ["output" ]["documents" ] = span ._get_ctx_item (OUTPUT_DOCUMENTS )
192
246
if span ._get_ctx_item (INPUT_PROMPT ) is not None :
@@ -205,6 +259,32 @@ def _llmobs_span_event(cls, span: Span) -> LLMObsSpanEvent:
205
259
ERROR_TYPE : span .get_tag (ERROR_TYPE ),
206
260
}
207
261
)
262
+
263
+ if self ._user_span_processor :
264
+ error = False
265
+ try :
266
+ llmobs_span ._tags = cast (Dict [str , str ], span ._get_ctx_item (TAGS ))
267
+ user_llmobs_span = self ._user_span_processor (llmobs_span )
268
+ if not isinstance (user_llmobs_span , LLMObsSpan ):
269
+ raise TypeError ("User span processor must return an LLMObsSpan, got %r" % type (user_llmobs_span ))
270
+ llmobs_span = user_llmobs_span
271
+ except Exception as e :
272
+ log .error ("Error in LLMObs span processor (%r): %r" , self ._user_span_processor , e )
273
+ error = True
274
+ finally :
275
+ telemetry .record_llmobs_user_processor_called (error )
276
+
277
+ if llmobs_span .input is not None :
278
+ if input_type == "messages" :
279
+ meta ["input" ]["messages" ] = llmobs_span .input
280
+ elif input_type == "value" :
281
+ meta ["input" ]["value" ] = llmobs_span .input [0 ]["content" ]
282
+ if llmobs_span .output is not None :
283
+ if output_type == "messages" :
284
+ meta ["output" ]["messages" ] = llmobs_span .output
285
+ elif output_type == "value" :
286
+ meta ["output" ]["value" ] = llmobs_span .output [0 ]["content" ]
287
+
208
288
if not meta ["input" ]:
209
289
meta .pop ("input" )
210
290
if not meta ["output" ]:
@@ -233,7 +313,7 @@ def _llmobs_span_event(cls, span: Span) -> LLMObsSpanEvent:
233
313
span ._set_ctx_item (SESSION_ID , session_id )
234
314
llmobs_span_event ["session_id" ] = session_id
235
315
236
- llmobs_span_event ["tags" ] = cls ._llmobs_tags (span , ml_app , session_id )
316
+ llmobs_span_event ["tags" ] = self ._llmobs_tags (span , ml_app , session_id )
237
317
238
318
span_links = span ._get_ctx_item (SPAN_LINKS )
239
319
if isinstance (span_links , list ) and span_links :
@@ -339,6 +419,7 @@ def enable(
339
419
api_key : Optional [str ] = None ,
340
420
env : Optional [str ] = None ,
341
421
service : Optional [str ] = None ,
422
+ span_processor : Optional [Callable [[LLMObsSpan ], LLMObsSpan ]] = None ,
342
423
_tracer : Optional [Tracer ] = None ,
343
424
_auto : bool = False ,
344
425
) -> None :
@@ -352,6 +433,8 @@ def enable(
352
433
:param str api_key: Your datadog api key.
353
434
:param str env: Your environment name.
354
435
:param str service: Your service name.
436
+ :param Callable[[LLMObsSpan], LLMObsSpan] span_processor: A function that takes an LLMObsSpan and returns an
437
+ LLMObsSpan.
355
438
"""
356
439
if cls .enabled :
357
440
log .debug ("%s already enabled" , cls .__name__ )
@@ -379,9 +462,9 @@ def enable(
379
462
)
380
463
381
464
config ._llmobs_agentless_enabled = should_use_agentless (
382
- user_defined_agentless_enabled = agentless_enabled
383
- if agentless_enabled is not None
384
- else config . _llmobs_agentless_enabled
465
+ user_defined_agentless_enabled = (
466
+ agentless_enabled if agentless_enabled is not None else config . _llmobs_agentless_enabled
467
+ )
385
468
)
386
469
387
470
if config ._llmobs_agentless_enabled :
@@ -411,7 +494,7 @@ def enable(
411
494
cls ._patch_integrations ()
412
495
413
496
# override the default _instance with a new tracer
414
- cls ._instance = cls (tracer = _tracer )
497
+ cls ._instance = cls (tracer = _tracer , span_processor = span_processor )
415
498
cls .enabled = True
416
499
cls ._instance .start ()
417
500
@@ -434,6 +517,18 @@ def enable(
434
517
finally :
435
518
telemetry .record_llmobs_enabled (error , config ._llmobs_agentless_enabled , config ._dd_site , start_ns , _auto )
436
519
520
+ @classmethod
521
+ def register_processor (cls , processor : Optional [Callable [[LLMObsSpan ], LLMObsSpan ]] = None ) -> None :
522
+ """Register a processor to be called on each LLMObs span.
523
+
524
+ This can be used to modify the span before it is sent to LLMObs. For example, you can modify the input/output.
525
+
526
+ To deregister the processor, call `register_processor(None)`.
527
+
528
+ :param processor: A function that takes an LLMObsSpan and returns an LLMObsSpan.
529
+ """
530
+ cls ._instance ._user_span_processor = processor
531
+
437
532
@classmethod
438
533
def _integration_is_enabled (cls , integration : str ) -> bool :
439
534
if integration not in SUPPORTED_LLMOBS_INTEGRATIONS :
0 commit comments