Skip to content

Commit

Permalink
undo style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nabobalis committed Nov 26, 2024
1 parent 941216f commit 1ddcc3f
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,40 +467,36 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
pixel_shape = self.data.shape[::-1]
if pixel_corners:
pixel_shape = tuple(np.array(pixel_shape) + 1)
pixel_ranges = [np.arange(i) - 0.5 for i in pixel_shape]
ranges = [np.arange(i) - 0.5 for i in pixel_shape]
else:
pixel_ranges = [np.arange(i) for i in pixel_shape]
ranges = [np.arange(i) for i in pixel_shape]

# Limit the pixel dimensions to the ones present in the ExtraCoords
if isinstance(wcs, ExtraCoords):
pixel_ranges = [pixel_ranges[i] for i in wcs.mapping]
ranges = [ranges[i] for i in wcs.mapping]
wcs = wcs.wcs
if wcs is None:
return []

# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
# required so values_to_high_level_objects in the calling function doesn't crash or warn
world_coords = [0] * wcs.world_n_dim
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
if (
needed_axes is not None
and len(needed_axes)
and all(
world_axis not in needed_axes
for world_axis in world_axes_indices
)
):
if (needed_axes is not None
and len(needed_axes)
and not any(world_axis in needed_axes for world_axis in world_axes_indices)):
# needed_axes indicates which values in world_coords will be used by the calling
# function, so skip this iteration if we won't be producing any of those values
continue
# First construct a range of pixel indices for this set of coupled dimensions
pixel_ranges_subset = [pixel_ranges[idx] for idx in pixel_axes_indices]
# Then get a set of non-correlated dimensions
sub_range = [ranges[idx] for idx in pixel_axes_indices]
# Then get a set of non correlated dimensions
non_corr_axes = set(range(wcs.pixel_n_dim)) - set(pixel_axes_indices)
# And inject 0s for those coordinates
for idx in non_corr_axes:
pixel_ranges_subset.insert(idx, 0)
# Generate a grid of broadcast-able pixel indices for all pixel dimensions
grid = np.meshgrid(*pixel_ranges_subset, indexing='ij')
sub_range.insert(idx, 0)
# Generate a grid of broadcastable pixel indices for all pixel dimensions
grid = np.meshgrid(*sub_range, indexing='ij')
# Convert to world coordinates
world = wcs.pixel_to_world_values(*grid)
# TODO: this isinstance check is to mitigate https://github.com/spacetelescope/gwcs/pull/332
Expand All @@ -513,6 +509,7 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units)
array_slice[wcs.axis_correlation_matrix[idx]] = slice(None)
tmp_world = world[idx][tuple(array_slice)].T
world_coords[idx] = tmp_world

if units:
for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
world_coords[i] = coord << u.Unit(unit)
Expand All @@ -524,47 +521,61 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return ()

object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
unique_obj_names = utils.misc.unique_sorted(object_names)
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]

# Create a mapping from world index in the WCS to object index in axes_coords
world_index_to_object_index = {}
for object_index, world_axes in enumerate(world_axes_for_obj):
for world_index in world_axes:
world_index_to_object_index[world_index] = object_index

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
object_indices = utils.misc.unique_sorted(
[world_index_to_object_index[world_index] for world_index in world_indices]
)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False)

axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)

if not axes:
return tuple(axes_coords)

return tuple(axes_coords[i] for i in object_indices)

@utils.cube.sanitize_wcs
def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return ()

Check warning on line 565 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L565

Added line #L565 was not covered by tests

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True)

world_axis_physical_types = wcs.world_axis_physical_types

# If user has supplied axes, extract only the
# world coords that correspond to those axes.
if axes:
axes_coords = [axes_coords[i] for i in world_indices]
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])

# Return in array order.
# First replace characters in physical types forbidden for namedtuple identifiers.
identifiers = []
Expand Down

0 comments on commit 1ddcc3f

Please sign in to comment.