Skip to content

Commit

Permalink
move dask tests to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
wtbarnes committed Oct 4, 2023
1 parent d7665ac commit 1aba1f1
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 52 deletions.
52 changes: 0 additions & 52 deletions sunpy/map/tests/test_mapbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Test Generic Map
"""
import re
import copy
import tempfile
import contextlib

Expand Down Expand Up @@ -1304,17 +1303,6 @@ def test_quicklook(mocker, aia171_test_map):
assert aia171_test_map._repr_html_() in html_string


def test_dask_array(generic_map):
dask_array = pytest.importorskip('dask.array')
da = dask_array.from_array(generic_map.data, chunks=(1, 1))
pair_map = sunpy.map.Map(da, generic_map.meta)

# Check that _repr_html_ functions for a dask array
html_dask_repr = pair_map._repr_html_(compute_dask=False)
html_computed_repr = pair_map._repr_html_(compute_dask=True)
assert html_dask_repr != html_computed_repr


@pytest.fixture
def generic_map2(generic_map):
generic_map.meta["CTYPE1"] = "HPLN-TAN"
Expand Down Expand Up @@ -1700,43 +1688,3 @@ def test_plot_annotate_nonboolean(aia171_test_map):
ax = plt.subplot(projection=aia171_test_map)
with pytest.raises(TypeError, match="non-boolean value"):
aia171_test_map.plot(ax)


@pytest.fixture
def aia171_test_dask_map(aia171_test_map):
dask_array = pytest.importorskip('dask.array')
return aia171_test_map._new_instance(
dask_array.from_array(aia171_test_map.data),
copy.deepcopy(aia171_test_map.meta)
)


# This is needed for the reproject_to function
with pytest.warns(VerifyWarning, match="Invalid 'BLANK' keyword in header."):
with fits.open(get_test_filepath('aia_171_level1.fits')) as hdu:
aia_wcs = astropy.wcs.WCS(header=hdu[0].header)


@pytest.mark.parametrize(("func", "args"), [
("max", {}),
("mean", {}),
("min", {}),
pytest.param("reproject_to", {"wcs": aia_wcs}, marks=pytest.mark.xfail(reason="reproject is not dask aware")),
pytest.param("resample", {"dimensions": (100, 100)*u.pix}, marks=pytest.mark.xfail()),
pytest.param("rotate", {}, marks=pytest.mark.xfail(reason="nanmedian is not implemented in Dask")),
("std", {}),
("superpixel", {"dimensions": (10, 10)*u.pix}),
("submap", {"bottom_left": (100, 100)*u.pixel, "width": 10*u.pixel, "height":10*u.pixel}),
])
def test_method_preserves_dask_array(aia171_test_map, aia171_test_dask_map, func, args):
dask_array = pytest.importorskip('dask.array')
# Check that result array is still a Dask array
res_dask = aia171_test_dask_map.__getattribute__(func)(**args)
res = aia171_test_map.__getattribute__(func)(**args)
result_is_map = isinstance(res, GenericMap)
if result_is_map:
assert isinstance(res_dask.data, dask_array.Array)
assert np.allclose(res_dask.data.compute(), res.data)
else:
assert isinstance(res_dask, dask_array.Array)
assert np.allclose(res_dask.compute(), res)
67 changes: 67 additions & 0 deletions sunpy/map/tests/test_mapbase_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import copy

import numpy as np
import pytest

import astropy.units as u
import astropy.wcs
from astropy.io import fits
from astropy.io.fits.verify import VerifyWarning

import sunpy.map
from sunpy.data.test import get_test_filepath


def test_dask_array(generic_map):
dask_array = pytest.importorskip('dask.array')
da = dask_array.from_array(generic_map.data, chunks=(1, 1))
pair_map = sunpy.map.Map(da, generic_map.meta)

# Check that _repr_html_ functions for a dask array
html_dask_repr = pair_map._repr_html_(compute_dask=False)
html_computed_repr = pair_map._repr_html_(compute_dask=True)
assert html_dask_repr != html_computed_repr


@pytest.fixture
def aia171_test_dask_map(aia171_test_map):
dask_array = pytest.importorskip('dask.array')
return aia171_test_map._new_instance(
dask_array.from_array(aia171_test_map.data),
copy.deepcopy(aia171_test_map.meta)
)


# This is needed for the reproject_to function
with pytest.warns(VerifyWarning, match="Invalid 'BLANK' keyword in header."):
with fits.open(get_test_filepath('aia_171_level1.fits')) as hdu:
aia_wcs = astropy.wcs.WCS(header=hdu[0].header)


@pytest.mark.parametrize(("func", "args"), [
("max", {}),
("mean", {}),
("min", {}),
pytest.param("reproject_to", {"wcs": aia_wcs}, marks=pytest.mark.xfail(reason="reproject is not dask aware")),
pytest.param("resample", {"dimensions": (100, 100)*u.pix}, marks=pytest.mark.xfail()),
pytest.param("rotate", {}, marks=pytest.mark.xfail(reason="nanmedian is not implemented in Dask")),
("std", {}),
("superpixel", {"dimensions": (10, 10)*u.pix}),
("submap", {"bottom_left": (100, 100)*u.pixel, "width": 10*u.pixel, "height":10*u.pixel}),
])
def test_method_preserves_dask_array(aia171_test_map, aia171_test_dask_map, func, args):
"""
Check that map methods preserve dask arrays if they are given as input, instead of eagerly
operating on them and bringing them into memory.
"""
dask_array = pytest.importorskip('dask.array')
# Check that result array is still a Dask array
res_dask = aia171_test_dask_map.__getattribute__(func)(**args)
res = aia171_test_map.__getattribute__(func)(**args)
result_is_map = isinstance(res, sunpy.map.GenericMap)
if result_is_map:
assert isinstance(res_dask.data, dask_array.Array)
assert np.all(res_dask.data.compute(), res.data)
else:
assert isinstance(res_dask, dask_array.Array)
assert np.all(res_dask.compute(), res)

0 comments on commit 1aba1f1

Please sign in to comment.