Skip to content

Commit

Permalink
improve test coverage
Browse files Browse the repository at this point in the history
AdeelH committed Feb 2, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 7d8e091 commit 2ed63e3
Showing 11 changed files with 33 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ def run(self,
pipeline: 'Pipeline',
commands: List[str],
num_splits: int = 1,
pipeline_run_name: str = 'raster-vision'):
pipeline_run_name: str = 'raster-vision'): # pragma: no cover
cmd, args = self.build_cmd(
cfg_json_uri,
pipeline,
@@ -117,7 +117,7 @@ def run_command(self,
use_gpu: bool = False,
job_queue: Optional[str] = None,
job_def: Optional[str] = None,
**kwargs) -> str:
**kwargs) -> str: # pragma: no cover
"""Submit a command as a job to AWS Batch.
Args:
Original file line number Diff line number Diff line change
@@ -7,13 +7,11 @@
from rastervision.pipeline.config import (register_config, Field)


def ss_label_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version < 4:
try:
# removed in version 4
del cfg_dict['rgb_class_config']
except KeyError:
pass
def ss_label_source_config_upgrader(cfg_dict: dict,
version: int) -> dict: # pragma: no cover
if version == 3:
# removed in version 4
cfg_dict.pop('rgb_class_config', None)
return cfg_dict


Original file line number Diff line number Diff line change
@@ -9,7 +9,8 @@
from rastervision.core.rv_pipeline import RVPipelineConfig


def rs_config_upgrader(cfg_dict: dict, version: int) -> dict:
def rs_config_upgrader(cfg_dict: dict,
version: int) -> dict: # pragma: no cover
if version == 6:
# removed in version 7
if cfg_dict.get('extent_crop') is not None:
Original file line number Diff line number Diff line change
@@ -5,7 +5,8 @@
from rastervision.pipeline.config import ConfigError, Field, register_config


def rasterio_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
def rasterio_source_config_upgrader(cfg_dict: dict,
version: int) -> dict: # pragma: no cover
if version == 5:
# removed in version 6
x_shift = cfg_dict.get('x_shift', 0)
Original file line number Diff line number Diff line change
@@ -10,7 +10,8 @@
SceneConfig, VectorSource)


def vector_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
def vector_source_config_upgrader(cfg_dict: dict,
version: int) -> dict: # pragma: no cover
if version == 4:
from rastervision.core.data.vector_transformer import (
ClassInferenceTransformerConfig, BufferTransformerConfig)
Original file line number Diff line number Diff line change
@@ -10,13 +10,11 @@
from rastervision.core.data import ClassConfig


def ss_evaluator_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version < 3:
try:
# removed in version 3
del cfg_dict['vector_output_uri']
except KeyError:
pass
def ss_evaluator_config_upgrader(cfg_dict: dict,
version: int) -> dict: # pragma: no cover
if version == 2:
# removed in version 3
cfg_dict.pop('vector_output_uri', None)
return cfg_dict


Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
PyTorchChipClassification)


def clf_learner_backend_config_upgrader(cfg_dict, version):
def clf_learner_backend_config_upgrader(cfg_dict, version): # pragma: no cover
if version == 0:
fields = {
'augmentors': default_augmentors,
Original file line number Diff line number Diff line change
@@ -9,7 +9,8 @@
PyTorchObjectDetection)


def objdet_learner_backend_config_upgrader(cfg_dict, version):
def objdet_learner_backend_config_upgrader(cfg_dict,
version): # pragma: no cover
if version == 0:
fields = {
'augmentors': default_augmentors,
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
PyTorchSemanticSegmentation)


def ss_learner_backend_config_upgrader(cfg_dict, version):
def ss_learner_backend_config_upgrader(cfg_dict, version): # pragma: no cover
if version == 0:
fields = {
'augmentors': default_augmentors,
Original file line number Diff line number Diff line change
@@ -987,6 +987,7 @@ def setup_ddp_params(self):
if not self.distributed:
return

# pragma: no cover
if self.model is not None:
raise ValueError(
'In distributed mode, the model must be specified via '
@@ -1147,7 +1148,7 @@ def setup_data(self, distributed: Optional[bool] = None):
distributed = self.distributed

if self.train_ds is None or self.valid_ds is None:
if distributed:
if distributed: # pragma: no cover
if self.is_ddp_local_master:
train_ds, valid_ds, test_ds = self.build_datasets()
log.debug(f'{self.ddp_rank=} Done.')
@@ -1219,7 +1220,7 @@ def build_dataloader(self,
collate_fn = self.get_collate_fn()
sampler = self.build_sampler(ds, split, distributed=distributed)

if distributed:
if distributed: # pragma: no cover
world_sz = self.ddp_world_size
if world_sz is None:
raise ValueError('World size not set. '
@@ -1269,14 +1270,14 @@ def build_sampler(self,
split = split.lower()
sampler = None
if split == 'train':
if distributed:
if distributed: # pragma: no cover
sampler = DistributedSampler(
ds,
shuffle=True,
num_replicas=self.ddp_world_size,
rank=self.ddp_rank)
elif split == 'valid':
if distributed:
if distributed: # pragma: no cover
sampler = DistributedSampler(
ds,
shuffle=False,
@@ -1560,8 +1561,8 @@ def _bundle_transforms(self, model_bundle_dir: str) -> None:
#########
# Misc.
#########
def ddp(self, rank: Optional[int] = None,
world_size: Optional[int] = None) -> DDPContextManager:
def ddp(self, rank: Optional[int] = None, world_size: Optional[int] = None
) -> DDPContextManager: # pragma: no cover
"""Return a :class:`DDPContextManager`.
This should be used to wrap code that needs to be executed in parallel.
@@ -1759,7 +1760,7 @@ def setup_tensorboard(self):

def run_tensorboard(self):
"""Run TB server serving logged stats."""
if self.cfg.run_tensorboard:
if self.cfg.run_tensorboard: # pragma: no cover
log.info('Starting tensorboard process')
self.tb_process = Popen(
['tensorboard', '--bind_all', f'--logdir={self.tb_log_dir}'])
@@ -1769,7 +1770,7 @@ def stop_tensorboard(self):
"""Stop TB logging and server if it's running."""
if self.tb_writer is not None:
self.tb_writer.close()
if self.tb_process is not None:
if self.tb_process is not None: # pragma: no cover
self.tb_process.terminate()

@property
3 changes: 2 additions & 1 deletion tests/pytorch_learner/test_semantic_segmentation_learner.py
Original file line number Diff line number Diff line change
@@ -122,7 +122,8 @@ def _test_learner(self,
learner = backend.learner_cfg.build(tmp_dir, training=True)

learner.plot_dataloaders()
learner.train()
learner.train(1)
learner.train(1)
learner.plot_predictions(split='valid')
learner.save_model_bundle()

0 comments on commit 2ed63e3

Please sign in to comment.