From b6fad67b07875a85a20954710dee14c38cffd5d0 Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Wed, 22 Nov 2023 13:50:16 +0000 Subject: [PATCH] Ready for batch (#44) Merging for testing on batch. * Integrate tiler and s3 upload to data pipeline * Remove unused file --- scripts/datacube.py | 196 ++++++++++++++++++++++++------------------ scripts/stack_tile.py | 13 --- scripts/tile.py | 129 +++++++++++++++------------ 3 files changed, 187 insertions(+), 151 deletions(-) mode change 100644 => 100755 scripts/datacube.py delete mode 100644 scripts/stack_tile.py diff --git a/scripts/datacube.py b/scripts/datacube.py old mode 100644 new mode 100755 index 6ad230c2..18db5c53 --- a/scripts/datacube.py +++ b/scripts/datacube.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 """ STAC Data Processing Script @@ -19,17 +20,17 @@ and date range. - search_dem(bbox, catalog): Search for DEM items within a given bounding box. -- make_datasets(s2_items, s1_items, dem_items, resolution): +- make_datasets(s2_item, s1_items, dem_items, resolution): Create xarray Datasets for Sentinel-2, Sentinel-1, and DEM data. - process(aoi, start_year, end_year, resolution, cloud_cover_percentage, nodata_pixel_percentage): Process Sentinel-2, Sentinel-1, and DEM data for a specified time range, area of interest, and resolution. """ - import random from datetime import timedelta +import click import geopandas as gpd import numpy as np import planetary_computer as pc @@ -38,12 +39,15 @@ import xarray as xr from pystac import ItemCollection from shapely.geometry import box +from tile import tiler STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1" S2_BANDS = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12", "SCL"] SPATIAL_RESOLUTION = 10 CLOUD_COVER_PERCENTAGE = 50 NODATA_PIXEL_PERCENTAGE = 20 +NODATA = 0 +S1_MATCH_ATTEMPTS = 10 def get_surrounding_days(reference, interval_days): @@ -63,12 +67,15 @@ def get_surrounding_days(reference, interval_days): return f"{start.date()}/{end.date()}" -def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage): +def search_sentinel2( + catalog, date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage, index=0 +): """ Search for Sentinel-2 items within a given date range and area of interest (AOI) with specified conditions. Parameters: + - catalog (pystac.Catalog): STAC catalog containing Sentinel-2 items. - date_range (str): The date range in the format 'start_date/end_date'. - aoi (shapely.geometry.base.BaseGeometry): Geometry object for an Area of Interest (AOI). @@ -76,6 +83,7 @@ def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_perce for Sentinel-2 images. - nodata_pixel_percentage (int): Maximum acceptable percentage of nodata pixels in Sentinel-2 images. + - index: Which of the found scenes to select Returns: - tuple: A tuple containing the STAC catalog, Sentinel-2 items, and the @@ -88,8 +96,6 @@ def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_perce as a tuple containing the STAC catalog, Sentinel-2 items, the bounding box of the first item, and an EPSG code for the coordinate reference system. 
""" - catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace) - search: pystac_client.item_search.ItemSearch = catalog.search( filter_lang="cql2-json", filter={ @@ -121,27 +127,23 @@ def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_perce s2_items_gdf = gpd.GeoDataFrame.from_features(s2_items.to_dict()) - least_clouds = s2_items_gdf.sort_values( - by=["eo:cloud_cover"], ascending=True - ).index[0] - - s2_items_gdf = s2_items_gdf.iloc[least_clouds] + least_clouds = s2_items_gdf.sort_values(by=["eo:cloud_cover"], ascending=True).iloc[ + index + ] # Get the datetime for the filtered Sentinel 2 dataframe # containing the least nodata and least cloudy scene - s2_items_gdf_datetime = s2_items_gdf["datetime"] for item in s2_items: - if item.properties["datetime"] == s2_items_gdf_datetime: + if item.properties["datetime"] == least_clouds["datetime"]: s2_item = item - else: - continue + break - bbox = s2_items_gdf.iloc[0].bounds + bbox = least_clouds.geometry.bounds epsg = s2_item.properties["proj:epsg"] print("EPSG code based on Sentinel-2 item: ", epsg) - return catalog, s2_item, bbox + return s2_item, bbox def search_sentinel1(bbox, catalog, date_range): @@ -185,9 +187,8 @@ def search_sentinel1(bbox, catalog, date_range): s1_items = search.item_collection() print(f"Found {len(s1_items)} Sentinel-1 items") - if s1_items is None: - return False - + if not len(s1_items): + return else: # Add id as property to persist in gdf for item in s1_items: @@ -271,37 +272,6 @@ def make_datasets(s2_items, s1_items, dem_items, resolution): fill_value=np.nan, ) - # Create xarray.Dataset datacube with all 10m and 20m bands from Sentinel-2 - da_s2_0: xr.DataArray = da_sen2.sel(band="B02", drop=True).rename("B02").squeeze() - da_s2_1: xr.DataArray = da_sen2.sel(band="B03", drop=True).rename("B03").squeeze() - da_s2_2: xr.DataArray = da_sen2.sel(band="B04", drop=True).rename("B04").squeeze() - da_s2_3: xr.DataArray = da_sen2.sel(band="B05", drop=True).rename("B05").squeeze() - da_s2_4: xr.DataArray = da_sen2.sel(band="B06", drop=True).rename("B06").squeeze() - da_s2_5: xr.DataArray = da_sen2.sel(band="B07", drop=True).rename("B07").squeeze() - da_s2_6: xr.DataArray = da_sen2.sel(band="B08", drop=True).rename("B08").squeeze() - da_s2_7: xr.DataArray = da_sen2.sel(band="B8A", drop=True).rename("B8A").squeeze() - da_s2_8: xr.DataArray = da_sen2.sel(band="B11", drop=True).rename("B11").squeeze() - da_s2_9: xr.DataArray = da_sen2.sel(band="B11", drop=True).rename("B11").squeeze() - da_s2_10: xr.DataArray = da_sen2.sel(band="SCL", drop=True).rename("SCL").squeeze() - - ds_sen2: xr.Dataset = xr.merge( - objects=[ - da_s2_0, - da_s2_1, - da_s2_2, - da_s2_3, - da_s2_4, - da_s2_5, - da_s2_6, - da_s2_7, - da_s2_8, - da_s2_9, - da_s2_10, - ], - join="override", - ) - ds_sen2.assign(time=da_sen2.time) - da_sen1: xr.DataArray = stackstac.stack( items=s1_items, assets=["vh", "vv"], @@ -312,12 +282,6 @@ def make_datasets(s2_items, s1_items, dem_items, resolution): fill_value=np.nan, ) - # Create xarray.Dataset datacube with VH and VV channels from SAR - da_sen1 = stackstac.mosaic(da_sen1, dim="time") - da_vh: xr.DataArray = da_sen1.sel(band="vh", drop=True).squeeze().rename("vh") - da_vv: xr.DataArray = da_sen1.sel(band="vv", drop=True).squeeze().rename("vv") - ds_sen1: xr.Dataset = xr.merge(objects=[da_vh, da_vv], join="override") - da_dem: xr.DataArray = stackstac.stack( items=dem_items, epsg=int(da_sen2.epsg), @@ -327,9 +291,27 @@ def make_datasets(s2_items, s1_items, dem_items, 
resolution): fill_value=np.nan, ) - da_dem: xr.DataArray = stackstac.mosaic(da_dem, dim="time").squeeze().rename("DEM") + da_sen1: xr.DataArray = stackstac.mosaic(da_sen1, dim="time") + + da_sen1 = da_sen1.drop_vars( + [var for var in da_sen1.coords if var not in da_sen1.dims] + ) + + da_sen2 = da_sen2.drop_vars( + [var for var in da_sen2.coords if var not in da_sen2.dims] + ).squeeze() + + del da_sen2.coords["time"] + + da_dem: xr.DataArray = stackstac.mosaic(da_dem, dim="time").assign_coords( + {"band": ["dem"]} + ) + + da_dem = da_dem.drop_vars([var for var in da_dem.coords if var not in da_dem.dims]) - return ds_sen2, ds_sen1, da_dem + result = xr.concat([da_sen2, da_sen1, da_dem], dim="band") + result = result.rename("tile") + return result def process( @@ -361,54 +343,102 @@ def process( """ year = random.randint(start_year, end_year) date_range = f"{year}-01-01/{year}-12-31" - catalog, s2_items, bbox = search_sentinel2( - date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage - ) - - surrounding_days = get_surrounding_days(s2_items.datetime, interval_days=3) - print("Searching S1 in date range", surrounding_days) - - s1_items = search_sentinel1(bbox, catalog, surrounding_days) + catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace) - if not s1_items: - catalog, s2_items, bbox = search_sentinel2( - date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage + for i in range(S1_MATCH_ATTEMPTS): + s2_item, bbox = search_sentinel2( + catalog, + date_range, + aoi, + cloud_cover_percentage, + nodata_pixel_percentage, + index=i, ) - surrounding_days = get_surrounding_days(s2_items.datetime, interval_days=3) + surrounding_days = get_surrounding_days(s2_item.datetime, interval_days=3) print("Searching S1 in date range", surrounding_days) + s1_items = search_sentinel1(bbox, catalog, surrounding_days) + if s1_items: + break + + if i == S1_MATCH_ATTEMPTS - 1: + raise ValueError( + f"No match for S1 scenes found after {S1_MATCH_ATTEMPTS} attempts." + ) + dem_items = search_dem(bbox, catalog) - ds_sen2, ds_sen1, da_dem = make_datasets( - s2_items, + date = s2_item.properties["datetime"][:10] + + result = make_datasets( + s2_item, s1_items, dem_items, resolution, ) - ds_merge = xr.merge([ds_sen2, ds_sen1, da_dem], compat="override") + return date, result - return ds_merge +def convert_attrs_and_coords_objects_to_str(data): + """ + Convert attributes and coordinates that are objects to + strings. -def main(): - tiles = gpd.read_file("scripts/data/mgrs_sample.geojson") - sample = tiles.sample(1, random_state=45) - aoi = sample.iloc[0].geometry + This is required for storing the xarray in netcdf. 
+ """ + for key, coord in data.coords.items(): + if coord.dtype == "object": + data.coords[key] = str(coord.values) + + for key, attr in data.attrs.items(): + data.attrs[key] = str(attr) + + for key, var in data.variables.items(): + var.attrs = {} + + +@click.command() +@click.option( + "--index", + required=True, + default=42, + help="Index of MGRS tile from sample file that should be processed", +) +@click.option( + "--subset", + required=False, + help="For debugging, subset x and y to this pixel window.", +) +def main(index, subset): + print("Starting algorithm", index) + index = int(index) + tiles = gpd.read_file("mgrs_sample.geojson") + tile = tiles.iloc[index] start_year = 2017 end_year = 2023 - - merged = process( - aoi, + date, merged = process( + tile.geometry, start_year, end_year, SPATIAL_RESOLUTION, CLOUD_COVER_PERCENTAGE, NODATA_PIXEL_PERCENTAGE, ) - return merged + mgrs = tile["name"] + if subset: + subset = [int(dat) for dat in subset.split(",")] + print(f"Subsetting to {subset}") + merged = merged.sel( + x=slice(merged.x.values[subset[0]], merged.x.values[subset[2]]), + y=slice(merged.y.values[subset[1]], merged.y.values[subset[3]]), + ) + merged = merged.compute() + + tiler(merged, date, mgrs) -# main() +if __name__ == "__main__": + main() diff --git a/scripts/stack_tile.py b/scripts/stack_tile.py deleted file mode 100644 index 85edc8dc..00000000 --- a/scripts/stack_tile.py +++ /dev/null @@ -1,13 +0,0 @@ -from datacube import main -from tile import tiler - - -def run_stack_tile(): - stack = main() - tiles = tiler(stack) - print("Stack: ", stack) - return tiles - - -tiles = run_stack_tile() -print("Number of tiles generated: ", len(tiles)) diff --git a/scripts/tile.py b/scripts/tile.py index b246a08c..d36a870e 100644 --- a/scripts/tile.py +++ b/scripts/tile.py @@ -5,17 +5,25 @@ stacks into smaller tiles, while filtering out tiles with high cloud coverage or no-data pixels. -It includes functions to filter tiles based on cloud coverage and no-data -pixels, and a tiling function that generates smaller tiles from the input -stack. +It includes functions to filter tiles based on cloud coverage and no-data pixels, +and a tiling function that generates smaller tiles from the input stack. """ +import os +import subprocess +import tempfile +import numpy as np +import rasterio +import rioxarray # noqa: F401 +from rasterio.enums import ColorInterp NODATA = 0 TILE_SIZE = 256 PIXELS_PER_TILE = TILE_SIZE * TILE_SIZE BAD_PIXEL_MAX_PERCENTAGE = 0.9 SCL_FILTER = [0, 1, 3, 8, 9, 10] +EPSILON = 0.1 +VERSION = "01" def filter_clouds_nodata(tile): @@ -29,13 +37,13 @@ def filter_clouds_nodata(tile): - bool: True if the tile is approved, False if rejected. """ # Check for nodata pixels - nodata_pixel_count = int(tile.B02.isin([NODATA]).sum()) + nodata_pixel_count = int(tile.sel(band="B02").isin([NODATA]).sum()) if nodata_pixel_count: print("Too much no-data") return False # Check for cloud coverage - cloudy_pixel_count = int(tile.SCL.isin(SCL_FILTER).sum()) + cloudy_pixel_count = int(tile.sel(band="SCL").isin(SCL_FILTER).sum()) if cloudy_pixel_count / PIXELS_PER_TILE >= BAD_PIXEL_MAX_PERCENTAGE: print("Too much cloud coverage") return False @@ -43,65 +51,76 @@ def filter_clouds_nodata(tile): return True # If both conditions pass -def tiler(stack): +def tiler(stack, date, mgrs): """ Function to tile a multi-dimensional imagery stack while filtering out tiles with high cloud coverage or no-data pixels. Args: - stack (xarray.Dataset): The input multi-dimensional imagery stack. 
- - Returns: - - list: A list containing approved tiles with specified dimensions. + - date (str): Date string yyyy-mm-dd + - mgrs (Str): MGRS Tile id """ # Calculate the number of full tiles in x and y directions num_x_tiles = stack.x.size // TILE_SIZE num_y_tiles = stack.y.size // TILE_SIZE - # Calculate the remaining sizes in x and y directions - remainder_x = stack.x.size % TILE_SIZE - remainder_y = stack.y.size % TILE_SIZE - - # Create a list to hold the tiles - tiles = [] - - # Counter for tiles - tile_count = 0 - - # Counter for bad tiles - # bad_tile_count = 0 - - # Iterate through each chunk of x and y dimensions and create tiles - for y_idx in range(num_y_tiles + 1 if remainder_y > 0 else num_y_tiles): - for x_idx in range(num_x_tiles + 1 if remainder_x > 0 else num_x_tiles): - # Calculate the start and end indices - # for x and y dimensions of the current tile - x_start = x_idx * TILE_SIZE - y_start = y_idx * TILE_SIZE - x_end = min((x_idx + 1) * TILE_SIZE, stack.x.size) - y_end = min((y_idx + 1) * TILE_SIZE, stack.y.size) - # print("x_start, y_start, x_end, y_end: ", x_start, y_start, x_end, y_end) - - # Select the subset of data for the current tile - tile = stack.sel( - x=slice(stack.x.values[x_start], stack.x.values[x_end - 1]), - y=slice(stack.y.values[y_start], stack.y.values[y_end - 1]), - ) - tile_spatial_dims = tuple(tile.dims[d] for d in ["x", "y"]) - if tile_spatial_dims[0] == TILE_SIZE and tile_spatial_dims[1] == TILE_SIZE: - tile_count = tile_count + 1 - """ - print( - "Tile size: ", - tuple(tile.dims[d] for d in ["x", "y"]), - "; tile count: ", - tile_count, + bucket = os.environ.get("TARGET_BUCKET", "whis-imagery") + + counter = 0 + with tempfile.TemporaryDirectory() as dir: + print("Writing tempfiles to ", dir) + # Iterate through each chunk of x and y dimensions and create tiles + for y_idx in range(num_y_tiles): + for x_idx in range(num_x_tiles): + counter += 1 + print(f"Counted {counter} tiles") + + # Calculate the start and end indices for x and y dimensions + # for the current tile + x_start = x_idx * TILE_SIZE + y_start = y_idx * TILE_SIZE + x_end = x_start + TILE_SIZE + y_end = y_start + TILE_SIZE + + # Select the subset of data for the current tile + tile = stack.sel( + x=slice( + stack.x.values[x_start], + stack.x.values[x_end] + + np.sign(stack.rio.transform()[4]) * EPSILON, + ), + y=slice( + stack.y.values[y_start], + stack.y.values[y_end] + + np.sign(stack.rio.transform()[0]) * EPSILON, + ), ) - """ - # Check for clouds and nodata - if filter_clouds_nodata(tile): - # Append the tile to the list - tiles.append(tile) - # print(f"{bad_tile_count} tiles removed due to clouds or nodata") - # 'tiles' now contains tiles with 256x256 pixels for x and y - return tiles + + if not filter_clouds_nodata(tile): + continue + + tile = tile.drop_sel(band="SCL") + + # Track band names and color interpretation + tile.attrs["long_name"] = [str(x.values) for x in tile.band] + color = [ColorInterp.blue, ColorInterp.green, ColorInterp.red] + [ + ColorInterp.gray + ] * (len(tile.band) - 3) + + # Write tile to tempdir + name = f"{dir}/claytile-{mgrs}-{date}-{VERSION}-{counter}.tif" + tile.rio.to_raster(name, compress="deflate") + + with rasterio.open(name, "r+") as rst: + rst.colorinterp = color + + import shutil + + shutil.copytree(dir, "/home/tam/Desktop/claytiles", dirs_exist_ok=True) + + print(f"Syncing {dir} with s3://{bucket}/clay/{VERSION}/{mgrs}/{date}") + subprocess.run( + ["aws", "s3", "sync", dir, f"s3://{bucket}/clay/{VERSION}/{mgrs}/{date}"], + check=True, + )
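
Note on the new --subset option (an editor's sketch, not part of the patch): main() parses --subset as an "x0,y0,x1,y1" pixel window and turns it into coordinate slices on the merged datacube before calling tiler(). The snippet below reproduces that slicing on a small synthetic stack; the array shape, coordinate values, and the subset window are illustrative stand-ins for the real output of make_datasets.

import numpy as np
import xarray as xr

# Synthetic 3-band stack with descending y, as in a north-up raster.
stack = xr.DataArray(
    np.zeros((3, 100, 100), dtype="float32"),
    dims=("band", "y", "x"),
    coords={
        "band": ["B02", "B03", "B04"],
        "y": np.arange(1000, 0, -10),
        "x": np.arange(0, 1000, 10),
    },
)

# --subset "10,10,50,50" becomes a list of pixel indices that are mapped
# to coordinate values and used as label-based slices, as in main().
subset = [10, 10, 50, 50]
window = stack.sel(
    x=slice(stack.x.values[subset[0]], stack.x.values[subset[2]]),
    y=slice(stack.y.values[subset[1]], stack.y.values[subset[3]]),
)
print(dict(window.sizes))  # {'band': 3, 'y': 41, 'x': 41}

Because .sel with value slices includes both endpoints, the window comes out one pixel larger than x1 - x0. The tiler added in this patch sidesteps that by nudging the end coordinate with EPSILON, so each written tile is exactly TILE_SIZE by TILE_SIZE pixels.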