Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster shortcut for working out coordinates values for non-correlated WCS #780

Merged
merged 12 commits into from
Dec 13, 2024
1 change: 1 addition & 0 deletions changelog/780.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.
nabobalis marked this conversation as resolved.
Show resolved Hide resolved
42 changes: 42 additions & 0 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l():
return WCS(header=header)


@pytest.fixture
def wcs_3d_wave_lt_ln():
header = {
'CTYPE1': 'WAVE ',
'CUNIT1': 'Angstrom',
'CDELT1': 0.2,
'CRPIX1': 0,
'CRVAL1': 10,

'CTYPE2': 'HPLT-TAN',
'CUNIT2': 'deg',
'CDELT2': 0.5,
'CRPIX2': 2,
'CRVAL2': 0.5,

'CTYPE3': 'HPLN-TAN ',
'CUNIT3': 'deg',
'CDELT3': 0.4,
'CRPIX3': 2,
'CRVAL3': 1,
}
return WCS(header=header)


@pytest.fixture
def wcs_2d_lt_ln():
spatial = {
Expand Down Expand Up @@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d):
return cube


@pytest.fixture
def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln):
shape = (3, 4, 5)
wcs_3d_wave_lt_ln.array_shape = shape
data = data_nd(shape)
mask = data > 0
cube = NDCube(
data,
wcs_3d_wave_lt_ln,
mask=mask,
uncertainty=data,
)
base_time = Time('2000-01-01', format='fits', scale='utc')
timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])])
cube.extra_coords.add('time', 0, timestamps)
return cube


@pytest.fixture
def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]],
Expand Down
125 changes: 100 additions & 25 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import astropy.nddata
import astropy.units as u
from astropy.units import UnitsError
from astropy.wcs.utils import _split_matrix

from ndcube.utils.wcs import world_axis_to_pixel_axes

try:
# Import sunpy coordinates if available to register the frames and WCS functions with astropy
Expand All @@ -20,7 +23,6 @@
pass

from astropy.wcs import WCS
from astropy.wcs.utils import _split_matrix
from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper
from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects

Expand Down Expand Up @@ -479,24 +481,76 @@
"""Unitful representation of the NDCube data."""
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)

def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units):
# Create meshgrid of all pixel coordinates.
# If user wants pixel_corners, set pixel values to pixel corners.
# Else make pixel centers.
def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for independent axes.

The idea is to workout only the specific grid that is needed for independent axes.
This speeds up the calculation of world coordinates and reduces memory usage.

Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.

Returns
-------
array-like
The world coordinates.
"""
needed_axes = np.array(needed_axes).squeeze()
if self.data.ndim in needed_axes:
required_axes = needed_axes - 1
else:
required_axes = needed_axes
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
world_coords = wcs.pixel_to_world_values(*indices)
if units:
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
return world_coords

def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for dependent axes.

This will work out the exact grid that is needed for dependent axes
and can be time and memory consuming.

Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.

Returns
-------
array-like
The world coordinates.
"""
pixel_shape = self.data.shape[::-1]
if pixel_corners:
pixel_shape = tuple(np.array(pixel_shape) + 1)
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
else:
ranges = [np.arange(i) for i in pixel_shape]

# Limit the pixel dimensions to the ones present in the ExtraCoords
if isinstance(wcs, ExtraCoords):
ranges = [ranges[i] for i in wcs.mapping]
wcs = wcs.wcs
if wcs is None:
return []

return ()

Check warning on line 553 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L553

Added line #L553 was not covered by tests
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
# required so values_to_high_level_objects in the calling function doesn't crash or warn
world_coords = [0] * wcs.world_n_dim
Expand Down Expand Up @@ -528,71 +582,92 @@
array_slice[wcs.axis_correlation_matrix[idx]] = slice(None)
tmp_world = world[idx][tuple(array_slice)].T
world_coords[idx] = tmp_world

if units:
for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
world_coords[i] = coord << u.Unit(unit)
return world_coords

def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
"""
Private method to generate world coordinates.

Handles both dependent and independent axes.

Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The axes that are needed.
units : bool
If units are needed.

Returns
-------
array-like
The world coordinates.
"""
axes_are_independent = []
pixel_axes = set()
for world_axis in needed_axes:
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
axes_are_independent.append(len(pix_ax) == 1)
pixel_axes = pixel_axes.union(set(pix_ax))
pixel_axes = list(pixel_axes)
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
else:
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
return world_coords

@utils.cube.sanitize_wcs
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return ()

object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
unique_obj_names = utils.misc.unique_sorted(object_names)
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]

# Create a mapping from world index in the WCS to object index in axes_coords
world_index_to_object_index = {}
for object_index, world_axes in enumerate(world_axes_for_obj):
for world_index in world_axes:
world_index_to_object_index[world_index] = object_index

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
object_indices = utils.misc.unique_sorted(
[world_index_to_object_index[world_index] for world_index in world_indices]
)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False)
axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)

if not axes:
return tuple(axes_coords)

return tuple(axes_coords[i] for i in object_indices)

@utils.cube.sanitize_wcs
def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs

if not wcs:
return ()

Check warning on line 662 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L662

Added line #L662 was not covered by tests
Comment on lines +661 to +662
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was missing and is included in the other version of this method, so I added it.

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True)
world_axis_physical_types = wcs.world_axis_physical_types

# If user has supplied axes, extract only the
# world coords that correspond to those axes.
if axes:
axes_coords = [axes_coords[i] for i in world_indices]
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])

# Return in array order.
# First replace characters in physical types forbidden for namedtuple identifiers.
identifiers = []
Expand Down
39 changes: 33 additions & 6 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astropy.coordinates import SkyCoord, SpectralCoord
from astropy.io import fits
from astropy.nddata import UnknownUncertainty
from astropy.tests.helper import assert_quantity_allclose
from astropy.time import Time
from astropy.units import UnitsError
from astropy.wcs import WCS
Expand Down Expand Up @@ -177,9 +178,19 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):

coords = cube.axis_world_coords()
assert len(coords) == 2
assert isinstance(coords[0], SkyCoord)
assert coords[0].shape == (5, 8)
assert isinstance(coords[1], SpectralCoord)
assert coords[1].shape == (10,)

coords = cube.axis_world_coords(wcs=cube.combined_wcs)
assert len(coords) == 3
assert isinstance(coords[0], SkyCoord)
assert coords[0].shape == (5, 8)
assert isinstance(coords[1], SpectralCoord)
assert coords[1].shape == (10,)
assert isinstance(coords[2], Time)
assert coords[2].shape == (5,)

coords = cube.axis_world_coords(wcs=cube.extra_coords)
assert len(coords) == 1
Expand All @@ -199,8 +210,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
# slice the cube so extra_coords is empty, and then try and run axis_world_coords
awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords)
assert awc == ()
sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True)
assert awc == ()
Comment on lines -202 to -203
Copy link
Contributor Author

@nabobalis nabobalis Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now the private method does not handle ECs, this happens now at the higher level.



@pytest.mark.xfail(reason=">1D Tables not supported")
Expand Down Expand Up @@ -235,13 +244,31 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)


def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time):
# This replicates a specific NDCube object in visualization.rst
coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
assert len(coords) == 1
assert isinstance(coords[0], Time)
assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))

coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
assert len(coords) == 1
assert isinstance(coords.time, u.Quantity)
assert_quantity_allclose(coords.time, [0, 60, 120] * u.second)


@pytest.mark.parametrize("axes", [[-1], [2], ["em"]])
def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):

# We go from 4 pixels to 6 pixels when we add pixel corners
nabobalis marked this conversation as resolved.
Show resolved Hide resolved
coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)

coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
nabobalis marked this conversation as resolved.
Show resolved Hide resolved

coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)


@pytest.mark.parametrize(("ndc", "item"),
Expand All @@ -252,10 +279,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
indirect=("ndc",))
def test_axis_world_coords_sliced_all_3d(ndc, item):
coords = ndc[item].axis_world_coords_values()
assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)

coords = ndc[item].axis_world_coords()
assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why have these asserts had to change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish I knew, I assume since I broke the code.

Copy link
Contributor Author

@nabobalis nabobalis Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this is common in most of the other tests, you return a tuple of length N and you have to escape it always to get the coord. There are lots of tests where you do len(coords) and it is 1 but then you need to index the return to get the coord info.



@pytest.mark.parametrize(("ndc", "item"),
Expand Down
Loading