Commit
LangChain tracing, with LangGraph tests
Twixes committed Jan 21, 2025
1 parent 0384b8c commit 199dfd2
Showing 4 changed files with 533 additions and 147 deletions.
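The handler below plugs into LangChain's standard callbacks mechanism, which LangGraph shares, so one CallbackHandler instance can trace an entire graph run. A minimal usage sketch, not part of this commit (the API key, distinct ID, and graph are illustrative; assumes langgraph is installed):

    from langgraph.graph import END, START, MessagesState, StateGraph

    from posthog import Client
    from posthog.ai.langchain import CallbackHandler

    posthog_client = Client("phc_example")  # hypothetical project API key
    handler = CallbackHandler(posthog_client, distinct_id="user-123")

    def echo(state: MessagesState):
        # Trivial node standing in for a real LLM call.
        return {"messages": [("assistant", "Hello!")]}

    builder = StateGraph(MessagesState)
    builder.add_node("echo", echo)
    builder.add_edge(START, "echo")
    builder.add_edge("echo", END)
    graph = builder.compile()

    # LangChain and LangGraph runnables both accept callbacks via `config`.
    result = graph.invoke(
        {"messages": [("user", "Hi!")]},
        config={"callbacks": [handler]},
    )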
186 changes: 173 additions & 13 deletions posthog/ai/langchain/callbacks.py
@@ -5,6 +5,7 @@
"Please install LangChain to use this feature: 'pip install langchain'"
)

import json
import logging
import time
import uuid
@@ -30,10 +31,12 @@
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.agents import AgentAction, AgentFinish
from pydantic import BaseModel

from posthog.ai.utils import get_model_params, with_privacy_mode
from posthog.client import Client
import posthog  # default_client is initialized lazily, so resolve it via the module at call time

log = logging.getLogger("posthog")

@@ -53,7 +56,7 @@ class RunMetadata(TypedDict, total=False):

class CallbackHandler(BaseCallbackHandler):
"""
A callback handler for LangChain that sends events to PostHog LLM Observability.
The PostHog LLM observability callback handler for LangChain.
"""

_client: Client
@@ -74,7 +77,8 @@ class CallbackHandler(BaseCallbackHandler):

def __init__(
self,
client: Client,
client: Optional[Client] = None,
*,
distinct_id: Optional[Union[str, int, float, UUID]] = None,
trace_id: Optional[Union[str, int, float, UUID]] = None,
properties: Optional[Dict[str, Any]] = None,
@@ -90,7 +94,7 @@ def __init__(
privacy_mode: Whether to redact the input and output of the trace.
groups: Optional additional PostHog groups to use for the trace.
"""
self._client = client
self._client = client or posthog.default_client  # picks up a default client created after import
self._distinct_id = distinct_id
self._trace_id = trace_id
self._properties = properties or {}
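
A sketch of the new optional-client behavior, not part of the diff (assumes the module-level default client has already been initialized, e.g. by configuring the posthog module and sending an event; the key is illustrative):

    import posthog
    from posthog.ai.langchain import CallbackHandler

    posthog.project_api_key = "phc_example"  # hypothetical key
    handler = CallbackHandler()  # no client passed: falls back to the default client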
@@ -106,9 +110,12 @@ def on_chain_start(
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
self._set_parent_of_run(run_id, parent_run_id)
self._set_run_metadata(serialized, run_id, inputs, metadata, **kwargs)

def on_chat_model_start(
self,
@@ -119,6 +126,9 @@ def on_chat_model_start(
parent_run_id: Optional[UUID] = None,
**kwargs,
):
self._log_debug_event(
"on_chat_model_start", run_id, parent_run_id, messages=messages
)
self._set_parent_of_run(run_id, parent_run_id)
input = [
_convert_message_to_dict(message) for row in messages for message in row
@@ -134,9 +144,58 @@ def on_llm_start(
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
self._set_parent_of_run(run_id, parent_run_id)
self._set_run_metadata(serialized, run_id, prompts, **kwargs)

def on_llm_new_token(
self,
token: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
self.log.debug(
f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}"
)

def on_tool_start(
self,
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event(
"on_tool_start", run_id, parent_run_id, input_str=input_str
)

def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)

def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)

def on_chain_end(
self,
outputs: Dict[str, Any],
@@ -146,7 +205,35 @@ def on_chain_end(
tags: Optional[List[str]] = None,
**kwargs: Any,
):
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
self._pop_parent_of_run(run_id)
run_metadata = self._pop_run_metadata(run_id)

if parent_run_id is None:
self._end_trace(
self._get_trace_id(run_id),
inputs=run_metadata.get("messages") if run_metadata else None,
outputs=outputs,
)

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
self._pop_parent_of_run(run_id)
run_metadata = self._pop_run_metadata(run_id)

if parent_run_id is None:
self._end_trace(
self._get_trace_id(run_id),
inputs=run_metadata.get("messages") if run_metadata else None,
outputs=None,
)

def on_llm_end(
self,
@@ -160,6 +247,9 @@
"""
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
"""
self._log_debug_event(
"on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
)
trace_id = self._get_trace_id(run_id)
self._pop_parent_of_run(run_id)
run = self._pop_run_metadata(run_id)
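
As the docstring above notes, streaming runs only carry token counts when the model is asked to emit them. A sketch of opting in, not part of the diff, assuming langchain-openai is installed and reusing the handler from the earlier sketch (model and prompt are illustrative):

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)
    for chunk in llm.stream("Hello!", config={"callbacks": [handler]}):
        pass  # usage arrives on the final chunk, which on_llm_end then captures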
@@ -207,16 +297,6 @@ def on_llm_end(
groups=self._groups,
)

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._pop_parent_of_run(run_id)

def on_llm_error(
self,
error: BaseException,
@@ -226,6 +306,7 @@ def on_llm_error(
tags: Optional[List[str]] = None,
**kwargs: Any,
):
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
trace_id = self._get_trace_id(run_id)
self._pop_parent_of_run(run_id)
run = self._pop_run_metadata(run_id)
@@ -255,6 +336,51 @@ def on_llm_error(
groups=self._groups,
)

def on_retriever_start(
self,
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)

def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)

def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)

def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)

def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
"""
Set the parent run ID for a chain run. If there is no parent, the run is the root.
@@ -324,6 +450,40 @@ def _get_trace_id(self, run_id: UUID):
trace_id = uuid.uuid4()
return trace_id

def _end_trace(
self, trace_id: UUID, inputs: Optional[Dict[str, Any]], outputs: Optional[Dict[str, Any]]
):
event_properties = {
"$ai_trace_id": trace_id,
"$ai_input_state": with_privacy_mode(
self._client, self._privacy_mode, inputs
),
**self._properties,
}
if outputs is not None:
event_properties["$ai_output_state"] = with_privacy_mode(
self._client, self._privacy_mode, outputs
)
if self._distinct_id is None:
event_properties["$process_person_profile"] = False
self._client.capture(
distinct_id=self._distinct_id or trace_id,
event="$ai_trace",
properties=event_properties,
groups=self._groups,
)

def _log_debug_event(
self,
event_name: str,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs,
):
log.debug(
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}"
)
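
Every callback above now funnels through this helper, so the full run tree can be inspected by raising the posthog logger's level — a minimal sketch, not part of the diff:

    import logging

    logging.basicConfig()
    logging.getLogger("posthog").setLevel(logging.DEBUG)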


def _extract_raw_esponse(last_response):
"""Extract the response from the last response of the LLM call."""
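
With the on_chain_end and on_chain_error paths above in place, every root run now ends in a single $ai_trace capture. A sketch of the properties _end_trace assembles, derived from the code above (values are illustrative):

    import uuid

    trace_id = uuid.uuid4()  # trace ID of the root run
    event_properties = {
        "$ai_trace_id": trace_id,
        "$ai_input_state": {"messages": ["..."]},  # redacted under privacy mode
        "$ai_output_state": {"answer": "..."},     # omitted when the chain errored
        "$process_person_profile": False,          # only set when no distinct_id was given
    }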
1 change: 1 addition & 0 deletions posthog/test/ai/langchain/__init__.py
@@ -2,3 +2,4 @@

pytest.importorskip("langchain")
pytest.importorskip("langchain_community")
pytest.importorskip("langgraph")