From ab57de9c8f6cd770eaf888cc31df817ec138bdca Mon Sep 17 00:00:00 2001 From: Josh Reini <60949774+joshreini1@users.noreply.github.com> Date: Fri, 22 Dec 2023 18:12:13 -0500 Subject: [PATCH] Add shortcut to select_context() (#706) * adjust docstring for select_context * langchain select_context, update quickstarts * undo app name change * remove dev cell * generalized langchain select_context (#711) * generalized langchain select_context * typo * typo in string * update langchain example to pass app in select_context --------- Co-authored-by: Josh Reini * comments, clarity updates to quickstarts * add lib-independent select_context * update lc li quickstarts --------- Co-authored-by: Piotr Mardziel --- .../quickstart/langchain_quickstart.ipynb | 64 ++++----- .../quickstart/llama_index_quickstart.ipynb | 125 ++++++++++++------ trulens_eval/trulens_eval/app.py | 19 +++ .../feedback/provider/endpoint/base.py | 11 ++ .../feedback/provider/endpoint/openai.py | 2 +- trulens_eval/trulens_eval/instruments.py | 34 ++--- trulens_eval/trulens_eval/schema.py | 5 + trulens_eval/trulens_eval/tru_chain.py | 49 ++++++- trulens_eval/trulens_eval/tru_llama.py | 10 ++ trulens_eval/trulens_eval/utils/python.py | 35 +++-- trulens_eval/trulens_eval/utils/threading.py | 20 ++- 11 files changed, 272 insertions(+), 102 deletions(-) diff --git a/trulens_eval/examples/quickstart/langchain_quickstart.ipynb b/trulens_eval/examples/quickstart/langchain_quickstart.ipynb index 63415ac97..cc46c7a0b 100644 --- a/trulens_eval/examples/quickstart/langchain_quickstart.ipynb +++ b/trulens_eval/examples/quickstart/langchain_quickstart.ipynb @@ -38,7 +38,7 @@ "outputs": [], "source": [ "import os\n", - "os.environ[\"OPENAI_API_KEY\"] = \"...\"" + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" ] }, { @@ -116,7 +116,8 @@ "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", "splits = text_splitter.split_documents(docs)\n", "\n", - "vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())" + "vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(\n", + "))" ] }, { @@ -180,16 +181,21 @@ "outputs": [], "source": [ "from trulens_eval.feedback.provider import OpenAI\n", - "from trulens_eval import Select\n", "import numpy as np\n", + "\n", "# Initialize provider class\n", "openai = OpenAI()\n", + "\n", + "# select context to be used in feedback. 
the location of context is app specific.\n", + "from trulens_eval.app import App\n", + "context = App.select_context(rag_chain)\n", + "\n", "from trulens_eval.feedback import Groundedness\n", "grounded = Groundedness(groundedness_provider=OpenAI())\n", "# Define a groundedness feedback function\n", "f_groundedness = (\n", " Feedback(grounded.groundedness_measure_with_cot_reasons)\n", - " .on(Select.RecordCalls.first.invoke.rets.context)\n", + " .on(context.collect()) # collect context chunks into a list\n", " .on_output()\n", " .aggregate(grounded.grounded_statements_aggregator)\n", ")\n", @@ -199,10 +205,10 @@ "# Question/statement relevance between question and each context chunk.\n", "f_context_relevance = (\n", " Feedback(openai.qs_relevance)\n", - " .on(Select.RecordCalls.first.invoke.args.input)\n", - " .on(Select.RecordCalls.first.invoke.rets.context)\n", + " .on_input()\n", + " .on(context)\n", " .aggregate(np.mean)\n", - ")" + " )" ] }, { @@ -236,15 +242,6 @@ "display(llm_response)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tru.run_dashboard()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -288,11 +285,14 @@ ] }, { - "attachments": {}, - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## Explore in a Dashboard" + "records, feedback = tru.get_records_and_feedback(app_ids=[\"Chain1_ChatApplication\"])\n", + "\n", + "records.head()" ] }, { @@ -301,9 +301,7 @@ "metadata": {}, "outputs": [], "source": [ - "tru.run_dashboard() # open a local streamlit app to explore\n", - "\n", - "# tru.stop_dashboard() # stop if needed" + "tru.get_leaderboard(app_ids=[\"Chain1_ChatApplication\"])" ] }, { @@ -311,15 +309,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Alternatively, you can run `trulens-eval` from a command line in the same folder to start the dashboard." + "## Explore in a Dashboard" ] }, { - "attachments": {}, - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "Note: Feedback functions evaluated in the deferred manner can be seen in the \"Progress\" page of the TruLens dashboard." + "tru.run_dashboard() # open a local streamlit app to explore\n", + "\n", + "# tru.stop_dashboard() # stop if needed" ] }, { @@ -327,16 +328,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Or view results directly in your notebook" + "Alternatively, you can run `trulens-eval` from a command line in the same folder to start the dashboard." ] }, { - "cell_type": "code", - "execution_count": null, + "attachments": {}, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "tru.get_records_and_feedback(app_ids=[])[0] # pass an empty list of app_ids to get all" + "Note: Feedback functions evaluated in the deferred manner can be seen in the \"Progress\" page of the TruLens dashboard." 
] } ], @@ -356,7 +356,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.0" }, "vscode": { "interpreter": { diff --git a/trulens_eval/examples/quickstart/llama_index_quickstart.ipynb b/trulens_eval/examples/quickstart/llama_index_quickstart.ipynb index cd7cd3e0e..0191d6e68 100644 --- a/trulens_eval/examples/quickstart/llama_index_quickstart.ipynb +++ b/trulens_eval/examples/quickstart/llama_index_quickstart.ipynb @@ -50,7 +50,7 @@ "outputs": [], "source": [ "import os\n", - "os.environ[\"OPENAI_API_KEY\"] = \"...\"" + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" ] }, { @@ -58,28 +58,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Import from LlamaIndex and TruLens" + "### Import from TruLens" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "🦑 Tru initialized with db url sqlite:///default.sqlite .\n", - "🛑 Secret keys may be written to the database. See the `database_redact_keys` option of `Tru` to prevent this.\n" - ] - } - ], + "outputs": [], "source": [ - "from trulens_eval import Feedback, Tru, TruLlama\n", - "from trulens_eval.feedback import Groundedness\n", - "from trulens_eval.feedback.provider.openai import OpenAI\n", - "\n", + "from trulens_eval import Tru\n", "tru = Tru()" ] }, @@ -145,23 +133,36 @@ "import numpy as np\n", "\n", "# Initialize provider class\n", + "from trulens_eval.feedback.provider.openai import OpenAI\n", "openai = OpenAI()\n", "\n", - "grounded = Groundedness(groundedness_provider=OpenAI())\n", + "# select context to be used in feedback. the location of context is app specific.\n", + "from trulens_eval.app import App\n", + "context = App.select_context(query_engine)\n", + "\n", + "# imports for feedback\n", + "from trulens_eval import Feedback\n", "\n", "# Define a groundedness feedback function\n", - "f_groundedness = Feedback(grounded.groundedness_measure_with_cot_reasons).on(\n", - " TruLlama.select_source_nodes().node.text.collect()\n", - " ).on_output(\n", - " ).aggregate(grounded.grounded_statements_aggregator)\n", + "from trulens_eval.feedback import Groundedness\n", + "grounded = Groundedness(groundedness_provider=OpenAI())\n", + "f_groundedness = (\n", + " Feedback(grounded.groundedness_measure_with_cot_reasons)\n", + " .on(context.collect()) # collect context chunks into a list\n", + " .on_output()\n", + " .aggregate(grounded.grounded_statements_aggregator)\n", + ")\n", "\n", "# Question/answer relevance between overall question and answer.\n", "f_qa_relevance = Feedback(openai.relevance).on_input_output()\n", "\n", "# Question/statement relevance between question and each context chunk.\n", - "f_qs_relevance = Feedback(openai.qs_relevance).on_input().on(\n", - " TruLlama.select_source_nodes().node.text\n", - " ).aggregate(np.mean)" + "f_qs_relevance = (\n", + " Feedback(openai.qs_relevance)\n", + " .on_input()\n", + " .on(context)\n", + " .aggregate(np.mean)\n", + ")" ] }, { @@ -178,6 +179,7 @@ "metadata": {}, "outputs": [], "source": [ + "from trulens_eval import TruLlama\n", "tru_query_engine_recorder = TruLlama(query_engine,\n", " app_id='LlamaIndex_App1',\n", " feedbacks=[f_groundedness, f_qa_relevance, f_qs_relevance])" @@ -195,11 +197,10 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Explore in a Dashboard" + "## Retrieve records and feedback" ] }, { @@ -208,25 +209,55 @@ "metadata": 
{}, "outputs": [], "source": [
-    "tru.run_dashboard() # open a local streamlit app to explore\n",
+    "# The record of the app invocation can be retrieved from the `recording`:\n",
     "\n",
-    "# tru.stop_dashboard() # stop if needed"
+    "rec = recording.get() # use .get if only one record\n",
+    "# recs = recording.records # use .records if multiple\n",
+    "\n",
+    "display(rec)"
   ]
  },
  {
-   "attachments": {},
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Alternatively, you can run `trulens-eval` from a command line in the same folder to start the dashboard."
+    "# The results of the feedback functions can be retrieved from the record. These\n",
+    "# are `Future` instances (see `concurrent.futures`). You can use `as_completed`\n",
+    "# to wait until they have finished evaluating.\n",
+    "\n",
+    "from trulens_eval.schema import FeedbackResult\n",
+    "\n",
+    "from concurrent.futures import as_completed\n",
+    "\n",
+    "for feedback_future in as_completed(rec.feedback_results):\n",
+    "    feedback, feedback_result = feedback_future.result()\n",
+    "    \n",
+    "    feedback: Feedback\n",
+    "    feedback_result: FeedbackResult\n",
+    "\n",
+    "    display(feedback.name, feedback_result.result)"
   ]
  },
  {
-   "attachments": {},
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "records, feedback = tru.get_records_and_feedback(app_ids=[\"LlamaIndex_App1\"])\n",
+    "\n",
+    "records.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Note: Feedback functions evaluated in the deferred manner can be seen in the \"Progress\" page of the TruLens dashboard."
+    "tru.get_leaderboard(app_ids=[\"LlamaIndex_App1\"])"
   ]
  },
  {
@@ -234,7 +265,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Or view results directly in your notebook"
+    "## Explore in a Dashboard"
   ]
  },
  {
@@ -243,7 +274,25 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tru.get_records_and_feedback(app_ids=[])[0] # pass an empty list of app_ids to get all"
+    "tru.run_dashboard() # open a local streamlit app to explore\n",
+    "\n",
+    "# tru.stop_dashboard() # stop if needed"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Alternatively, you can run `trulens-eval` from a command line in the same folder to start the dashboard."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: Feedback functions evaluated in the deferred manner can be seen in the \"Progress\" page of the TruLens dashboard."
   ]
  }
 ],
@@ -263,7 +312,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.11.5"
+  "version": "3.11.0"
  },
  "vscode": {
   "interpreter": {
diff --git a/trulens_eval/trulens_eval/app.py b/trulens_eval/trulens_eval/app.py
index 614a8960d..f8476f62b 100644
--- a/trulens_eval/trulens_eval/app.py
+++ b/trulens_eval/trulens_eval/app.py
@@ -458,6 +458,25 @@ def __init__(

         self.tru_post_init()

+    @classmethod
+    def select_context(
+        cls,
+        app: Optional[Any] = None
+    ) -> Lens:
+        if app is None:
+            raise ValueError("Could not determine context selection without `app` argument.")
+
+        # Checking by module name so we don't have to try to import either
+        # langchain or llama_index beforehand.
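+        # A sketch of the intended usage, mirroring the quickstarts in this
+        # change (`rag_chain` stands in for the user's langchain app):
+        #
+        #   from trulens_eval.app import App
+        #   context = App.select_context(rag_chain)
+        #   f_context_relevance = (
+        #       Feedback(openai.qs_relevance)
+        #       .on_input()
+        #       .on(context)
+        #       .aggregate(np.mean)
+        #   )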
+ if type(app).__module__.startswith("langchain"): + from trulens_eval.tru_chain import TruChain + return TruChain.select_context(app) + elif type(app).__module__.startswith("llama_index"): + from trulens_eval.tru_llama import TruLlama + return TruLlama.select_context(app) + else: + raise ValueError(f"Could not determine context from unrecognized `app` type {type(app)}.") + def __hash__(self): return hash(id(self)) diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py index 640e4dcd7..760fcbe1f 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py @@ -349,6 +349,15 @@ def metawrap(*args, **kwargs): setattr(cls, wrapper_method_name, metawrap) def _instrument_module_members(self, mod: ModuleType, method_name: str): + if not safe_hasattr(mod, INSTRUMENT): + setattr(mod, INSTRUMENT, set()) + + already_instrumented = safe_getattr(mod, INSTRUMENT) + + if method_name in already_instrumented: + logger.debug(f"module {mod} already instrumented for {method_name}") + return + for m in dir(mod): logger.debug( f"instrumenting module {mod} member {m} for method {method_name}" @@ -357,6 +366,8 @@ def _instrument_module_members(self, mod: ModuleType, method_name: str): obj = safe_getattr(mod, m) self._instrument_class(obj, method_name=method_name) + already_instrumented.add(method_name) + # TODO: CODEDUP @staticmethod def track_all_costs( diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py index edec2cb81..bf7a4946f 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py @@ -23,7 +23,7 @@ import inspect import logging import pprint -from typing import Any, Callable, ClassVar, List, Optional, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Union from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.schema import Generation diff --git a/trulens_eval/trulens_eval/instruments.py b/trulens_eval/trulens_eval/instruments.py index 87d0d67f1..3dcd3b3cf 100644 --- a/trulens_eval/trulens_eval/instruments.py +++ b/trulens_eval/trulens_eval/instruments.py @@ -422,7 +422,10 @@ def tracked_method_wrapper( sig = safe_signature(func) - async def awrapper(*args, **kwargs): + def find_instrumented(f): + return id(f) in [id(tru_awrapper.__code__), id(tru_wrapper.__code__)] + + async def tru_awrapper(*args, **kwargs): # TODO: figure out how to have less repetition between the async and # sync versions of this method. @@ -430,14 +433,11 @@ async def awrapper(*args, **kwargs): f"{query}: calling instrumented async method {func}" ) # DIFF - apps = getattr(awrapper, Instrument.APPS) # DIFF + apps = getattr(tru_awrapper, Instrument.APPS) # DIFF # If not within a root method, call the wrapped function without # any recording. - def find_instrumented(f): - return id(f) in [id(awrapper.__code__)] # DIFF - # Get any contexts already known from higher in the call stack. contexts = get_first_local_in_call_stack( key="contexts", @@ -613,20 +613,17 @@ def find_instrumented(f): return rets - def wrapper(*args, **kwargs): + def tru_wrapper(*args, **kwargs): # TODO: figure out how to have less repetition between the async and # sync versions of this method. 
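        # NOTE: the shared `find_instrumented` above matches stack frames whose
        # code object belongs to either `tru_wrapper` or `tru_awrapper`, so a
        # nested instrumented call (sync or async) can locate the recording
        # `contexts` local of an enclosing instrumented frame.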
logger.debug(f"{query}: calling instrumented method {func}") - apps = getattr(wrapper, Instrument.APPS) + apps = getattr(tru_wrapper, Instrument.APPS) # If not within a root method, call the wrapped function without # any recording. - def find_instrumented(f): - return id(f) in [id(wrapper.__code__), id(awrapper.__code__)] - # Get any contexts already known from higher in the call stack. contexts = get_first_local_in_call_stack( key="contexts", func=find_instrumented, offset=1 @@ -794,9 +791,9 @@ def find_instrumented(f): return rets - w = wrapper + w = tru_wrapper if inspect.iscoroutinefunction(func): - w = awrapper + w = tru_awrapper # Indicate that the wrapper is an instrumented method so that we dont # further instrument it in another layer accidentally. @@ -974,7 +971,12 @@ def instrument_object( ) ) - if self.to_instrument_object(obj): + if self.to_instrument_object(obj) or isinstance(obj, (dict, list, tuple)): + vals = None + if isinstance(obj, dict): + attrs = obj.keys() + vals = obj.values() + if isinstance(obj, pydantic.BaseModel): # NOTE(piotrm): This will not include private fields like # llama_index's LLMPredictor._llm which might be useful to @@ -994,9 +996,11 @@ def instrument_object( # not so this section applies. attrs = clean_attributes(obj, include_props=True).keys() - for k in attrs: - v = safe_getattr(obj, k, get_prop=True) + if vals is None: + vals = [safe_getattr(obj, k, get_prop=True) for k in attrs] + for k, v in zip(attrs, vals): + if isinstance(v, (str, bool, int, float)): pass diff --git a/trulens_eval/trulens_eval/schema.py b/trulens_eval/trulens_eval/schema.py index f68895dcd..4245dc0ef 100644 --- a/trulens_eval/trulens_eval/schema.py +++ b/trulens_eval/trulens_eval/schema.py @@ -280,6 +280,11 @@ class Select: # The whole output of the first called / last returned method call. 
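    # (i.e. `Select.RecordRets` just below). The `Select.context(app)` helper
    # added in this change is a shortcut for `App.select_context(app)`; a
    # sketch of its use, with `rag_chain` standing in for the user's app:
    #
    #   context = Select.context(rag_chain)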
RecordRets: Query = RecordCall.rets + @staticmethod + def context(app: Optional[Any] = None) -> Lens: + from trulens_eval.app import App + return App.select_context(app) + @staticmethod def for_record(query: Select.Query) -> Query: return Select.Query(path=Select.Record.path + query.path) diff --git a/trulens_eval/trulens_eval/tru_chain.py b/trulens_eval/trulens_eval/tru_chain.py index 0e70bf1c2..99c3ee6f8 100644 --- a/trulens_eval/trulens_eval/tru_chain.py +++ b/trulens_eval/trulens_eval/tru_chain.py @@ -6,7 +6,7 @@ from inspect import Signature import logging from pprint import PrettyPrinter -from typing import Any, Callable, ClassVar, Dict, List, Tuple +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple # import nest_asyncio # NOTE(piotrm): disabling for now, need more investigation from pydantic import Field @@ -14,12 +14,16 @@ from trulens_eval.app import App from trulens_eval.instruments import Instrument from trulens_eval.schema import Record +from trulens_eval.schema import Select from trulens_eval.utils.imports import OptionalImports from trulens_eval.utils.imports import REQUIREMENT_LANGCHAIN +from trulens_eval.utils.json import jsonify from trulens_eval.utils.langchain import WithFeedbackFilterDocuments from trulens_eval.utils.pyschema import Class from trulens_eval.utils.pyschema import FunctionOrMethod from trulens_eval.utils.python import safe_hasattr +from trulens_eval.utils.serial import all_queries +from trulens_eval.utils.serial import Lens logger = logging.getLogger(__name__) @@ -28,7 +32,7 @@ with OptionalImports(messages=REQUIREMENT_LANGCHAIN): # langchain.agents.agent.AgentExecutor, # is langchain.chains.base.Chain # import langchain - + from langchain.agents.agent import BaseMultiActionAgent from langchain.agents.agent import BaseSingleActionAgent from langchain.chains.base import Chain @@ -96,6 +100,10 @@ class Default: lambda o: isinstance(o, (RunnableSerializable)), "_aget_relevant_documents": lambda o: isinstance(o, (RunnableSerializable)), + "get_relevant_documents": + lambda o: isinstance(o, (RunnableSerializable)), + "aget_relevant_documents": + lambda o: isinstance(o, (RunnableSerializable)), # "format_prompt": lambda o: isinstance(o, langchain.prompts.base.BasePromptTemplate), # "format": lambda o: isinstance(o, langchain.prompts.base.BasePromptTemplate), # the prompt calls might be too small to be interesting @@ -210,6 +218,43 @@ def __init__(self, app: Chain, **kwargs): super().__init__(**kwargs) + @classmethod + def select_context(cls, app: Optional[Chain] = None) -> Lens: + """ + Get the path to the context in the query output. + """ + + if app is None: + raise ValueError( + "langchain app/chain is required to determine context for langchain apps. 
" + "Pass it in as the `app` argument" + ) + + retrievers = [] + + app_json = jsonify(app) + for lens in all_queries(app_json): + try: + comp = lens.get_sole_item(app) + if isinstance(comp, BaseRetriever): + retrievers.append((lens, comp)) + + except Exception: + pass + + if len(retrievers) == 0: + raise ValueError("Cannot find any `BaseRetriever` in app.") + + if len(retrievers) > 1: + raise ValueError( + "Found more than one `BaseRetriever` in app:\n\t" + \ + ("\n\t".join(map( + lambda lr: f"{type(lr[1])} at {lr[0]}", + retrievers))) + ) + + return (Select.RecordCalls + retrievers[0][0]).get_relevant_documents.rets + # TODEP # Chain requirement @property diff --git a/trulens_eval/trulens_eval/tru_llama.py b/trulens_eval/trulens_eval/tru_llama.py index 1c434795f..6b5f794db 100644 --- a/trulens_eval/trulens_eval/tru_llama.py +++ b/trulens_eval/trulens_eval/tru_llama.py @@ -251,6 +251,16 @@ def select_source_nodes(cls) -> Lens: """ return cls.select_outputs().source_nodes[:] + @classmethod + def select_context( + cls, + app: Optional[Union[BaseQueryEngine, BaseChatEngine]] = None + ) -> Lens: + """ + Get the path to the context in the query output. + """ + return cls.select_outputs().source_nodes[:].node.text + def main_input( self, func: Callable, sig: Signature, bindings: BoundArguments ) -> str: diff --git a/trulens_eval/trulens_eval/utils/python.py b/trulens_eval/trulens_eval/utils/python.py index bb13e7e9a..527283682 100644 --- a/trulens_eval/trulens_eval/utils/python.py +++ b/trulens_eval/trulens_eval/utils/python.py @@ -185,21 +185,22 @@ def stack_with_tasks() -> Sequence['frame']: ret = [fi.frame for fi in inspect.stack()[1:]] # skip stack_with_task_stack - logger.debug("Getting cross-Task stacks. Current stack:") - for f in ret: - logger.debug(f"\t{f}") + # Need a more verbose debug mode for these: + #logger.debug("Getting cross-Task stacks. Current stack:") + #for f in ret: + # logger.debug(f"\t{f}") try: task_stack = get_task_stack(asyncio.current_task()) - logger.debug(f"Merging in stack from {asyncio.current_task()}:") - for s in task_stack: - logger.debug(f"\t{s}") + #logger.debug(f"Merging in stack from {asyncio.current_task()}:") + #for s in task_stack: + # logger.debug(f"\t{s}") temp = merge_stacks(ret, task_stack) - logger.debug(f"Complete stack:") - for f in temp: - logger.debug(f"\t{f}") + #logger.debug(f"Complete stack:") + #for f in temp: + # logger.debug(f"\t{f}") return temp @@ -207,7 +208,7 @@ def stack_with_tasks() -> Sequence['frame']: return ret -def _future_target_wrapper(stack, func, *args, **kwargs): +def _future_target_wrapper(stack, context, func, *args, **kwargs): """ Wrapper for a function that is started by threads. This is needed to record the call stack prior to thread creation as in python threads do @@ -220,6 +221,10 @@ def _future_target_wrapper(stack, func, *args, **kwargs): # Keep this for looking up via get_first_local_in_call_stack . pre_start_stack = stack + for var, value in context.items(): + logger.debug(f"Copying context var {var} to thread.") + var.set(value) + return func(*args, **kwargs) @@ -244,10 +249,13 @@ class above. with async tasks. In those cases, the `skip` argument is more reliable. 
""" - logger.debug(f"Looking for local '{key}' in the stack.") + # TODO: Need a more verbose mode for these: + # logger.debug(f"Looking for local '{key}' in the stack.") if skip is not None: - logger.debug(f"Will be skipping {skip}.") + pass + # TODO: verbose debug + # logger.debug(f"Will be skipping {skip}.") frames = stack_with_tasks()[1:] # + 1 to skip this method itself # NOTE: skipping offset frames is done below since the full stack may need @@ -261,7 +269,8 @@ class above. while not q.empty(): f = q.get() - logger.debug(f"{f.f_code}") + # TODO: verbose debug + # logger.debug(f"{f.f_code}") if id(f.f_code) == id(_future_target_wrapper.__code__): logger.debug( diff --git a/trulens_eval/trulens_eval/utils/threading.py b/trulens_eval/trulens_eval/utils/threading.py index d9b3e85ab..e7e6b776a 100644 --- a/trulens_eval/trulens_eval/utils/threading.py +++ b/trulens_eval/trulens_eval/utils/threading.py @@ -5,6 +5,7 @@ from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor as fThreadPoolExecutor from concurrent.futures import TimeoutError +import contextvars from inspect import stack import logging import threading @@ -28,12 +29,29 @@ class ThreadPoolExecutor(fThreadPoolExecutor): invocation. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def submit(self, fn, /, *args, **kwargs): present_stack = stack() + present_context = contextvars.copy_context() return super().submit( - _future_target_wrapper, present_stack, fn, *args, **kwargs + _future_target_wrapper, present_stack, present_context, fn, *args, **kwargs ) +# Attempt other users of ThreadPoolExecutor to use our version. +import concurrent + +concurrent.futures.ThreadPoolExecutor = ThreadPoolExecutor +concurrent.futures.thread.ThreadPoolExecutor = ThreadPoolExecutor + +# Hack to try to make langchain use our ThreadPoolExecutor as the above doesn't +# seem to do the trick. +try: + import langchain_core + langchain_core.runnables.config.ThreadPoolExecutor = ThreadPoolExecutor +except Exception: + pass class TP(SingletonPerName['TP']): # "thread processing"