Skip to content

Commit

Permalink
processing scripts for ssl4eo
Browse files Browse the repository at this point in the history
  • Loading branch information
jdilger committed Apr 18, 2024
1 parent 72afbd6 commit e730254
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
31 changes: 31 additions & 0 deletions fao_models/scripts/clean_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pandas as pd
import os


def load_csv(filepath: str) -> pd.DataFrame:
    """Read the CSV file at ``filepath`` into a DataFrame.

    Args:
        filepath: Path of the CSV file, handed straight to ``pandas.read_csv``.

    Returns:
        The parsed table, using pandas' default parsing options.
    """
    return pd.read_csv(filepath)


def add_new_index(df: pd.DataFrame) -> None:
    """Add a sequential ``id`` column to ``df`` in place.

    The column holds 0, 1, ..., len(df) - 1 in the frame's current row
    order. An existing ``id`` column is overwritten. Nothing is returned;
    the caller's DataFrame itself is modified.
    """
    row_count = len(df)
    df["id"] = range(row_count)


def select_columns(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
    """Return only the columns of ``df`` named in ``columns``.

    Args:
        df: Source table.
        columns: Column labels to keep; the result lists them in this order.

    Returns:
        The result of indexing ``df`` with the list of labels.
    """
    subset = df[columns]
    return subset


def save(df: pd.DataFrame, location) -> None:
    """Write ``df`` as CSV to ``location``.

    Only the destination is passed to ``DataFrame.to_csv``, so pandas'
    defaults apply: a header row is written and the frame's index is
    emitted as an unnamed first column.

    Args:
        df: Table to write.
        location: Destination passed to ``to_csv`` as ``path_or_buf``.
    """
    # NOTE(review): the index column is written in addition to any "id"
    # column in df - confirm downstream readers expect that layout.
    df.to_csv(path_or_buf=location)


def main(files=None):
    """Write a ``match_<name>`` CSV with ``id``, ``lng`` and ``lat`` per input.

    Each input is loaded, given a fresh sequential ``id`` column, reduced to
    the columns ``id``, ``lng`` and ``lat``, and saved next to the input
    under the same file name prefixed with ``match_``.

    Args:
        files: Paths of the CSV files to process. When omitted, the three
            sample files ``testing_sample.csv``, ``training_sample.csv`` and
            ``validation_sample.csv`` in the current directory are used.
    """
    if files is None:
        files = ["testing_sample.csv", "training_sample.csv", "validation_sample.csv"]
    for path in files:
        df = load_csv(path)
        add_new_index(df)
        df = select_columns(df, ["id", "lng", "lat"])
        # Prefix the base name only: for a bare file name this yields
        # "match_<name>", and for "dir/<name>" it yields "dir/match_<name>"
        # instead of the non-existent "match_dir/<name>".
        folder, name = os.path.split(path)
        save(df, os.path.join(folder, f"match_{name}"))


# Run the conversion only when this file is executed as a script, so that
# importing the module does not read and write CSV files as a side effect.
if __name__ == "__main__":
    main()
51 changes: 51 additions & 0 deletions fao_models/scripts/hex_to_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import ee
import requests

# Cloud project used to initialize Earth Engine; the hexagon asset below
# lives under the same project.
PROJECT = "pc530-fao-fra-rss"
ee.Initialize(project=PROJECT)

# Named `hexagons` rather than `hex` so the builtin hex() is not shadowed.
hexagons = ee.FeatureCollection(
    f"projects/{PROJECT}/assets/reference/hexWCenPropertiesTropics"
)
# Forest class: FOREST is 1 and LU18CEN is "Forest".
hex_forest = hexagons.filter(
    ee.Filter.And(ee.Filter.eq("FOREST", 1), ee.Filter.eq("LU18CEN", "Forest"))
)
# Non-forest class: FOREST is 0 and LU18CEN is anything but "Forest".
# Hexagons matching neither filter are left out of the samples.
hex_nonforest = hexagons.filter(
    ee.Filter.And(ee.Filter.eq("FOREST", 0), ee.Filter.neq("LU18CEN", "Forest"))
)
# Combine both classes into one collection and attach a seeded random column,
# which the train/test/validation split further down filters on.
hex_samples = (
    ee.FeatureCollection([hex_forest, hex_nonforest]).flatten().randomColumn(seed=42)
)


# Coordinates are ordered (lng, lat), as in ee.Geometry.Point(lng, lat).
def to_centroid(feat: ee.Feature):
    """Replace a feature's geometry with its centroid and record its position.

    Args:
        feat: Feature to convert.

    Returns:
        The centroid feature with two added properties: ``lng`` (first
        coordinate of the centroid) and ``lat`` (second coordinate).
    """
    centroid = feat.centroid()
    position = centroid.geometry().coordinates()
    lng = position.get(0)
    lat = position.get(1)
    return centroid.set({"lng": lng, "lat": lat})


# Swap every hexagon for its centroid, tagged with lng/lat properties.
hex_samples = hex_samples.map(to_centroid)

# Fetch and print a single feature as a quick look at the mapped output.
first_sample = hex_samples.limit(1).getInfo()
print(first_sample)

# Split on the "random" column; intended shares are 70% training,
# 20% testing and 10% validation.
training_sample = hex_samples.filter("random <= 0.7")
testing_sample = hex_samples.filter("random > 0.7 and random <=.9")
validation_sample = hex_samples.filter("random > 0.9")

# Uncomment to print the number of features in each split:
# print(
#     training_sample.size().getInfo(),
#     testing_sample.size().getInfo(),
#     validation_sample.size().getInfo(),
# )


def download(
    collection: ee.FeatureCollection, full_namepath: str, timeout: float = 300
):
    """Download ``collection`` through its Earth Engine URL to a local file.

    The download URL is requested without arguments, so Earth Engine's
    default file type applies. The destination file is opened only after the
    HTTP response has been received and checked, so a failed request does
    not create or truncate it.

    Args:
        collection: Feature collection to download.
        full_namepath: Path of the file to write; an existing file is
            overwritten.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout a stalled connection would block the script
            indefinitely.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    url = collection.getDownloadURL()
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with open(full_namepath, "wb") as fd:
        fd.write(response.content)


# Write each split to a CSV file in the current working directory, in the
# order training, testing, validation.
for sample, filename in (
    (training_sample, "training_sample.csv"),
    (testing_sample, "testing_sample.csv"),
    (validation_sample, "validation_sample.csv"),
):
    download(sample, filename)

0 comments on commit e730254

Please sign in to comment.