From bfb8ad0a9d53c1f15cdb791d12c4805273690b88 Mon Sep 17 00:00:00 2001 From: David Kurokawa Date: Fri, 25 Oct 2024 13:52:34 -0700 Subject: [PATCH] Always ensure endpoint context variable is cleaned up. (#1589) * Ensure we always clean up the ContextVar as otherwise we may try to wrap things outside of the app invocation if the app invocation fails in the middle. * Add in test. --- src/core/trulens/core/feedback/endpoint.py | 41 ++++++------ tests/e2e/test_context_variables.py | 78 ++++++++++++++++++++++ 2 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 tests/e2e/test_context_variables.py diff --git a/src/core/trulens/core/feedback/endpoint.py b/src/core/trulens/core/feedback/endpoint.py index d44bace70..c1b0c58c7 100644 --- a/src/core/trulens/core/feedback/endpoint.py +++ b/src/core/trulens/core/feedback/endpoint.py @@ -603,29 +603,30 @@ def _track_costs( # following call to retrieve. endpoints_token = Endpoint._context_endpoints.set(endpoints) # noqa: F841 - # context_vars = contextvars.copy_context() - context_vars = { - Endpoint._context_endpoints: Endpoint._context_endpoints.get() - } - - # Call the function. - result: T = __func(*args, **kwargs) - - def rewrap(result): - if python_utils.is_lazy(result): - return python_utils.wrap_lazy( - result, - wrap=rewrap, - context_vars=context_vars, - ) + try: + # context_vars = contextvars.copy_context() + context_vars = { + Endpoint._context_endpoints: Endpoint._context_endpoints.get() + } - return result + # Call the function. + result: T = __func(*args, **kwargs) + + def rewrap(result): + if python_utils.is_lazy(result): + return python_utils.wrap_lazy( + result, + wrap=rewrap, + context_vars=context_vars, + ) - result = rewrap(result) + return result - # Pop the endpoints from the contextvars. - # Optionally disable to debug context issues. See App._set_context_vars. - Endpoint._context_endpoints.reset(endpoints_token) + result = rewrap(result) + finally: + # Pop the endpoints from the contextvars. + # Optionally disable to debug context issues. See App._set_context_vars. + Endpoint._context_endpoints.reset(endpoints_token) # Return result and only the callbacks created here. Outer thunks might # return others. diff --git a/tests/e2e/test_context_variables.py b/tests/e2e/test_context_variables.py new file mode 100644 index 000000000..8203b606f --- /dev/null +++ b/tests/e2e/test_context_variables.py @@ -0,0 +1,78 @@ +""" +Tests for context variable issues. +""" + +import os +import unittest + +import openai +from snowflake.snowpark import Session +from trulens.apps.custom import TruCustomApp +from trulens.apps.custom import instrument +from trulens.connectors.snowflake import SnowflakeConnector +from trulens.core import TruSession + +from tests.test import optional_test + + +class TestContextVariables(unittest.TestCase): + def setUp(self): + connection_parameters = { + "account": os.environ["SNOWFLAKE_ACCOUNT"], + "user": os.environ["SNOWFLAKE_USER"], + "password": os.environ["SNOWFLAKE_USER_PASSWORD"], + "database": os.environ["SNOWFLAKE_DATABASE"], + "role": os.environ["SNOWFLAKE_ROLE"], + "warehouse": os.environ["SNOWFLAKE_WAREHOUSE"], + "schema": "TestContextVariables", + } + self._snowpark_session = Session.builder.configs( + connection_parameters + ).create() + connector = SnowflakeConnector( + **connection_parameters, init_server_side=True + ) + self._session = TruSession(connector=connector) + + @optional_test + def test_endpoint_contextvar_always_cleaned(self): + class FailingRAG: + oai_client = openai.OpenAI() + + @instrument + def retrieve(self, query: str) -> list: + return ["A", "B", "C"] + + @instrument + def generate_completion(self, query: str, context_str: list) -> str: + raise ValueError() + + @instrument + def query(self, query: str) -> str: + context_str = self.retrieve(query=query) + completion = self.generate_completion( + query=query, context_str=context_str + ) + return completion + + # Set up trulens. + rag = FailingRAG() + tru_rag = TruCustomApp( + rag, + app_name="FailingRAG", + app_version="base", + ) + + with tru_rag: + self.assertRaises(ValueError, rag.query, "X") + # During app invocation, the endpoint context variable is set to track + # costs, but because in this test it fails prematurely, we must make + # sure the context variable is cleaned up properly. When it's set, + # the snowflake.snowpark.Session.sql function is handled differently + # in such a way that the following call will fail. + # TODO: find a better way to check if the context variable is cleaned. + self._snowpark_session.sql("SELECT 1").collect() + + +if __name__ == "__main__": + unittest.main()