diff --git a/pytest_mpl/plugin.py b/pytest_mpl/plugin.py
index 4613979c..3ddc24b4 100644
--- a/pytest_mpl/plugin.py
+++ b/pytest_mpl/plugin.py
@@ -81,12 +81,33 @@ def pathify(path):
     return Path(path + ext)
 
 
-def _pytest_pyfunc_call(obj, pyfuncitem):
-    testfunction = pyfuncitem.obj
-    funcargs = pyfuncitem.funcargs
-    testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
-    obj.result = testfunction(**testargs)
-    return True
+def generate_test_name(item):
+    """
+    Generate a unique name for the hash for this test.
+    """
+    if item.cls is not None:
+        name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
+    else:
+        name = f"{item.module.__name__}.{item.name}"
+    return name
+
+
+def wrap_figure_interceptor(plugin, item):
+    """
+    Intercept and store figures returned by test functions.
+    """
+    # Only intercept figures on marked figure tests
+    if get_compare(item) is not None:
+
+        # Use the full test name as a key to ensure correct figure is being retrieved
+        test_name = generate_test_name(item)
+
+        def figure_interceptor(store, obj):
+            def wrapper(*args, **kwargs):
+                store.return_value[test_name] = obj(*args, **kwargs)
+            return wrapper
+
+        item.obj = figure_interceptor(plugin, item.obj)
 
 
 def pytest_report_header(config, startdir):
@@ -275,6 +296,7 @@ def __init__(self,
         self._generated_hash_library = {}
         self._test_results = {}
         self._test_stats = None
+        self.return_value = {}
 
         # https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
         # turn debug prints on only if "-vv" or more passed
@@ -287,7 +309,7 @@ def generate_filename(self, item):
         Given a pytest item, generate the figure filename.
         """
         if self.config.getini('mpl-use-full-test-name'):
-            filename = self.generate_test_name(item) + '.png'
+            filename = generate_test_name(item) + '.png'
         else:
             compare = get_compare(item)
             # Find test name to use as plot name
@@ -298,21 +320,11 @@ def generate_filename(self, item):
         filename = str(pathify(filename))
         return filename
 
-    def generate_test_name(self, item):
-        """
-        Generate a unique name for the hash for this test.
-        """
-        if item.cls is not None:
-            name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
-        else:
-            name = f"{item.module.__name__}.{item.name}"
-        return name
-
     def make_test_results_dir(self, item):
         """
         Generate the directory to put the results in.
         """
-        test_name = pathify(self.generate_test_name(item))
+        test_name = pathify(generate_test_name(item))
         results_dir = self.results_dir / test_name
         results_dir.mkdir(exist_ok=True, parents=True)
         return results_dir
@@ -526,7 +538,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
             pytest.fail(f"Can't find hash library at path {hash_library_filename}")
 
         hash_library = self.load_hash_library(hash_library_filename)
-        hash_name = self.generate_test_name(item)
+        hash_name = generate_test_name(item)
         baseline_hash = hash_library.get(hash_name, None)
         summary['baseline_hash'] = baseline_hash
 
@@ -607,13 +619,17 @@ def pytest_runtest_call(self, item):  # noqa
         with plt.style.context(style, after_reset=True), switch_backend(backend):
 
             # Run test and get figure object
+            wrap_figure_interceptor(self, item)
             yield
-            fig = self.result
+            test_name = generate_test_name(item)
+            if test_name not in self.return_value:
+                # Test function did not complete successfully
+                return
+            fig = self.return_value[test_name]
 
             if remove_text:
                 remove_ticks_and_titles(fig)
 
-            test_name = self.generate_test_name(item)
             result_dir = self.make_test_results_dir(item)
 
             summary = {
@@ -677,10 +693,6 @@ def pytest_runtest_call(self, item):  # noqa
         if summary['status'] == 'skipped':
             pytest.skip(summary['status_msg'])
 
-    @pytest.hookimpl(tryfirst=True)
-    def pytest_pyfunc_call(self, pyfuncitem):
-        return _pytest_pyfunc_call(self, pyfuncitem)
-
     def generate_summary_json(self):
         json_file = self.results_dir / 'results.json'
         with open(json_file, 'w') as f:
@@ -732,13 +744,16 @@ class FigureCloser:
 
     def __init__(self, config):
         self.config = config
+        self.return_value = {}
 
     @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_call(self, item):
+        wrap_figure_interceptor(self, item)
         yield
         if get_compare(item) is not None:
-            close_mpl_figure(self.result)
-
-    @pytest.hookimpl(tryfirst=True)
-    def pytest_pyfunc_call(self, pyfuncitem):
-        return _pytest_pyfunc_call(self, pyfuncitem)
+            test_name = generate_test_name(item)
+            if test_name not in self.return_value:
+                # Test function did not complete successfully
+                return
+            fig = self.return_value[test_name]
+            close_mpl_figure(fig)
diff --git a/setup.cfg b/setup.cfg
index b9eb7133..d2fb5e0a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -42,6 +42,13 @@ test =
 
 [tool:pytest]
 testpaths = "tests"
+markers =
+    image: run test during image comparison only mode.
+    hash: run test during hash comparison only mode.
+filterwarnings =
+    error
+    ignore:distutils Version classes are deprecated
+    ignore:the imp module is deprecated in favour of importlib
 
 [flake8]
 max-line-length = 100
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..9c3572d9
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,9 @@
+import pytest
+from packaging.version import Version
+
+pytest_plugins = ["pytester"]
+
+if Version(pytest.__version__) < Version("6.2.0"):
+    @pytest.fixture
+    def pytester(testdir):
+        return testdir
diff --git a/tests/test_pytest_mpl.py b/tests/test_pytest_mpl.py
index fec94cb0..a9d2cba1 100644
--- a/tests/test_pytest_mpl.py
+++ b/tests/test_pytest_mpl.py
@@ -3,6 +3,7 @@
 import json
 import subprocess
 from pathlib import Path
+from unittest import TestCase
 
 import matplotlib
 import matplotlib.ft2font
@@ -259,6 +260,23 @@ def test_succeeds(self):
         return fig
 
 
+class TestClassWithTestCase(TestCase):
+
+    # Regression test for a bug that occurred when using unittest.TestCase
+
+    def setUp(self):
+        self.x = [1, 2, 3]
+
+    @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir_local,
+                                   filename='test_succeeds.png',
+                                   tolerance=DEFAULT_TOLERANCE)
+    def test_succeeds(self):
+        fig = plt.figure()
+        ax = fig.add_subplot(1, 1, 1)
+        ax.plot(self.x)
+        return fig
+
+
 # hashlib
 
 @pytest.mark.skipif(not hash_library.exists(), reason="No hash library for this mpl version")
@@ -514,8 +532,27 @@ def test_fails(self):
         return fig
 """
 
+TEST_FAILING_UNITTEST_TESTCASE = """
+from unittest import TestCase
+import pytest
+import matplotlib.pyplot as plt
+class TestClassWithTestCase(TestCase):
+    def setUp(self):
+        self.x = [1, 2, 3]
+    @pytest.mark.mpl_image_compare
+    def test_fails(self):
+        fig = plt.figure()
+        ax = fig.add_subplot(1, 1, 1)
+        ax.plot(self.x)
+        return fig
+"""
 
-@pytest.mark.parametrize("code", [TEST_FAILING_CLASS, TEST_FAILING_CLASS_SETUP_METHOD])
+
+@pytest.mark.parametrize("code", [
+    TEST_FAILING_CLASS,
+    TEST_FAILING_CLASS_SETUP_METHOD,
+    TEST_FAILING_UNITTEST_TESTCASE,
+])
 def test_class_fail(code, tmpdir):
 
     test_file = tmpdir.join('test.py').strpath
@@ -529,3 +566,107 @@ def test_class_fail(code, tmpdir):
     # If we don't use --mpl option, the test should succeed
     code = call_pytest([test_file])
     assert code == 0
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_fail(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_fail():
+            pytest.fail("Manually failed by user.")
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes(failed=1)
+    result.stdout.fnmatch_lines("FAILED*Manually failed by user.*")
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_skip(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_skip():
+            pytest.skip("Manually skipped by user.")
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes(skipped=1)
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_importorskip(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_importorskip():
+            pytest.importorskip("nonexistantmodule")
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes(skipped=1)
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_xfail(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_xfail():
+            pytest.xfail()
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes(xfailed=1)
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_exit_success(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_exit_success():
+            pytest.exit("Manually exited by user.", returncode=0)
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes()
+    assert result.ret == 0
+    result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_exit_failure(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_exit_fail():
+            pytest.exit("Manually exited by user.", returncode=1)
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes()
+    assert result.ret == 1
+    result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")
+
+
+@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
+def test_user_function_raises(pytester, runpytest_args):
+    pytester.makepyfile(
+        """
+        import pytest
+        @pytest.mark.mpl_image_compare
+        def test_raises():
+            raise ValueError("User code raised an exception.")
+        """
+    )
+    result = pytester.runpytest(*runpytest_args)
+    result.assert_outcomes(failed=1)
+    result.stdout.fnmatch_lines("FAILED*ValueError*User code*")
diff --git a/tox.ini b/tox.ini
index bec14d22..c4bdf468 100644
--- a/tox.ini
+++ b/tox.ini
@@ -51,8 +51,3 @@
 description = check code style, e.g. with flake8
 deps = pre-commit
 commands = pre-commit run --all-files
-
-[pytest]
-markers =
-    image: run test during image comparison only mode.
-    hash: run test during hash comparison only mode.