diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index c79b0fc4c84fb..df04e0f88f610 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -657,6 +657,7 @@ async def _arun_llm(
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """Asynchronously run the language model.
 
@@ -682,7 +683,9 @@ async def _arun_llm(
         ):
             return await llm.ainvoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -695,12 +698,18 @@ async def _arun_llm(
         try:
             prompt = _get_prompt(inputs)
             llm_output: Union[str, BaseMessage] = await llm.ainvoke(
-                prompt, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                prompt,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         except InputFormatError:
             messages = _get_messages(inputs)
             llm_output = await llm.ainvoke(
-                messages, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                messages,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
     return llm_output
 
@@ -712,6 +721,7 @@ async def _arun_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -723,10 +733,15 @@ async def _arun_chain(
     ):
         val = next(iter(inputs_.values()))
         output = await chain.ainvoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = await chain.ainvoke(inputs_, config=runnable_config)
     return output
 
@@ -762,6 +777,7 @@ async def _arun_llm_or_chain(
                 tags=config["tags"],
                 callbacks=config["callbacks"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         else:
             chain = llm_or_chain_factory()
@@ -771,6 +787,7 @@ async def _arun_llm_or_chain(
                 tags=config["tags"],
                 callbacks=config["callbacks"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         result = output
     except Exception as e:
@@ -793,6 +810,7 @@ def _run_llm(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """
     Run the language model on the example.
@@ -819,7 +837,9 @@ def _run_llm(
         ):
             llm_output: Union[str, BaseMessage] = llm.invoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -831,12 +851,16 @@ def _run_llm(
         try:
             llm_prompts = _get_prompt(inputs)
             llm_output = llm.invoke(
-                llm_prompts, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                llm_prompts,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             llm_output = llm.invoke(
-                llm_messages, config=RunnableConfig(callbacks=callbacks)
+                llm_messages,
+                config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
             )
     return llm_output
 
@@ -848,6 +872,7 @@ def _run_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -859,10 +884,15 @@ def _run_chain(
     ):
         val = next(iter(inputs_.values()))
         output = chain.invoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = chain.invoke(inputs_, config=runnable_config)
     return output
 
@@ -899,6 +929,7 @@ def _run_llm_or_chain(
                 config["callbacks"],
                 tags=config["tags"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         else:
             chain = llm_or_chain_factory()
@@ -908,6 +939,7 @@ def _run_llm_or_chain(
                 config["callbacks"],
                 tags=config["tags"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         result = output
     except Exception as e:
@@ -1083,8 +1115,13 @@ def prepare(
         input_mapper: Optional[Callable[[Dict], Any]] = None,
         concurrency_level: int = 5,
         project_metadata: Optional[Dict[str, Any]] = None,
+        revision_id: Optional[str] = None,
     ) -> _DatasetRunContainer:
         project_name = project_name or name_generation.random_name()
+        if revision_id:
+            if not project_metadata:
+                project_metadata = {}
+            project_metadata.update({"revision_id": revision_id})
         wrapped_model, project, dataset, examples = _prepare_eval_run(
             client,
             dataset_name,
@@ -1121,6 +1158,7 @@ def prepare(
                 ],
                 tags=tags,
                 max_concurrency=concurrency_level,
+                metadata={"revision_id": revision_id} if revision_id else {},
             )
             for example in examples
         ]
@@ -1183,6 +1221,7 @@ async def arun_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1208,6 +1247,7 @@ async def arun_on_dataset(
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
         container.configs[0].get("max_concurrency"),
@@ -1235,6 +1275,7 @@ def run_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1260,6 +1301,7 @@ def run_on_dataset(
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     if concurrency_level == 0:
         batch_results = [
@@ -1309,6 +1351,8 @@ def run_on_dataset(
         log feedback and run traces.
     verbose: Whether to print progress.
     tags: Tags to add to each run in the project.
+    revision_id: Optional revision identifier to assign this test run to
+        track the performance of different versions of your system.
 
 Returns:
     A dictionary containing the run's project name and the resulting model outputs.
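For reviewers, a minimal usage sketch of the new `revision_id` keyword. This is not part of the patch: the dataset name, the echo chain, and the revision string are placeholders, and the call assumes the existing `run_on_dataset` signature with `client`, `dataset_name`, and `llm_or_chain_factory` as its leading parameters.

    from langsmith import Client

    from langchain.smith import RunEvalConfig, run_on_dataset
    from langchain_core.runnables import RunnableLambda


    def chain_factory():
        # Placeholder system under test; swap in the chain or agent being evaluated.
        return RunnableLambda(lambda inputs: {"output": inputs["question"]})


    results = run_on_dataset(
        client=Client(),
        dataset_name="my-eval-dataset",  # hypothetical dataset name
        llm_or_chain_factory=chain_factory,
        evaluation=RunEvalConfig(evaluators=["qa"]),
        # Recorded in the test project's metadata and propagated to each run.
        revision_id="prompt-v2",
    )

With this change, `revision_id` is merged into `project_metadata` in `_DatasetRunContainer.prepare` and also forwarded through each `RunnableConfig`, so the same identifier appears on the project and on every individual run.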