diff --git a/.gitignore b/.gitignore
index 87283e4..a61a696 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,8 @@ data/
 tfrecords/
 logs/
 /**/*.png
+*.tif
+*.tiff
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/fao_models/export_s2_patches.py b/fao_models/export_s2_patches.py
index 4c08605..e85740b 100644
--- a/fao_models/export_s2_patches.py
+++ b/fao_models/export_s2_patches.py
@@ -4,25 +4,36 @@ from serving import *
 import os
 import argparse
 
+
 # os.environ['TF_ENABLE_ONEDNN_OPTS=0']
-PROJECT = 'sig-ee-cloud' # change to your cloud project name
-ee.Initialize(project=PROJECT)
+PROJECT = "pc530-fao-fra-rss"  # change to your cloud project name
+# ee.Initialize(project=PROJECT)
 
 ## INIT WITH HIGH VOLUME ENDPOINT
 credentials, _ = google.auth.default()
-ee.Initialize(credentials, project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')
+ee.Initialize(
+    credentials,
+    project=PROJECT,
+    opt_url="https://earthengine-highvolume.googleapis.com",
+)
 
 # USE HEXAGONS TO MAKE PATCH BOUNDS ######################################################
-all_hex = ee.FeatureCollection("projects/pc530-fao-fra-rss/assets/reference/hexWCenPropertiesTropics")
+all_hex = ee.FeatureCollection(
+    "projects/pc530-fao-fra-rss/assets/reference/hexWCenPropertiesTropics"
+)
 # print(all_hex.size().getInfo())
 # print(all_hex.limit(1).getInfo()['features'][0]['properties'])
 
-hexForest = all_hex.filter(ee.Filter.And(ee.Filter.eq('FOREST',1),ee.Filter.eq('LU18CEN','Forest')))
+hexForest = all_hex.filter(
+    ee.Filter.And(ee.Filter.eq("FOREST", 1), ee.Filter.eq("LU18CEN", "Forest"))
+)
 # print('pureForest size',hexForest.size().getInfo())
 # print('LU18CEN values pure forest',hexForest.aggregate_histogram('LU18CEN').getInfo())
 
-hexNonForest = all_hex.filter(ee.Filter.And(ee.Filter.eq('FOREST',0),ee.Filter.neq('LU18CEN','Forest')))
+hexNonForest = all_hex.filter(
+    ee.Filter.And(ee.Filter.eq("FOREST", 0), ee.Filter.neq("LU18CEN", "Forest"))
+)
 # print('pureNonForest size',hexNonForest.size().getInfo())
 # print('LU18CEN values pure nonForest',hexNonForest.aggregate_histogram('LU18CEN').getInfo())
 
@@ -32,106 +43,133 @@ sample_size_total = FNFhex.size().getInfo()
 # print(sample_size_total)
 
+
 # create 320 x 320 m box for image patches (32x32 px patches for training)
-def hex_patch_box(fc,size):
-  def per_hex(f):
-    centroid = f.geometry().centroid()
-    patch_box = centroid.buffer(size/2).bounds()
-    return ee.Feature(patch_box)
-  return fc.map(per_hex)
+def hex_patch_box(fc, size):
+    def per_hex(f):
+        centroid = f.geometry().centroid()
+        patch_box = centroid.buffer(size / 2).bounds()
+        return ee.Feature(patch_box)
+
+    return fc.map(per_hex)
 
-patch_boxes = hex_patch_box(FNFhex,320)
+
+patch_boxes = hex_patch_box(FNFhex, 320)
 
 # Finally, for the actual workflow at the bottom we only need the centroids of each hexagon to generate the image patches
 FNFhex_centroids = FNFhex.map(lambda h: ee.Feature(h.geometry().centroid()))
 # print(FNFhex_centroids.first().getInfo())
 
 # image patch generation from hexagon centroid
-hexLabel = ee.Image(0).paint(hexForest,1).paint(hexNonForest,2).selfMask().rename('class')
+hexLabel = (
+    ee.Image(0).paint(hexForest, 1).paint(hexNonForest, 2).selfMask().rename("class")
+)
 
 ## MAKE S2 COMPOSITE IN HEXAGONS ##########################################
 # Using Cloud Score+ for cloud/cloud-shadow masking
 # Harmonized Sentinel-2 Level 2A collection.
-s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
+s2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
 
 # Cloud Score+ image collection. Note Cloud Score+ is produced from Sentinel-2
 # Level 1C data and can be applied to either L1C or L2A collections.
-csPlus = ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED');
+csPlus = ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")
 
 # Use 'cs' or 'cs_cdf', depending on your use case; see docs for guidance.
-QA_BAND = 'cs_cdf'
+QA_BAND = "cs_cdf"
 
 # The threshold for masking; values between 0.50 and 0.65 generally work well.
 # Higher values will remove thin clouds, haze & cirrus shadows.
-CLEAR_THRESHOLD = 0.50;
+CLEAR_THRESHOLD = 0.50
 
 # Make a clear median composite.
-sampleImage = (s2
-    .filterDate('2017-01-01', '2019-12-31')
+sampleImage = (
+    s2.filterDate("2017-01-01", "2019-12-31")
     .linkCollection(csPlus, [QA_BAND])
     .map(lambda img: img.updateMask(hexLabel))
     .map(lambda img: img.updateMask(img.select(QA_BAND).gte(CLEAR_THRESHOLD)))
     .median()
    .addBands(hexLabel)
-    .select(['B4','B3','B2','B8','class'],['R','G','B','N','class'])) # B G R classlabel
+    .select(["B4", "B3", "B2", "B8", "class"], ["R", "G", "B", "N", "class"])
+)  # R G B N + class label
+
 
 ## TESTING ################################################################
 def main():
-  # initalize new cli parser
-  parser = argparse.ArgumentParser(
-    description="Export S2 image patches."
-  )
-
-  parser.add_argument(
-    "-o",
-    "--output_dir",
-    type=str,
-    help="path to config file",
-  )
-
-  parser.add_argument(
-    "-f",
-    "--forest",
-    dest="forest",
-    action="store_true",
-    help="export forest labeled patches",
-    required=False
-  )
-
-  parser.add_argument(
-    "-nf",
-    "--nonforest",
-    dest="nonforest",
-    action="store_true",
-    help="export nonforest labeled patches",
-    required=False
-  )
-  args = parser.parse_args()
-
-  parser.set_defaults(forest=False)
-  parser.set_defaults(nonforest=False)
-
-  if not os.path.exists(args.output_dir):
-    os.makedirs(args.output_dir)
-
-  # we have about a 1/3 to 2/3 split of forest / nonforest makeup of total hexagons
-  ee_points_forest = hexForest.map(lambda h: ee.Feature(h.geometry().centroid())).randomColumn().sort('random')
-  ee_points_nonforest = hexNonForest.map(lambda h: ee.Feature(h.geometry().centroid())).randomColumn().sort('random')
-
-  if not args.forest and not args.nonforest:
-    print('Please specify --forest and/or --nonforest')
-    exit()
-
-  if args.forest:
-    # for i in [0.,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]: # first chunk finishes then hangs..
-    # print(f'exporting patches from points chunk {i}')
-    # points_chunk = ee_points_forest.filter(ee.Filter.And(ee.Filter.gt('random',i),ee.Filter.lte('random',i+0.1)))
-    # # print(points_chunk.size().getInfo())
-    # write_geotiff_patch_from_points_v2(sampleImage,points_chunk,['R','G','B','N'],10,32,output_directory=args.output_dir, suffix='forest')
-    write_geotiff_patch_from_points_v2(sampleImage,ee_points_forest,['R','G','B','N'],10,32,output_directory=args.output_dir, suffix='forest')
-
-  if args.nonforest:
-    write_geotiff_patch_from_points_v2(sampleImage,ee_points_nonforest,['R','G','B','N'],10,32,output_directory=args.output_dir, suffix='nonforest')
+    # initialize new cli parser
+    parser = argparse.ArgumentParser(description="Export S2 image patches.")
+
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        type=str,
+        help="path to output directory",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--forest",
+        dest="forest",
+        action="store_true",
+        help="export forest labeled patches",
+        required=False,
+    )
+
+    parser.add_argument(
+        "-nf",
+        "--nonforest",
+        dest="nonforest",
+        action="store_true",
+        help="export nonforest labeled patches",
+        required=False,
+    )
+    args = parser.parse_args()
+
+    parser.set_defaults(forest=False)
+    parser.set_defaults(nonforest=False)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    # we have about a 1/3 to 2/3 split of forest / nonforest makeup of total hexagons
+    # set seed so we can reproduce the list and skip exports that have finished.
+    seed = 42
+    ee_points_forest = (
+        hexForest.map(lambda h: ee.Feature(h.geometry().centroid()))
+        .randomColumn(seed=seed)
+        .sort("random")
+    )
+    ee_points_nonforest = (
+        hexNonForest.map(lambda h: ee.Feature(h.geometry().centroid()))
+        .randomColumn(seed=seed)
+        .sort("random")
+    )
+
+    if not args.forest and not args.nonforest:
+        print("Please specify --forest and/or --nonforest")
+        exit()
+
+    if args.forest:
+        write_geotiff_patch_from_points_v2(
+            sampleImage,
+            ee_points_forest,
+            ["R", "G", "B", "N"],
+            10,
+            32,
+            output_directory=args.output_dir,
+            suffix="forest",
+        )
+
+    if args.nonforest:
+        write_geotiff_patch_from_points_v2(
+            sampleImage,
+            ee_points_nonforest,
+            ["R", "G", "B", "N"],
+            10,
+            32,
+            output_directory=args.output_dir,
+            suffix="nonforest",
+        )
+
 
 if __name__ == "__main__":
-  main()
+    main()
diff --git a/fao_models/geotiffToRecords.py b/fao_models/geotiffToRecords.py
index 2e054b3..fbb4797 100644
--- a/fao_models/geotiffToRecords.py
+++ b/fao_models/geotiffToRecords.py
@@ -1,8 +1,29 @@
 import os
+import math
+from enum import Enum, auto
+
 import numpy as np
 import rasterio
-import math
 import tensorflow as tf
+from tqdm import tqdm
+
+
+class SplitStrategy(Enum):
+    all = auto()
+    balanced = auto()
+
+
+def balance_files(files: list[str]) -> list[str]:
+    forest = []
+    nonforest = []
+    for i in files:
+        if "non" in i:
+            nonforest.append(i)
+        else:
+            forest.append(i)
+    min_recs = min(len(forest), len(nonforest))
+    return nonforest[:min_recs] + forest[:min_recs]
+
 
 def _bytes_feature(value):
     """Returns a bytes_list from a string / byte."""
@@ -10,29 +31,41 @@ def _bytes_feature(value):
         value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
+
 def serialize_example(image, label):
     feature = {
-        'image': _bytes_feature(image),
-        'label': _bytes_feature(tf.io.serialize_tensor(label))
+        "image": _bytes_feature(image),
+        "label": _bytes_feature(tf.io.serialize_tensor(label)),
     }
     example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
     return example_proto.SerializeToString()
 
-def write_sharded_tfrecords(folder_path, output_dir, items_per_record=500):
+
+def write_sharded_tfrecords(
+    folder_path,
+    output_dir,
+    items_per_record=500,
+    balance_strat: SplitStrategy = SplitStrategy.all,
+):
     filenames = [f for f in os.listdir(folder_path) if f.endswith(".tif")]
+    if balance_strat is SplitStrategy.balanced:
+        filenames = balance_files(filenames)
+
     num_geotiffs = len(filenames)
     num_shards = math.ceil(num_geotiffs / items_per_record)
     options = tf.io.TFRecordOptions(compression_type="GZIP")
 
     for shard_id in range(num_shards):
-        shard_filename = os.path.join(output_dir, f'geotiffs_shard_{shard_id:03d}.tfrecord.gz')
+        shard_filename = os.path.join(
+            output_dir, f"geotiffs_shard_{shard_id:03d}.tfrecord.gz"
+        )
         start_index = shard_id * items_per_record
         end_index = min((shard_id + 1) * items_per_record, num_geotiffs)
 
         with tf.io.TFRecordWriter(shard_filename, options=options) as writer:
-            for filename in filenames[start_index:end_index]:
+            for filename in tqdm(filenames[start_index:end_index]):
                 file_path = os.path.join(folder_path, filename)
                 try:
                     with rasterio.open(file_path) as dataset:
@@ -40,10 +73,12 @@ def write_sharded_tfrecords(folder_path, output_dir, items_per_record=500):
                         if raster_data.shape != (4, 32, 32):
                             continue
                         raster_data = np.transpose(raster_data, (1, 2, 0))
-                        image_raw = tf.io.serialize_tensor(raster_data.astype(np.float32))
-                        label = 1 if 'nonforest' in filename else 0
-                        print(filename)
-                        print(label)
+                        image_raw = tf.io.serialize_tensor(
+                            raster_data.astype(np.float32)
+                        )
+                        label = 1 if "nonforest" in filename else 0
+                        # print(filename)
+                        # print(label)
                         label = tf.convert_to_tensor(label, dtype=tf.int64)
                         tf_example = serialize_example(image_raw, label)
                         writer.write(tf_example)
@@ -54,15 +89,21 @@ def write_sharded_tfrecords(folder_path, output_dir, items_per_record=500):
     # Print a message after finishing writing each TFRecord file
     print(f"Finished writing {shard_filename}")
 
-# TESTING ######################################
-# tiff_path = "/home/ate/sig/gitmodels/fao-models/data/"
-tiff_path = r"C:\fao-models\data"
-# tf_path = "/home/ate/sig/gitmodels/fao-models/tfrecords"
-tf_path = r"C:\fao-models\tfrecords_test"
-for folder in [tiff_path,tf_path]:
-    if not os.path.exists(folder):
-        os.makedirs(folder)
-items_per_record = 1000 # Number of GeoTIFFs to store in each TFRecord file
-
-# Call the function
-write_sharded_tfrecords(tiff_path, tf_path, items_per_record=items_per_record)
+
+if __name__ == "__main__":
+    tiff_path = r"data"
+    tf_path_root = r"tfrecords"
+    balance_strat = SplitStrategy.balanced
+    items_per_record = 1000  # Number of GeoTIFFs to store in each TFRecord file
+
+    tf_path_strat = os.path.join(tf_path_root, balance_strat.name)
+    for folder in [tiff_path, tf_path_root, tf_path_strat]:
+        if not os.path.exists(folder):
+            os.makedirs(folder)
+
+    write_sharded_tfrecords(
+        tiff_path,
+        tf_path_strat,
+        items_per_record=items_per_record,
+        balance_strat=balance_strat,
+    )
diff --git a/fao_models/serving.py b/fao_models/serving.py
index 53fdb15..d79065c 100644
--- a/fao_models/serving.py
+++ b/fao_models/serving.py
@@ -1,271 +1,278 @@
-import ee
-import os
-# from google.colab import auth
-from google.api_core import exceptions, retry
+import concurrent.futures
-import concurrent
-import google
+import os
 import io
-import multiprocessing
+
+from google.api_core import exceptions, retry
+import ee
 import numpy as np
-import requests
 import tensorflow as tf
 import requests
-# import exceptions # where do we import this from?
-
-@retry.Retry(deadline=60*10)
-def get_tiff_patch_url_file_point(image: ee.Image, point: ee.Geometry, bands: list, scale:int, patch_size: int, output_file:str):
-    """
-    Return ee.Image.getDownloadURL response and a filename as a tuple for a GeoTIFF. filename is passed through for concurrent.futures multiprocessing jobs. Uses points rather than polygons.
-    args:
-        image: ee.Image
-        box: ee.Geometry
-        bands: list(str)
-        output_file: str
-    """
-    # Create the URL to download the band values of the patch of pixels.
-    point = ee.Geometry(point)
-    region = point.buffer(scale * patch_size / 2, 1).bounds(1)
-
-    url = image.getDownloadURL({
-        "region": region,
-        "dimensions": [patch_size, patch_size],
-        "format": "GEO_TIFF",
-        "bands": bands,
-    })
-    response = requests.get(url)
-    if response.status_code == 429:
-        raise exceptions.TooManyRequests(response.text)
-    response.raise_for_status()
-    return (response, output_file)
+
 
 @retry.Retry()
-def get_tiff_patch_url_file_box(image,box:ee.Geometry,bands:list,output_file:str):
-    """
-    Return ee.Image.getDownloadURL response and a filename as a tuple for a GeoTIFF. filename is passed through for concurrent.futures multiprocessing jobs. Uses polygons (boxes) rather than points.
-    args:
-        image: ee.Image
-        box: ee.Geometry
-        bands: list(str)
-        output_file: str
-    """
-    url = image.getDownloadUrl({
-        'bands': bands,
-        'region': box,
-        'scale': 10,
-        'format': 'GEO_TIFF'
-    })
-    return (requests.get(url),output_file)
-
-def write_geotiff_patch_from_boxes(image,boxes,bands,output_directory):
-    """Writes patches inside boxes a GEE Image within a FeatureCollection of boxes to individual GeoTIFFs
-    args:
-        image: ee.Image
-        boxes: ee.FeatureCollection
-        bands: list(str)
-
-    """
-    EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=40) # max concurrent requests to high volume endpoint
-
-    # convert boxes FeatureCollection to ee.Geomtry's
-    patch_box_list = boxes.toList(boxes.size()).map(lambda f:ee.Feature(f).geometry()).getInfo() # list of ee.Geometry's
-    # TODO: split into train/val/test folders within data/ directory
-    patch_box_list_filenames = [os.path.join(output_directory,f'patch_box{list_index}.tif') for list_index in list(range(0,boxes.size().getInfo()))] # list of filenames
-
-    future_to_point = {
-        EXECUTOR.submit(get_tiff_patch_url_file_box, image, box, bands, filename): (box,filename) for (box,filename) in zip(patch_box_list,patch_box_list_filenames)
-    }
-
-    for future in concurrent.futures.as_completed(future_to_point):
-        result = future.result()
-        resp = result[0]
-        filename = result[1]
-        with open(filename, 'wb') as fd:
-            fd.write(resp.content)
-
-def write_geotiff_patch_from_points_v2(image,points,bands,scale,patch_size,output_directory,suffix=None,num_workers=10):
-    """Writes patches inside boxes a GEE Image within a FeatureCollection of boxes to individual GeoTIFFs
-    args:
-        image: ee.Image
-        points: ee.FeatureCollection
-        bands: list(str)
-        scale: int
-        patch_size: int
-        output_directory: str
-        suffix: str
-
-    """
-    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: # max concurrent requests to high volume endpoint is 40
-
+def get_tiff_patch_url_file_point(
+    image: ee.Image,
+    point: ee.Geometry,
+    bands: list,
+    scale: int,
+    patch_size: int,
+    output_file: str,
+):
+    """
+    Download a GeoTIFF patch via ee.Image.getDownloadURL and write it to output_file. Uses points rather than polygons.
+    args:
+        image: ee.Image
+        point: ee.Geometry
+        bands: list(str)
+        output_file: str
+    """
+    # Create the URL to download the band values of the patch of pixels.
+    point = ee.Geometry(point)
+    region = point.buffer(scale * patch_size / 2, 1).bounds(1)
+
+    url = image.getDownloadURL(
+        {
+            "region": region,
+            "dimensions": [patch_size, patch_size],
+            "format": "GEO_TIFF",
+            "bands": bands,
+        }
+    )
+    response = requests.get(url)
+    if response.status_code == 429:
+        raise exceptions.TooManyRequests(response.text)
+    response.raise_for_status()
+    with open(output_file, "wb") as fd:
+        fd.write(response.content)
+    return 1
+
+
+def write_geotiff_patch_from_points_v2(
+    image,
+    points,
+    bands,
+    scale,
+    patch_size,
+    output_directory,
+    suffix=None,
+    num_workers=10,
+):
+    """Writes image patches of a GEE Image, centered on each point in a FeatureCollection, to individual GeoTIFFs
+    args:
+        image: ee.Image
+        points: ee.FeatureCollection
+        bands: list(str)
+        scale: int
+        patch_size: int
+        output_directory: str
+        suffix: str
+
+    """
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=num_workers
+    ) as executor:  # max concurrent requests to high volume endpoint is 40
+        # convert points FeatureCollection to ee.Geometry's
+        patch_pt_list = (
+            points.toList(points.size())
+            .map(lambda f: ee.Feature(f).geometry())
+            .getInfo()
+        )  # list of ee.Geometry's
+        patch_pt_list_filenames = [
+            os.path.join(output_directory, f"patch_pt{list_index}_{suffix}.tif")
+            for list_index in list(range(0, points.size().getInfo()))
+        ]  # list of filenames
+        pt_filename = zip(patch_pt_list, patch_pt_list_filenames)
+        # don't write patches that already exist on disk
+        pt_filename = [pf for pf in pt_filename if not os.path.exists(pf[1])]
+        # end of skip-existing filter
+        futures = {
+            executor.submit(
+                get_tiff_patch_url_file_point,
+                image,
+                pt,
+                bands,
+                scale,
+                patch_size,
+                filename,
+            ): (pt, filename)
+            for (
+                pt,
+                filename,
+            ) in pt_filename  # zip(patch_pt_list,patch_pt_list_filenames)
+        }
+        try:
+            for future in concurrent.futures.as_completed(futures, timeout=120):
+                try:
+                    # get result of future with timeout of 120s
+                    result = future.result(timeout=120)
+
+                except concurrent.futures.TimeoutError:
+                    print("The task exceeded 2-minute limit and was cancelled")
+                except Exception as e:
+                    print(f"Loop - Generated an exception: {e}")
+        except Exception as e:
+            # TODO - this still generates an exception when doing a long running (and finishing) task
+            # Test to see if removing try/except will still allow task to finish.
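+            # Likely cause: concurrent.futures.as_completed(futures, timeout=120)
+            # raises TimeoutError when the whole batch has not finished within 120 s,
+            # even if every individual future eventually succeeds.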
+ print(f"Outer - Generated an exception: {e}") + + +def write_geotiff_patch_from_points( + image, points, bands, scale, patch_size, output_directory, suffix=None +): + """Writes patches inside boxes a GEE Image within a FeatureCollection of boxes to individual GeoTIFFs + args: + image: ee.Image + points: ee.FeatureCollection + bands: list(str) + scale: int + patch_size: int + output_directory: str + suffix: str + + """ + EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=40 + ) # max concurrent requests to high volume endpoint + # convert points FeatureCollection to ee.Geomtry's - patch_pt_list = points.toList(points.size()).map(lambda f:ee.Feature(f).geometry()).getInfo() # list of ee.Geometry's - patch_pt_list_filenames = [os.path.join(output_directory,f'patch_pt{list_index}_{suffix}.tif') for list_index in list(range(0,points.size().getInfo()))] # list of filenames - pt_filename = zip(patch_pt_list,patch_pt_list_filenames) - # # don't write patches that already exist on disk - # pt_filename = [pf for pf in pt_filename if not os.path.exists(pf[1])] - - futures = { - executor.submit(get_tiff_patch_url_file_point, image, pt, bands, scale, patch_size, filename): - (pt,filename) for (pt,filename) in pt_filename # zip(patch_pt_list,patch_pt_list_filenames) + patch_pt_list = ( + points.toList(points.size()).map(lambda f: ee.Feature(f).geometry()).getInfo() + ) # list of ee.Geometry's + patch_pt_list_filenames = [ + os.path.join(output_directory, f"patch_pt{list_index}_{suffix}.tif") + for list_index in list(range(0, points.size().getInfo())) + ] # list of filenames + + pt_filename = zip(patch_pt_list, patch_pt_list_filenames) + + # don't write patches that already exist on disk + pt_filename = [pf for pf in pt_filename if not os.path.exists(pf[1])] + + future_to_point = { + EXECUTOR.submit( + get_tiff_patch_url_file_point, image, pt, bands, scale, patch_size, filename + ): (pt, filename) + for (pt, filename) in zip(patch_pt_list, patch_pt_list_filenames) } - for future in concurrent.futures.as_completed(futures): - try: - # get result of future with timeout of 120s - result = future.result(timeout=120) + for future in concurrent.futures.as_completed(future_to_point): + result = future.result() resp = result[0] filename = result[1] - with open(filename, 'wb') as fd: - fd.write(resp.content) - except concurrent.futures.TimeoutError: - print("The task exceeded 2-minute limit and was cancelled") - except Exception as e: - print(f"Generated an exception: {e}") - - -def write_geotiff_patch_from_points(image,points,bands,scale,patch_size,output_directory,suffix=None): - """Writes patches inside boxes a GEE Image within a FeatureCollection of boxes to individual GeoTIFFs - args: - image: ee.Image - points: ee.FeatureCollection - bands: list(str) - scale: int - patch_size: int - output_directory: str - suffix: str - - """ - EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=40) # max concurrent requests to high volume endpoint - - # convert points FeatureCollection to ee.Geomtry's - patch_pt_list = points.toList(points.size()).map(lambda f:ee.Feature(f).geometry()).getInfo() # list of ee.Geometry's - patch_pt_list_filenames = [os.path.join(output_directory,f'patch_pt{list_index}_{suffix}.tif') for list_index in list(range(0,points.size().getInfo()))] # list of filenames - - pt_filename = zip(patch_pt_list,patch_pt_list_filenames) - - # don't write patches that already exist on disk - pt_filename = [pf for pf in pt_filename if not os.path.exists(pf[1])] - - future_to_point = { - 
EXECUTOR.submit(get_tiff_patch_url_file_point, image, pt, bands, scale, patch_size, filename): (pt,filename) for (pt,filename) in zip(patch_pt_list,patch_pt_list_filenames) - } - - for future in concurrent.futures.as_completed(future_to_point): - result = future.result() - resp = result[0] - filename = result[1] - with open(filename, 'wb') as fd: - fd.write(resp.content) + with open(filename, "wb") as fd: + fd.write(resp.content) + def write_tfrecord_batch(image, patch_size, points, scale, output_file): - """Writes patches at a set of points to a TFRecord file, using ee.data.ComputePixels - args: - image: ee.Image - patch_size: int - points: python list of ee.Geometry.Point objects, easily done with `pointFC.aggregate_array('.geo').getInfo()` - scale: int - output_file: str - returns: None - """ - # REPLACE WITH YOUR BUCKET! - OUTPUT_FILE = output_file - - # Output resolution in meters. - SCALE = scale - - # Pre-compute a geographic coordinate system. - proj = ee.Projection('EPSG:4326').atScale(SCALE).getInfo() - - # Get scales in degrees out of the transform. - SCALE_X = proj['transform'][0] - SCALE_Y = -proj['transform'][4] - - # Patch size in pixels. - PATCH_SIZE = patch_size - - # Offset to the upper left corner. - OFFSET_X = -SCALE_X * PATCH_SIZE / 2 - OFFSET_Y = -SCALE_Y * PATCH_SIZE / 2 - - # Request template for ee.data.ComputePixels - REQUEST = { - 'fileFormat': 'NPY', - 'grid': { - 'dimensions': { - 'width': PATCH_SIZE, - 'height': PATCH_SIZE - }, - 'affineTransform': { - 'scaleX': SCALE_X, - 'shearX': 0, - 'shearY': 0, - 'scaleY': SCALE_Y, + """Writes patches at a set of points to a TFRecord file, using ee.data.ComputePixels + args: + image: ee.Image + patch_size: int + points: python list of ee.Geometry.Point objects, easily done with `pointFC.aggregate_array('.geo').getInfo()` + scale: int + output_file: str + returns: None + """ + # REPLACE WITH YOUR BUCKET! + OUTPUT_FILE = output_file + + # Output resolution in meters. + SCALE = scale + + # Pre-compute a geographic coordinate system. + proj = ee.Projection("EPSG:4326").atScale(SCALE).getInfo() + + # Get scales in degrees out of the transform. + SCALE_X = proj["transform"][0] + SCALE_Y = -proj["transform"][4] + + # Patch size in pixels. + PATCH_SIZE = patch_size + + # Offset to the upper left corner. + OFFSET_X = -SCALE_X * PATCH_SIZE / 2 + OFFSET_Y = -SCALE_Y * PATCH_SIZE / 2 + + # Request template for ee.data.ComputePixels + REQUEST = { + "fileFormat": "NPY", + "grid": { + "dimensions": {"width": PATCH_SIZE, "height": PATCH_SIZE}, + "affineTransform": { + "scaleX": SCALE_X, + "shearX": 0, + "shearY": 0, + "scaleY": SCALE_Y, }, - 'crsCode': proj['crs'] - } + "crsCode": proj["crs"], + }, } - # Blue, green, red, NIR, AOT. - FEATURES = image.bandNames().getInfo()#['B2_median', 'B3_median', 'B4_median', 'B8_median', 'AOT_median'] - - # Specify the size and shape of patches expected by the model. 
-    KERNEL_SHAPE = [PATCH_SIZE, PATCH_SIZE]
-    COLUMNS = [
-        tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for k in FEATURES
-    ]
-    FEATURES_DICT = dict(zip(FEATURES, COLUMNS))
-
-    EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=40) # max concurrent requests to high volume endpoint
-
-    # functions for batch .tfrecord writer workflow
-    @retry.Retry()
-    def get_patch(coords, image,format='NPY'):
-        """Uses ee.data.ComputePixels() to get a patch centered on the coordinates, as a numpy array."""
-        request = dict(REQUEST)
-        request['fileFormat'] = format
-        request['expression'] = image
-        request['grid']['affineTransform']['translateX'] = coords[0] + OFFSET_X
-        request['grid']['affineTransform']['translateY'] = coords[1] + OFFSET_Y
-        return np.load(io.BytesIO(ee.data.computePixels(request)))
-
-    def get_sample_coords(roi, n):
-        """"Get a random sample of N points in the ROI."""
-        points = ee.FeatureCollection.randomPoints(region=roi, points=n, maxError=1)
-        return points.aggregate_array('.geo').getInfo()
-
-    def array_to_example(structured_array):
-        """"Serialize a structured numpy array into a tf.Example proto."""
-        feature = {}
-        for f in FEATURES:
-            feature[f] = tf.train.Feature(
-                float_list = tf.train.FloatList(
-                    value = structured_array[f].flatten()))
-        return tf.train.Example(
-            features = tf.train.Features(feature = feature))
-
-
-    def write_tf_dataset(image, sample_points, file_name):
-        """"Write patches at the sample points into one TFRecord file."""
-        future_to_point = {
-            EXECUTOR.submit(get_patch, point['coordinates'], image): point for point in sample_points
-        }
-
-        # Optionally compress files.
-        writer = tf.io.TFRecordWriter(file_name)
-
-        for future in concurrent.futures.as_completed(future_to_point):
-            point = future_to_point[future]
-            try:
-                np_array = future.result()
-                example_proto = array_to_example(np_array)
-                writer.write(example_proto.SerializeToString())
-                writer.flush()
-            except Exception as e:
-                # print(e)
-                pass
-
-        writer.close()
-
-    # write patches to .tfrecord file
-    write_tf_dataset(image, points, OUTPUT_FILE)
-    return
+    # Blue, green, red, NIR, AOT.
+    FEATURES = (
+        image.bandNames().getInfo()
+    )  # ['B2_median', 'B3_median', 'B4_median', 'B8_median', 'AOT_median']
+
+    # Specify the size and shape of patches expected by the model.
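+    # (COLUMNS/FEATURES_DICT below define one fixed-length float feature per band;
+    # presumably this is the schema used to parse these records back with
+    # tf.io.parse_single_example when training.)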
+    KERNEL_SHAPE = [PATCH_SIZE, PATCH_SIZE]
+    COLUMNS = [
+        tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for k in FEATURES
+    ]
+    FEATURES_DICT = dict(zip(FEATURES, COLUMNS))
+
+    EXECUTOR = concurrent.futures.ThreadPoolExecutor(
+        max_workers=40
+    )  # max concurrent requests to high volume endpoint
+
+    # functions for batch .tfrecord writer workflow
+    @retry.Retry()
+    def get_patch(coords, image, format="NPY"):
+        """Uses ee.data.ComputePixels() to get a patch centered on the coordinates, as a numpy array."""
+        request = dict(REQUEST)
+        request["fileFormat"] = format
+        request["expression"] = image
+        request["grid"]["affineTransform"]["translateX"] = coords[0] + OFFSET_X
+        request["grid"]["affineTransform"]["translateY"] = coords[1] + OFFSET_Y
+        return np.load(io.BytesIO(ee.data.computePixels(request)))
+
+    def get_sample_coords(roi, n):
+        """Get a random sample of N points in the ROI."""
+        points = ee.FeatureCollection.randomPoints(region=roi, points=n, maxError=1)
+        return points.aggregate_array(".geo").getInfo()
+
+    def array_to_example(structured_array):
+        """Serialize a structured numpy array into a tf.Example proto."""
+        feature = {}
+        for f in FEATURES:
+            feature[f] = tf.train.Feature(
+                float_list=tf.train.FloatList(value=structured_array[f].flatten())
+            )
+        return tf.train.Example(features=tf.train.Features(feature=feature))
+
+    def write_tf_dataset(image, sample_points, file_name):
+        """Write patches at the sample points into one TFRecord file."""
+        future_to_point = {
+            EXECUTOR.submit(get_patch, point["coordinates"], image): point
+            for point in sample_points
+        }
+
+        # Optionally compress files.
+        writer = tf.io.TFRecordWriter(file_name)
+
+        for future in concurrent.futures.as_completed(future_to_point):
+            point = future_to_point[future]
+            try:
+                np_array = future.result()
+                example_proto = array_to_example(np_array)
+                writer.write(example_proto.SerializeToString())
+                writer.flush()
+            except Exception as e:
+                # print(e)
+                pass
+
+        writer.close()
+
+    # write patches to .tfrecord file
+    write_tf_dataset(image, points, OUTPUT_FILE)
+    return
diff --git a/requirements.txt b/requirements.txt
index 296d654..ef2f7a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,5 @@
-numpy
\ No newline at end of file
+numpy
+tensorflow
+earthengine-api
+rasterio
+tqdm
\ No newline at end of file