Skip to content

Commit

Permalink
fix(llmobs): replace trace processor with event listener (#11781)
Browse files Browse the repository at this point in the history
The LLMObs service formerly depended on the TraceProcessor interface in
the tracer. This was problematic because that interface is shared with
the public API: users could configure a trace filter (which, under the
hood, is a trace processor) and thereby overwrite the LLMObs TraceProcessor.

Instead, the tracer can emit span start and finish events which the
LLMObs service listens to and acts on, as proposed here.

The gotcha is that the LLMObs service no longer has a way to drop traces
when run in agentless mode, which only LLMObs supports. Instead, we
encourage users to explicitly turn off APM, which carries the benefit of
clarity, since this was implicit before.

Co-authored-by: Yun Kim <[email protected]>
  • Loading branch information
Kyle-Verhoog and Yun-Kim authored Jan 9, 2025
1 parent bfa3b82 commit d676233
Show file tree
Hide file tree
Showing 15 changed files with 798 additions and 1,079 deletions.
32 changes: 32 additions & 0 deletions .riot/requirements/16562eb.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# This file is autogenerated by pip-compile with Python 3.7
# by the following command:
#
# pip-compile --allow-unsafe --config=pyproject.toml --no-annotate --resolver=backtracking .riot/requirements/16562eb.in
#
attrs==24.2.0
coverage[toml]==7.2.7
exceptiongroup==1.2.2
hypothesis==6.45.0
idna==3.10
importlib-metadata==6.7.0
iniconfig==2.0.0
mock==5.1.0
multidict==6.0.5
opentracing==2.4.0
packaging==24.0
pluggy==1.2.0
pytest==7.4.4
pytest-asyncio==0.21.1
pytest-cov==4.1.0
pytest-mock==3.11.1
pyyaml==6.0.1
six==1.17.0
sortedcontainers==2.4.0
tomli==2.0.1
typing-extensions==4.7.1
urllib3==1.26.20
vcrpy==4.4.0
wrapt==1.16.0
yarl==1.9.4
zipp==3.15.0
5 changes: 4 additions & 1 deletion ddtrace/_trace/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ddtrace.internal.atexit import register_on_exit_signal
from ddtrace.internal.constants import SAMPLING_DECISION_TRACE_TAG_KEY
from ddtrace.internal.constants import SPAN_API_DATADOG
from ddtrace.internal.core import dispatch
from ddtrace.internal.dogstatsd import get_dogstatsd_client
from ddtrace.internal.logger import get_logger
from ddtrace.internal.peer_service.processor import PeerServiceProcessor
Expand Down Expand Up @@ -849,7 +850,7 @@ def _start_span(
for p in chain(self._span_processors, SpanProcessor.__processors__, self._deferred_processors):
p.on_span_start(span)
self._hooks.emit(self.__class__.start_span, span)

dispatch("trace.span_start", (span,))
return span

start_span = _start_span
Expand All @@ -866,6 +867,8 @@ def _on_span_finish(self, span: Span) -> None:
for p in chain(self._span_processors, SpanProcessor.__processors__, self._deferred_processors):
p.on_span_finish(span)

dispatch("trace.span_finish", (span,))

if log.isEnabledFor(logging.DEBUG):
log.debug("finishing span %s (enabled:%s)", span._pprint(), self.enabled)

Expand Down
161 changes: 140 additions & 21 deletions ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import ddtrace
from ddtrace import Span
from ddtrace import config
from ddtrace import patch
from ddtrace._trace.context import Context
from ddtrace.constants import ERROR_MSG
from ddtrace.constants import ERROR_STACK
from ddtrace.constants import ERROR_TYPE
from ddtrace.ext import SpanTypes
from ddtrace.internal import atexit
from ddtrace.internal import core
from ddtrace.internal import forksafe
from ddtrace.internal._rand import rand64bits
from ddtrace.internal.compat import ensure_text
Expand All @@ -24,6 +30,7 @@
from ddtrace.internal.telemetry.constants import TELEMETRY_APM_PRODUCT
from ddtrace.internal.utils.formats import asbool
from ddtrace.internal.utils.formats import parse_tags_str
from ddtrace.llmobs import _constants as constants
from ddtrace.llmobs._constants import ANNOTATIONS_CONTEXT_ID
from ddtrace.llmobs._constants import INPUT_DOCUMENTS
from ddtrace.llmobs._constants import INPUT_MESSAGES
Expand All @@ -45,11 +52,11 @@
from ddtrace.llmobs._constants import SPAN_START_WHILE_DISABLED_WARNING
from ddtrace.llmobs._constants import TAGS
from ddtrace.llmobs._evaluators.runner import EvaluatorRunner
from ddtrace.llmobs._trace_processor import LLMObsTraceProcessor
from ddtrace.llmobs._utils import AnnotationContext
from ddtrace.llmobs._utils import _get_llmobs_parent_id
from ddtrace.llmobs._utils import _get_ml_app
from ddtrace.llmobs._utils import _get_session_id
from ddtrace.llmobs._utils import _get_span_name
from ddtrace.llmobs._utils import _inject_llmobs_parent_id
from ddtrace.llmobs._utils import safe_json
from ddtrace.llmobs._utils import validate_prompt
Expand Down Expand Up @@ -81,34 +88,157 @@ class LLMObs(Service):
def __init__(self, tracer=None):
super(LLMObs, self).__init__()
self.tracer = tracer or ddtrace.tracer
self._llmobs_span_writer = None

self._llmobs_span_writer = LLMObsSpanWriter(
is_agentless=config._llmobs_agentless_enabled,
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)),
)

self._llmobs_eval_metric_writer = LLMObsEvalMetricWriter(
site=config._dd_site,
api_key=config._dd_api_key,
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)),
)

self._evaluator_runner = EvaluatorRunner(
interval=float(os.getenv("_DD_LLMOBS_EVALUATOR_INTERVAL", 1.0)),
llmobs_service=self,
)

self._trace_processor = LLMObsTraceProcessor(self._llmobs_span_writer, self._evaluator_runner)
forksafe.register(self._child_after_fork)

self._annotations = []
self._annotation_context_lock = forksafe.RLock()
self.tracer.on_start_span(self._do_annotations)

def _do_annotations(self, span):
# Register hooks for span events
core.on("trace.span_start", self._do_annotations)
core.on("trace.span_finish", self._on_span_finish)

def _on_span_finish(self, span):
    """Listener for the ``trace.span_finish`` event.

    Submits the finished span to LLMObs only when the service is enabled
    and the span's ``span_type`` equals ``SpanTypes.LLM``; otherwise the
    span is ignored.
    """
    if not self.enabled:
        return
    if span.span_type == SpanTypes.LLM:
        self._submit_llmobs_span(span)

def _submit_llmobs_span(self, span: Span) -> None:
    """Generate and submit an LLMObs span event to be sent to LLMObs.

    The event is built from ``span`` and enqueued on the span writer. A
    ``KeyError`` or ``TypeError`` raised while doing so is logged (with
    traceback) and not re-raised. Afterwards, if an event was produced, the
    span kind is ``"llm"`` and the span is not a ragas integration span, the
    event is also enqueued on the evaluator runner when one is set.

    Any other exception propagates to the caller unchanged.
    """
    span_event = None
    is_llm_span = span._get_ctx_item(SPAN_KIND) == "llm"
    is_ragas_integration_span = False
    try:
        span_event, is_ragas_integration_span = self._llmobs_span_event(span)
        self._llmobs_span_writer.enqueue(span_event)
    except (KeyError, TypeError):
        log.error(
            "Error generating LLMObs span event for span %s, likely due to malformed span", span, exc_info=True
        )
    # The evaluator hand-off is intentionally NOT placed in a ``finally``
    # clause: a ``return`` inside ``finally`` discards whatever exception is
    # in flight, which would silently hide unexpected errors for every span
    # that is not an evaluable LLM span.
    if not span_event or not is_llm_span or is_ragas_integration_span:
        return
    if self._evaluator_runner:
        self._evaluator_runner.enqueue(span_event, span)

@classmethod
def _llmobs_span_event(cls, span: Span) -> Tuple[Dict[str, Any], bool]:
    """Build the LLMObs span event dictionary for ``span``.

    Returns a ``(span_event, is_ragas_integration_span)`` tuple, where the
    flag is true when the span's ml_app starts with
    ``constants.RAGAS_ML_APP_PREFIX``.

    Raises ``KeyError`` when the span carries no span kind.

    Side effects: stores the resolved ml_app (and the session id, when one
    is found) back onto the span's context items.
    """
    get_item = span._get_ctx_item

    kind = get_item(SPAN_KIND)
    if not kind:
        raise KeyError("Span kind not found in span context")
    is_llm = kind == "llm"

    inputs: Dict[str, Any] = {}
    outputs: Dict[str, Any] = {}
    meta: Dict[str, Any] = {"span.kind": kind, "input": inputs, "output": outputs}

    # Model information only applies to llm/embedding spans with a model name.
    if kind in ("llm", "embedding"):
        model_name = get_item(MODEL_NAME)
        if model_name is not None:
            meta["model_name"] = model_name
            meta["model_provider"] = (get_item(MODEL_PROVIDER) or "custom").lower()
    meta["metadata"] = get_item(METADATA) or {}

    parameters = get_item(INPUT_PARAMETERS)
    if parameters:
        inputs["parameters"] = parameters
    if is_llm:
        input_messages = get_item(INPUT_MESSAGES)
        if input_messages is not None:
            inputs["messages"] = input_messages
    input_value = get_item(INPUT_VALUE)
    if input_value is not None:
        inputs["value"] = safe_json(input_value)
    if is_llm:
        output_messages = get_item(OUTPUT_MESSAGES)
        if output_messages is not None:
            outputs["messages"] = output_messages
    if kind == "embedding":
        input_documents = get_item(INPUT_DOCUMENTS)
        if input_documents is not None:
            inputs["documents"] = input_documents
    output_value = get_item(OUTPUT_VALUE)
    if output_value is not None:
        outputs["value"] = safe_json(output_value)
    if kind == "retrieval":
        output_documents = get_item(OUTPUT_DOCUMENTS)
        if output_documents is not None:
            outputs["documents"] = output_documents

    prompt = get_item(INPUT_PROMPT)
    if prompt is not None:
        if is_llm:
            inputs["prompt"] = prompt
        else:
            log.warning(
                "Dropping prompt on non-LLM span kind, annotating prompts is only supported for LLM span kinds."
            )

    if span.error:
        meta[ERROR_MSG] = span.get_tag(ERROR_MSG)
        meta[ERROR_STACK] = span.get_tag(ERROR_STACK)
        meta[ERROR_TYPE] = span.get_tag(ERROR_TYPE)

    # Empty input/output sections are omitted from the event entirely.
    for section in ("input", "output"):
        if not meta[section]:
            del meta[section]

    metrics = get_item(METRICS) or {}
    ml_app = _get_ml_app(span)
    is_ragas_integration_span = ml_app.startswith(constants.RAGAS_ML_APP_PREFIX)

    span._set_ctx_item(ML_APP, ml_app)
    parent_id = str(_get_llmobs_parent_id(span) or "undefined")

    event = {
        "trace_id": "{:x}".format(span.trace_id),
        "span_id": str(span.span_id),
        "parent_id": parent_id,
        "name": _get_span_name(span),
        "start_ns": span.start_ns,
        "duration": span.duration_ns,
        "status": "error" if span.error else "ok",
        "meta": meta,
        "metrics": metrics,
    }

    session_id = _get_session_id(span)
    if session_id is not None:
        span._set_ctx_item(SESSION_ID, session_id)
        event["session_id"] = session_id

    event["tags"] = cls._llmobs_tags(
        span, ml_app, session_id, is_ragas_integration_span=is_ragas_integration_span
    )
    return event, is_ragas_integration_span

@staticmethod
def _llmobs_tags(
    span: Span, ml_app: str, session_id: Optional[str] = None, is_ragas_integration_span: bool = False
) -> List[str]:
    """Return the LLMObs tags for ``span`` as ``"key:value"`` strings.

    Starts from a fixed set of default tags, adds the error type, session id
    and ragas runner marker when applicable, then lets any tags stored on
    the span's context override entries with the same key.
    """
    tag_map = {
        "version": config.version or "",
        "env": config.env or "",
        "service": span.service or "",
        "source": "integration",
        "ml_app": ml_app,
        "ddtrace.version": ddtrace.__version__,
        "language": "python",
        "error": span.error,
    }

    error_type = span.get_tag(ERROR_TYPE)
    if error_type:
        tag_map["error_type"] = error_type
    if session_id:
        tag_map["session_id"] = session_id
    if is_ragas_integration_span:
        tag_map[constants.RUNNER_IS_INTEGRATION_SPAN_TAG] = "ragas"

    # Tags set on the span take precedence over the defaults above.
    span_tags = span._get_ctx_item(TAGS)
    if span_tags is not None:
        tag_map.update(span_tags)

    return ["{}:{}".format(*pair) for pair in tag_map.items()]

def _do_annotations(self, span: Span) -> None:
# get the current span context
# only do the annotations if it matches the context
if span.span_type != SpanTypes.LLM: # do this check to avoid the warning log in `annotate`
Expand All @@ -120,20 +250,14 @@ def _do_annotations(self, span):
if current_context_id == context_id:
self.annotate(span, **annotation_kwargs)

def _child_after_fork(self):
def _child_after_fork(self) -> None:
self._llmobs_span_writer = self._llmobs_span_writer.recreate()
self._llmobs_eval_metric_writer = self._llmobs_eval_metric_writer.recreate()
self._evaluator_runner = self._evaluator_runner.recreate()
self._trace_processor._span_writer = self._llmobs_span_writer
self._trace_processor._evaluator_runner = self._evaluator_runner
if self.enabled:
self._start_service()

def _start_service(self) -> None:
tracer_filters = self.tracer._filters
if not any(isinstance(tracer_filter, LLMObsTraceProcessor) for tracer_filter in tracer_filters):
tracer_filters += [self._trace_processor]
self.tracer.configure(settings={"FILTERS": tracer_filters})
try:
self._llmobs_span_writer.start()
self._llmobs_eval_metric_writer.start()
Expand All @@ -160,11 +284,7 @@ def _stop_service(self) -> None:
except ServiceStatusError:
log.debug("Error stopping LLMObs writers")

try:
forksafe.unregister(self._child_after_fork)
self.tracer.shutdown()
except Exception:
log.warning("Failed to shutdown tracer", exc_info=True)
forksafe.unregister(self._child_after_fork)

@classmethod
def enable(
Expand Down Expand Up @@ -265,7 +385,6 @@ def disable(cls) -> None:

cls._instance.stop()
cls.enabled = False
cls._instance.tracer.deregister_on_start_span(cls._instance._do_annotations)
telemetry_writer.product_activated(TELEMETRY_APM_PRODUCT.LLMOBS, False)

log.debug("%s disabled", cls.__name__)
Expand Down
Loading

0 comments on commit d676233

Please sign in to comment.