all changes on old laptop [clean up later]
1 parent 4ef4e79, commit 10e8581
Showing 24 changed files with 1,208 additions and 491 deletions.
@@ -0,0 +1,179 @@
import collections
import typing
import argparse
from types import SimpleNamespace
import csv
import io
import logging

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.io import ReadFromCsv, WriteToText

from beam_utils import parse_shp_to_latlon
from common import load_yml


# The pipeline has these general steps:
#
# 1. Read in data from SHP (hexagons were provided as SHP and CSV, but the CSV has no
#    geometry column; the centroids only came as an SHP file).
# 2. Parse the data into row-wise elements of (global_id, [lon, lat]); the rest of the
#    pipeline passes these elements through.
# 3. Download imagery for each element and convert it to a tensor.
# 4. Load the model and run inference on the tensor to return a prediction value.
# 5. Write the prediction value to a new CSV file with (global_id, lat, lon, prediction value).

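# For orientation, a minimal sketch of the element shape at each stage; the id and
# coordinate values below are hypothetical placeholders, not real data:
#
#   ("hex-001", [2.35, 48.85])                                    # after parse_shp_to_latlon
#   {"id": "hex-001", "long": 2.35, "lat": 48.85, "patch": ...}   # after GetPatch
#   {"id": "hex-001", "long": 2.35, "lat": 48.85,
#    "prob_label": 0.87, "pred_label": "Forest"}                  # after Predict
#   "hex-001,2.35,48.85,0.87,Forest"                              # after DictToCSVString
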
class GetPatch(beam.DoFn):
    def __init__(self):
        super().__init__()

    def setup(self):
        # Import and initialize Earth Engine on the worker (setup runs once per DoFn instance).
        import ee
        import google.auth

        credentials, _ = google.auth.default()
        ee.Initialize(
            credentials,
            project="pc530-fao-fra-rss",
            opt_url="https://earthengine-highvolume.googleapis.com",
        )
        return super().setup()

    def process(self, element):
        from beam_utils import get_ee_img, get_patch_numpy, to_tensor

        # element is a tuple of (global_id, [lon, lat]).
        global_id = element[0]
        coords = element[1]

        image = get_ee_img(coords)
        patch = get_patch_numpy(coords, image)
        patch_tensor = to_tensor(patch)

        yield {
            "id": global_id,
            "long": coords[0],
            "lat": coords[1],
            "patch": patch_tensor,
        }

class Predict(beam.DoFn):
    def __init__(self, config_path):
        from common import load_yml
        # from _types import Config  # Config was a dataclass subclass in John's repo
        # that type-casts the loaded YAML.

        self._config = load_yml(config_path)
        super().__init__()

    def setup(self):
        # Load the model once per worker.
        from models import get_model, freeze

        self.model = get_model(
            model_name=self._config.model_name,
            optimizer=self._config.optimizer,
            loss_fn=self._config.loss_fn,
            training_mode=True,
        )
        self.model.load_weights(self._config.checkpoint)
        freeze(self.model)

        return super().setup()

    def process(self, element):
        model = self.model
        patch = element["patch"]
        prob = round(float(model(patch).numpy()), 2)
        prediction = "Forest" if prob > 0.5 else "Non-Forest"

        yield {
            "id": element["id"],
            "long": element["long"],
            "lat": element["lat"],
            "prob_label": prob,
            "pred_label": prediction,
        }

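# Illustrative shape of the --model-config YAML consumed by Predict above. The keys are
# the attributes read in setup(); the values are placeholders, not the real settings:
#
#   model_name: resnet50
#   optimizer: adam
#   loss_fn: binary_crossentropy
#   checkpoint: /path/to/model_weights.h5
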
# https://github.com/kubeflow/examples/blob/master/LICENSE
class DictToCSVString(beam.DoFn):
    """Convert an incoming dict to a CSV string.

    This DoFn converts a Python dict into a CSV string.

    Args:
      fieldnames: A list of strings representing keys of a dict.
    """

    def __init__(self, fieldnames):
        super(DictToCSVString, self).__init__()

        self.fieldnames = fieldnames

    def process(self, element, *_args, **_kwargs):
        """Convert a Python dict instance into a CSV string.

        This routine uses the Python csv DictWriter to robustly convert an input
        dict to a comma-separated CSV string. This also handles appropriate escaping
        of characters like the delimiter ",". The dict values must be serializable
        into a string.

        Args:
          element: A dict mapping string keys to string values.
            {
              "key1": "STRING",
              "key2": "STRING"
            }

        Yields:
          A string representing the row in CSV format.
        """
        fieldnames = self.fieldnames
        filtered_element = {
            key: value for (key, value) in element.items() if key in fieldnames
        }
        with io.StringIO() as stream:
            writer = csv.DictWriter(stream, fieldnames)
            writer.writerow(filtered_element)
            csv_string = stream.getvalue().strip("\r\n")

        yield csv_string


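# Quick illustrative check of DictToCSVString outside a pipeline (values are made up):
#
#   fn = DictToCSVString(["id", "long", "lat"])
#   next(fn.process({"id": "hex-001", "long": 2.35, "lat": 48.85, "extra": "dropped"}))
#   # -> 'hex-001,2.35,48.85'
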
def pipeline(beam_options, dotargs: SimpleNamespace):
    if beam_options is not None:
        beam_options = PipelineOptions(**load_yml(beam_options))

    pColl = parse_shp_to_latlon(dotargs.input)
    cols = ["id", "long", "lat", "prob_label", "pred_label"]
    with beam.Pipeline(options=beam_options) as p:
        (
            p
            | "Construct PCollection" >> beam.Create(pColl)
            | "Get Patch" >> beam.ParDo(GetPatch())
            | "Predict" >> beam.ParDo(Predict(config_path=dotargs.model_config))
            | "Dict To CSV String" >> beam.ParDo(DictToCSVString(cols))
            | "Write String To CSV" >> WriteToText(dotargs.output, header=",".join(cols))
        )

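# Illustrative --beam-config YAML; the load_yml result is splatted into PipelineOptions,
# so any valid PipelineOptions keyword works. The values here are placeholders:
#
#   runner: DirectRunner
#   temp_location: gs://my-bucket/tmp
#   region: us-central1
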
# Test file:
# file = 'C:\\Users\\kyle\\Downloads\\FRA_hex_shp_5records.shp'
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", type=str, required=True)
    parser.add_argument("--output", "-o", type=str, required=True)
    parser.add_argument("--model-config", "-mc", type=str, required=True)
    group = parser.add_argument_group("pipeline-options")
    group.add_argument("--beam-config", "-bc", type=str)
    args = parser.parse_args()

    pipeline(beam_options=args.beam_config, dotargs=args)

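# Example invocation (the script and file names here are hypothetical):
#
#   python inference_pipeline.py -i FRA_hex.shp -o predictions/out -mc model_config.yml -bc beam_config.yml
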
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    run()
@@ -0,0 +1,213 @@
import collections.abc
import argparse
from types import SimpleNamespace
import csv
import io

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.io import ReadFromCsv, WriteToText

from common import load_yml


TMP = "/Users/johndilger/Documents/projects/SSL4EO-S12/fao_models/TMP"
# Sentinel-2 band names and, in the same order, the crop dimension used for each band.
BANDS = [
    "B1",
    "B2",
    "B3",
    "B4",
    "B5",
    "B6",
    "B7",
    "B8",
    "B8A",
    "B9",
    "B10",
    "B11",
    "B12",
]
CROPS = [44, 264, 264, 264, 132, 132, 132, 264, 132, 44, 44, 132, 132]
PROJECT = "pc530-fao-fra-rss"


# https://github.com/kubeflow/examples/blob/master/LICENSE
class DictToCSVString(beam.DoFn):
    """Convert an incoming dict to a CSV string.

    This DoFn converts a Python dict into a CSV string.

    Args:
      fieldnames: A list of strings representing keys of a dict.
    """

    def __init__(self, fieldnames):
        super(DictToCSVString, self).__init__()

        self.fieldnames = fieldnames

    def process(self, element, *_args, **_kwargs) -> collections.abc.Iterator[str]:
        """Convert a Python dict instance into a CSV string.

        This routine uses the Python csv DictWriter to robustly convert an input
        dict to a comma-separated CSV string. This also handles appropriate escaping
        of characters like the delimiter ",". The dict values must be serializable
        into a string.

        Args:
          element: A dict mapping string keys to string values.
            {
              "key1": "STRING",
              "key2": "STRING"
            }

        Yields:
          A string representing the row in CSV format.
        """
        fieldnames = self.fieldnames
        filtered_element = {
            key: value for (key, value) in element.items() if key in fieldnames
        }
        with io.StringIO() as stream:
            writer = csv.DictWriter(stream, fieldnames)
            writer.writerow(filtered_element)
            csv_string = stream.getvalue().strip("\r\n")

        yield csv_string


# Example DoFn from the Beam programming guide; not used in this pipeline.
class ComputeWordLengthFn(beam.DoFn):
    def process(self, element):
        return [len(element)]


class Predict(beam.DoFn):
    def __init__(self, config_path):
        from common import load_yml
        from _types import Config

        self._config = Config(**load_yml(config_path))
        super().__init__()

    def setup(self):
        self.load_model()
        return super().setup()

    def load_model(self):
        """Load the backbone and the linear classification head from checkpoints."""
        from models._models import get_model
        from models.dino.utils import restart_from_checkpoint
        import os

        c = self._config
        self.model, self.linear_classifier = get_model(**c.__dict__)
        restart_from_checkpoint(
            os.path.join(c.model_head_root),
            state_dict=self.linear_classifier,
        )

    def process(self, element):
        import torch
        from datasets.ssl4eo_dataset import SSL4EO

        dataset = SSL4EO(
            root=element["img_root"].parent,
            mode="s2c",
            normalize=False,  # TODO: add a normalize option to self._config.
        )

        image = dataset[0]
        image = torch.unsqueeze(torch.tensor(image), 0).type(torch.float32)

        self.linear_classifier.eval()
        with torch.no_grad():
            intermediate_output = self.model.get_intermediate_layers(
                image, self._config.n_last_blocks
            )
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            output = self.linear_classifier(output)

        element["prob_label"] = output.detach().cpu().item()
        element["pred_label"] = round(element["prob_label"])
        yield element


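# Illustrative shape of the --model-config YAML consumed by Config(**load_yml(...)) above.
# Only model_head_root and n_last_blocks are referenced directly in this file; any other
# keys are whatever get_model() expects, and the values below are placeholders:
#
#   model_head_root: /path/to/linear_head_checkpoint.pth
#   n_last_blocks: 4
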
class GetImagery(beam.DoFn):
    def __init__(self, dst):
        self.dst = dst
        super().__init__()

    def setup(self):
        # Import and initialize Earth Engine on the worker.
        import ee
        import google.auth

        credentials, _ = google.auth.default()
        ee.Initialize(
            credentials,
            project=PROJECT,
            opt_url="https://earthengine-highvolume.googleapis.com",
        )
        return super().setup()

    def process(self, element):
        """Download imagery for one sample."""
        from download_data.download_wraper import single_patch
        from pathlib import Path

        sample = element
        coords = (sample.long, sample.lat)
        local_root = Path(self.dst)
        img_root = single_patch(
            coords,
            id=sample.id,
            dst=local_root / "imgs",
            year=2019,
            bands=BANDS,
            crop_dimensions=CROPS,
        )
        yield {
            "img_root": img_root,
            "long": sample.long,
            "lat": sample.lat,
            "id": sample.id,
        }


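# For orientation, a minimal sketch of the element shape at each stage; the id, coordinates,
# path, and scores below are hypothetical placeholders:
#
#   a row with attributes (id, long, lat)                            # after ReadFromCsv
#   {"img_root": Path("<local imagery dir>"), "long": 2.35,
#    "lat": 48.85, "id": "hex-001"}                                  # after GetImagery
#   the same dict plus "prob_label": 0.87, "pred_label": 1           # after Predict
#   "hex-001,2.35,48.85,0.87,1"                                      # after DictToCSVString
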
def pipeline(beam_options, dotargs: SimpleNamespace):
    if beam_options is not None:
        beam_options = PipelineOptions(**load_yml(beam_options))

    cols = ["id", "long", "lat", "prob_label", "pred_label"]
    with beam.Pipeline(options=beam_options) as p:
        (
            p
            | "read input data" >> ReadFromCsv(dotargs.input)
            | "download imagery" >> beam.ParDo(GetImagery(dst=TMP)).with_output_types(dict)
            | "predict" >> beam.ParDo(Predict(config_path=dotargs.model_config)).with_output_types(dict)
            | "to csv str" >> beam.ParDo(DictToCSVString(cols))
            | "write to csv" >> WriteToText(dotargs.output, header=",".join(cols))
        )

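# Illustrative input CSV for --input; GetImagery reads the id, long, and lat columns by
# attribute, so those headers must be present (the values below are made up):
#
#   id,long,lat
#   hex-001,2.35,48.85
#   hex-002,-54.10,-3.20
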
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", type=str, required=True)
    parser.add_argument("--output", "-o", type=str, required=True)
    parser.add_argument("--model-config", "-mc", type=str, required=True)
    group = parser.add_argument_group("pipeline-options")
    group.add_argument("--beam-config", "-bc", type=str)
    args = parser.parse_args()

    pipeline(beam_options=args.beam_config, dotargs=args)


if __name__ == "__main__":
    run()