Skip to content
This repository has been archived by the owner on Oct 1, 2024. It is now read-only.

Commit

Permalink
Support decorations on tests
Browse files Browse the repository at this point in the history
Decorations on functions are not re-applied after rewriting their
assertions. This is because a function like

@foo
def a():
    pass

is (effectively) sugar for

def a():
    pass
a = foo(a)

However, rewrite_assertion only extracts the function from the
recompiled code. To include the decorations, we have to execute the
code and extract the function from the globals. However, this presents a
problem, since the decorator must be in scope when executing the code.
We could use the module's __dict__, but it's possible that the
decorator is redefined between when it is applied to the function and
when the module finishes execution:

def foo(f):
    return f

@foo
def a():
    pass

foo = 5

This will load properly, but would fail if we tried to execute code for
the function with the module's final __dict__. We have similar problems
for constructs like

for i in range(4):
    def _(i=i):
        return i

To really be correct here, we have to rewrite assertions when importing
the module. This is actually much simpler than the existing strategy (as
can be seen by the negative diffstat). It does result in a behavioral
change where all assertions in a test module are rewritten instead of
just those in tests.

This patch does not handle cases like

import b_test

assert 1 == 2

because if a_test is imported first, then it will import b_test with the
regular importer and the assertions will not be rewritten. To fix this
correctly, we need to replace the loader to add an exec hook which
applies only to test modules. This patch does not implement this and so
I have marked this patch RFC. However, if such a scenario does not
occur, this is more than enough to get hypothesis working.
  • Loading branch information
Forty-Bot committed Jun 23, 2024
1 parent f41c616 commit 2819512
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 107 deletions.
78 changes: 29 additions & 49 deletions tests/test_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import ast

from tests.utilities import testable_test
from tests.utilities import testable_test, failing_assertion
from ward import fixture, test
from ward._rewrite import (
RewriteAssert,
get_assertion_msg,
is_binary_comparison,
is_comparison_type,
make_call_node,
rewrite_assertions_in_tests,
)
from ward.expect import TestAssertionFailure, raises
from ward.testing import Test, each


Expand All @@ -34,37 +34,6 @@ def as_dict(node):
return node


@testable_test
def passing_fn():
assert 1 == 1


@testable_test
def failing_fn():
assert 1 == 2


@fixture
def passing():
yield Test(fn=passing_fn, module_name="m", id="id-pass")


@fixture
def failing():
yield Test(fn=failing_fn, module_name="m", id="id-fail")


@test("rewrite_assertions_in_tests returns all tests, keeping metadata")
def _(p=passing, f=failing):
in_tests = [p, f]
out_tests = rewrite_assertions_in_tests(in_tests)

def meta(test):
return test.description, test.id, test.module_name, test.fn.ward_meta

assert [meta(test) for test in in_tests] == [meta(test) for test in out_tests]


@test("RewriteAssert.visit_Assert doesn't transform `{src}`")
def _(
src=each(
Expand Down Expand Up @@ -121,6 +90,33 @@ def _(
assert out_tree.value.args[1].id == "y"
assert out_tree.value.args[2].s == ""

@test("This test suite's assertions are themselves rewritten")
def _():
with raises(TestAssertionFailure):
assert 1 == 2
with raises(TestAssertionFailure):
assert 1 != 1
with raises(TestAssertionFailure):
assert 1 in ()
with raises(TestAssertionFailure):
assert 1 not in (1,)
with raises(TestAssertionFailure):
assert None is Ellipsis
with raises(TestAssertionFailure):
assert None is not None
with raises(TestAssertionFailure):
assert 2 < 1
with raises(TestAssertionFailure):
assert 2 <= 1
with raises(TestAssertionFailure):
assert 1 > 2
with raises(TestAssertionFailure):
assert 1 >= 2

@test("Non-test modules' assertions aren't rewritten")
def _():
with raises(AssertionError):
failing_assertion()

@test("RewriteAssert.visit_Assert transforms `{src}`")
def _(src="assert 1 == 2, 'msg'"):
Expand Down Expand Up @@ -210,19 +206,3 @@ def _():
@test("test with indentation level of 2")
def _():
assert 2 + 3 == 5


@test("rewriter finds correct function when there is a lambda in an each")
def _():
@testable_test
def _(x=each(lambda: 5)):
assert x == 5

t = Test(fn=_, module_name="m")

rewritten = rewrite_assertions_in_tests([t])[0]

# https://github.com/darrenburns/ward/issues/169
# The assertion rewriter thought the lambda function stored in co_consts was the test function,
# so it was rebuilding the test function using the lambda as the test instead of the original function.
assert rewritten.fn.__code__.co_name != "<lambda>"
4 changes: 4 additions & 0 deletions tests/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def testable_test(func):
testable_test.path = FORCE_TEST_PATH # type: ignore[attr-defined]


def failing_assertion():
assert 1 == 2


@fixture
def dummy_fixture():
"""
Expand Down
3 changes: 2 additions & 1 deletion ward/_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from cucumber_tag_expressions.model import Expression

from ward._errors import CollectionError
from ward._rewrite import exec_module
from ward._testing import COLLECTED_TESTS, is_test_module_name
from ward._utilities import get_absolute_path
from ward.fixtures import Fixture
Expand Down Expand Up @@ -149,7 +150,7 @@ def load_modules(modules: Iterable[pkgutil.ModuleInfo]) -> List[ModuleType]:
if pkg_data.pkg_root not in sys.path:
sys.path.append(str(pkg_data.pkg_root))
m.__package__ = pkg_data.pkg_name
m.__loader__.exec_module(m)
exec_module(m)
loaded_modules.append(m)

return loaded_modules
Expand Down
61 changes: 8 additions & 53 deletions ward/_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import textwrap
import types
from pathlib import Path
from typing import Iterable, List

from ward.expect import (
Expand Down Expand Up @@ -87,57 +88,11 @@ def visit_Assert(self, node): # noqa: C901 - no chance to reduce complexity
return node


def rewrite_assertions_in_tests(tests: Iterable[Test]) -> List[Test]:
return [rewrite_assertion(test) for test in tests]


def rewrite_assertion(test: Test) -> Test:
# Get the old code and code object
code_lines, line_no = inspect.getsourcelines(test.fn)

code = "".join(code_lines)
indents = textwrap._leading_whitespace_re.findall(code)
col_offset = len(indents[0]) if len(indents) > 0 else 0
code = textwrap.dedent(code)
code_obj = test.fn.__code__

# Rewrite the AST of the code
tree = ast.parse(code)
ast.increment_lineno(tree, line_no - 1)

def exec_module(module: types.ModuleType):
filename = module.__spec__.origin
code = module.__loader__.get_source(module.__name__)
tree = ast.parse(code, filename=filename)
new_tree = RewriteAssert().visit(tree)

if sys.version_info[:2] < (3, 11):
# We dedented the code so that it was a valid tree, now re-apply the indent
for child in ast.walk(new_tree):
if hasattr(child, "col_offset"):
child.col_offset = getattr(child, "col_offset", 0) + col_offset

# Reconstruct the test function
new_mod_code_obj = compile(new_tree, code_obj.co_filename, "exec")

# TODO: This probably isn't correct for nested closures
clo_glob = {}
if test.fn.__closure__:
clo_glob = test.fn.__closure__[0].cell_contents.__globals__

# Look through the new module,
# find the code object with the same name as the original code object,
# and build a new function with the injected assert functions added to the global namespace.
# Filtering on the code object name prevents finding other kinds of code objects,
# like lambdas stored directly in test function arguments.
for const in new_mod_code_obj.co_consts:
if isinstance(const, types.CodeType) and const.co_name == code_obj.co_name:
new_test_func = types.FunctionType(
const,
{**assert_func_namespace, **test.fn.__globals__, **clo_glob},
test.fn.__name__,
test.fn.__defaults__,
)
new_test_func.ward_meta = test.fn.ward_meta
return Test(
**{k: vars(test)[k] for k in vars(test) if k != "fn"},
fn=new_test_func,
)

return test
code = compile(new_tree, filename, "exec", dont_inherit=True)
module.__dict__.update(assert_func_namespace)
exec(code, module.__dict__)
5 changes: 1 addition & 4 deletions ward/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from ward._config import set_defaults_from_config
from ward._debug import init_breakpointhooks
from ward._rewrite import rewrite_assertions_in_tests
from ward._suite import Suite
from ward._terminal import (
SessionPrelude,
Expand Down Expand Up @@ -204,11 +203,9 @@ def test(
if config.order == "random":
shuffle(filtered_tests)

tests = rewrite_assertions_in_tests(filtered_tests)

time_to_collect_secs = default_timer() - start_run

suite = Suite(tests=tests)
suite = Suite(tests=filtered_tests)
test_results = suite.generate_test_runs(
dry_run=dry_run, capture_output=capture_output
)
Expand Down

0 comments on commit 2819512

Please sign in to comment.