diff --git a/langfuse/callback/langchain.py b/langfuse/callback/langchain.py index 7a4f9094..1c327aab 100644 --- a/langfuse/callback/langchain.py +++ b/langfuse/callback/langchain.py @@ -149,24 +149,36 @@ def on_llm_new_token( self.updated_completion_start_time_memo.add(run_id) - def get_langchain_run_name(self, serialized: Dict[str, Any], **kwargs: Any) -> str: - """Retrieves the 'run_name' for an entity based on Langchain convention, prioritizing the 'name' - key in 'kwargs' or falling back to the 'name' or 'id' in 'serialized'. Defaults to "" - if none are available. + def get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str: + """Retrieve the name of a serialized LangChain runnable. + + The prioritization for the determination of the run name is as follows: + - The value assigned to the "name" key in `kwargs`. + - The value assigned to the "name" key in `serialized`. + - The last entry of the value assigned to the "id" key in `serialized`. + - "". Args: - serialized (Dict[str, Any]): A dictionary containing the entity's serialized data. + serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data. **kwargs (Any): Additional keyword arguments, potentially including the 'name' override. Returns: - str: The determined Langchain run name for the entity. + str: The determined name of the Langchain runnable. """ - # Check if 'name' is in kwargs and not None, otherwise use default fallback logic if "name" in kwargs and kwargs["name"] is not None: return kwargs["name"] - # Fallback to serialized 'name', 'id', or "" - return serialized.get("name", serialized.get("id", [""])[-1]) + try: + return serialized["name"] + except (KeyError, TypeError): + pass + + try: + return serialized["id"][-1] + except (KeyError, TypeError): + pass + + return "" def on_retriever_error( self, @@ -196,7 +208,7 @@ def on_retriever_error( def on_chain_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], inputs: Dict[str, Any], *, run_id: UUID, @@ -289,7 +301,7 @@ def _deregister_langfuse_prompt(self, run_id: Optional[UUID]): def __generate_trace_and_parent( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], inputs: Union[Dict[str, Any], List[str], str, None], *, run_id: UUID, @@ -479,7 +491,7 @@ def on_chain_error( def on_chat_model_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], messages: List[List[BaseMessage]], *, run_id: UUID, @@ -508,7 +520,7 @@ def on_chat_model_start( def on_llm_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], prompts: List[str], *, run_id: UUID, @@ -535,7 +547,7 @@ def on_llm_start( def on_tool_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], input_str: str, *, run_id: UUID, @@ -573,7 +585,7 @@ def on_tool_start( def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], query: str, *, run_id: UUID, @@ -698,7 +710,7 @@ def on_tool_error( def __on_llm_action( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], run_id: UUID, prompts: List[str], parent_run_id: Optional[UUID] = None, diff --git a/langfuse/extract_model.py b/langfuse/extract_model.py index ba365b53..19252284 100644 --- a/langfuse/extract_model.py +++ b/langfuse/extract_model.py @@ -10,7 +10,7 @@ def _extract_model_name( - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], **kwargs: Any, ): """Extracts the model name from the serialized or kwargs object. This is used to get the model names for Langfuse.""" @@ -106,13 +106,18 @@ def _extract_model_name( def _extract_model_from_repr_by_pattern( - id: str, serialized: dict, pattern: str, default: Optional[str] = None + id: str, serialized: Optional[Dict[str, Any]], pattern: str, default: Optional[str] = None ): + if serialized is None: + return None + if serialized.get("id")[-1] == id: if serialized.get("repr"): extracted = _extract_model_with_regex(pattern, serialized.get("repr")) return extracted if extracted else default if default else None + return None + def _extract_model_with_regex(pattern: str, text: str): match = re.search(rf"{pattern}='(.*?)'", text) @@ -123,21 +128,27 @@ def _extract_model_with_regex(pattern: str, text: str): def _extract_model_by_path_for_id( id: str, - serialized: dict, + serialized: Optional[Dict[str, Any]], kwargs: dict, keys: List[str], - select_from: str = Literal["serialized", "kwargs"], + select_from: Literal["serialized", "kwargs"], ): + if serialized is None and select_from == "serialized": + return None + if serialized.get("id")[-1] == id: return _extract_model_by_path(serialized, kwargs, keys, select_from) def _extract_model_by_path( - serialized: dict, + serialized: Optional[Dict[str, Any]], kwargs: dict, keys: List[str], - select_from: str = Literal["serialized", "kwargs"], + select_from: Literal["serialized", "kwargs"], ): + if serialized is None and select_from == "serialized": + return None + current_obj = kwargs if select_from == "kwargs" else serialized for key in keys: