Skip to content

Commit

Permalink
Merge pull request #780 from sunpy/YOLO
Browse files Browse the repository at this point in the history
Faster shortcut for working out coordinates values for non-correlated WCS
  • Loading branch information
DanRyanIrish authored Dec 13, 2024
2 parents ee91f49 + f821a69 commit cd93cf8
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 31 deletions.
1 change: 1 addition & 0 deletions changelog/780.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.
42 changes: 42 additions & 0 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l():
return WCS(header=header)


@pytest.fixture
def wcs_3d_wave_lt_ln():
header = {
'CTYPE1': 'WAVE ',
'CUNIT1': 'Angstrom',
'CDELT1': 0.2,
'CRPIX1': 0,
'CRVAL1': 10,

'CTYPE2': 'HPLT-TAN',
'CUNIT2': 'deg',
'CDELT2': 0.5,
'CRPIX2': 2,
'CRVAL2': 0.5,

'CTYPE3': 'HPLN-TAN ',
'CUNIT3': 'deg',
'CDELT3': 0.4,
'CRPIX3': 2,
'CRVAL3': 1,
}
return WCS(header=header)


@pytest.fixture
def wcs_2d_lt_ln():
spatial = {
Expand Down Expand Up @@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d):
return cube


@pytest.fixture
def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln):
shape = (3, 4, 5)
wcs_3d_wave_lt_ln.array_shape = shape
data = data_nd(shape)
mask = data > 0
cube = NDCube(
data,
wcs_3d_wave_lt_ln,
mask=mask,
uncertainty=data,
)
base_time = Time('2000-01-01', format='fits', scale='utc')
timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])])
cube.extra_coords.add('time', 0, timestamps)
return cube


@pytest.fixture
def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]],
Expand Down
125 changes: 100 additions & 25 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import astropy.nddata
import astropy.units as u
from astropy.units import UnitsError
from astropy.wcs.utils import _split_matrix

from ndcube.utils.wcs import world_axis_to_pixel_axes

try:
# Import sunpy coordinates if available to register the frames and WCS functions with astropy
Expand All @@ -20,7 +23,6 @@
pass

from astropy.wcs import WCS
from astropy.wcs.utils import _split_matrix
from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper
from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects

Expand Down Expand Up @@ -479,24 +481,76 @@ def quantity(self):
"""Unitful representation of the NDCube data."""
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)

def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units):
# Create meshgrid of all pixel coordinates.
# If user wants pixel_corners, set pixel values to pixel corners.
# Else make pixel centers.
def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for independent axes.
The idea is to workout only the specific grid that is needed for independent axes.
This speeds up the calculation of world coordinates and reduces memory usage.
Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.
Returns
-------
array-like
The world coordinates.
"""
needed_axes = np.array(needed_axes).squeeze()
if self.data.ndim in needed_axes:
required_axes = needed_axes - 1
else:
required_axes = needed_axes
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
world_coords = wcs.pixel_to_world_values(*indices)
if units:
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
return world_coords

def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for dependent axes.
This will work out the exact grid that is needed for dependent axes
and can be time and memory consuming.
Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.
Returns
-------
array-like
The world coordinates.
"""
pixel_shape = self.data.shape[::-1]
if pixel_corners:
pixel_shape = tuple(np.array(pixel_shape) + 1)
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
else:
ranges = [np.arange(i) for i in pixel_shape]

# Limit the pixel dimensions to the ones present in the ExtraCoords
if isinstance(wcs, ExtraCoords):
ranges = [ranges[i] for i in wcs.mapping]
wcs = wcs.wcs
if wcs is None:
return []

return ()
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
# required so values_to_high_level_objects in the calling function doesn't crash or warn
world_coords = [0] * wcs.world_n_dim
Expand Down Expand Up @@ -528,71 +582,92 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
array_slice[wcs.axis_correlation_matrix[idx]] = slice(None)
tmp_world = world[idx][tuple(array_slice)].T
world_coords[idx] = tmp_world

if units:
for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
world_coords[i] = coord << u.Unit(unit)
return world_coords

def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
"""
Private method to generate world coordinates.
Handles both dependent and independent axes.
Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The axes that are needed.
units : bool
If units are needed.
Returns
-------
array-like
The world coordinates.
"""
axes_are_independent = []
pixel_axes = set()
for world_axis in needed_axes:
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
axes_are_independent.append(len(pix_ax) == 1)
pixel_axes = pixel_axes.union(set(pix_ax))
pixel_axes = list(pixel_axes)
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
else:
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
return world_coords

@utils.cube.sanitize_wcs
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return ()

object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
unique_obj_names = utils.misc.unique_sorted(object_names)
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]

# Create a mapping from world index in the WCS to object index in axes_coords
world_index_to_object_index = {}
for object_index, world_axes in enumerate(world_axes_for_obj):
for world_index in world_axes:
world_index_to_object_index[world_index] = object_index

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
object_indices = utils.misc.unique_sorted(
[world_index_to_object_index[world_index] for world_index in world_indices]
)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False)
axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)

if not axes:
return tuple(axes_coords)

return tuple(axes_coords[i] for i in object_indices)

@utils.cube.sanitize_wcs
def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs

if not wcs:
return ()
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True)
world_axis_physical_types = wcs.world_axis_physical_types

# If user has supplied axes, extract only the
# world coords that correspond to those axes.
if axes:
axes_coords = [axes_coords[i] for i in world_indices]
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])

# Return in array order.
# First replace characters in physical types forbidden for namedtuple identifiers.
identifiers = []
Expand Down
39 changes: 33 additions & 6 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astropy.coordinates import SkyCoord, SpectralCoord
from astropy.io import fits
from astropy.nddata import UnknownUncertainty
from astropy.tests.helper import assert_quantity_allclose
from astropy.time import Time
from astropy.units import UnitsError
from astropy.wcs import WCS
Expand Down Expand Up @@ -177,9 +178,19 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):

coords = cube.axis_world_coords()
assert len(coords) == 2
assert isinstance(coords[0], SkyCoord)
assert coords[0].shape == (5, 8)
assert isinstance(coords[1], SpectralCoord)
assert coords[1].shape == (10,)

coords = cube.axis_world_coords(wcs=cube.combined_wcs)
assert len(coords) == 3
assert isinstance(coords[0], SkyCoord)
assert coords[0].shape == (5, 8)
assert isinstance(coords[1], SpectralCoord)
assert coords[1].shape == (10,)
assert isinstance(coords[2], Time)
assert coords[2].shape == (5,)

coords = cube.axis_world_coords(wcs=cube.extra_coords)
assert len(coords) == 1
Expand All @@ -199,8 +210,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
# slice the cube so extra_coords is empty, and then try and run axis_world_coords
awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords)
assert awc == ()
sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True)
assert awc == ()


@pytest.mark.xfail(reason=">1D Tables not supported")
Expand Down Expand Up @@ -235,13 +244,31 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)


def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time):
# This replicates a specific NDCube object in visualization.rst
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
assert len(coords) == 1
assert isinstance(coords[0], Time)
assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))

coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
assert len(coords) == 1
assert isinstance(coords.time, u.Quantity)
assert_quantity_allclose(coords.time, [0, 60, 120] * u.second)


@pytest.mark.parametrize("axes", [[-1], [2], ["em"]])
def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):

# We go from 4 pixels to 6 pixels when we add pixel corners
coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)

coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)

coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)


@pytest.mark.parametrize(("ndc", "item"),
Expand All @@ -252,10 +279,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
indirect=("ndc",))
def test_axis_world_coords_sliced_all_3d(ndc, item):
coords = ndc[item].axis_world_coords_values()
assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)

coords = ndc[item].axis_world_coords()
assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)


@pytest.mark.parametrize(("ndc", "item"),
Expand Down

0 comments on commit cd93cf8

Please sign in to comment.