Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test exception handling preserves the call stack. #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 78 additions & 20 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import asyncio
import collections.abc
import concurrent.futures
import contextlib
import functools
import gc
import inspect
import math
import sys
import threading
import traceback
import unittest
import uuid
import weakref
Expand Down Expand Up @@ -190,6 +192,41 @@ def counting_fn(*args) -> int:
return counting_fn, counter


class LineCapture:
def __init__(self):
self.line = None

def record_next_line(self):
"""Record the next line in the parent frame"""
self.line = inspect.currentframe().f_back.f_lineno + 1


class ExceptionContextManager:
exception: Exception


@contextlib.contextmanager
def assertRaisesWithLineInStackTrace(test: unittest.TestCase, exception_type, line: LineCapture):
try:
container = ExceptionContextManager()
yield container
except exception_type as exception:
container.exception = exception
traceback_exception = traceback.TracebackException.from_exception(exception)
if not len(traceback_exception.stack):
test.fail("Exception stack not preserved. Did you use the raw assertRaises by mistake?")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there was a raw assertRaises, wouldn't we never get to this failure in the first place?

locations = [(frame.filename, frame.lineno) for frame in traceback_exception.stack]
line_number = line.line
error_message = [
f"Traceback for exception {repr(exception)} did not have frame on line {line_number}. Exception below\n"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is a bit strange, can you explain what it's about?

]
error_message.extend(traceback_exception.format())
test.assertIn((__file__, line_number), locations, msg="".join(error_message))

else:
test.fail("expected exception not called")


class TestFunctionInspection(unittest.TestCase):
"""Unit tests for function inspection"""

Expand Down Expand Up @@ -387,33 +424,42 @@ def test_partial(self):

def test_failing_function(self):
counter = Counter()
failing_line = LineCapture()

@once.once
def sample_failing_fn():
nonlocal failing_line
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line) as cm:
sample_failing_fn()
self.assertEqual(cm.exception.args[0], "expected failure")
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")

def test_failing_function_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
def sample_failing_fn():
nonlocal failing_line
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
# This ensures that this was a new function call, not a cached result.
self.assertEqual(counter.get_incremented(), 4)
Expand All @@ -433,13 +479,15 @@ def yielding_iterator():

def test_failing_generator(self):
counter = Counter()
failing_line = LineCapture()

@once.once
def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("expected failure after 2.")

# Both of these calls should return the same results.
Expand All @@ -449,9 +497,9 @@ def sample_failing_fn():
self.assertEqual(next(call2), 1)
self.assertEqual(next(call1), 2)
self.assertEqual(next(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call2)
# These next 2 calls should also fail.
call3 = sample_failing_fn()
Expand All @@ -460,20 +508,22 @@ def sample_failing_fn():
self.assertEqual(next(call4), 1)
self.assertEqual(next(call3), 2)
self.assertEqual(next(call4), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call3)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call4)

def test_failing_generator_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("expected failure after 2.")

# Both of these calls should return the same results.
Expand All @@ -483,9 +533,9 @@ def sample_failing_fn():
self.assertEqual(next(call2), 1)
self.assertEqual(next(call1), 2)
self.assertEqual(next(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call2)
# These next 2 calls should succeed.
call3 = sample_failing_fn()
Expand Down Expand Up @@ -983,33 +1033,37 @@ def execute(*args):

async def test_failing_function(self):
counter = Counter()
failing_line = LineCapture()

@once.once
async def sample_failing_fn():
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")

async def test_failing_function_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
async def sample_failing_fn():
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
# This ensures that this was a new function call, not a cached result.
self.assertEqual(counter.get_incremented(), 4)
Expand Down Expand Up @@ -1062,13 +1116,15 @@ async def async_yielding_iterator():

async def test_failing_generator(self):
counter = Counter()
failing_line = LineCapture()

@once.once
async def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("we raise an error when result is exactly 2")

# Both of these calls should return the same results.
Expand All @@ -1078,9 +1134,9 @@ async def sample_failing_fn():
self.assertEqual(await anext(call2), 1)
self.assertEqual(await anext(call1), 2)
self.assertEqual(await anext(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call2)
# These next 2 calls should also fail.
call3 = sample_failing_fn()
Expand All @@ -1089,20 +1145,22 @@ async def sample_failing_fn():
self.assertEqual(await anext(call4), 1)
self.assertEqual(await anext(call3), 2)
self.assertEqual(await anext(call4), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call3)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call4)

async def test_failing_generator_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
async def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("we raise an error when result is exactly 2")

# Both of these calls should return the same results.
Expand All @@ -1112,9 +1170,9 @@ async def sample_failing_fn():
self.assertEqual(await anext(call2), 1)
self.assertEqual(await anext(call1), 2)
self.assertEqual(await anext(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call2)
# These next 2 calls should succeed.
call3 = sample_failing_fn()
Expand Down
Loading