Skip to content

Commit

Permalink
Always ensure endpoint context variable is cleaned up. (#1589)
Browse files Browse the repository at this point in the history
* Ensure we always clean up the ContextVar as otherwise we may try to wrap things outside of the app invocation if the app invocation fails in the middle.

* Add in test.
  • Loading branch information
sfc-gh-dkurokawa authored Oct 25, 2024
1 parent ca8843e commit bfb8ad0
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 20 deletions.
41 changes: 21 additions & 20 deletions src/core/trulens/core/feedback/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,29 +603,30 @@ def _track_costs(
# following call to retrieve.
endpoints_token = Endpoint._context_endpoints.set(endpoints) # noqa: F841

# context_vars = contextvars.copy_context()
context_vars = {
Endpoint._context_endpoints: Endpoint._context_endpoints.get()
}

# Call the function.
result: T = __func(*args, **kwargs)

def rewrap(result):
if python_utils.is_lazy(result):
return python_utils.wrap_lazy(
result,
wrap=rewrap,
context_vars=context_vars,
)
try:
# context_vars = contextvars.copy_context()
context_vars = {
Endpoint._context_endpoints: Endpoint._context_endpoints.get()
}

return result
# Call the function.
result: T = __func(*args, **kwargs)

def rewrap(result):
if python_utils.is_lazy(result):
return python_utils.wrap_lazy(
result,
wrap=rewrap,
context_vars=context_vars,
)

result = rewrap(result)
return result

# Pop the endpoints from the contextvars.
# Optionally disable to debug context issues. See App._set_context_vars.
Endpoint._context_endpoints.reset(endpoints_token)
result = rewrap(result)
finally:
# Pop the endpoints from the contextvars.
# Optionally disable to debug context issues. See App._set_context_vars.
Endpoint._context_endpoints.reset(endpoints_token)

# Return result and only the callbacks created here. Outer thunks might
# return others.
Expand Down
78 changes: 78 additions & 0 deletions tests/e2e/test_context_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Tests for context variable issues.
"""

import os
import unittest

import openai
from snowflake.snowpark import Session
from trulens.apps.custom import TruCustomApp
from trulens.apps.custom import instrument
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core import TruSession

from tests.test import optional_test


class TestContextVariables(unittest.TestCase):
def setUp(self):
connection_parameters = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USER"],
"password": os.environ["SNOWFLAKE_USER_PASSWORD"],
"database": os.environ["SNOWFLAKE_DATABASE"],
"role": os.environ["SNOWFLAKE_ROLE"],
"warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
"schema": "TestContextVariables",
}
self._snowpark_session = Session.builder.configs(
connection_parameters
).create()
connector = SnowflakeConnector(
**connection_parameters, init_server_side=True
)
self._session = TruSession(connector=connector)

@optional_test
def test_endpoint_contextvar_always_cleaned(self):
class FailingRAG:
oai_client = openai.OpenAI()

@instrument
def retrieve(self, query: str) -> list:
return ["A", "B", "C"]

@instrument
def generate_completion(self, query: str, context_str: list) -> str:
raise ValueError()

@instrument
def query(self, query: str) -> str:
context_str = self.retrieve(query=query)
completion = self.generate_completion(
query=query, context_str=context_str
)
return completion

# Set up trulens.
rag = FailingRAG()
tru_rag = TruCustomApp(
rag,
app_name="FailingRAG",
app_version="base",
)

with tru_rag:
self.assertRaises(ValueError, rag.query, "X")
# During app invocation, the endpoint context variable is set to track
# costs, but because in this test it fails prematurely, we must make
# sure the context variable is cleaned up properly. When it's set,
# the snowflake.snowpark.Session.sql function is handled differently
# in such a way that the following call will fail.
# TODO: find a better way to check if the context variable is cleaned.
self._snowpark_session.sql("SELECT 1").collect()


if __name__ == "__main__":
unittest.main()

0 comments on commit bfb8ad0

Please sign in to comment.