Skip to content

Commit

Permalink
feat: add tiling function for raster
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn committed Feb 13, 2025
1 parent 6482e77 commit c1c6715
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 2 deletions.
47 changes: 47 additions & 0 deletions geoutils/raster/georeferencing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from __future__ import annotations

import math
import warnings
from typing import Iterable, Literal

Expand Down Expand Up @@ -284,3 +285,49 @@ def _cast_nodata(out_dtype: DTypeLike, nodata: int | float | None) -> int | floa
nodata = nodata

return nodata


def _generate_tiling_grid(
row_min: int,
col_min: int,
row_max: int,
col_max: int,
row_split: int,
col_split: int,
overlap: int = 0,
) -> NDArrayNum:
"""
Generate a grid of positions by splitting [row_min, row_max] x
[col_min, col_max] into tiles of size row_split x col_split with optional overlap.
:param row_min: Minimum row index of the bounding box to split.
:param col_min: Minimum column index of the bounding box to split.
:param row_max: Maximum row index of the bounding box to split.
:param col_max: Maximum column index of the bounding box to split.
:param row_split: Height of each tile.
:param col_split: Width of each tile.
:param overlap: size of overlapping between tiles (both vertically and horizontally).
:return: A numpy array grid with splits in two dimensions (0: row, 1: column),
where each cell contains [row_min, row_max, col_min, col_max].
"""
# Calculate the number of splits considering overlap
nb_col_split = math.ceil((col_max - col_min) / (col_split - overlap))
nb_row_split = math.ceil((row_max - row_min) / (row_split - overlap))

# Initialize the output grid
tiling_grid = np.zeros(shape=(nb_row_split, nb_col_split, 4), dtype=int)

for row in range(nb_row_split):
for col in range(nb_col_split):
# Calculate the start of the tile
row_start = row_min + row * (row_split - overlap)
col_start = col_min + col * (col_split - overlap)

# Calculate the end of the tile ensuring it doesn't exceed the bounds
row_end = min(row_max, row_start + row_split)
col_end = min(col_max, col_start + col_split)

# Populate the grid with the tile boundaries
tiling_grid[row, col] = [row_start, row_end, col_start, col_end]

return tiling_grid
75 changes: 73 additions & 2 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import rioxarray
import xarray as xr
from affine import Affine
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
from packaging.version import Version
from rasterio.crs import CRS
Expand Down Expand Up @@ -77,6 +78,7 @@
_cast_pixel_interpretation,
_coords,
_default_nodata,
_generate_tiling_grid,
_ij2xy,
_outside_image,
_res,
Expand Down Expand Up @@ -2655,6 +2657,44 @@ def translate(
raster_copy.transform = translated_transform
return raster_copy

def compute_tiling(
self,
tile_size: int,
raster_ref: RasterType,
overlap: int = 0,
) -> NDArrayNum:
"""
Compute the raster tiling grid to coregister raster by block.
:param tile_size: Size of each tile (square tiles)
:param raster_ref: The other raster to coregister, use to validate the shape
:param overlap: Size of overlap between tiles (optional)
:return: tiling_grid (array of tile boundaries), new_shape (shape of the tiled grid)
"""
if self.shape != raster_ref.shape:
raise Exception("Reference and secondary rasters do not have the same shape")
row_max, col_max = self.shape

# Generate tiling
tiling_grid = _generate_tiling_grid(0, 0, row_max, col_max, tile_size, tile_size, overlap=overlap)
return tiling_grid

def plot_tiling(self, tiling_grid: NDArrayNum) -> None:
"""
Plot raster with its tiling.
:param tiling_grid: tiling given by Raster.compute_tiling.
"""
ax, caxes = self.plot(return_axes=True)
for tile in tiling_grid.reshape(-1, 4):
row_min, row_max, col_min, col_max = tile
x_min, y_min = self.transform * (col_min, row_min) # Bottom-left corner
x_max, y_max = self.transform * (col_max, row_max) # Top-right corne
rect = Rectangle(
(x_min, y_min), x_max - x_min, y_max - y_min, edgecolor="red", facecolor="none", linewidth=1.5
)
ax.add_patch(rect)

def save(
self,
filename: str | pathlib.Path | IO[bytes],
Expand Down Expand Up @@ -2922,6 +2962,38 @@ def intersection(self, raster: str | Raster, match_ref: bool = True) -> tuple[fl
# mypy raises a type issue, not sure how to address the fact that output of merge_bounds can be ()
return intersection # type: ignore

@overload
def plot(
self,
bands: int | tuple[int, ...] | None = None,
cmap: matplotlib.colors.Colormap | str | None = None,
vmin: float | int | None = None,
vmax: float | int | None = None,
alpha: float | int | None = None,
cbar_title: str | None = None,
add_cbar: bool = True,
ax: matplotlib.axes.Axes | Literal["new"] | None = None,
*,
return_axes: Literal[False] = False,
**kwargs: Any,
) -> None: ...

@overload
def plot(
self,
bands: int | tuple[int, ...] | None = None,
cmap: matplotlib.colors.Colormap | str | None = None,
vmin: float | int | None = None,
vmax: float | int | None = None,
alpha: float | int | None = None,
cbar_title: str | None = None,
add_cbar: bool = True,
ax: matplotlib.axes.Axes | Literal["new"] | None = None,
*,
return_axes: Literal[True],
**kwargs: Any,
) -> tuple[matplotlib.axes.Axes, matplotlib.colors.Colormap]: ...

def plot(
self,
bands: int | tuple[int, ...] | None = None,
Expand Down Expand Up @@ -3059,8 +3131,7 @@ def plot(
# If returning axes
if return_axes:
return ax0, cax
else:
return None
return None

def reduce_points(
self,
Expand Down

0 comments on commit c1c6715

Please sign in to comment.