From 26cd2c9de15f7d5116a09699136bb0fb34d53c94 Mon Sep 17 00:00:00 2001
From: John Dilger
Date: Mon, 17 Jun 2024 12:36:27 -0500
Subject: [PATCH] model params from dict, define pipeline as var, format

---
 fao_models/beam_pipeline.py | 61 ++++++++++++++++++++-----------------
 1 file changed, 33 insertions(+), 28 deletions(-)

diff --git a/fao_models/beam_pipeline.py b/fao_models/beam_pipeline.py
index b673de8..1076823 100644
--- a/fao_models/beam_pipeline.py
+++ b/fao_models/beam_pipeline.py
@@ -22,6 +22,7 @@
 # 4. load model and run inference on tensor to return prediction value
 # 5. write prediction value to new CSV file with (global id, lat, long, prediction value)
 
+
 class GetPatch(beam.DoFn):
     def __init__(self):
         super().__init__()
@@ -40,26 +41,28 @@ def setup(self):
 
     def process(self, element):
         from beam_utils import get_ee_img, get_patch_numpy, to_tensor
-        
+
         # element is a tuple of (global_id, [lon,lat])
         global_id = element[0]
         coords = element[1]
-        
+
         image = get_ee_img(coords)
         patch = get_patch_numpy(coords, image)
         patch_tensor = to_tensor(patch)
-        
+
         yield {
             "id": global_id,
             "long": coords[0],
             "lat": coords[1],
-            "patch": patch_tensor
+            "patch": patch_tensor,
         }
 
+
 class Predict(beam.DoFn):
     def __init__(self, config_path):
         from common import load_yml
-        # from _types import Config # Config was a dataclass subclass in Johns repo that type casts the yml file loaded..
+
+        # from _types import Config # Config was a dataclass subclass in Johns repo that type casts the yml file loaded..
         self._config = load_yml(config_path)
 
         super().__init__()
@@ -67,30 +70,34 @@ def __init__(self, config_path):
     def setup(self):
         # load the model
         from models import get_model, freeze
-        self.model = get_model(model_name=self._config.model_name,
-                            optimizer = self._config.optimizer,
-                            loss_fn = self._config.loss_fn,
-                            training_mode = True)
-        self.model.load_weights(self._config.checkpoint)
+
+        self.model = get_model(
+            model_name=self._config["model_name"],
+            optimizer=self._config["optimizer"],
+            loss_fn=self._config["loss_function"],
+            training_mode=True,
+        )
+        self.model.load_weights(self._config["checkpoint"])
         freeze(self.model)
         return super().setup()
-    
+
     def process(self, element):
-        
+
         model = self.model
         patch = element["patch"]
 
-        prob = round(float(model(patch).numpy()),2)
+        prob = round(float(model(patch).numpy()), 2)
         prediction = "Forest" if prob > 0.5 else "Non-Forest"
-        
+
         yield {
             "id": element["id"],
             "long": element["long"],
             "lat": element["lat"],
             "prob_label": prob,
-            "pred_label": prediction
+            "pred_label": prediction,
         }
 
+
 # https://github.com/kubeflow/examples/blob/master/LICENSE
 class DictToCSVString(beam.DoFn):
     """Convert incoming dict to a CSV string.
@@ -138,25 +145,25 @@ def process(self, element, *_args, **_kwargs):
         yield csv_string
 
 
-def pipeline(beam_options,
-             dotargs: SimpleNamespace):
+def pipeline(beam_options, dotargs: SimpleNamespace):
     if beam_options is not None:
         beam_options = PipelineOptions(**load_yml(beam_options))
-    
+
     pColl = parse_shp_to_latlon(dotargs.input)
     cols = ["id", "long", "lat", "prob_label", "pred_label"]
     with beam.Pipeline() as p:
-        (
+        var = (
            p
            | "Construct PCollection" >> beam.Create(pColl)
            | "Get Patch" >> beam.ParDo(GetPatch())
-            | "Predict"
-            >> beam.ParDo(Predict(config_path=dotargs.model_config))
+            | "Predict" >> beam.ParDo(Predict(config_path=dotargs.model_config))
            | "Dict To CSV String" >> beam.ParDo(DictToCSVString(cols))
-            | "Write String To CSV" >> WriteToText(dotargs.output, header=",".join(cols))
+            | "Write String To CSV"
+            >> WriteToText(dotargs.output, header=",".join(cols))
         )
-
-#test file
+
+
+# test file
 # file = 'C:\\Users\\kyle\\Downloads\\FRA_hex_shp_5records.shp'
 def run():
     argparse.FileType()
@@ -169,11 +176,9 @@ def run():
     group.add_argument("--beam-config", "-bc", type=str)
 
     args = parser.parse_args()
-    pipeline(
-        beam_options=args.beam_config,
-        dotargs=args)
+    pipeline(beam_options=args.beam_config, dotargs=args)
 
 
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
-    run()
\ No newline at end of file
+    run()
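
Note: with this change Predict.setup reads the model parameters via dict lookups (self._config["model_name"], ["optimizer"], ["loss_function"], ["checkpoint"]), so load_yml is expected to return a plain dict whose keys match the YAML config. Below is a minimal sketch of a config that would satisfy those lookups; the key names come from the patch, while the values and the use of yaml.safe_load in place of the repo's load_yml helper are illustrative assumptions only.

# Sketch only: key names are taken from the patch; values are placeholders,
# and yaml.safe_load stands in for the repo's load_yml helper, which is
# assumed to return a plain dict.
import yaml

example_yaml = """\
model_name: mobilenet_v3            # placeholder value
optimizer: adam                     # placeholder value
loss_function: binary_crossentropy  # placeholder value
checkpoint: /path/to/checkpoint     # placeholder value
"""

config = yaml.safe_load(example_yaml)
# These are the keys Predict.setup indexes into after this commit.
assert {"model_name", "optimizer", "loss_function", "checkpoint"} <= config.keys()
print(config["model_name"], config["checkpoint"])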