
Commit

Merge pull request #1 from kyle-woodward/feature/exporting
Feature/exporting
kyle-woodward authored Apr 4, 2024
2 parents 1be7e37 + 162ad2a commit b25e412
Showing 5 changed files with 444 additions and 352 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@ data/
tfrecords/
logs/
/**/*.png
*.tif
*.tiff

# Byte-compiled / optimized / DLL files
__pycache__/
192 changes: 115 additions & 77 deletions fao_models/export_s2_patches.py
@@ -4,25 +4,36 @@
from serving import *
import os
import argparse

# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

PROJECT = 'sig-ee-cloud' # change to your cloud project name
ee.Initialize(project=PROJECT)
PROJECT = "pc530-fao-fra-rss" # change to your cloud project name
# ee.Initialize(project=PROJECT)

## INIT WITH HIGH VOLUME ENDPOINT
credentials, _ = google.auth.default()
ee.Initialize(credentials, project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')
ee.Initialize(
    credentials,
    project=PROJECT,
    opt_url="https://earthengine-highvolume.googleapis.com",
)
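The high-volume endpoint is built for many small concurrent requests, which matches the per-patch download pattern used below. A retry helper often pairs well with it; this is a hedged sketch, not part of the PR, and with_retries is a hypothetical name:

import time

def with_retries(fn, max_tries=5, base_delay=1.0):
    # Retry a transient Earth Engine failure with exponential backoff.
    for attempt in range(max_tries):
        try:
            return fn()
        except ee.EEException:
            if attempt == max_tries - 1:
                raise
            time.sleep(base_delay * 2**attempt)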

# USE HEXAGONS TO MAKE PATCH BOUNDS ######################################################
all_hex = ee.FeatureCollection("projects/pc530-fao-fra-rss/assets/reference/hexWCenPropertiesTropics")
all_hex = ee.FeatureCollection(
    "projects/pc530-fao-fra-rss/assets/reference/hexWCenPropertiesTropics"
)
# print(all_hex.size().getInfo())
# print(all_hex.limit(1).getInfo()['features'][0]['properties'])

hexForest = all_hex.filter(ee.Filter.And(ee.Filter.eq('FOREST',1),ee.Filter.eq('LU18CEN','Forest')))
hexForest = all_hex.filter(
    ee.Filter.And(ee.Filter.eq("FOREST", 1), ee.Filter.eq("LU18CEN", "Forest"))
)
# print('pureForest size',hexForest.size().getInfo())
# print('LU18CEN values pure forest',hexForest.aggregate_histogram('LU18CEN').getInfo())

hexNonForest = all_hex.filter(ee.Filter.And(ee.Filter.eq('FOREST',0),ee.Filter.neq('LU18CEN','Forest')))
hexNonForest = all_hex.filter(
    ee.Filter.And(ee.Filter.eq("FOREST", 0), ee.Filter.neq("LU18CEN", "Forest"))
)
# print('pureNonForest size',hexNonForest.size().getInfo())
# print('LU18CEN values pure nonForest',hexNonForest.aggregate_histogram('LU18CEN').getInfo())

@@ -32,106 +43,133 @@
sample_size_total = FNFhex.size().getInfo()
# print(sample_size_total)


# create 320 x 320 m box for image patches (32x32 px patches for training)
def hex_patch_box(fc,size):
    def per_hex(f):
        centroid = f.geometry().centroid()
        patch_box = centroid.buffer(size/2).bounds()
        return ee.Feature(patch_box)
    return fc.map(per_hex)
def hex_patch_box(fc, size):
    def per_hex(f):
        centroid = f.geometry().centroid()
        patch_box = centroid.buffer(size / 2).bounds()
        return ee.Feature(patch_box)

    return fc.map(per_hex)

patch_boxes = hex_patch_box(FNFhex,320)

patch_boxes = hex_patch_box(FNFhex, 320)

# Finally, for the actual workflow at the bottom we only need the centroids of each hexagon to generate the image patches
FNFhex_centroids = FNFhex.map(lambda h: ee.Feature(h.geometry().centroid()))
# print(FNFhex_centroids.first().getInfo())
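The 320 m box is sized to the model input: at the 10 m resolution of the selected Sentinel-2 bands, 320 m divided by 10 m per pixel gives 32 pixels per side. An illustrative arithmetic check (not part of the PR):

PATCH_METERS = 320
SCALE_METERS = 10  # Sentinel-2 visible/NIR bands
assert PATCH_METERS // SCALE_METERS == 32  # 32 x 32 px training patches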

# image patch generation from hexagon centroid
hexLabel = ee.Image(0).paint(hexForest,1).paint(hexNonForest,2).selfMask().rename('class')
hexLabel = (
    ee.Image(0).paint(hexForest, 1).paint(hexNonForest, 2).selfMask().rename("class")
)
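paint() burns a constant into the pixels covered by each collection (1 for forest hexagons, 2 for nonforest), and selfMask() drops the zero background so only hexagon pixels carry a class. A sanity check along these lines can confirm the labeling (illustrative only, not in the PR):

hist = hexLabel.reduceRegion(
    reducer=ee.Reducer.frequencyHistogram(),
    geometry=hexForest.first().geometry(),
    scale=10,
    maxPixels=1e9,
)
print(hist.getInfo())  # expect only the value 1 inside a forest hexagon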

## MAKE S2 COMPOSITE IN HEXAGONS ##########################################
# Using Cloud Score + for cloud/cloud-shadow masking
# Harmonized Sentinel-2 Level 2A collection.
s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
s2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")

# Cloud Score+ image collection. Note Cloud Score+ is produced from Sentinel-2
# Level 1C data and can be applied to either L1C or L2A collections.
csPlus = ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED');
csPlus = ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")

# Use 'cs' or 'cs_cdf', depending on your use case; see docs for guidance.
QA_BAND = 'cs_cdf'
QA_BAND = "cs_cdf"

# The threshold for masking; values between 0.50 and 0.65 generally work well.
# Higher values will remove thin clouds, haze & cirrus shadows.
CLEAR_THRESHOLD = 0.50;
CLEAR_THRESHOLD = 0.50

# Make a clear median composite.
sampleImage = (s2
    .filterDate('2017-01-01', '2019-12-31')
sampleImage = (
    s2.filterDate("2017-01-01", "2019-12-31")
    .linkCollection(csPlus, [QA_BAND])
    .map(lambda img: img.updateMask(hexLabel))
    .map(lambda img: img.updateMask(img.select(QA_BAND).gte(CLEAR_THRESHOLD)))
    .median()
    .addBands(hexLabel)
    .select(['B4','B3','B2','B8','class'],['R','G','B','N','class'])) # R, G, B, NIR, class label
    .select(["B4", "B3", "B2", "B8", "class"], ["R", "G", "B", "N", "class"])
)  # R, G, B, NIR, class label
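linkCollection() joins each Sentinel-2 image to its Cloud Score+ counterpart by system:index and appends the cs_cdf band, which lets the per-image cloud mask run before the median reduction. The masking lambda above, written out as a named function for clarity (an equivalent sketch, not part of the PR):

def mask_clouds(img):
    # cs_cdf ranges over [0, 1]; keep pixels at or above the clear threshold.
    return img.updateMask(img.select(QA_BAND).gte(CLEAR_THRESHOLD))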


## TESTING ################################################################
def main():
    # initialize new CLI parser
    parser = argparse.ArgumentParser(
        description="Export S2 image patches."
    )

    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        help="output directory for exported patches",
    )

    parser.add_argument(
        "-f",
        "--forest",
        dest="forest",
        action="store_true",
        help="export forest labeled patches",
        required=False
    )

    parser.add_argument(
        "-nf",
        "--nonforest",
        dest="nonforest",
        action="store_true",
        help="export nonforest labeled patches",
        required=False
    )
    args = parser.parse_args()

    parser.set_defaults(forest=False)
    parser.set_defaults(nonforest=False)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # the total hexagons split roughly 1/3 forest to 2/3 nonforest
    ee_points_forest = hexForest.map(lambda h: ee.Feature(h.geometry().centroid())).randomColumn().sort('random')
    ee_points_nonforest = hexNonForest.map(lambda h: ee.Feature(h.geometry().centroid())).randomColumn().sort('random')

    if not args.forest and not args.nonforest:
        print('Please specify --forest and/or --nonforest')
        exit()

    if args.forest:
        # for i in [0.,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]: # first chunk finishes then hangs..
        #     print(f'exporting patches from points chunk {i}')
        #     points_chunk = ee_points_forest.filter(ee.Filter.And(ee.Filter.gt('random',i),ee.Filter.lte('random',i+0.1)))
        #     # print(points_chunk.size().getInfo())
        #     write_geotiff_patch_from_points_v2(sampleImage,points_chunk,['R','G','B','N'],10,32,output_directory=args.output_dir, suffix='forest')
        write_geotiff_patch_from_points_v2(sampleImage,ee_points_forest,['R','G','B','N'],10,32,output_directory=args.output_dir, suffix='forest')

    if args.nonforest:
        write_geotiff_patch_from_points_v2(sampleImage,ee_points_nonforest,['R','G','B','N'],10,32,output_directory=args.output_dir, suffix='nonforest')
    # initialize new CLI parser
    parser = argparse.ArgumentParser(description="Export S2 image patches.")

    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        help="output directory for exported patches",
    )

    parser.add_argument(
        "-f",
        "--forest",
        dest="forest",
        action="store_true",
        help="export forest labeled patches",
        required=False,
    )

    parser.add_argument(
        "-nf",
        "--nonforest",
        dest="nonforest",
        action="store_true",
        help="export nonforest labeled patches",
        required=False,
    )
    args = parser.parse_args()

    parser.set_defaults(forest=False)
    parser.set_defaults(nonforest=False)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # the total hexagons split roughly 1/3 forest to 2/3 nonforest
    # set seed so we can reproduce the list and skip exports that have finished.
    seed = 42
    ee_points_forest = (
        hexForest.map(lambda h: ee.Feature(h.geometry().centroid()))
        .randomColumn(seed=seed)
        .sort("random")
    )
    ee_points_nonforest = (
        hexNonForest.map(lambda h: ee.Feature(h.geometry().centroid()))
        .randomColumn(seed=seed)
        .sort("random")
    )
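    # With a fixed seed, randomColumn() is deterministic: rebuilding the
    # collection yields the same "random" values and the same sorted order,
    # which is what lets a rerun skip exports that already finished.
    # Illustrative check (not part of this PR; costs one getInfo round trip):
    #   idx_a = ee_points_forest.first().get("system:index").getInfo()
    #   idx_b = (hexForest.map(lambda h: ee.Feature(h.geometry().centroid()))
    #            .randomColumn(seed=seed).sort("random")
    #            .first().get("system:index").getInfo())
    #   assert idx_a == idx_b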

    if not args.forest and not args.nonforest:
        print("Please specify --forest and/or --nonforest")
        exit()

    if args.forest:
        write_geotiff_patch_from_points_v2(
            sampleImage,
            ee_points_forest,
            ["R", "G", "B", "N"],
            10,
            32,
            output_directory=args.output_dir,
            suffix="forest",
        )

    if args.nonforest:
        write_geotiff_patch_from_points_v2(
            sampleImage,
            ee_points_nonforest,
            ["R", "G", "B", "N"],
            10,
            32,
            output_directory=args.output_dir,
            suffix="nonforest",
        )
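write_geotiff_patch_from_points_v2 is defined in serving and is not shown in this diff. A minimal sketch of the per-point, high-volume download pattern such a helper typically follows (the function and parameter names here are assumptions, not the PR's actual implementation):

import requests

def download_patch(image, point, bands, scale, patch_px, out_path):
    # A box of roughly scale * patch_px meters per side, centered on the point.
    region = point.buffer(scale * patch_px / 2).bounds()
    url = image.getDownloadURL({
        "bands": bands,
        "region": region,
        "scale": scale,
        "format": "GEO_TIFF",
    })
    with open(out_path, "wb") as f:
        f.write(requests.get(url).content)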


if __name__ == "__main__":
    main()
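A usage sketch for the CLI above (the output path is illustrative):

python fao_models/export_s2_patches.py --output_dir data --forest --nonforest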
85 changes: 63 additions & 22 deletions fao_models/geotiffToRecords.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,84 @@
import os
import math
from enum import Enum, auto

import numpy as np
import rasterio
import math
import tensorflow as tf
from tqdm import tqdm


class SplitStrategy(Enum):
    all = auto()
    balanced = auto()


def balance_files(files: list[str]) -> list[str]:
    forest = []
    nonforest = []
    for i in files:
        if "non" in i:
            nonforest.append(i)
        else:
            forest.append(i)
    min_recs = min(len(forest), len(nonforest))
    return nonforest[:min_recs] + forest[:min_recs]
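balance_files() classifies names by the "non" substring and truncates each class list to the smaller count, so both labels contribute equally to the TFRecords. An illustrative call (filenames are made up):

names = ["a_nonforest.tif", "b_nonforest.tif", "c_nonforest.tif",
         "d_forest.tif", "e_forest.tif"]
assert len(balance_files(names)) == 4  # 2 nonforest + 2 forest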


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def serialize_example(image, label):
    feature = {
        'image': _bytes_feature(image),
        'label': _bytes_feature(tf.io.serialize_tensor(label))
        "image": _bytes_feature(image),
        "label": _bytes_feature(tf.io.serialize_tensor(label)),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
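Each record stores the image as the serialized bytes of a float32 tensor and the label as the serialized bytes of an int64 tensor. A read-side sketch that inverts serialize_example (not part of this PR):

def parse_example(proto):
    spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(proto, spec)
    image = tf.io.parse_tensor(parsed["image"], out_type=tf.float32)
    label = tf.io.parse_tensor(parsed["label"], out_type=tf.int64)
    return image, label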

def write_sharded_tfrecords(folder_path, output_dir, items_per_record=500):

def write_sharded_tfrecords(
    folder_path,
    output_dir,
    items_per_record=500,
    balance_strat: SplitStrategy = SplitStrategy.all,
):
    filenames = [f for f in os.listdir(folder_path) if f.endswith(".tif")]
    if balance_strat is SplitStrategy.balanced:
        filenames = balance_files(filenames)

    num_geotiffs = len(filenames)
    num_shards = math.ceil(num_geotiffs / items_per_record)

    options = tf.io.TFRecordOptions(compression_type="GZIP")

    for shard_id in range(num_shards):
        shard_filename = os.path.join(output_dir, f'geotiffs_shard_{shard_id:03d}.tfrecord.gz')
        shard_filename = os.path.join(
            output_dir, f"geotiffs_shard_{shard_id:03d}.tfrecord.gz"
        )

        start_index = shard_id * items_per_record
        end_index = min((shard_id + 1) * items_per_record, num_geotiffs)

        with tf.io.TFRecordWriter(shard_filename, options=options) as writer:
            for filename in filenames[start_index:end_index]:
            for filename in tqdm(filenames[start_index:end_index]):
                file_path = os.path.join(folder_path, filename)
                try:
                    with rasterio.open(file_path) as dataset:
                        raster_data = dataset.read() / 10000
                        if raster_data.shape != (4, 32, 32):
                            continue
                        raster_data = np.transpose(raster_data, (1, 2, 0))
                        image_raw = tf.io.serialize_tensor(raster_data.astype(np.float32))
                        label = 1 if 'nonforest' in filename else 0
                        print(filename)
                        print(label)
                        image_raw = tf.io.serialize_tensor(
                            raster_data.astype(np.float32)
                        )
                        label = 1 if "nonforest" in filename else 0
                        # print(filename)
                        # print(label)
                        label = tf.convert_to_tensor(label, dtype=tf.int64)
                        tf_example = serialize_example(image_raw, label)
                        writer.write(tf_example)
@@ -54,15 +89,21 @@ def write_sharded_tfrecords(folder_path, output_dir, items_per_record=500):
        # Print a message after finishing writing each TFRecord file
        print(f"Finished writing {shard_filename}")

# TESTING ######################################
# tiff_path = "/home/ate/sig/gitmodels/fao-models/data/"
tiff_path = r"C:\fao-models\data"
# tf_path = "/home/ate/sig/gitmodels/fao-models/tfrecords"
tf_path = r"C:\fao-models\tfrecords_test"
for folder in [tiff_path,tf_path]:
    if not os.path.exists(folder):
        os.makedirs(folder)
items_per_record = 1000 # Number of GeoTIFFs to store in each TFRecord file

# Call the function
write_sharded_tfrecords(tiff_path, tf_path, items_per_record=items_per_record)

if __name__ == "__main__":
    tiff_path = r"data"
    tf_path_root = r"tfrecords"
    balance_strat = SplitStrategy.balanced
    items_per_record = 1000  # Number of GeoTIFFs to store in each TFRecord file

    tf_path_strat = os.path.join(tf_path_root, balance_strat.name)
    for folder in [tiff_path, tf_path_root, tf_path_strat]:
        if not os.path.exists(folder):
            os.makedirs(folder)

    write_sharded_tfrecords(
        tiff_path,
        tf_path_strat,
        items_per_record=items_per_record,
        balance_strat=balance_strat,
    )
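Running the module shards everything under data/ into GZIP-compressed TFRecords (paths follow the defaults above):

python fao_models/geotiffToRecords.py
# writes tfrecords/balanced/geotiffs_shard_000.tfrecord.gz, ...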
(2 additional changed files not shown)
