Merge pull request #22 from MarjanAsgari/main
transformers_as_arg
mpelchat04 authored Oct 4, 2024
2 parents ab02c3e + 3e468d6 commit 0fb5766
Showing 6 changed files with 58 additions and 2 deletions.
3 changes: 3 additions & 0 deletions geo_inference/config/sample.yaml
@@ -9,6 +9,9 @@ arguments:
vec: False # Vector Conversion: bool
yolo: False # YOLO Conversion: bool
coco: False # COCO Conversion: bool
transformers: True # Test-Time Augmentation (TTA): bool
transformer_flip: False # TTA flip transforms: bool
transformer_rotate: True # TTA 90-degree rotation: bool
device: "gpu" # cpu or gpu: str
gpu_id: 0
mgpu: False
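These three new keys turn test-time augmentation (TTA) on and select which transform families to apply. A minimal sketch of consuming them, assuming the config is loaded with PyYAML (the project's own read_yaml helper may differ):

# Illustration only, not part of this commit: load the sample config with PyYAML
# and pull out the new TTA keys (file path and loader are assumptions).
import yaml

with open("geo_inference/config/sample.yaml") as f:
    config = yaml.safe_load(f)

args = config["arguments"]
use_tta = bool(args.get("transformers", False))            # True in this sample
tta_flip = bool(args.get("transformer_flip", False))       # False
tta_rotate = bool(args.get("transformer_rotate", False))   # True
print(use_tta, tta_flip, tta_rotate)
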
27 changes: 26 additions & 1 deletion geo_inference/geo_inference.py
@@ -11,6 +11,7 @@
import threading
import numpy as np
import xarray as xr
import ttach as tta
from typing import Dict
from dask import config
import dask.array as da
@@ -83,6 +84,9 @@ def __init__(
gpu_id: int = 0,
num_classes: int = 5,
prediction_threshold : float = 0.3,
transformers : bool = False,
transformer_flip: bool = False,
transformer_rotate: bool = False,
):
self.work_dir: Path = get_directory(work_dir)
self.device = (
@@ -95,6 +99,23 @@
),
map_location=self.device,
)
if transformers:
if transformer_flip and transformer_rotate: # do all
transforms = tta.aliases.d4_transform()
elif transformer_rotate: # do rotate only
transforms = tta.Compose(
[
tta.Rotate90(angles=[90]),
]
)
elif transformer_flip: # do flip only
transforms = tta.Compose(
[
tta.HorizontalFlip(),
tta.VerticalFlip(),
]
)
self.model = tta.SegmentationTTAWrapper(self.model, transforms, merge_mode='mean')
self.mask_to_vec = mask_to_vec
self.mask_to_coco = mask_to_coco
self.mask_to_yolo = mask_to_yolo
@@ -363,7 +384,10 @@ def main() -> None:
device=arguments["device"],
gpu_id=arguments["gpu_id"],
num_classes=arguments["classes"],
prediction_threshold=arguments["prediction_threshold"]
prediction_threshold=arguments["prediction_threshold"],
transformers=arguments["transformers"],
transformer_flip=arguments["transformer_flip"],
transformer_rotate=arguments["transformer_rotate"],
)
inference_mask_layer_name = geo_inference(
inference_input=arguments["image"],
@@ -372,6 +396,7 @@
workers=arguments["workers"],
bbox=arguments["bbox"],
)
print(inference_mask_layer_name)



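The __init__ change wraps the loaded model with ttach's SegmentationTTAWrapper: when both flags are set it uses the full d4 group (flips plus 90-degree rotations), otherwise a rotate-only or flip-only composition, and predictions over all augmented views are averaged (merge_mode='mean'). Note that transforms is only assigned inside the flip/rotate branches, so the wrapper is meaningful only when at least one of the two flags is enabled. A minimal standalone sketch of the same wrapping, using a stand-in model rather than the project's network:

# Standalone illustration of the ttach usage added in __init__; the Conv2d is a
# stand-in for the real segmentation model.
import torch
import ttach as tta

model = torch.nn.Conv2d(3, 1, kernel_size=1)

# Mirrors the "do all" branch: flips plus 90-degree rotations (d4 group).
transforms = tta.aliases.d4_transform()

tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode="mean")

x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    y = tta_model(x)  # each augmented view is predicted, de-augmented, then averaged
print(y.shape)  # torch.Size([1, 1, 64, 64])
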
15 changes: 15 additions & 0 deletions geo_inference/utils/helpers.py
@@ -433,6 +433,10 @@ def cmd_interface(argv=None):

parser.add_argument("-pr", "--prediction_thr", type=float, nargs=1, help="Prediction Threshold")

parser.add_argument("-tr", "--transformers", nargs=1, help="Transformers Addition")
parser.add_argument("-tr_f", "--transformer_flip", nargs=1, help="Transformers Addition - Flip")
parser.add_argument("-tr_e", "--transformer_rotate", nargs=1, help="Transformers Addition - Rotate")

args = parser.parse_args()

if args.args:
@@ -452,6 +456,10 @@ def cmd_interface(argv=None):
classes = config["arguments"]["classes"]
patch_size = config["arguments"]["patch_size"]
prediction_threshold = config["arguments"]["prediction_thr"]
transformers = config["arguments"]["transformers"]
transformer_flip = config["arguments"]["transformer_flip"]
transformer_rotate = config["arguments"]["transformer_rotate"]

elif args.image:
image =args.image[0]
model = args.model[0] if args.model else None
@@ -468,6 +476,10 @@ def cmd_interface(argv=None):
classes = args.classes[0] if args.classes else 5
patch_size = args.patch_size[0] if args.patch_size else 1024
prediction_threshold = args.prediction_thr[0] if args.prediction_thr else 0.3
transformers = args.transformers[0] if args.transformers else False
transformer_flip = args.transformer_flip if args.transformer_flip else False
transformer_rotate = args.transformer_rotate if args.transformer_rotate else False

else:
print("use the help [-h] option for correct usage")
raise SystemExit
@@ -487,6 +499,9 @@ def cmd_interface(argv=None):
"gpu_id": gpu_id,
"patch_size": patch_size,
"prediction_threshold": prediction_threshold,
"transformers": transformers,
"transformer_flip": transformer_flip,
"transformer_rotate":transformer_rotate,
}
return arguments

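The new CLI options mirror the config keys. Because they are declared with nargs=1, argparse returns single-element lists of strings, which is why the transformers value is indexed with [0] above. A stripped-down sketch of just these options (the rest of cmd_interface is omitted):

# Stripped-down sketch of only the new options; flag names match the diff,
# everything else about cmd_interface is left out.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-tr", "--transformers", nargs=1, help="Transformers Addition")
parser.add_argument("-tr_f", "--transformer_flip", nargs=1, help="Transformers Addition - Flip")
parser.add_argument("-tr_e", "--transformer_rotate", nargs=1, help="Transformers Addition - Rotate")

args = parser.parse_args(["-tr", "True", "-tr_f", "False"])
print(args.transformers)        # ['True']  (nargs=1 yields a one-element list of strings)
print(args.transformer_flip)    # ['False']
print(args.transformer_rotate)  # None (option not supplied)
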
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,4 +10,5 @@ dask[distributed]>=2024.6.2
requests>=2.32.3
xarray>=2024.6.0
pystac>=1.10.1
rioxarray>=0.15.6
rioxarray>=0.15.6
ttach>=0.0.3
3 changes: 3 additions & 0 deletions tests/data/sample.yaml
@@ -14,4 +14,7 @@ arguments:
classes : 5
n_workers: 20
prediction_thr : 0.3
transformers: False
transformer_flip : False
transformer_rotate : False
patch_size: 1024
9 changes: 9 additions & 0 deletions tests/utils/test_helpers.py
@@ -57,6 +57,9 @@ def test_read_yaml(test_data_dir):
"mgpu": False,
"classes": 5,
"prediction_thr": 0.3,
"transformers": False,
"transformer_flip" : False,
"transformer_rotate" : False,
"n_workers": 20
}

@@ -145,6 +148,9 @@ def test_cmd_interface_with_args(monkeypatch, test_data_dir):
"classes": 5,
"multi_gpu": False,
"prediction_threshold": 0.3,
"transformers": False,
"transformer_flip" : False,
"transformer_rotate" : False,
"patch_size": 1024
}

@@ -169,6 +175,9 @@ def test_cmd_interface_with_image(monkeypatch):
"gpu_id": 0,
"classes": 5,
"prediction_threshold": 0.3,
"transformers": False,
"transformer_flip" : False,
"transformer_rotate" : False,
"multi_gpu": False,
}

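The test fixtures and expected dictionaries gain the same three keys with False defaults. A minimal sketch of an equivalent check, assuming read_yaml returns the flattened "arguments" mapping as the updated test_read_yaml expectation suggests:

# Hypothetical test in the spirit of test_read_yaml; the import path, the
# test_data_dir fixture, and read_yaml's return shape are assumptions.
from geo_inference.utils.helpers import read_yaml

def test_sample_yaml_has_tta_keys(test_data_dir):
    config = read_yaml(test_data_dir / "sample.yaml")
    assert config["transformers"] is False
    assert config["transformer_flip"] is False
    assert config["transformer_rotate"] is False
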
