Skip to content

Commit 900bfd8

Browse files
add type annotations
1 parent 2db4102 commit 900bfd8

File tree

1 file changed

+60
-49
lines changed

1 file changed

+60
-49
lines changed

tests/profiling_v2/collector/test_threading.py

Lines changed: 60 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import sys
44
import threading
5-
from typing import Any
5+
from typing import Any, Optional, Type, Union
66
import uuid
77

88
import mock
@@ -17,35 +17,38 @@
1717
from tests.profiling.collector.lock_utils import init_linenos
1818

1919

20-
# Module-level globals for testing global lock profiling
21-
_test_global_lock = None
22-
_test_global_bar_instance = None
20+
# Type aliases for supported classes
21+
LockClass = Union[Type[threading.Lock], Type[threading.RLock]]
22+
CollectorClass = Union[Type[collector_threading.ThreadingLockCollector], Type[collector_threading.ThreadingRLockCollector]]
2323

24+
# Module-level globals for testing global lock profiling
25+
_test_global_lock: Optional[Any] = None
26+
_test_global_bar_instance: Optional[Any] = None
2427

25-
TESTING_GEVENT = os.getenv("DD_PROFILE_TEST_GEVENT", False)
28+
TESTING_GEVENT: Union[str, bool] = os.getenv("DD_PROFILE_TEST_GEVENT", False)
2629

2730
init_linenos(__file__)
2831

2932

3033
# Helper classes for testing lock collector
3134
class Foo:
32-
def __init__(self, lock_class: Any):
35+
def __init__(self, lock_class: LockClass) -> None:
3336
self.foo_lock = lock_class() # !CREATE! foolock
3437

35-
def foo(self):
38+
def foo(self) -> None:
3639
with self.foo_lock: # !RELEASE! !ACQUIRE! foolock
3740
pass
3841

3942

4043
class Bar:
41-
def __init__(self, lock_class: Any):
44+
def __init__(self, lock_class: LockClass) -> None:
4245
self.foo = Foo(lock_class)
4346

44-
def bar(self):
47+
def bar(self) -> None:
4548
self.foo.foo()
4649

4750

48-
def test_repr():
51+
def test_repr() -> None:
4952
test_collector._test_repr(
5053
collector_threading.ThreadingLockCollector,
5154
"ThreadingLockCollector(status=<ServiceStatus.STOPPED: 'stopped'>, "
@@ -54,7 +57,7 @@ def test_repr():
5457
)
5558

5659

57-
def test_patch():
60+
def test_patch() -> None:
5861
lock = threading.Lock
5962
collector = collector_threading.ThreadingLockCollector()
6063
collector.start()
@@ -69,7 +72,7 @@ def test_patch():
6972
@pytest.mark.skipif(not sys.platform.startswith("linux"), reason="only works on linux")
7073
@pytest.mark.subprocess(err=None)
7174
# For macOS: Could print 'Error uploading' but okay to ignore since we are checking if native_id is set
72-
def test_user_threads_have_native_id():
75+
def test_user_threads_have_native_id() -> None:
7376
from os import getpid
7477
from threading import Thread
7578
from threading import _MainThread
@@ -96,13 +99,16 @@ def test_user_threads_have_native_id():
9699
for _ in range(10):
97100
try:
98101
# The TID should be higher than the PID, but not too high
99-
assert 0 < t.native_id - getpid() < 100, (t.native_id, getpid())
102+
native_id = getattr(t, 'native_id', None)
103+
if native_id is not None:
104+
assert 0 < native_id - getpid() < 100, (native_id, getpid())
105+
break
106+
else:
107+
raise AttributeError("native_id not set yet")
100108
except AttributeError:
101109
# The native_id attribute is set by the thread so we might have to
102110
# wait a bit for it to be set.
103111
sleep(0.1)
104-
else:
105-
break
106112
else:
107113
raise AssertionError("Thread.native_id not set")
108114

@@ -114,7 +120,7 @@ def test_user_threads_have_native_id():
114120
@pytest.mark.subprocess(
115121
env=dict(WRAPT_DISABLE_EXTENSIONS="True", DD_PROFILING_FILE_PATH=__file__),
116122
)
117-
def test_wrapt_disable_extensions():
123+
def test_wrapt_disable_extensions() -> None:
118124
import os
119125
import threading
120126

@@ -181,7 +187,7 @@ def test_wrapt_disable_extensions():
181187
@pytest.mark.subprocess(
182188
env=dict(DD_PROFILING_FILE_PATH=__file__),
183189
)
184-
def test_lock_gevent_tasks():
190+
def test_lock_gevent_tasks() -> None:
185191
from gevent import monkey
186192

187193
monkey.patch_all()
@@ -207,7 +213,7 @@ def test_lock_gevent_tasks():
207213

208214
init_linenos(os.environ["DD_PROFILING_FILE_PATH"])
209215

210-
def play_with_lock():
216+
def play_with_lock() -> None:
211217
lock = threading.Lock() # !CREATE! test_lock_gevent_tasks
212218
lock.acquire() # !ACQUIRE! test_lock_gevent_tasks
213219
lock.release() # !RELEASE! test_lock_gevent_tasks
@@ -261,18 +267,22 @@ def play_with_lock():
261267

262268

263269
class BaseThreadingLockCollectorTest:
270+
test_name: str
271+
pprof_prefix: str
272+
output_filename: str
273+
264274
# These should be implemented by child classes
265275
@property
266-
def collector_class(self):
276+
def collector_class(self) -> CollectorClass:
267277
raise NotImplementedError("Child classes must implement collector_class")
268278

269279
@property
270-
def lock_class(self):
280+
def lock_class(self) -> LockClass:
271281
raise NotImplementedError("Child classes must implement lock_class")
272282

273283
# setup_method and teardown_method which will be called before and after
274284
# each test method, respectively, part of pytest api.
275-
def setup_method(self, method):
285+
def setup_method(self, method: Any) -> None:
276286
self.test_name = method.__name__
277287
self.pprof_prefix = "/tmp" + os.sep + self.test_name
278288
# The output filename will be /tmp/method_name.<pid>.<counter>.
@@ -285,7 +295,7 @@ def setup_method(self, method):
285295
ddup.config(env="test", service=self.test_name, version="my_version", output_filename=self.pprof_prefix)
286296
ddup.start()
287297

288-
def teardown_method(self, method):
298+
def teardown_method(self, method: Any) -> None:
289299
# might be unnecessary but this will ensure that the file is removed
290300
# after each successful test, and when a test fails it's easier to
291301
# pinpoint and debug.
@@ -295,15 +305,15 @@ def teardown_method(self, method):
295305
except Exception as e:
296306
print("Error removing file: {}".format(e))
297307

298-
def test_wrapper(self):
308+
def test_wrapper(self) -> None:
299309
# TODO: change to collector_class
300310
collector = collector_threading.ThreadingLockCollector()
301311
with collector:
302312

303313
class Foobar(object):
304314
lock_class = threading.Lock
305315

306-
def __init__(self):
316+
def __init__(self) -> None:
307317
lock = self.lock_class()
308318
assert lock.acquire()
309319
lock.release()
@@ -316,7 +326,7 @@ def __init__(self):
316326
Foobar()
317327

318328
# Tests
319-
def test_lock_events(self):
329+
def test_lock_events(self) -> None:
320330
# The first argument is the recorder.Recorder which is used for the
321331
# v1 exporter. We don't need it for the v2 exporter.
322332
with self.collector_class(capture_pct=100):
@@ -348,12 +358,12 @@ def test_lock_events(self):
348358
],
349359
)
350360

351-
def test_lock_acquire_events_class(self):
361+
def test_lock_acquire_events_class(self) -> None:
352362
with self.collector_class(capture_pct=100):
353363
lock_class = self.lock_class # Capture for inner class
354364

355365
class Foobar(object):
356-
def lockfunc(self):
366+
def lockfunc(self) -> None:
357367
lock = lock_class() # !CREATE! test_lock_acquire_events_class
358368
lock.acquire() # !ACQUIRE! test_lock_acquire_events_class
359369

@@ -376,7 +386,7 @@ def lockfunc(self):
376386
],
377387
)
378388

379-
def test_lock_events_tracer(self, tracer):
389+
def test_lock_events_tracer(self, tracer: Any) -> None:
380390
tracer._endpoint_call_counter_span_processor.enable()
381391
resource = str(uuid.uuid4())
382392
span_type = ext.SpanTypes.WEB
@@ -437,7 +447,7 @@ def test_lock_events_tracer(self, tracer):
437447
],
438448
)
439449

440-
def test_lock_events_tracer_non_web(self, tracer):
450+
def test_lock_events_tracer_non_web(self, tracer: Any) -> None:
441451
tracer._endpoint_call_counter_span_processor.enable()
442452
resource = str(uuid.uuid4())
443453
span_type = ext.SpanTypes.SQL
@@ -479,7 +489,7 @@ def test_lock_events_tracer_non_web(self, tracer):
479489
],
480490
)
481491

482-
def test_lock_events_tracer_late_finish(self, tracer):
492+
def test_lock_events_tracer_late_finish(self, tracer: Any) -> None:
483493
tracer._endpoint_call_counter_span_processor.enable()
484494
resource = str(uuid.uuid4())
485495
span_type = ext.SpanTypes.WEB
@@ -534,7 +544,7 @@ def test_lock_events_tracer_late_finish(self, tracer):
534544
],
535545
)
536546

537-
def test_resource_not_collected(self, tracer):
547+
def test_resource_not_collected(self, tracer: Any) -> None:
538548
tracer._endpoint_call_counter_span_processor.enable()
539549
resource = str(uuid.uuid4())
540550
span_type = ext.SpanTypes.WEB
@@ -595,7 +605,7 @@ def test_resource_not_collected(self, tracer):
595605
],
596606
)
597607

598-
def test_lock_enter_exit_events(self):
608+
def test_lock_enter_exit_events(self) -> None:
599609
with self.collector_class(capture_pct=100):
600610
th_lock = self.lock_class() # !CREATE! test_lock_enter_exit_events
601611
with th_lock: # !ACQUIRE! !RELEASE! test_lock_enter_exit_events
@@ -631,7 +641,7 @@ def test_lock_enter_exit_events(self):
631641
"inspect_dir_enabled",
632642
[True, False],
633643
)
634-
def test_class_member_lock(self, inspect_dir_enabled):
644+
def test_class_member_lock(self, inspect_dir_enabled: bool) -> None:
635645
with mock.patch("ddtrace.settings.profiling.config.lock.name_inspect_dir", inspect_dir_enabled):
636646
expected_lock_name = "foo_lock" if inspect_dir_enabled else None
637647

@@ -670,12 +680,12 @@ def test_class_member_lock(self, inspect_dir_enabled):
670680
],
671681
)
672682

673-
def test_private_lock(self):
683+
def test_private_lock(self) -> None:
674684
class Foo:
675-
def __init__(self, lock_class: Any):
685+
def __init__(self, lock_class: LockClass) -> None:
676686
self.__lock = lock_class() # !CREATE! test_private_lock
677687

678-
def foo(self):
688+
def foo(self) -> None:
679689
with self.__lock: # !RELEASE! !ACQUIRE! test_private_lock
680690
pass
681691

@@ -709,12 +719,12 @@ def foo(self):
709719
],
710720
)
711721

712-
def test_inner_lock(self):
722+
def test_inner_lock(self) -> None:
713723
class Bar:
714-
def __init__(self, lock_class: Any):
724+
def __init__(self, lock_class: LockClass) -> None:
715725
self.foo = Foo(lock_class)
716726

717-
def bar(self):
727+
def bar(self) -> None:
718728
with self.foo.foo_lock: # !RELEASE! !ACQUIRE! test_inner_lock
719729
pass
720730

@@ -749,7 +759,7 @@ def bar(self):
749759
],
750760
)
751761

752-
def test_anonymous_lock(self):
762+
def test_anonymous_lock(self) -> None:
753763
with self.collector_class(capture_pct=100):
754764
with self.lock_class(): # !CREATE! !ACQUIRE! !RELEASE! test_anonymous_lock
755765
pass
@@ -776,23 +786,24 @@ def test_anonymous_lock(self):
776786
],
777787
)
778788

779-
def test_global_locks(self):
789+
def test_global_locks(self) -> None:
780790
global _test_global_lock, _test_global_bar_instance
781791

782792
with self.collector_class(capture_pct=100):
783793
# Create true module-level globals
784794
_test_global_lock = self.lock_class() # !CREATE! _test_global_lock
785795

786796
class TestBar:
787-
def __init__(self, lock_class: Any):
797+
def __init__(self, lock_class: LockClass) -> None:
788798
self.bar_lock = lock_class() # !CREATE! bar_lock
789799

790-
def bar(self):
800+
def bar(self) -> None:
791801
with self.bar_lock: # !ACQUIRE! !RELEASE! bar_lock
792802
pass
793803

794-
def foo():
804+
def foo() -> None:
795805
global _test_global_lock
806+
assert _test_global_lock is not None
796807
with _test_global_lock: # !ACQUIRE! !RELEASE! _test_global_lock
797808
pass
798809

@@ -843,7 +854,7 @@ def foo():
843854
],
844855
)
845856

846-
def test_upload_resets_profile(self):
857+
def test_upload_resets_profile(self) -> None:
847858
# This test checks that the profile is cleared after each upload() call
848859
# It is added in test_threading.py as LockCollector can easily be
849860
# configured to be deterministic with capture_pct=100.
@@ -885,21 +896,21 @@ class TestThreadingLockCollector(BaseThreadingLockCollectorTest):
885896
"""Test threading.Lock profiling"""
886897

887898
@property
888-
def collector_class(self):
899+
def collector_class(self) -> Type[collector_threading.ThreadingLockCollector]:
889900
return collector_threading.ThreadingLockCollector
890901

891902
@property
892-
def lock_class(self):
903+
def lock_class(self) -> Type[threading.Lock]:
893904
return threading.Lock
894905

895906

896907
class TestThreadingRLockCollector(BaseThreadingLockCollectorTest):
897908
"""Test threading.RLock profiling"""
898909

899910
@property
900-
def collector_class(self):
911+
def collector_class(self) -> Type[collector_threading.ThreadingRLockCollector]:
901912
return collector_threading.ThreadingRLockCollector
902913

903914
@property
904-
def lock_class(self):
915+
def lock_class(self) -> Type[threading.RLock]:
905916
return threading.RLock

0 commit comments

Comments
 (0)