
Commit

model params from dict, define pipeline as var, format
jdilger committed Jun 17, 2024
1 parent 10e8581 commit 26cd2c9
Showing 1 changed file with 33 additions and 28 deletions.
61 changes: 33 additions & 28 deletions fao_models/beam_pipeline.py
@@ -22,6 +22,7 @@
# 4. load model and run inference on tensor to return prediction value
# 5. write prediction value to new CSV file with (global id, lat, long, prediction value)

+
class GetPatch(beam.DoFn):
    def __init__(self):
        super().__init__()
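
Both DoFns in this file lean on Beam's DoFn lifecycle: __init__ runs at pipeline-construction time (so its state must stay picklable), setup() runs once per worker before elements arrive, and process() runs per element. That is why the heavy imports and model loading below live in setup() rather than __init__. A generic sketch of that split, not code from this repo:

import apache_beam as beam


class LifecycleSketch(beam.DoFn):
    """Illustrates the construction-time vs. worker-time split GetPatch/Predict use."""

    def __init__(self, config):
        super().__init__()
        self.config = config  # shipped to workers, so keep it picklable
        self.client = None

    def setup(self):
        # Once per worker: open sessions, load models, do expensive imports here.
        self.client = object()  # stand-in for an EE session or a loaded Keras model

    def process(self, element):
        # Once per element: keep this cheap.
        yield {"id": element, "ready": self.client is not None}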
@@ -40,57 +41,63 @@ def setup(self):

    def process(self, element):
        from beam_utils import get_ee_img, get_patch_numpy, to_tensor

        # element is a tuple of (global_id, [lon,lat])
        global_id = element[0]
        coords = element[1]

        image = get_ee_img(coords)
        patch = get_patch_numpy(coords, image)
        patch_tensor = to_tensor(patch)

        yield {
            "id": global_id,
            "long": coords[0],
            "lat": coords[1],
-            "patch": patch_tensor
+            "patch": patch_tensor,
        }


class Predict(beam.DoFn):
    def __init__(self, config_path):
        from common import load_yml
-        # from _types import Config # Config was a dataclass subclass in Johns repo that type casts the yml file loaded..
+
+        # from _types import Config # Config was a dataclass subclass in Johns repo that type casts the yml file loaded..

        self._config = load_yml(config_path)
        super().__init__()

    def setup(self):
        # load the model
        from models import get_model, freeze
-        self.model = get_model(model_name=self._config.model_name,
-                               optimizer = self._config.optimizer,
-                               loss_fn = self._config.loss_fn,
-                               training_mode = True)
-        self.model.load_weights(self._config.checkpoint)
+
+        self.model = get_model(
+            model_name=self._config["model_name"],
+            optimizer=self._config["optimizer"],
+            loss_fn=self._config["loss_function"],
+            training_mode=True,
+        )
+        self.model.load_weights(self._config["checkpoint"])
        freeze(self.model)

        return super().setup()

    def process(self, element):

        model = self.model
        patch = element["patch"]
-        prob = round(float(model(patch).numpy()),2)
+        prob = round(float(model(patch).numpy()), 2)
        prediction = "Forest" if prob > 0.5 else "Non-Forest"

        yield {
            "id": element["id"],
            "long": element["long"],
            "lat": element["lat"],
            "prob_label": prob,
-            "pred_label": prediction
+            "pred_label": prediction,
        }
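
The headline change in this hunk is reading model parameters from a dict instead of attribute access. Assuming common.load_yml is a thin wrapper over yaml.safe_load (the wrapper itself is not shown in this diff), it returns a plain dict, which supports cfg["model_name"] but not cfg.model_name; the commented-out _types.Config dataclass was what used to provide attribute access. A small illustration, with a hypothetical load_yml:

import yaml


# Hypothetical stand-in for common.load_yml; assumed to return yaml.safe_load's plain dict.
def load_yml(text):
    return yaml.safe_load(text)


cfg = load_yml(
    """
model_name: resnet
optimizer: adam
loss_function: binary_crossentropy
checkpoint: weights/ckpt-10
"""
)
print(cfg["model_name"])  # dict-style access works: resnet
# cfg.model_name          # AttributeError: 'dict' object has no attribute 'model_name'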


# https://github.com/kubeflow/examples/blob/master/LICENSE
class DictToCSVString(beam.DoFn):
    """Convert incoming dict to a CSV string.
@@ -138,25 +145,25 @@ def process(self, element, *_args, **_kwargs):
        yield csv_string

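DictToCSVString's body is collapsed in this view; judging by the kubeflow example it credits, a DoFn like this is usually built on csv.DictWriter over an in-memory buffer. A sketch of that shape under that assumption, not the verbatim class:

import csv
import io

import apache_beam as beam


class DictToCSVStringSketch(beam.DoFn):
    """Hypothetical reconstruction: serialize one dict per element to a CSV row."""

    def __init__(self, fieldnames):
        super().__init__()
        self.fieldnames = fieldnames

    def process(self, element, *_args, **_kwargs):
        buf = io.StringIO()
        writer = csv.DictWriter(buf, fieldnames=self.fieldnames)
        writer.writerow(element)  # header comes from WriteToText(header=...) downstream
        yield buf.getvalue().rstrip("\r\n")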

-def pipeline(beam_options,
-             dotargs: SimpleNamespace):
+def pipeline(beam_options, dotargs: SimpleNamespace):
    if beam_options is not None:
        beam_options = PipelineOptions(**load_yml(beam_options))

    pColl = parse_shp_to_latlon(dotargs.input)
    cols = ["id", "long", "lat", "prob_label", "pred_label"]
    with beam.Pipeline() as p:
-        (
+        var = (
            p
            | "Construct PCollection" >> beam.Create(pColl)
            | "Get Patch" >> beam.ParDo(GetPatch())
-            | "Predict"
-            >> beam.ParDo(Predict(config_path=dotargs.model_config))
+            | "Predict" >> beam.ParDo(Predict(config_path=dotargs.model_config))
            | "Dict To CSV String" >> beam.ParDo(DictToCSVString(cols))
-            | "Write String To CSV" >> WriteToText(dotargs.output, header=",".join(cols))
+            | "Write String To CSV"
+            >> WriteToText(dotargs.output, header=",".join(cols))
        )

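The "define pipeline as var" part of the commit message is the var = ( line above. Assigning the transform chain to a name does not change execution, but it keeps the resulting PCollection available for fan-out, extra sinks, or assertions later; here the name is not yet used. A minimal example of why a named PCollection helps:

import apache_beam as beam

with beam.Pipeline() as p:
    rows = p | beam.Create([{"id": 1, "prob_label": 0.92}])
    # The same named PCollection can feed several branches.
    rows | "Print" >> beam.Map(print)
    rows | "Probs" >> beam.Map(lambda d: d["prob_label"])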
-#test file
+
+
+# test file
# file = 'C:\\Users\\kyle\\Downloads\\FRA_hex_shp_5records.shp'
def run():
    argparse.FileType()
@@ -169,11 +176,9 @@ def run():
    group.add_argument("--beam-config", "-bc", type=str)
    args = parser.parse_args()

-    pipeline(
-        beam_options=args.beam_config,
-        dotargs=args)
+    pipeline(beam_options=args.beam_config, dotargs=args)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
-    run()
\ No newline at end of file
+    run()
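
Only --beam-config/-bc is visible in these hunks; the flags behind dotargs.input, dotargs.output, and dotargs.model_config are collapsed out of the diff. A hypothetical reconstruction of the argparse wiring in run(), with all the unseen flag names assumed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", type=str)  # assumed: shapefile parsed by parse_shp_to_latlon
parser.add_argument("--output", "-o", type=str)  # assumed: prefix handed to WriteToText
parser.add_argument("--model-config", "-mc", type=str)  # assumed: YAML consumed by Predict
group = parser.add_mutually_exclusive_group()  # assumed; the diff shows group.add_argument
group.add_argument("--beam-config", "-bc", type=str)
args = parser.parse_args(["-i", "plots.shp", "-o", "preds", "-mc", "model.yml"])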
