Skip to content

Commit

Permalink
Hack to prevent len(self.data.shape) != wcs.world_n_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
nabobalis committed Nov 11, 2024
1 parent 9b93c83 commit b34f6c5
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,50 +444,56 @@ def quantity(self):
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)

def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units):
# LOLOLOLOLOLOLOLO - sHORT cIRCUIT fOR nON-cORRELATED wORLD cOORDINATES
if not pixel_corners and needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1:
indices = [np.arange(self.data.shape[::-1][needed_axes[0]]).tolist() if wanted else 0
for wanted in wcs.axis_correlation_matrix[needed_axes][0]]
# This will generate only the coordinates that are needed if there is no correlation within the WCS
# This bypasses the entire rest of the function below which works out the full set of coordinates
# This only works for WCS that have the same number of world and pixel dimensions
if not pixel_corners and needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1 and len(self.data.shape) == wcs.world_n_dim:
indices = [np.arange(self.data.shape[::-1][needed_axes[0]]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axes][0]]
world_coords = wcs.pixel_to_world_values(*indices)
if units:
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes[0]])
return world_coords
# Create meshgrid of all pixel coordinates.
# If user wants pixel_corners, set pixel values to pixel corners.

# Create a meshgrid of all pixel coordinates.
# If the user wants pixel corners, set pixel values to pixel corners.
# Else make pixel centers.
pixel_shape = self.data.shape[::-1]
if pixel_corners:
pixel_shape = tuple(np.array(pixel_shape) + 1)
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
pixel_ranges = [np.arange(i) - 0.5 for i in pixel_shape]
else:
ranges = [np.arange(i) for i in pixel_shape]
pixel_ranges = [np.arange(i) for i in pixel_shape]

# Limit the pixel dimensions to the ones present in the ExtraCoords
if isinstance(wcs, ExtraCoords):
ranges = [ranges[i] for i in wcs.mapping]
pixel_ranges = [pixel_ranges[i] for i in wcs.mapping]
wcs = wcs.wcs
if wcs is None:
return []

# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
# required so values_to_high_level_objects in the calling function doesn't crash or warn
world_coords = [0] * wcs.world_n_dim
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
if (needed_axes is not None
and len(needed_axes)
and not any(world_axis in needed_axes for world_axis in world_axes_indices)):
if (
needed_axes is not None
and len(needed_axes)
and all(
world_axis not in needed_axes
for world_axis in world_axes_indices
)
):
# needed_axes indicates which values in world_coords will be used by the calling
# function, so skip this iteration if we won't be producing any of those values
continue
# First construct a range of pixel indices for this set of coupled dimensions
sub_range = [ranges[idx] for idx in pixel_axes_indices]
# Then get a set of non correlated dimensions
pixel_ranges_subset = [pixel_ranges[idx] for idx in pixel_axes_indices]
# Then get a set of non-correlated dimensions
non_corr_axes = set(list(range(wcs.pixel_n_dim))) - set(pixel_axes_indices)
# And inject 0s for those coordinates
for idx in non_corr_axes:
sub_range.insert(idx, 0)
# Generate a grid of broadcastable pixel indices for all pixel dimensions
grid = np.meshgrid(*sub_range, indexing='ij')
pixel_ranges_subset.insert(idx, 0)
# Generate a grid of broadcast-able pixel indices for all pixel dimensions
grid = np.meshgrid(*pixel_ranges_subset, indexing='ij')
# Convert to world coordinates
world = wcs.pixel_to_world_values(*grid)
# TODO: this isinstance check is to mitigate https://github.com/spacetelescope/gwcs/pull/332
Expand All @@ -500,7 +506,6 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
array_slice[wcs.axis_correlation_matrix[idx]] = slice(None)
tmp_world = world[idx][tuple(array_slice)].T
world_coords[idx] = tmp_world

if units:
for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
world_coords[i] = coord << u.Unit(unit)
Expand Down

0 comments on commit b34f6c5

Please sign in to comment.