Skip to content

Commit

Permalink
Override base patch test due to langgraph import structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed Jan 9, 2025
1 parent 4c63c72 commit 74a2e87
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tests/contrib/langgraph/test_langgraph_patch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os
import sys
from tempfile import NamedTemporaryFile

from ddtrace.contrib.langgraph import get_version
from ddtrace.contrib.langgraph import patch
from ddtrace.contrib.langgraph import unpatch
from tests.contrib.patch import PatchTestCase
from tests.utils import call_program


class TestLangGraphPatch(PatchTestCase.Base):
Expand Down Expand Up @@ -43,3 +48,52 @@ def assert_not_module_double_patched(self, langgraph):
self.assert_not_double_wrapped(Pregel.stream)
self.assert_not_double_wrapped(Pregel.astream)
self.assert_not_double_wrapped(PregelLoop.tick)

def test_ddtrace_run_patch_on_import(self):
# We check that the integration's patch function is called only
# after import of the relevant module when using ddtrace-run.
with NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""
import sys
from ddtrace.internal.module import ModuleWatchdog
from wrapt import wrap_function_wrapper as wrap
patched = False
def patch_hook(module):
def patch_wrapper(wrapped, _, args, kwrags):
global patched
result = wrapped(*args, **kwrags)
sys.stdout.write("K")
patched = True
return result
wrap(module.__name__, module.patch.__name__, patch_wrapper)
ModuleWatchdog.register_module_hook("ddtrace.contrib..patch", patch_hook)
sys.stdout.write("O")
import langgraph as mod
from langgraph import graph
# If the module was already loaded during the sitecustomize
# we check that the module was marked as patched.
if not patched and (
getattr(mod, "__datadog_patch", False) or getattr(mod, "_datadog_patch", False)
):
sys.stdout.write("K")
"""
)
f.flush()

env = os.environ.copy()
env["DD_TRACE_%s_ENABLED" % self.__integration_name__.upper()] = "1"

out, err, _, _ = call_program("ddtrace-run", sys.executable, f.name, env=env)

self.assertEqual(out, b"OK", "stderr:\n%s" % err.decode())

0 comments on commit 74a2e87

Please sign in to comment.