From b34f6c5eb55140c26cf0cead63030a311ec4ac83 Mon Sep 17 00:00:00 2001 From: Nabil Freij Date: Mon, 11 Nov 2024 11:54:39 -0700 Subject: [PATCH] Hack to prevent len(self.data.shape) != wcs.world_n_dim --- ndcube/ndcube.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index 9f3426f8d..805a9f0eb 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -444,50 +444,56 @@ def quantity(self): return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED) def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units): - # LOLOLOLOLOLOLOLO - sHORT cIRCUIT fOR nON-cORRELATED wORLD cOORDINATES - if not pixel_corners and needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1: - indices = [np.arange(self.data.shape[::-1][needed_axes[0]]).tolist() if wanted else 0 - for wanted in wcs.axis_correlation_matrix[needed_axes][0]] + # This will generate only the coordinates that are needed if there is no correlation within the WCS + # This bypasses the entire rest of the function below which works out the full set of coordinates + # This only works for WCS that have the same number of world and pixel dimensions + if not pixel_corners and needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1 and len(self.data.shape) == wcs.world_n_dim: + indices = [np.arange(self.data.shape[::-1][needed_axes[0]]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axes][0]] world_coords = wcs.pixel_to_world_values(*indices) if units: world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes[0]]) return world_coords - # Create meshgrid of all pixel coordinates. - # If user wants pixel_corners, set pixel values to pixel corners. + + # Create a meshgrid of all pixel coordinates. + # If the user wants pixel corners, set pixel values to pixel corners. # Else make pixel centers. pixel_shape = self.data.shape[::-1] if pixel_corners: pixel_shape = tuple(np.array(pixel_shape) + 1) - ranges = [np.arange(i) - 0.5 for i in pixel_shape] + pixel_ranges = [np.arange(i) - 0.5 for i in pixel_shape] else: - ranges = [np.arange(i) for i in pixel_shape] + pixel_ranges = [np.arange(i) for i in pixel_shape] # Limit the pixel dimensions to the ones present in the ExtraCoords if isinstance(wcs, ExtraCoords): - ranges = [ranges[i] for i in wcs.mapping] + pixel_ranges = [pixel_ranges[i] for i in wcs.mapping] wcs = wcs.wcs if wcs is None: return [] - # This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is # required so values_to_high_level_objects in the calling function doesn't crash or warn world_coords = [0] * wcs.world_n_dim for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix): - if (needed_axes is not None - and len(needed_axes) - and not any(world_axis in needed_axes for world_axis in world_axes_indices)): + if ( + needed_axes is not None + and len(needed_axes) + and all( + world_axis not in needed_axes + for world_axis in world_axes_indices + ) + ): # needed_axes indicates which values in world_coords will be used by the calling # function, so skip this iteration if we won't be producing any of those values continue # First construct a range of pixel indices for this set of coupled dimensions - sub_range = [ranges[idx] for idx in pixel_axes_indices] - # Then get a set of non correlated dimensions + pixel_ranges_subset = [pixel_ranges[idx] for idx in pixel_axes_indices] + # Then get a set of non-correlated dimensions non_corr_axes = set(list(range(wcs.pixel_n_dim))) - set(pixel_axes_indices) # And inject 0s for those coordinates for idx in non_corr_axes: - sub_range.insert(idx, 0) - # Generate a grid of broadcastable pixel indices for all pixel dimensions - grid = np.meshgrid(*sub_range, indexing='ij') + pixel_ranges_subset.insert(idx, 0) + # Generate a grid of broadcast-able pixel indices for all pixel dimensions + grid = np.meshgrid(*pixel_ranges_subset, indexing='ij') # Convert to world coordinates world = wcs.pixel_to_world_values(*grid) # TODO: this isinstance check is to mitigate https://github.com/spacetelescope/gwcs/pull/332 @@ -500,7 +506,6 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) array_slice[wcs.axis_correlation_matrix[idx]] = slice(None) tmp_world = world[idx][tuple(array_slice)].T world_coords[idx] = tmp_world - if units: for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)): world_coords[i] = coord << u.Unit(unit)