diff --git a/newrelic/hooks/mlmodel_langchain.py b/newrelic/hooks/mlmodel_langchain.py index ff8e0381d..11de2817f 100644 --- a/newrelic/hooks/mlmodel_langchain.py +++ b/newrelic/hooks/mlmodel_langchain.py @@ -20,11 +20,11 @@ from newrelic.api.function_trace import FunctionTrace from newrelic.api.time_trace import current_trace, get_trace_linking_metadata from newrelic.api.transaction import current_transaction -from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper +from newrelic.common.object_wrapper import wrap_function_wrapper from newrelic.common.package_version_utils import get_package_version from newrelic.common.signature import bind_args from newrelic.core.config import global_settings -from newrelic.core.context import ContextOf +from newrelic.core.context import context_wrapper _logger = logging.getLogger(__name__) LANGCHAIN_VERSION = get_package_version("langchain") @@ -125,16 +125,6 @@ } -def context_wrapper(func, trace=None, request=None, trace_cache_id=None, strict=True): - @function_wrapper - def _context_wrapper(wrapped, instance, args, kwargs): - bound_args = bind_args(wrapped, kwargs["args"], kwargs["kwargs"]) - with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id, strict=strict): - return wrapped(**bound_args) - - return _context_wrapper(func) - - def wrap_ContextThreadPoolExecutor_submit(wrapped, instance, args, kwargs): trace = current_trace() if not trace: @@ -142,8 +132,7 @@ def wrap_ContextThreadPoolExecutor_submit(wrapped, instance, args, kwargs): bound_args = bind_args(wrapped, args, kwargs) bound_args["func"] = context_wrapper(bound_args["func"], trace=trace, strict=True) - - return wrapped(**bound_args) + return wrapped(bound_args["func"], *bound_args["args"], **bound_args["kwargs"]) def _create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped):