Skip to content

Commit

Permalink
Another go that failed
Browse files Browse the repository at this point in the history
  • Loading branch information
nabobalis committed Dec 5, 2024
1 parent 8d79d7d commit eb88be8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 51 deletions.
78 changes: 29 additions & 49 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,7 @@ def quantity(self):
"""Unitful representation of the NDCube data."""
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)


def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, units):
def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for independent axes.
Expand All @@ -459,8 +458,8 @@ def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, uni
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
pixel_axes : array-like
The pixel axes.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.
Expand All @@ -469,29 +468,19 @@ def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, uni
array-like
The world coordinates.
"""
naxes = len(self.data.shape)
pixel_indices = [np.array([0], dtype=int).reshape([1] * naxes).squeeze()] * naxes
for pixel_axis in pixel_axes:
len_axis = self.data.shape[::-1][pixel_axis]
# Define limits of desired pixel range based on whether corners or centers are desired
lims = (-0.5, len_axis + 1) if pixel_corners else (0, len_axis)
pix_ind = np.arange(lims[0], lims[1])
shape = [1] * naxes
shape[pixel_axis] = len(pix_ind)
pixel_indices[pixel_axis] = pix_ind.reshape(shape)
world_coords = wcs.pixel_to_world_values(*pixel_indices)
# TODO: Remove NaNs??? These should not be here
if np.isnan(world_coords).any():
if isinstance(world_coords, tuple| list):
world_coords = [world_coord[~np.isnan(world_coord)] for world_coord in world_coords]
else:
world_coords = world_coords[~np.isnan(world_coords)]
needed_axes = np.array(needed_axes).squeeze()
if self.data.ndim in needed_axes:
required_axes = needed_axes - 1
else:
required_axes = needed_axes
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
world_coords = wcs.pixel_to_world_values(*indices)
if units:
mod = abs(wcs.world_n_dim - naxes) if wcs.world_n_dim > naxes else 0
world_coords = world_coords << u.Unit(wcs.world_axis_units[np.squeeze(pixel_axes)+mod])
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
return world_coords

def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units):
def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for dependent axes.
Expand All @@ -504,8 +493,8 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
pixel_axes : array-like
The pixel axes.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.
Expand All @@ -520,13 +509,19 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
else:
ranges = [np.arange(i) for i in pixel_shape]
# Limit the pixel dimensions to the ones present in the ExtraCoords
if isinstance(wcs, ExtraCoords):
ranges = [ranges[i] for i in wcs.mapping]
wcs = wcs.wcs
if wcs is None:
return []
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
# required so values_to_high_level_objects in the calling function doesn't crash or warn
world_coords = [0] * wcs.world_n_dim
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
if (pixel_axes is not None
and len(pixel_axes)
and not any(world_axis in pixel_axes for world_axis in world_axes_indices)):
if (needed_axes is not None
and len(needed_axes)
and not any(world_axis in needed_axes for world_axis in world_axes_indices)):
# needed_axes indicates which values in world_coords will be used by the calling
# function, so skip this iteration if we won't be producing any of those values
continue
Expand Down Expand Up @@ -556,7 +551,6 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units
world_coords[i] = coord << u.Unit(unit)
return world_coords


def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes=None, units=None):
"""
Private method to generate world coordinates.
Expand All @@ -579,32 +573,21 @@ def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes=None, units=
array-like
The world coordinates.
"""
# TODO: Workout why I need this twice now.
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return ()
if needed_axes is None or len(needed_axes) == 0:
needed_axes = np.array(list(range(wcs.world_n_dim)),dtype=int)
axes_are_independent = []
pixel_axes = set()
for world_axis in needed_axes:
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
axes_are_independent.append(len(pix_ax) == 1)
pixel_axes = pixel_axes.union(set(pix_ax))
if len(pixel_axes) == 1:
pixel_axes = list(pixel_axes)
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes):
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, pixel_axes, units)
else:
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, pixel_axes, units)
if len(world_coords) > 1 and isinstance(world_coords, tuple | list):
world_coords = [np.squeeze(world_coord) for world_coord in world_coords]
pixel_axes = list(pixel_axes)
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
else:
world_coords = np.squeeze(world_coords)
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
return world_coords


@utils.cube.sanitize_wcs
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
Expand All @@ -628,8 +611,6 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
[world_index_to_object_index[world_index] for world_index in world_indices]
)
axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False)
if not isinstance(axes_coords, list):
axes_coords = [axes_coords]
axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)
if not axes:
return tuple(axes_coords)
Expand Down Expand Up @@ -662,8 +643,7 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
identifier = identifier.replace("-", "__")
identifiers.append(identifier)
CoordValues = namedtuple("CoordValues", identifiers)
flag = len(axes_coords) == 1 or isinstance(axes_coords, tuple | list)
return CoordValues(*axes_coords[::-1]) if flag else CoordValues(axes_coords)
return CoordValues(*axes_coords[::-1])

def crop(self, *points, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
Expand Down
12 changes: 10 additions & 2 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,19 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):

coords = cube.axis_world_coords()
assert len(coords) == 2
assert isinstance(coords[0], SkyCoord)
assert coords[0].shape == (5, 8)
assert isinstance(coords[1], SpectralCoord)
assert coords[1].shape == (10,)

coords = cube.axis_world_coords(wcs=cube.combined_wcs)
assert len(coords) == 3
assert isinstance(coords[0], SkyCoord)
assert coords[0].shape == (5, 8)
assert isinstance(coords[1], SpectralCoord)
assert coords[1].shape == (10,)
assert isinstance(coords[2], Time)
assert coords[2].shape == (5,)

coords = cube.axis_world_coords(wcs=cube.extra_coords)
assert len(coords) == 1
Expand All @@ -200,8 +210,6 @@ def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
# slice the cube so extra_coords is empty, and then try and run axis_world_coords
awc = sub_cube.axis_world_coords(wcs=sub_cube.extra_coords)
assert awc == ()
sub_cube._generate_world_coords(pixel_corners=False, wcs=sub_cube.extra_coords, units=True)
assert awc == ()


@pytest.mark.xfail(reason=">1D Tables not supported")
Expand Down

0 comments on commit eb88be8

Please sign in to comment.