Skip to content

Commit

Permalink
ENH: Add the possibility to load an image/geo window in `rasters(_rio…
Browse files Browse the repository at this point in the history
…).read`** #1
  • Loading branch information
remi-braun committed Dec 13, 2022
1 parent 0ce6ff2 commit 89f5d26
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 162 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## 1.XX.Y (20YY-MM-DD)

## 1.21.0 (2022-12-13)

- **ENH: Add the possibility to load an image/geo window in `rasters(_rio).read`** ([#1](https://github.com/sertit/eoreader/issues/1))

## 1.20.3 (2022-11-30)

- FIX: Ensure that attributes and encoding are propagated through `rasters` functions
Expand Down
Binary file added CI/DATA/rasters/window.tif
Binary file not shown.
Binary file added CI/DATA/rasters/window_20.tif
Binary file not shown.
17 changes: 17 additions & 0 deletions CI/SCRIPTS/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def test_rasters():
raster_sieved_path = rasters_path().joinpath("raster_sieved.tif")
raster_to_merge_path = rasters_path().joinpath("raster_to_merge.tif")
raster_merged_gtiff_path = rasters_path().joinpath("raster_merged.tif")
raster_window_path = rasters_path().joinpath("window.tif")
raster_window_20_path = rasters_path().joinpath("window_20.tif")

# Vectors
mask_path = rasters_path().joinpath("raster_mask.geojson")
Expand Down Expand Up @@ -141,6 +143,21 @@ def test_rasters():
assert_xr_encoding_attrs(xda, xda_4)
assert_xr_encoding_attrs(xda, xda_dask)

# ----------------------------------------------------------------------------------------------
# -- Read with window
xda_window_out = os.path.join(tmp_dir, "test_xda_window.tif")
xda_window = rasters.read(
raster_path,
window=mask_path,
)
rasters.write(xda_window, xda_window_out, dtype=np.uint8)
ci.assert_raster_equal(xda_window_out, raster_window_path)

xda_window_20_out = os.path.join(tmp_dir, "test_xda_20_window.tif")
xda_window_20 = rasters.read(raster_path, window=mask_path, resolution=20)
rasters.write(xda_window_20, xda_window_20_out, dtype=np.uint8)
ci.assert_raster_equal(xda_window_20_out, raster_window_20_path)

# ----------------------------------------------------------------------------------------------
# -- Write
# DataArray
Expand Down
18 changes: 18 additions & 0 deletions CI/SCRIPTS/test_rasters_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def test_rasters_rio():
raster_sieved_path = rasters_path().joinpath("raster_sieved.tif")
raster_to_merge_path = rasters_path().joinpath("raster_to_merge.tif")
raster_merged_gtiff_path = rasters_path().joinpath("raster_merged.tif")
raster_window_path = rasters_path().joinpath("window.tif")
raster_window_20_path = rasters_path().joinpath("window_20.tif")

# Vectors
mask_path = rasters_path().joinpath("raster_mask.geojson")
Expand Down Expand Up @@ -88,6 +90,22 @@ def test_rasters_rio():
np.testing.assert_array_equal(raster_1, raster_3)
np.testing.assert_array_equal(raster, raster_4) # 2D array

# -- Read with window
window_out = os.path.join(tmp_dir, "test_xda_window.tif")
window, w_mt = rasters_rio.read(
raster_path,
window=mask_path,
)
rasters_rio.write(window, w_mt, window_out)
ci.assert_raster_equal(window_out, raster_window_path)

window_20_out = os.path.join(tmp_dir, "test_xda_20_window.tif")
window_20, w_mt_20 = rasters_rio.read(
raster_path, window=mask_path, resolution=20
)
rasters_rio.write(window_20, w_mt_20, window_20_out)
ci.assert_raster_equal(window_20_out, raster_window_20_path)

# Write
raster_out = os.path.join(tmp_dir, "test.tif")
rasters_rio.write(raster, meta, raster_out)
Expand Down
100 changes: 63 additions & 37 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
from rioxarray.exceptions import MissingCRS

from sertit.logs import SU_NAME
from sertit.rasters_rio import MAX_CORES, PATH_ARR_DS, bigtiff_value, path_arr_dst
from sertit.rasters_rio import (
MAX_CORES,
PATH_ARR_DS,
bigtiff_value,
get_window,
path_arr_dst,
)

try:
import geopandas as gpd
Expand Down Expand Up @@ -69,13 +75,13 @@ def path_xarr_dst(function: Callable) -> Callable:
>>> # Create mock function
>>> @path_xarr_dst
>>> def fct(dst):
>>> read(dst)
>>> def fct(ds):
>>> read(ds)
>>>
>>> # Test the two ways
>>> read1 = fct("path/to/raster.tif")
>>> with rasterio.open("path/to/raster.tif") as dst:
>>> read2 = fct(dst)
>>> with rasterio.open("path/to/raster.tif") as ds:
>>> read2 = fct(ds)
>>>
>>> # Test
>>> read1 == read2
Expand Down Expand Up @@ -339,8 +345,8 @@ def vectorize(
>>> raster_path = "path/to/raster.tif"
>>> vec1 = vectorize(raster_path)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> vec2 = vectorize(dst)
>>> with rasterio.open(raster_path) as ds:
>>> vec2 = vectorize(ds)
>>> vec1 == vec2
True
Expand Down Expand Up @@ -376,8 +382,8 @@ def get_valid_vector(xds: PATH_XARR_DS, default_nodata: int = 0) -> gpd.GeoDataF
>>> raster_path = "path/to/raster.tif"
>>> nodata1 = get_valid_vector(raster_path)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> nodata2 = get_valid_vector(dst)
>>> with rasterio.open(raster_path) as ds:
>>> nodata2 = get_valid_vector(ds)
>>> nodata1 == nodata2
True
Expand All @@ -397,7 +403,7 @@ def get_valid_vector(xds: PATH_XARR_DS, default_nodata: int = 0) -> gpd.GeoDataF


@path_xarr_dst
def get_nodata_vector(dst: PATH_ARR_DS, default_nodata: int = 0) -> gpd.GeoDataFrame:
def get_nodata_vector(ds: PATH_ARR_DS, default_nodata: int = 0) -> gpd.GeoDataFrame:
"""
Get the nodata vector of a raster as a vector.
Expand All @@ -409,21 +415,19 @@ def get_nodata_vector(dst: PATH_ARR_DS, default_nodata: int = 0) -> gpd.GeoDataF
>>> raster_path = "path/to/raster.tif" # Classified raster, with no data set to 255
>>> nodata1 = get_nodata_vector(raster_path)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> nodata2 = get_nodata_vector(dst)
>>> with rasterio.open(raster_path) as ds:
>>> nodata2 = get_nodata_vector(ds)
>>> nodata1 == nodata2
True
Args:
dst (PATH_ARR_DS): Path to the raster, its dataset, its :code:`xarray` or a tuple containing its array and metadata
ds (PATH_ARR_DS): Path to the raster, its dataset, its :code:`xarray` or a tuple containing its array and metadata
default_nodata (int): Default nodata value, used when none is set in the file
Returns:
gpd.GeoDataFrame: Nodata Vector
"""
nodata = _vectorize(
dst, values=None, get_nodata=True, default_nodata=default_nodata
)
nodata = _vectorize(ds, values=None, get_nodata=True, default_nodata=default_nodata)
return nodata[nodata.raster_val == 0]


Expand Down Expand Up @@ -452,8 +456,8 @@ def mask(
>>> shapes = gpd.read_file(shape_path)
>>> mask1 = mask(raster_path, shapes)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> mask2 = mask(dst, shapes)
>>> with rasterio.open(raster_path) as ds:
>>> mask2 = mask(ds, shapes)
>>> mask1 == mask2
True
Expand Down Expand Up @@ -505,8 +509,8 @@ def paint(
>>> shapes = gpd.read_file(shape_path)
>>> paint1 = paint(raster_path, shapes, value=100)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> paint2 = paint(dst, shapes, value=100)
>>> with rasterio.open(raster_path) as ds:
>>> paint2 = paint(ds, shapes, value=100)
>>> paint1 == paint2
True
Expand Down Expand Up @@ -572,8 +576,8 @@ def crop(
>>> shapes = gpd.read_file(shape_path)
>>> xds2 = crop(raster_path, shapes)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> xds2 = crop(dst, shapes)
>>> with rasterio.open(raster_path) as ds:
>>> xds2 = crop(ds, shapes)
>>> xds1 == xds2
True
Expand Down Expand Up @@ -604,9 +608,10 @@ def crop(

@path_arr_dst
def read(
dst: PATH_ARR_DS,
ds: PATH_ARR_DS,
resolution: Union[tuple, list, float] = None,
size: Union[tuple, list] = None,
window: Any = None,
resampling: Resampling = Resampling.nearest,
masked: bool = True,
indexes: Union[int, list] = None,
Expand Down Expand Up @@ -637,15 +642,16 @@ def read(
>>> raster_path = "path/to/raster.tif"
>>> xds1 = read(raster_path)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> xds2 = read(dst)
>>> with rasterio.open(raster_path) as ds:
>>> xds2 = read(ds)
>>> xds1 == xds2
True
Args:
dst (PATH_ARR_DS): Path to the raster or a rasterio dataset or a xarray
ds (PATH_ARR_DS): Path to the raster or a rasterio dataset or a xarray
resolution (Union[tuple, list, float]): Resolution of the wanted band, in dataset resolution unit (X, Y)
size (Union[tuple, list]): Size of the array (width, height). Not used if resolution is provided.
window (Any): Anything that can be converted to a window by :code:`get_window`. An iterable is assumed to hold geographic bounds. To specify the window in pixel coordinates, please provide a Window directly.
resampling (Resampling): Resampling method
masked (bool): Get a masked array
indexes (Union[int, list]): Indexes to load. Load the whole array if None. Starts at 1.
Expand All @@ -659,16 +665,27 @@ def read(
Union[XDS_TYPE]: Masked xarray corresponding to the raster data and its meta data
"""
if window is not None:
window = get_window(ds, window)

# Get new height and width
new_height, new_width = rasters_rio.get_new_shape(dst, resolution, size)
new_height, new_width, do_resampling = rasters_rio.get_new_shape(
ds, resolution, size, window
)

# Read data (and load it to discard lock)
with xarray.set_options(keep_attrs=True):
with rioxarray.set_options(export_grid_mapping=False):
with rioxarray.open_rasterio(
dst, default_name=files.get_filename(dst.name), chunks=chunks, **kwargs
ds, default_name=files.get_filename(ds.name), chunks=chunks, **kwargs
) as xda:
orig_dtype = xda.dtype

# Windows
if window is not None:
xda = xda.rio.isel_window(window).load()

# Indexes
if indexes is not None:
if not isinstance(indexes, list):
indexes = [indexes]
Expand All @@ -680,7 +697,7 @@ def read(
ok_indexes = np.isin(indexes, xda.band)
if any(~ok_indexes):
LOGGER.warning(
f"Non available index: {[idx for i, idx in enumerate(indexes) if not ok_indexes[i]]} for {dst.name}"
f"Non available index: {[idx for i, idx in enumerate(indexes) if not ok_indexes[i]]} for {ds.name}"
)

xda = xda[np.isin(xda.band, indexes)]
Expand All @@ -696,9 +713,16 @@ def read(
pass

# Manage resampling
if new_height != dst.height or new_width != dst.width:
factor_h = dst.height / new_height
factor_w = dst.width / new_width
if do_resampling:
if window is not None:
factor_h = window.height / new_height
factor_w = window.width / new_width
else:
factor_h = ds.height / new_height
factor_w = ds.width / new_width

# Manage 2 ways of resampling, coarsen being faster than reprojection
# TODO: find a way to match rasterio's speed
if factor_h.is_integer() and factor_w.is_integer():
xda = xda.coarsen(x=int(factor_w), y=int(factor_h)).mean()
else:
Expand All @@ -708,14 +732,16 @@ def read(
resampling=resampling,
)

# Convert to wanted type
if as_type:
# Modify the type as wanted by the user
# TODO: manage nodata and uint/int numbers
xda = xda.astype(as_type)

# Mask if necessary
if masked:
# Set nodata not in opening due to some performance issues
xda = set_nodata(xda, dst.meta["nodata"])
xda = set_nodata(xda, ds.meta["nodata"])

# Set original dtype
xda.encoding["dtype"] = orig_dtype
Expand Down Expand Up @@ -958,8 +984,8 @@ def get_extent(xds: PATH_XARR_DS) -> gpd.GeoDataFrame:
>>> extent1 = get_extent(raster_path)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> extent2 = get_extent(dst)
>>> with rasterio.open(raster_path) as ds:
>>> extent2 = get_extent(ds)
>>> extent1 == extent2
True
Expand All @@ -984,8 +1010,8 @@ def get_footprint(xds: PATH_XARR_DS) -> gpd.GeoDataFrame:
>>> footprint1 = get_footprint(raster_path)
>>> # or
>>> with rasterio.open(raster_path) as dst:
>>> footprint2 = get_footprint(dst)
>>> with rasterio.open(raster_path) as ds:
>>> footprint2 = get_footprint(ds)
>>> footprint1 == footprint2
Args:
Expand Down
Loading

0 comments on commit 89f5d26

Please sign in to comment.