Skip to content

Commit

Permalink
Merge pull request sunpy#7100 from wtbarnes/add-dask-array-tests
Browse files Browse the repository at this point in the history
Add tests for checking whether Map methods are lazy
  • Loading branch information
dstansby authored Oct 5, 2023
2 parents be1e158 + e1f17b6 commit ba60d3f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog/7100.trivial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add tests to check whether various `~sunpy.map.Map` methods preserve laziness when operating on Maps backed by a `dask.array.Array`.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
"asdf": ("https://asdf.readthedocs.io/en/stable/", None),
"astropy": ("https://docs.astropy.org/en/stable/", None),
"astroquery": ("https://astroquery.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/stable/", None),
"drms": ("https://docs.sunpy.org/projects/drms/en/stable/", None),
"hvpy": ("https://hvpy.readthedocs.io/en/latest/", None),
"matplotlib": ("https://matplotlib.org/stable", None),
Expand Down
11 changes: 0 additions & 11 deletions sunpy/map/tests/test_mapbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,17 +1303,6 @@ def test_quicklook(mocker, aia171_test_map):
assert aia171_test_map._repr_html_() in html_string


def test_dask_array(generic_map):
dask_array = pytest.importorskip('dask.array')
da = dask_array.from_array(generic_map.data, chunks=(1, 1))
pair_map = sunpy.map.Map(da, generic_map.meta)

# Check that _repr_html_ functions for a dask array
html_dask_repr = pair_map._repr_html_(compute_dask=False)
html_computed_repr = pair_map._repr_html_(compute_dask=True)
assert html_dask_repr != html_computed_repr


@pytest.fixture
def generic_map2(generic_map):
generic_map.meta["CTYPE1"] = "HPLN-TAN"
Expand Down
62 changes: 62 additions & 0 deletions sunpy/map/tests/test_mapbase_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import copy

import numpy as np
import pytest

import astropy.units as u
import astropy.wcs
from astropy.io import fits
from astropy.io.fits.verify import VerifyWarning

import sunpy.map
from sunpy.data.test import get_test_filepath


@pytest.fixture
def aia171_test_dask_map(aia171_test_map):
dask_array = pytest.importorskip('dask.array')
return aia171_test_map._new_instance(
dask_array.from_array(aia171_test_map.data),
copy.deepcopy(aia171_test_map.meta)
)


def test_dask_array(aia171_test_dask_map):
# Check that _repr_html_ functions for a dask array
html_dask_repr = aia171_test_dask_map._repr_html_(compute_dask=False)
html_computed_repr = aia171_test_dask_map._repr_html_(compute_dask=True)
assert html_dask_repr != html_computed_repr


# This is needed for the reproject_to function
with pytest.warns(VerifyWarning, match="Invalid 'BLANK' keyword in header."):
with fits.open(get_test_filepath('aia_171_level1.fits')) as hdu:
aia_wcs = astropy.wcs.WCS(header=hdu[0].header)


@pytest.mark.parametrize(("func", "args"), [
("max", {}),
("mean", {}),
("min", {}),
pytest.param("reproject_to", {"wcs": aia_wcs}, marks=pytest.mark.xfail(reason="reproject is not dask aware")),
pytest.param("resample", {"dimensions": (100, 100)*u.pix}, marks=pytest.mark.xfail()),
pytest.param("rotate", {}, marks=pytest.mark.xfail(reason="nanmedian is not implemented in Dask")),
("std", {}),
("superpixel", {"dimensions": (10, 10)*u.pix}),
("submap", {"bottom_left": (100, 100)*u.pixel, "width": 10*u.pixel, "height":10*u.pixel}),
])
def test_method_preserves_dask_array(aia171_test_map, aia171_test_dask_map, func, args):
"""
Check that map methods preserve dask arrays if they are given as input, instead of eagerly
operating on them and bringing them into memory.
"""
# Check that result array is still a Dask array
res_dask = aia171_test_dask_map.__getattribute__(func)(**args)
res = aia171_test_map.__getattribute__(func)(**args)
result_is_map = isinstance(res, sunpy.map.GenericMap)
if result_is_map:
assert isinstance(res_dask.data, type(aia171_test_dask_map.data))
assert np.allclose(res_dask.data.compute(), res.data, atol=0.0, rtol=0.0)
else:
assert isinstance(res_dask, type(aia171_test_dask_map.data))
assert np.allclose(res_dask.compute(), res, atol=0.0, rtol=0.0)

0 comments on commit ba60d3f

Please sign in to comment.