Skip to content

Commit

Permalink
fix(langchain): handle optional serialized values gracefully (#1047)
Browse files Browse the repository at this point in the history
Co-authored-by: Sebastian Thomas <[email protected]>
  • Loading branch information
hassiebp and SebastianThomas1 authored Dec 16, 2024
1 parent 37187cc commit 2324237
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 22 deletions.
44 changes: 28 additions & 16 deletions langfuse/callback/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,36 @@ def on_llm_new_token(

self.updated_completion_start_time_memo.add(run_id)

def get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
    """Retrieve the name of a serialized LangChain runnable.

    The prioritization for the determination of the run name is as follows:
    - The value assigned to the "name" key in `kwargs` (if not None).
    - The value assigned to the "name" key in `serialized` (if not None).
    - The last entry of the value assigned to the "id" key in `serialized`.
    - "<unknown>".

    Args:
        serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
            May be None, in which case only `kwargs` and the "<unknown>" fallback apply.
        **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.

    Returns:
        str: The determined name of the LangChain runnable.
    """
    name_override = kwargs.get("name")
    if name_override is not None:
        return name_override

    # TypeError covers `serialized` being None (or otherwise not subscriptable).
    try:
        serialized_name = serialized["name"]
    except (KeyError, TypeError):
        serialized_name = None

    # Treat an explicit None the same way as a missing key, mirroring the
    # handling of the `kwargs` override above.
    if serialized_name is not None:
        return serialized_name

    # IndexError covers an empty "id" sequence; TypeError covers a None "id".
    try:
        return serialized["id"][-1]
    except (KeyError, TypeError, IndexError):
        return "<unknown>"

def on_retriever_error(
self,
Expand Down Expand Up @@ -196,7 +208,7 @@ def on_retriever_error(

def on_chain_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
inputs: Dict[str, Any],
*,
run_id: UUID,
Expand Down Expand Up @@ -289,7 +301,7 @@ def _deregister_langfuse_prompt(self, run_id: Optional[UUID]):

def __generate_trace_and_parent(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
inputs: Union[Dict[str, Any], List[str], str, None],
*,
run_id: UUID,
Expand Down Expand Up @@ -479,7 +491,7 @@ def on_chain_error(

def on_chat_model_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
Expand Down Expand Up @@ -508,7 +520,7 @@ def on_chat_model_start(

def on_llm_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
prompts: List[str],
*,
run_id: UUID,
Expand All @@ -535,7 +547,7 @@ def on_llm_start(

def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
Expand Down Expand Up @@ -573,7 +585,7 @@ def on_tool_start(

def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
Expand Down Expand Up @@ -698,7 +710,7 @@ def on_tool_error(

def __on_llm_action(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
run_id: UUID,
prompts: List[str],
parent_run_id: Optional[UUID] = None,
Expand Down
23 changes: 17 additions & 6 deletions langfuse/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def _extract_model_name(
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
**kwargs: Any,
):
"""Extracts the model name from the serialized or kwargs object. This is used to get the model names for Langfuse."""
Expand Down Expand Up @@ -106,13 +106,18 @@ def _extract_model_name(


def _extract_model_from_repr_by_pattern(
    id: str,
    serialized: Optional[Dict[str, Any]],
    pattern: str,
    default: Optional[str] = None,
):
    """Extract a model name from the "repr" entry of `serialized` for a given runnable id.

    Args:
        id: Expected last entry of the "id" value in `serialized`.
        serialized: Serialized runnable data; may be None.
        pattern: Attribute name handed to `_extract_model_with_regex`, which
            searches the "repr" text for it.
        default: Value returned when the id matches and a "repr" is present
            but nothing could be extracted from it.

    Returns:
        The extracted value, else `default` (if truthy), else None. None is
        also returned when `serialized` is None, has no usable "id", has an
        "id" whose last entry differs from `id`, or has no "repr".
    """
    if serialized is None:
        return None

    # A missing or empty "id" can never match; bail out instead of indexing
    # into None / an empty sequence.
    serialized_id = serialized.get("id")
    if not serialized_id or serialized_id[-1] != id:
        return None

    repr_text = serialized.get("repr")
    if not repr_text:
        return None

    extracted = _extract_model_with_regex(pattern, repr_text)
    return extracted or default or None


def _extract_model_with_regex(pattern: str, text: str):
match = re.search(rf"{pattern}='(.*?)'", text)
Expand All @@ -123,21 +128,27 @@ def _extract_model_with_regex(pattern: str, text: str):

def _extract_model_by_path_for_id(
    id: str,
    serialized: Optional[Dict[str, Any]],
    kwargs: dict,
    keys: List[str],
    select_from: Literal["serialized", "kwargs"],
):
    """Extract a model name by key path, but only for the runnable identified by `id`.

    Args:
        id: Expected last entry of the "id" value in `serialized`.
        serialized: Serialized runnable data; may be None.
        kwargs: Keyword arguments forwarded to `_extract_model_by_path`.
        keys: Key path forwarded to `_extract_model_by_path`.
        select_from: Which object ("serialized" or "kwargs") the key path is
            resolved against by `_extract_model_by_path`.

    Returns:
        The result of `_extract_model_by_path` when the last entry of
        `serialized["id"]` equals `id`; otherwise None.
    """
    # The id check always reads `serialized`, so a None value cannot match
    # regardless of `select_from`.
    if serialized is None:
        return None

    # A missing or empty "id" can never match; avoid indexing into it.
    serialized_id = serialized.get("id")
    if not serialized_id or serialized_id[-1] != id:
        return None

    return _extract_model_by_path(serialized, kwargs, keys, select_from)


def _extract_model_by_path(
serialized: dict,
serialized: Optional[Dict[str, Any]],
kwargs: dict,
keys: List[str],
select_from: str = Literal["serialized", "kwargs"],
select_from: Literal["serialized", "kwargs"],
):
if serialized is None and select_from == "serialized":
return None

current_obj = kwargs if select_from == "kwargs" else serialized

for key in keys:
Expand Down

0 comments on commit 2324237

Please sign in to comment.