-
Notifications
You must be signed in to change notification settings - Fork 387
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
consolidate prediction options into PredictOptions for all tasks (#2055)
Changes: - Move `predict_chip_sz` and `predict_batch_sz` to `PredictOptions`. Also add `stride`. - Make `predict_options` a field in `RVPipelineConfig`. Previously, this was only defined in the SS and OD subclasses. - Make `Backend.predict_scene()` take `PredictOptions` instead of `chip_sz`, `stride` etc. - Move default `stride`, `crop_sz` initialization to pydantic validators. - Move OD prediction post-processing to the OD PyTorch `Backend`. - Remove unused SS post processing. - Remove support for `RASTERVISION_PREDICT_BATCH_SIZE` `RVConfig` param that was used by `Learner.predict_dataset()`. - Update usage in examples. - Update unit and integration tests.
Showing
32 changed files
with
244 additions
and
240 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,4 +31,5 @@ | |
ChipOptions.__name__, | ||
WindowSamplingConfig.__name__, | ||
WindowSamplingMethod.__name__, | ||
PredictOptions.__name__, | ||
] |
30 changes: 1 addition & 29 deletions
30
rastervision_core/rastervision/core/rv_pipeline/object_detection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,5 @@ | ||
from typing import TYPE_CHECKING | ||
import logging | ||
|
||
from rastervision.core.rv_pipeline import RVPipeline | ||
from rastervision.core.data.label import ObjectDetectionLabels | ||
|
||
if TYPE_CHECKING: | ||
from rastervision.core.data import Labels, Scene | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class ObjectDetection(RVPipeline): | ||
def predict_scene(self, scene: 'Scene') -> 'Labels': | ||
if self.backend is None: | ||
self.build_backend() | ||
|
||
# Use strided windowing to ensure that each object is fully visible (ie. not | ||
# cut off) within some window. This means prediction takes 4x longer for object | ||
# detection :( | ||
chip_sz = self.config.predict_chip_sz | ||
stride = chip_sz // 2 | ||
labels = self.backend.predict_scene( | ||
scene, chip_sz=chip_sz, stride=stride) | ||
labels = self.post_process_predictions(labels, scene) | ||
return labels | ||
|
||
def post_process_predictions(self, labels: ObjectDetectionLabels, | ||
scene: 'Scene') -> ObjectDetectionLabels: | ||
return ObjectDetectionLabels.prune_duplicates( | ||
labels, | ||
score_thresh=self.config.predict_options.score_thresh, | ||
merge_thresh=self.config.predict_options.merge_thresh) | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 1 addition & 49 deletions
50
rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,5 @@ | ||
from typing import TYPE_CHECKING | ||
import logging | ||
|
||
import numpy as np | ||
|
||
from rastervision.core.rv_pipeline import RVPipeline | ||
|
||
if TYPE_CHECKING: | ||
from rastervision.core.data import ( | ||
Labels, | ||
Scene, | ||
) | ||
from rastervision.core.rv_pipeline.semantic_segmentation_config import ( | ||
SemanticSegmentationConfig) | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class SemanticSegmentation(RVPipeline): | ||
def post_process_batch(self, windows, chips, labels): | ||
# Fill in null class for any NODATA pixels. | ||
null_class_id = self.config.dataset.class_config.null_class_id | ||
for window, chip in zip(windows, chips): | ||
nodata_mask = np.sum(chip, axis=2) == 0 | ||
labels.mask_fill(window, nodata_mask, fill_value=null_class_id) | ||
|
||
return labels | ||
|
||
def predict_scene(self, scene: 'Scene') -> 'Labels': | ||
if self.backend is None: | ||
self.build_backend() | ||
|
||
cfg: 'SemanticSegmentationConfig' = self.config | ||
chip_sz = cfg.predict_chip_sz | ||
stride = cfg.predict_options.stride | ||
crop_sz = cfg.predict_options.crop_sz | ||
|
||
if stride is None: | ||
stride = chip_sz | ||
|
||
if crop_sz == 'auto': | ||
overlap_sz = chip_sz - stride | ||
if overlap_sz % 2 == 1: | ||
log.warning( | ||
'Using crop_sz="auto" but overlap size (chip_sz minus ' | ||
'stride) is odd. This means that one pixel row/col will ' | ||
'still overlap after cropping.') | ||
crop_sz = overlap_sz // 2 | ||
|
||
labels = self.backend.predict_scene( | ||
scene, chip_sz=chip_sz, stride=stride, crop_sz=crop_sz) | ||
labels = self.post_process_predictions(labels, scene) | ||
return labels | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import unittest | ||
|
||
from rastervision.core.rv_pipeline import (ObjectDetectionPredictOptions) | ||
|
||
|
||
class TestObjectDetectionPredictOptions(unittest.TestCase): | ||
def test_stride_validator(self): | ||
cfg = ObjectDetectionPredictOptions(chip_sz=10) | ||
self.assertEqual(cfg.stride, 5) | ||
cfg = ObjectDetectionPredictOptions(chip_sz=11) | ||
self.assertEqual(cfg.stride, 5) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import unittest | ||
|
||
from rastervision.core.data import (ClassConfig, DatasetConfig) | ||
from rastervision.core.backend import (BackendConfig) | ||
from rastervision.core.rv_pipeline.rv_pipeline_config import ( | ||
PredictOptions, RVPipelineConfig, rv_pipeline_config_upgrader) | ||
|
||
|
||
class TestPredictOptions(unittest.TestCase): | ||
def test_stride_validator(self): | ||
cfg = PredictOptions(chip_sz=10) | ||
self.assertEqual(cfg.stride, 10) | ||
|
||
|
||
class TestRVPipelineConfig(unittest.TestCase): | ||
def test_upgrader(self): | ||
cfg_dict = dict( | ||
dataset=DatasetConfig( | ||
class_config=ClassConfig(names=[]), | ||
train_scenes=[], | ||
validation_scenes=[]), | ||
backend=BackendConfig(), | ||
train_chip_sz=20, | ||
chip_nodata_threshold=0.5, | ||
predict_chip_sz=20, | ||
predict_batch_sz=8) | ||
cfg_dict = rv_pipeline_config_upgrader(cfg_dict, 10) | ||
cfg_dict = rv_pipeline_config_upgrader(cfg_dict, 11) | ||
cfg = RVPipelineConfig(**cfg_dict) | ||
|
||
cfg_dict = dict( | ||
dataset=DatasetConfig( | ||
class_config=ClassConfig(names=[]), | ||
train_scenes=[], | ||
validation_scenes=[]), | ||
backend=BackendConfig(), | ||
train_chip_sz=20, | ||
chip_nodata_threshold=0.5, | ||
chip_options=dict(method='random'), | ||
predict_chip_sz=20, | ||
predict_batch_sz=8, | ||
predict_options=dict()) | ||
cfg_dict = rv_pipeline_config_upgrader(cfg_dict, 10) | ||
cfg_dict = rv_pipeline_config_upgrader(cfg_dict, 11) | ||
cfg = RVPipelineConfig(**cfg_dict) | ||
self.assertEqual(cfg.chip_options.get_chip_sz(), 20) | ||
self.assertEqual(cfg.chip_options.nodata_threshold, 0.5) | ||
self.assertEqual(cfg.predict_options.chip_sz, 20) | ||
self.assertEqual(cfg.predict_options.batch_sz, 8) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters