From 0c798905284cf22b446571499818aa26d270c068 Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Wed, 8 Nov 2023 05:25:34 +0000 Subject: [PATCH] Test exception handling preserves the call stack. This way, using a once decorator will not swallow all exception traces. --- once_test.py | 98 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/once_test.py b/once_test.py index 627f11d..a5bac6c 100644 --- a/once_test.py +++ b/once_test.py @@ -3,12 +3,14 @@ import asyncio import collections.abc import concurrent.futures +import contextlib import functools import gc import inspect import math import sys import threading +import traceback import unittest import uuid import weakref @@ -190,6 +192,41 @@ def counting_fn(*args) -> int: return counting_fn, counter +class LineCapture: + def __init__(self): + self.line = None + + def record_next_line(self): + """Record the next line in the parent frame""" + self.line = inspect.currentframe().f_back.f_lineno + 1 + + +class ExceptionContextManager: + exception: Exception + + +@contextlib.contextmanager +def assertRaisesWithLineInStackTrace(test: unittest.TestCase, exception_type, line: LineCapture): + try: + container = ExceptionContextManager() + yield container + except exception_type as exception: + container.exception = exception + traceback_exception = traceback.TracebackException.from_exception(exception) + if not len(traceback_exception.stack): + test.fail("Exception stack not preserved. Did you use the raw assertRaises by mistake?") + locations = [(frame.filename, frame.lineno) for frame in traceback_exception.stack] + line_number = line.line + error_message = [ + f"Traceback for exception {repr(exception)} did not have frame on line {line_number}. Exception below\n" + ] + error_message.extend(traceback_exception.format()) + test.assertIn((__file__, line_number), locations, msg="".join(error_message)) + + else: + test.fail("expected exception not called") + + class TestFunctionInspection(unittest.TestCase): """Unit tests for function inspection""" @@ -387,33 +424,42 @@ def test_partial(self): def test_failing_function(self): counter = Counter() + failing_line = LineCapture() @once.once def sample_failing_fn(): + nonlocal failing_line if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): + sample_failing_fn() + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line) as cm: sample_failing_fn() + self.assertEqual(cm.exception.args[0], "expected failure") self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): sample_failing_fn() self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter") def test_failing_function_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) def sample_failing_fn(): + nonlocal failing_line if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): sample_failing_fn() self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): sample_failing_fn() # This ensures that this was a new function call, not a cached result. self.assertEqual(counter.get_incremented(), 4) @@ -433,6 +479,7 @@ def yielding_iterator(): def test_failing_generator(self): counter = Counter() + failing_line = LineCapture() @once.once def sample_failing_fn(): @@ -440,6 +487,7 @@ def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("expected failure after 2.") # Both of these calls should return the same results. @@ -449,9 +497,9 @@ def sample_failing_fn(): self.assertEqual(next(call2), 1) self.assertEqual(next(call1), 2) self.assertEqual(next(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call2) # These next 2 calls should also fail. call3 = sample_failing_fn() @@ -460,13 +508,14 @@ def sample_failing_fn(): self.assertEqual(next(call4), 1) self.assertEqual(next(call3), 2) self.assertEqual(next(call4), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call3) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call4) def test_failing_generator_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) def sample_failing_fn(): @@ -474,6 +523,7 @@ def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("expected failure after 2.") # Both of these calls should return the same results. @@ -483,9 +533,9 @@ def sample_failing_fn(): self.assertEqual(next(call2), 1) self.assertEqual(next(call1), 2) self.assertEqual(next(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call2) # These next 2 calls should succeed. call3 = sample_failing_fn() @@ -983,33 +1033,37 @@ def execute(*args): async def test_failing_function(self): counter = Counter() + failing_line = LineCapture() @once.once async def sample_failing_fn(): if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter") async def test_failing_function_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) async def sample_failing_fn(): if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() # This ensures that this was a new function call, not a cached result. self.assertEqual(counter.get_incremented(), 4) @@ -1062,6 +1116,7 @@ async def async_yielding_iterator(): async def test_failing_generator(self): counter = Counter() + failing_line = LineCapture() @once.once async def sample_failing_fn(): @@ -1069,6 +1124,7 @@ async def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("we raise an error when result is exactly 2") # Both of these calls should return the same results. @@ -1078,9 +1134,9 @@ async def sample_failing_fn(): self.assertEqual(await anext(call2), 1) self.assertEqual(await anext(call1), 2) self.assertEqual(await anext(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call2) # These next 2 calls should also fail. call3 = sample_failing_fn() @@ -1089,13 +1145,14 @@ async def sample_failing_fn(): self.assertEqual(await anext(call4), 1) self.assertEqual(await anext(call3), 2) self.assertEqual(await anext(call4), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call3) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call4) async def test_failing_generator_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) async def sample_failing_fn(): @@ -1103,6 +1160,7 @@ async def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("we raise an error when result is exactly 2") # Both of these calls should return the same results. @@ -1112,9 +1170,9 @@ async def sample_failing_fn(): self.assertEqual(await anext(call2), 1) self.assertEqual(await anext(call1), 2) self.assertEqual(await anext(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call2) # These next 2 calls should succeed. call3 = sample_failing_fn()