Skip to content

feat(llmobs): add span processor #13426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ddtrace/llmobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from ._llmobs import LLMObs
from ._llmobs import LLMObsSpan


__all__ = ["LLMObs"]
__all__ = ["LLMObs", "LLMObsSpan"]
98 changes: 88 additions & 10 deletions ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from dataclasses import dataclass
from dataclasses import field
import json
import os
import time
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import TypedDict
from typing import Union
from typing import cast

import ddtrace
from ddtrace import config
Expand Down Expand Up @@ -97,14 +102,50 @@
}


@dataclass
class LLMObsSpan:
    """User-facing representation of an LLMObs span.

    Instances of this class are handed to the `span_processor` callable that
    users supply via the `enable` or `register_processor` methods, allowing
    the span's input/output messages to be inspected or rewritten before
    submission.

    Example::
        def span_processor(span: LLMObsSpan) -> LLMObsSpan:
            if span.get_tag("no_input") == "1":
                span.input = []
            return span
    """

    class Message(TypedDict):
        content: str
        role: str

    # Messages sent to / received from the LLM; processors may mutate these.
    input: List[Message] = field(default_factory=list)
    output: List[Message] = field(default_factory=list)
    # Internal tag mapping; exposed read-only through get_tag().
    _tags: Dict[str, str] = field(default_factory=dict)

    def get_tag(self, key: str) -> Optional[str]:
        """Look up a tag set on the span.

        :param str key: The key of the tag to get.
        :return: The value of the tag or None if the tag does not exist.
        :rtype: Optional[str]
        """
        try:
            return self._tags[key]
        except KeyError:
            return None


class LLMObs(Service):
_instance = None # type: LLMObs
enabled = False

def __init__(self, tracer=None):
def __init__(
self,
tracer: Tracer = None,
span_processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None,
):
super(LLMObs, self).__init__()
self.tracer = tracer or ddtrace.tracer
self._llmobs_context_provider = LLMObsContextProvider()
self._user_span_processor = span_processor
agentless_enabled = config._llmobs_agentless_enabled if config._llmobs_agentless_enabled is not None else True
self._llmobs_span_writer = LLMObsSpanWriter(
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
Expand Down Expand Up @@ -156,12 +197,14 @@ def _submit_llmobs_span(self, span: Span) -> None:
if self._evaluator_runner:
self._evaluator_runner.enqueue(span_event, span)

@classmethod
def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
def _llmobs_span_event(self, span: Span) -> Dict[str, Any]:
"""Span event object structure."""
span_kind = span._get_ctx_item(SPAN_KIND)
if not span_kind:
raise KeyError("Span kind not found in span context")

llmobs_span = LLMObsSpan()

meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}}
if span_kind in ("llm", "embedding") and span._get_ctx_item(MODEL_NAME) is not None:
meta["model_name"] = span._get_ctx_item(MODEL_NAME)
Expand All @@ -170,14 +213,14 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:

input_messages = span._get_ctx_item(INPUT_MESSAGES)
if span_kind == "llm" and input_messages is not None:
meta["input"]["messages"] = enforce_message_role(input_messages)
llmobs_span.input = cast(List[LLMObsSpan.Message], enforce_message_role(input_messages))

if span._get_ctx_item(INPUT_VALUE) is not None:
meta["input"]["value"] = safe_json(span._get_ctx_item(INPUT_VALUE), ensure_ascii=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do the same as above to cast the input value to llmobs_span.input?


output_messages = span._get_ctx_item(OUTPUT_MESSAGES)
if span_kind == "llm" and output_messages is not None:
meta["output"]["messages"] = enforce_message_role(output_messages)
llmobs_span.output = cast(List[LLMObsSpan.Message], enforce_message_role(output_messages))

if span_kind == "embedding" and span._get_ctx_item(INPUT_DOCUMENTS) is not None:
meta["input"]["documents"] = span._get_ctx_item(INPUT_DOCUMENTS)
Expand All @@ -201,6 +244,26 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
ERROR_TYPE: span.get_tag(ERROR_TYPE),
}
)

if self._user_span_processor:
error = False
try:
llmobs_span._tags = cast(Dict[str, str], span._get_ctx_item(TAGS))
user_llmobs_span = self._user_span_processor(llmobs_span)
if not isinstance(user_llmobs_span, LLMObsSpan):
raise TypeError("User span processor must return an LLMObsSpan, got %r" % type(user_llmobs_span))
llmobs_span = user_llmobs_span
except Exception as e:
log.error("Error in LLMObs span processor (%r): %r", self._user_span_processor, e)
error = True
finally:
telemetry.record_llmobs_user_processor_called(error)

if llmobs_span.input is not None:
meta["input"]["messages"] = llmobs_span.input
if llmobs_span.output is not None:
meta["output"]["messages"] = llmobs_span.output

if not meta["input"]:
meta.pop("input")
if not meta["output"]:
Expand Down Expand Up @@ -228,7 +291,7 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
span._set_ctx_item(SESSION_ID, session_id)
llmobs_span_event["session_id"] = session_id

llmobs_span_event["tags"] = cls._llmobs_tags(span, ml_app, session_id)
llmobs_span_event["tags"] = self._llmobs_tags(span, ml_app, session_id)

span_links = span._get_ctx_item(SPAN_LINKS)
if isinstance(span_links, list) and span_links:
Expand Down Expand Up @@ -332,6 +395,7 @@ def enable(
api_key: Optional[str] = None,
env: Optional[str] = None,
service: Optional[str] = None,
span_processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None,
_tracer: Optional[Tracer] = None,
_auto: bool = False,
) -> None:
Expand All @@ -345,6 +409,8 @@ def enable(
:param str api_key: Your datadog api key.
:param str env: Your environment name.
:param str service: Your service name.
:param Callable[[LLMObsSpan], LLMObsSpan] span_processor: A function that takes an LLMObsSpan and returns an
LLMObsSpan.
"""
if cls.enabled:
log.debug("%s already enabled", cls.__name__)
Expand Down Expand Up @@ -372,9 +438,9 @@ def enable(
)

config._llmobs_agentless_enabled = should_use_agentless(
user_defined_agentless_enabled=agentless_enabled
if agentless_enabled is not None
else config._llmobs_agentless_enabled
user_defined_agentless_enabled=(
agentless_enabled if agentless_enabled is not None else config._llmobs_agentless_enabled
)
)

if config._llmobs_agentless_enabled:
Expand Down Expand Up @@ -404,7 +470,7 @@ def enable(
cls._patch_integrations()

# override the default _instance with a new tracer
cls._instance = cls(tracer=_tracer)
cls._instance = cls(tracer=_tracer, span_processor=span_processor)
cls.enabled = True
cls._instance.start()

Expand All @@ -427,6 +493,18 @@ def enable(
finally:
telemetry.record_llmobs_enabled(error, config._llmobs_agentless_enabled, config._dd_site, start_ns, _auto)

@classmethod
def register_processor(cls, processor: Optional[Callable[[LLMObsSpan], LLMObsSpan]] = None) -> None:
    """Install (or remove) a user span processor run on every LLMObs span.

    The processor runs before the span is submitted to LLMObs and may
    rewrite the span's input/output messages.

    Passing `None` removes any previously registered processor.

    :param processor: A function that takes an LLMObsSpan and returns an LLMObsSpan.
    """
    cls._instance._user_span_processor = processor

@classmethod
def _integration_is_enabled(cls, integration: str) -> bool:
if integration not in SUPPORTED_LLMOBS_INTEGRATIONS:
Expand Down
11 changes: 11 additions & 0 deletions ddtrace/llmobs/_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LLMObsTelemetryMetrics:
USER_FLUSHES = "user_flush"
INJECT_HEADERS = "inject_distributed_headers"
ACTIVATE_HEADERS = "activate_distributed_headers"
USER_PROCESSOR_CALLED = "user_processor_called"


def _find_integration_from_tags(tags):
Expand Down Expand Up @@ -156,6 +157,16 @@ def record_llmobs_annotate(span: Optional[Span], error: Optional[str]):
)


def record_llmobs_user_processor_called(error: bool) -> None:
    """Emit a count metric each time a user-supplied span processor is invoked.

    :param bool error: Whether the processor raised; recorded as the
        "error" tag with value "1" or "0".
    """
    error_value = "1" if error else "0"
    telemetry_writer.add_count_metric(
        namespace=TELEMETRY_NAMESPACE.MLOBS,
        name=LLMObsTelemetryMetrics.USER_PROCESSOR_CALLED,
        value=1,
        tags=(("error", error_value),),
    )


def record_llmobs_submit_evaluation(join_on: Dict[str, Any], metric_type: str, error: Optional[str]):
_metric_type = metric_type if metric_type in ("categorical", "score") else "other"
custom_joining_key = str(int(join_on.get("tag") is not None))
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/llmobs-processor-d5cb47b12bc3bbd1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
LLM Observability: add processor capability to process span inputs and outputs. See usage documentation [here](https://docs.datadoghq.com/llm_observability/setup/sdk/python/#span-processing).
2 changes: 2 additions & 0 deletions tests/llmobs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ def _expected_span_link(span_event, link_from, link_to):


class TestLLMObsSpanWriter(LLMObsSpanWriter):
__test__ = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.events = []
Expand Down
16 changes: 13 additions & 3 deletions tests/llmobs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from http.server import HTTPServer
import json
import os
import pprint
import threading
import time

Expand Down Expand Up @@ -224,26 +225,35 @@ def _llmobs_backend():

@pytest.fixture
def llmobs_backend(_llmobs_backend):
    """Expose the fake LLMObs backend started by the ``_llmobs_backend`` fixture.

    Returns an object with the backend URL and a helper that polls until the
    expected number of span-event requests has been received.
    """
    _url, reqs = _llmobs_backend

    class _LLMObsBackend:
        def url(self):
            return _url

        def wait_for_num_events(self, num, attempts=1000):
            for _ in range(attempts):
                if len(reqs) == num:
                    return [json.loads(r["body"]) for r in reqs]
                # time.sleep will yield the GIL so the server can process the request
                time.sleep(0.001)
            # Fix: pprint.pprint() writes to stdout and returns None, so the
            # original f-string always rendered "None"; pformat returns the
            # formatted string for inclusion in the error message.
            raise TimeoutError(f"Expected {num} events, got {len(reqs)}: {pprint.pformat(reqs)}")

    return _LLMObsBackend()


@pytest.fixture
def llmobs_enable_opts():
    # Default: no extra keyword arguments for LLMObs.enable(); tests override
    # this fixture to pass options such as span_processor.
    opts = {}
    yield opts


@pytest.fixture
def llmobs(
ddtrace_global_config,
monkeypatch,
tracer,
llmobs_enable_opts,
llmobs_env,
llmobs_span_writer,
mock_llmobs_eval_metric_writer,
Expand All @@ -256,7 +266,7 @@ def llmobs(
global_config.update(ddtrace_global_config)
# TODO: remove once rest of tests are moved off of global config tampering
with override_global_config(global_config):
llmobs_service.enable(_tracer=tracer)
llmobs_service.enable(_tracer=tracer, **llmobs_enable_opts)
llmobs_service._instance._llmobs_span_writer = llmobs_span_writer
llmobs_service._instance._llmobs_span_writer.start()
yield llmobs_service
Expand Down
Loading
Loading