diff --git a/pytest_mpl/plugin.py b/pytest_mpl/plugin.py index 65545be8..a6251489 100644 --- a/pytest_mpl/plugin.py +++ b/pytest_mpl/plugin.py @@ -177,6 +177,34 @@ def get_marker(item, marker_name): return item.keywords.get(marker_name) +def _raise_on_image_difference(expected, actual, tol): + """ + Based on matplotlib.testing.decorators._raise_on_image_difference + + Compare image size ourselves since the Matplotlib + exception is a bit cryptic in this case and doesn't show + the filenames + """ + from matplotlib.image import imread + from matplotlib.testing.compare import compare_images + + expected_shape = imread(expected).shape[:2] + actual_shape = imread(actual).shape[:2] + if expected_shape != actual_shape: + error = SHAPE_MISMATCH_ERROR.format(expected_path=expected, + expected_shape=expected_shape, + actual_path=actual, + actual_shape=actual_shape) + pytest.fail(error, pytrace=False) + + msg = compare_images(expected, actual, tol=tol) + + if msg is None: + shutil.rmtree(os.path.dirname(expected)) + else: + pytest.fail(msg, pytrace=False) + + class ImageComparison(object): def __init__(self, config, baseline_dir=None, generate_dir=None, results_dir=None): @@ -195,9 +223,7 @@ def pytest_runtest_setup(self, item): return import matplotlib - from matplotlib.image import imread import matplotlib.pyplot as plt - from matplotlib.testing.compare import compare_images try: from matplotlib.testing.decorators import remove_ticks_and_titles except ImportError: @@ -246,7 +272,10 @@ def item_function_wrapper(*args, **kwargs): fig = original(*args, **kwargs) if remove_text: - remove_ticks_and_titles(fig) + if not isinstance(fig, tuple): + remove_ticks_and_titles(fig) + else: + [remove_ticks_and_titles(f) for f in fig] # Find test name to use as plot name filename = compare.kwargs.get('filename', None) @@ -260,52 +289,51 @@ def item_function_wrapper(*args, **kwargs): # reference images or simply running the test. if self.generate_dir is None: - # Save the figure + # Save the figure(s) result_dir = tempfile.mkdtemp(dir=self.results_dir) test_image = os.path.abspath(os.path.join(result_dir, filename)) - - fig.savefig(test_image, **savefig_kwargs) - close_mpl_figure(fig) - - # Find path to baseline image - if baseline_remote: - baseline_image_ref = _download_file(baseline_dir, filename) - else: - baseline_image_ref = os.path.abspath(os.path.join( - os.path.dirname(item.fspath.strpath), baseline_dir, filename)) - - if not os.path.exists(baseline_image_ref): - pytest.fail("Image file not found for comparison test in: " - "\n\t{baseline_dir}" - "\n(This is expected for new tests.)\nGenerated Image: " - "\n\t{test}".format(baseline_dir=baseline_dir, - test=test_image), - pytrace=False) - - # distutils may put the baseline images in non-accessible places, - # copy to our tmpdir to be sure to keep them in case of failure baseline_image = os.path.abspath(os.path.join(result_dir, 'baseline-' + filename)) - shutil.copyfile(baseline_image_ref, baseline_image) - - # Compare image size ourselves since the Matplotlib - # exception is a bit cryptic in this case and doesn't show - # the filenames - expected_shape = imread(baseline_image).shape[:2] - actual_shape = imread(test_image).shape[:2] - if expected_shape != actual_shape: - error = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image, - expected_shape=expected_shape, - actual_path=test_image, - actual_shape=actual_shape) - pytest.fail(error, pytrace=False) - - msg = compare_images(baseline_image, test_image, tol=tolerance) - - if msg is None: - shutil.rmtree(result_dir) + + if not isinstance(fig, tuple): + fig.savefig(test_image, **savefig_kwargs) + close_mpl_figure(fig) + + # Find path to baseline image + if baseline_remote: + baseline_image_ref = _download_file(baseline_dir, filename) + else: + baseline_image_ref = os.path.abspath(os.path.join( + os.path.dirname(item.fspath.strpath), baseline_dir, filename)) + + if not os.path.exists(baseline_image_ref): + pytest.fail("Image file not found for comparison test in: " + "\n\t{baseline_dir}" + "\n(This is expected for new tests.)\nGenerated Image: " + "\n\t{test}".format(baseline_dir=baseline_dir, + test=test_image), + pytrace=False) + + # distutils may put the baseline images in non-accessible places, + # copy to our tmpdir to be sure to keep them in case of failure + shutil.copyfile(baseline_image_ref, baseline_image) + else: - pytest.fail(msg, pytrace=False) + fig[0].savefig(test_image, **savefig_kwargs) + close_mpl_figure(fig[0]) + fig[1].savefig(baseline_image, **savefig_kwargs) + close_mpl_figure(fig[1]) + + _raise_on_image_difference( + expected=baseline_image, + actual=test_image, + tol=tolerance + ) + + elif self.generate_dir and isinstance(fig, tuple): + close_mpl_figure(fig[0]) + close_mpl_figure(fig[1]) + pytest.skip("Skipping image comparison test") else: diff --git a/tests/test_pytest_mpl.py b/tests/test_pytest_mpl.py index ebf4bf88..a120ca5e 100644 --- a/tests/test_pytest_mpl.py +++ b/tests/test_pytest_mpl.py @@ -227,3 +227,58 @@ def test_succeeds(self): ax = fig.add_subplot(1, 1, 1) ax.plot(self.x) return fig + + +@pytest.mark.mpl_image_compare +def test_check_equal(): + fig_test, ax_test = plt.subplots() + ax_test.plot([1, 3, 5]) + + fig_ref, ax_ref = plt.subplots() + ax_ref.plot([0, 1, 2], [1, 3, 5]) + + return fig_test, fig_ref + + +TEST_GENERATE_2 = """ +import pytest +import matplotlib.pyplot as plt +@pytest.mark.mpl_image_compare +def test_gen_two_figs(): + fig_test, ax_test = plt.subplots() + ax_test.plot([1, 3, 5]) + fig_ref, ax_ref = plt.subplots() + ax_ref.plot([0, 1, 2], [1, 3, 7]) + return fig_test, fig_ref +""" + + +def test_check_unequal_fails(tmpdir): + + test_file = tmpdir.join("test2.py").strpath + with open(test_file, "w") as f: + f.write(TEST_GENERATE_2) + + # If we use --mpl, it should detect that the two figures are not the same + code = subprocess.call([sys.executable, "-m", "pytest", "--mpl", test_file]) + assert code != 0 + + # If we don't use --mpl option, the test should succeed + code = subprocess.call([sys.executable, "-m", "pytest", test_file]) + assert code == 0 + + +def test_skip_generate_two_figures(tmpdir): + + test_file = tmpdir.join("test2.py").strpath + with open(test_file, "w") as f: + f.write(TEST_GENERATE_2) + + gen_dir = tmpdir.mkdir("spam").mkdir("egg").strpath + + # If we try to generate, the test should be skipped and a new file won't appear + code = subprocess.call([sys.executable, "-m", "pytest", + "--mpl-generate-path={0}".format(gen_dir), + test_file]) + assert code == 0 + assert not os.path.exists(os.path.join(gen_dir, "test_gen_two_figs.png"))