Ready for batch (#44)
Merging for testing on batch.

* Integrate tiler and S3 upload into the data pipeline

* Remove unused file
yellowcap authored Nov 22, 2023
1 parent 17f4698 commit b6fad67
Showing 3 changed files with 187 additions and 151 deletions.
196 changes: 113 additions & 83 deletions scripts/datacube.py
100644 → 100755
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
"""
STAC Data Processing Script
@@ -19,17 +20,17 @@
and date range.
- search_dem(bbox, catalog):
Search for DEM items within a given bounding box.
- make_datasets(s2_items, s1_items, dem_items, resolution):
- make_datasets(s2_item, s1_items, dem_items, resolution):
Create xarray Datasets for Sentinel-2, Sentinel-1, and DEM data.
- process(aoi, start_year, end_year, resolution, cloud_cover_percentage,
nodata_pixel_percentage):
Process Sentinel-2, Sentinel-1, and DEM data for a specified time range,
area of interest, and resolution.
"""

import random
from datetime import timedelta

import click
import geopandas as gpd
import numpy as np
import planetary_computer as pc
@@ -38,12 +39,15 @@
import xarray as xr
from pystac import ItemCollection
from shapely.geometry import box
from tile import tiler

STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
S2_BANDS = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12", "SCL"]
SPATIAL_RESOLUTION = 10
CLOUD_COVER_PERCENTAGE = 50
NODATA_PIXEL_PERCENTAGE = 20
NODATA = 0
S1_MATCH_ATTEMPTS = 10


def get_surrounding_days(reference, interval_days):
@@ -63,19 +67,23 @@ def get_surrounding_days(reference, interval_days):
return f"{start.date()}/{end.date()}"
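
To make the returned window concrete, a small sketch, assuming start and end are the reference date shifted by interval_days in each direction (consistent with the return statement above):

    from datetime import datetime, timedelta

    reference = datetime(2023, 6, 15)
    start = reference - timedelta(days=3)
    end = reference + timedelta(days=3)
    print(f"{start.date()}/{end.date()}")  # 2023-06-12/2023-06-18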


def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage):
def search_sentinel2(
catalog, date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage, index=0
):
"""
Search for Sentinel-2 items within a given date range and area of interest (AOI)
with specified conditions.
Parameters:
- catalog (pystac.Catalog): STAC catalog containing Sentinel-2 items.
- date_range (str): The date range in the format 'start_date/end_date'.
- aoi (shapely.geometry.base.BaseGeometry): Geometry object for an Area of
Interest (AOI).
- cloud_cover_percentage (int): Maximum acceptable cloud cover percentage
for Sentinel-2 images.
- nodata_pixel_percentage (int): Maximum acceptable percentage of nodata
pixels in Sentinel-2 images.
- index (int): Index of the found scenes to select, ordered by ascending cloud cover (0 is the least cloudy match).
Returns:
- tuple: A tuple containing the selected Sentinel-2 item and the
@@ -88,8 +96,6 @@ def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage):
as a tuple containing the selected Sentinel-2 item and the bounding box
of its geometry.
"""
catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace)

search: pystac_client.item_search.ItemSearch = catalog.search(
filter_lang="cql2-json",
filter={
@@ -121,27 +127,23 @@ def search_sentinel2(date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage):

s2_items_gdf = gpd.GeoDataFrame.from_features(s2_items.to_dict())

least_clouds = s2_items_gdf.sort_values(
by=["eo:cloud_cover"], ascending=True
).index[0]

s2_items_gdf = s2_items_gdf.iloc[least_clouds]
least_clouds = s2_items_gdf.sort_values(by=["eo:cloud_cover"], ascending=True).iloc[
index
]

# Get the datetime for the filtered Sentinel 2 dataframe
# containing the least nodata and least cloudy scene
s2_items_gdf_datetime = s2_items_gdf["datetime"]
for item in s2_items:
if item.properties["datetime"] == s2_items_gdf_datetime:
if item.properties["datetime"] == least_clouds["datetime"]:
s2_item = item
else:
continue
break

bbox = s2_items_gdf.iloc[0].bounds
bbox = least_clouds.geometry.bounds

epsg = s2_item.properties["proj:epsg"]
print("EPSG code based on Sentinel-2 item: ", epsg)

return catalog, s2_item, bbox
return s2_item, bbox
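
After this refactor the caller owns the STAC client and can step through candidate scenes by index. A usage sketch; the bounding box for the AOI is a hypothetical placeholder:

    import planetary_computer as pc
    import pystac_client
    from shapely.geometry import box

    catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace)
    aoi = box(13.0, 52.3, 13.8, 52.7)  # hypothetical area of interest
    s2_item, bbox = search_sentinel2(
        catalog,
        "2022-01-01/2022-12-31",
        aoi,
        CLOUD_COVER_PERCENTAGE,
        NODATA_PIXEL_PERCENTAGE,
        index=0,  # least cloudy match; process() retries with 1, 2, ...
    )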


def search_sentinel1(bbox, catalog, date_range):
@@ -185,9 +187,8 @@ def search_sentinel1(bbox, catalog, date_range):
s1_items = search.item_collection()
print(f"Found {len(s1_items)} Sentinel-1 items")

if s1_items is None:
return False

if not len(s1_items):
return
else:
# Add id as property to persist in gdf
for item in s1_items:
@@ -271,37 +272,6 @@ def make_datasets(s2_items, s1_items, dem_items, resolution):
fill_value=np.nan,
)

# Create xarray.Dataset datacube with all 10m and 20m bands from Sentinel-2
da_s2_0: xr.DataArray = da_sen2.sel(band="B02", drop=True).rename("B02").squeeze()
da_s2_1: xr.DataArray = da_sen2.sel(band="B03", drop=True).rename("B03").squeeze()
da_s2_2: xr.DataArray = da_sen2.sel(band="B04", drop=True).rename("B04").squeeze()
da_s2_3: xr.DataArray = da_sen2.sel(band="B05", drop=True).rename("B05").squeeze()
da_s2_4: xr.DataArray = da_sen2.sel(band="B06", drop=True).rename("B06").squeeze()
da_s2_5: xr.DataArray = da_sen2.sel(band="B07", drop=True).rename("B07").squeeze()
da_s2_6: xr.DataArray = da_sen2.sel(band="B08", drop=True).rename("B08").squeeze()
da_s2_7: xr.DataArray = da_sen2.sel(band="B8A", drop=True).rename("B8A").squeeze()
da_s2_8: xr.DataArray = da_sen2.sel(band="B11", drop=True).rename("B11").squeeze()
da_s2_9: xr.DataArray = da_sen2.sel(band="B11", drop=True).rename("B11").squeeze()
da_s2_10: xr.DataArray = da_sen2.sel(band="SCL", drop=True).rename("SCL").squeeze()

ds_sen2: xr.Dataset = xr.merge(
objects=[
da_s2_0,
da_s2_1,
da_s2_2,
da_s2_3,
da_s2_4,
da_s2_5,
da_s2_6,
da_s2_7,
da_s2_8,
da_s2_9,
da_s2_10,
],
join="override",
)
ds_sen2.assign(time=da_sen2.time)

da_sen1: xr.DataArray = stackstac.stack(
items=s1_items,
assets=["vh", "vv"],
@@ -312,12 +282,6 @@ def make_datasets(s2_items, s1_items, dem_items, resolution):
fill_value=np.nan,
)

# Create xarray.Dataset datacube with VH and VV channels from SAR
da_sen1 = stackstac.mosaic(da_sen1, dim="time")
da_vh: xr.DataArray = da_sen1.sel(band="vh", drop=True).squeeze().rename("vh")
da_vv: xr.DataArray = da_sen1.sel(band="vv", drop=True).squeeze().rename("vv")
ds_sen1: xr.Dataset = xr.merge(objects=[da_vh, da_vv], join="override")

da_dem: xr.DataArray = stackstac.stack(
items=dem_items,
epsg=int(da_sen2.epsg),
@@ -327,9 +291,27 @@ def make_datasets(s2_items, s1_items, dem_items, resolution):
fill_value=np.nan,
)

da_dem: xr.DataArray = stackstac.mosaic(da_dem, dim="time").squeeze().rename("DEM")
da_sen1: xr.DataArray = stackstac.mosaic(da_sen1, dim="time")

da_sen1 = da_sen1.drop_vars(
[var for var in da_sen1.coords if var not in da_sen1.dims]
)

da_sen2 = da_sen2.drop_vars(
[var for var in da_sen2.coords if var not in da_sen2.dims]
).squeeze()

del da_sen2.coords["time"]

da_dem: xr.DataArray = stackstac.mosaic(da_dem, dim="time").assign_coords(
{"band": ["dem"]}
)

da_dem = da_dem.drop_vars([var for var in da_dem.coords if var not in da_dem.dims])

return ds_sen2, ds_sen1, da_dem
result = xr.concat([da_sen2, da_sen1, da_dem], dim="band")
result = result.rename("tile")
return result
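
The three sources are now concatenated along a shared band dimension instead of being merged as separate Datasets. A sketch of the expected band coordinate of the returned "tile" array, assuming stackstac preserves the requested asset order:

    print(result.band.values)
    # Ten Sentinel-2 reflectance bands plus SCL, two SAR polarizations, DEM:
    # ['B02' 'B03' 'B04' 'B05' 'B06' 'B07' 'B08' 'B8A' 'B11' 'B12' 'SCL'
    #  'vh' 'vv' 'dem']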


def process(
@@ -361,54 +343,102 @@
"""
year = random.randint(start_year, end_year)
date_range = f"{year}-01-01/{year}-12-31"
catalog, s2_items, bbox = search_sentinel2(
date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage
)

surrounding_days = get_surrounding_days(s2_items.datetime, interval_days=3)
print("Searching S1 in date range", surrounding_days)

s1_items = search_sentinel1(bbox, catalog, surrounding_days)
catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace)

if not s1_items:
catalog, s2_items, bbox = search_sentinel2(
date_range, aoi, cloud_cover_percentage, nodata_pixel_percentage
for i in range(S1_MATCH_ATTEMPTS):
s2_item, bbox = search_sentinel2(
catalog,
date_range,
aoi,
cloud_cover_percentage,
nodata_pixel_percentage,
index=i,
)

surrounding_days = get_surrounding_days(s2_items.datetime, interval_days=3)
surrounding_days = get_surrounding_days(s2_item.datetime, interval_days=3)
print("Searching S1 in date range", surrounding_days)

s1_items = search_sentinel1(bbox, catalog, surrounding_days)

if s1_items:
break

if i == S1_MATCH_ATTEMPTS - 1:
raise ValueError(
f"No match for S1 scenes found after {S1_MATCH_ATTEMPTS} attempts."
)

dem_items = search_dem(bbox, catalog)

ds_sen2, ds_sen1, da_dem = make_datasets(
s2_items,
date = s2_item.properties["datetime"][:10]

result = make_datasets(
s2_item,
s1_items,
dem_items,
resolution,
)

ds_merge = xr.merge([ds_sen2, ds_sen1, da_dem], compat="override")
return date, result

return ds_merge

def main():
    tiles = gpd.read_file("scripts/data/mgrs_sample.geojson")
    sample = tiles.sample(1, random_state=45)
    aoi = sample.iloc[0].geometry


def convert_attrs_and_coords_objects_to_str(data):
    """
    Convert attributes and coordinates that are objects to
    strings.

    This is required for storing the xarray in netcdf.
    """
    for key, coord in data.coords.items():
        if coord.dtype == "object":
            data.coords[key] = str(coord.values)

    for key, attr in data.attrs.items():
        data.attrs[key] = str(attr)

    for key, var in data.variables.items():
        var.attrs = {}

@click.command()
@click.option(
"--index",
required=True,
default=42,
help="Index of MGRS tile from sample file that should be processed",
)
@click.option(
"--subset",
required=False,
help="For debugging, subset x and y to this pixel window.",
)
def main(index, subset):
print("Starting algorithm", index)
index = int(index)
tiles = gpd.read_file("mgrs_sample.geojson")
tile = tiles.iloc[index]
start_year = 2017
end_year = 2023

merged = process(
aoi,
date, merged = process(
tile.geometry,
start_year,
end_year,
SPATIAL_RESOLUTION,
CLOUD_COVER_PERCENTAGE,
NODATA_PIXEL_PERCENTAGE,
)
return merged
mgrs = tile["name"]
if subset:
subset = [int(dat) for dat in subset.split(",")]
print(f"Subsetting to {subset}")
merged = merged.sel(
x=slice(merged.x.values[subset[0]], merged.x.values[subset[2]]),
y=slice(merged.y.values[subset[1]], merged.y.values[subset[3]]),
)
merged = merged.compute()

tiler(merged, date, mgrs)


# main()
if __name__ == "__main__":
main()
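
With the click entry point, a batch job processes one MGRS tile per invocation. A hedged usage sketch; the subset argument is parsed as an x0,y0,x1,y1 pixel window by the slicing code above:

    # Process sample tile 42, cropped to a 512x512 pixel window for debugging.
    python scripts/datacube.py --index 42 --subset 0,0,512,512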
13 changes: 0 additions & 13 deletions scripts/stack_tile.py

This file was deleted.

