diff --git a/.readthedocs.yml b/.readthedocs.yml index 022dd676e..324282218 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -40,6 +40,8 @@ python: path: rastervision_pytorch_learner/ - method: pip path: rastervision_pytorch_backend/ + - method: pip + path: rastervision_aws_sagemaker/ # https://docs.readthedocs.io/en/stable/config-file/v2.html#search search: diff --git a/docs/framework/examples.rst b/docs/framework/examples.rst index 5abb09088..c231b5068 100644 --- a/docs/framework/examples.rst +++ b/docs/framework/examples.rst @@ -83,7 +83,7 @@ The ``--tensorboard`` option should be used if running locally and you would lik export PROCESSED_URI="/opt/data/examples/spacenet/rio/processed-data" export ROOT_URI="/opt/data/examples/spacenet/rio/local-output" - rastervision run local rastervision.examples.chip_classification.spacenet_rio \ + rastervision run local rastervision.pytorch_backend.examples.chip_classification.spacenet_rio \ -a raw_uri $RAW_URI -a processed_uri $PROCESSED_URI -a root_uri $ROOT_URI \ -a test True --splits 2 @@ -104,7 +104,7 @@ To run the full experiment on GPUs using AWS Batch, use something like the follo export PROCESSED_URI="s3://mybucket/examples/spacenet/rio/processed-data" export ROOT_URI="s3://mybucket/examples/spacenet/rio/remote-output" - rastervision run batch rastervision.examples.chip_classification.spacenet_rio \ + rastervision run batch rastervision.pytorch_backend.examples.chip_classification.spacenet_rio \ -a raw_uri $RAW_URI -a processed_uri $PROCESSED_URI -a root_uri $ROOT_URI \ -a test False --splits 8 diff --git a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py index f6d4ff854..4c33deb78 100644 --- a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py +++ b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Any, Iterator, Tuple import io import os import subprocess @@ -16,41 +16,38 @@ # Code from https://alexwlchan.net/2017/07/listing-s3-keys/ -def get_matching_s3_objects(bucket, prefix='', suffix='', - request_payer='None'): - """ - Generate objects in an S3 bucket. - - :param bucket: Name of the S3 bucket. - :param prefix: Only fetch objects whose key starts with - this prefix (optional). - :param suffix: Only fetch objects whose keys end with - this suffix (optional). +def get_matching_s3_objects( + bucket: str, + prefix: str = '', + suffix: str = '', + delimiter: str = '/', + request_payer: str = 'None') -> Iterator[tuple[str, Any]]: + """Generate objects in an S3 bucket. + + Args: + bucket: Name of the S3 bucket. + prefix: Only fetch objects whose key starts with this prefix. + suffix: Only fetch objects whose keys end with this suffix. """ s3 = S3FileSystem.get_client() - kwargs = {'Bucket': bucket, 'RequestPayer': request_payer} - - # If the prefix is a single string (not a tuple of strings), we can - # do the filtering directly in the S3 API. - if isinstance(prefix, str): - kwargs['Prefix'] = prefix - + kwargs = dict( + Bucket=bucket, + RequestPayer=request_payer, + Delimiter=delimiter, + Prefix=prefix, + ) while True: - - # The S3 API response is a large blob of metadata. - # 'Contents' contains information about the listed objects. - resp = s3.list_objects_v2(**kwargs) - - try: - contents = resp['Contents'] - except KeyError: - return - - for obj in contents: + resp: dict = s3.list_objects_v2(**kwargs) + dirs: list[dict] = resp.get('CommonPrefixes', {}) + files: list[dict] = resp.get('Contents', {}) + for obj in dirs: + key = obj['Prefix'] + if key.startswith(prefix) and key.endswith(suffix): + yield key, obj + for obj in files: key = obj['Key'] if key.startswith(prefix) and key.endswith(suffix): - yield obj - + yield key, obj # The S3 API is paginated, returning up to 1000 keys at a time. # Pass the continuation token into the next response, until we # reach the final page (when this field is missing). @@ -60,16 +57,26 @@ def get_matching_s3_objects(bucket, prefix='', suffix='', break -def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'): - """ - Generate the keys in an S3 bucket. +def get_matching_s3_keys(bucket: str, + prefix: str = '', + suffix: str = '', + delimiter: str = '/', + request_payer: str = 'None') -> Iterator[str]: + """Generate the keys in an S3 bucket. - :param bucket: Name of the S3 bucket. - :param prefix: Only fetch keys that start with this prefix (optional). - :param suffix: Only fetch keys that end with this suffix (optional). + Args: + bucket: Name of the S3 bucket. + prefix: Only fetch keys that start with this prefix. + suffix: Only fetch keys that end with this suffix. """ - for obj in get_matching_s3_objects(bucket, prefix, suffix, request_payer): - yield obj['Key'] + obj_iterator = get_matching_s3_objects( + bucket, + prefix=prefix, + suffix=suffix, + delimiter=delimiter, + request_payer=request_payer) + out = (key for key, _ in obj_iterator) + return out def progressbar(total_size: int, desc: str): @@ -180,8 +187,9 @@ def read_bytes(uri: str) -> bytes: bucket, key = S3FileSystem.parse_uri(uri) with io.BytesIO() as file_buffer: try: - file_size = s3.head_object( - Bucket=bucket, Key=key)['ContentLength'] + obj = s3.head_object( + Bucket=bucket, Key=key, RequestPayer=request_payer) + file_size = obj['ContentLength'] with progressbar(file_size, desc='Downloading') as bar: s3.download_fileobj( Bucket=bucket, @@ -256,7 +264,9 @@ def copy_from(src_uri: str, dst_path: str) -> None: request_payer = S3FileSystem.get_request_payer() bucket, key = S3FileSystem.parse_uri(src_uri) try: - file_size = s3.head_object(Bucket=bucket, Key=key)['ContentLength'] + obj = s3.head_object( + Bucket=bucket, Key=key, RequestPayer=request_payer) + file_size = obj['ContentLength'] with progressbar(file_size, desc=f'Downloading') as bar: s3.download_file( Bucket=bucket, @@ -284,11 +294,16 @@ def last_modified(uri: str) -> datetime: return head_data['LastModified'] @staticmethod - def list_paths(uri, ext=''): + def list_paths(uri: str, ext: str = '', delimiter: str = '/') -> list[str]: request_payer = S3FileSystem.get_request_payer() parsed_uri = urlparse(uri) bucket = parsed_uri.netloc prefix = os.path.join(parsed_uri.path[1:]) keys = get_matching_s3_keys( - bucket, prefix, suffix=ext, request_payer=request_payer) - return [os.path.join('s3://', bucket, key) for key in keys] + bucket, + prefix, + suffix=ext, + delimiter=delimiter, + request_payer=request_payer) + paths = [os.path.join('s3://', bucket, key) for key in keys] + return paths diff --git a/rastervision_core/rastervision/core/data/dataset_config.py b/rastervision_core/rastervision/core/data/dataset_config.py index 0fd7e12d3..b2e7348f8 100644 --- a/rastervision_core/rastervision/core/data/dataset_config.py +++ b/rastervision_core/rastervision/core/data/dataset_config.py @@ -90,3 +90,12 @@ def get_split_config(self, split_ind, num_splits): @property def all_scenes(self) -> List[SceneConfig]: return self.train_scenes + self.validation_scenes + self.test_scenes + + def __repr__(self): + num_train = len(self.train_scenes) + num_val = len(self.validation_scenes) + num_test = len(self.test_scenes) + out = (f'DatasetConfig(train_scenes=<{num_train} scenes>, ' + f'validation_scenes=<{num_val} scenes>, ' + f'test_scenes=<{num_test} scenes>)') + return out diff --git a/rastervision_core/rastervision/core/data/raster_source/raster_source.py b/rastervision_core/rastervision/core/data/raster_source/raster_source.py index 6a9343d9b..da851d4e6 100644 --- a/rastervision_core/rastervision/core/data/raster_source/raster_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/raster_source.py @@ -147,24 +147,20 @@ def get_chip(self, return chip - def get_chip_by_map_window( - self, - window_map_coords: 'Box', - out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray': - """Same as get_chip(), but input is a window in map coords. """ + def get_chip_by_map_window(self, window_map_coords: 'Box', *args, + **kwargs) -> 'np.ndarray': + """Same as get_chip(), but input is a window in map coords.""" window_pixel_coords = self.crs_transformer.map_to_pixel( window_map_coords, bbox=self.bbox).normalize() - chip = self.get_chip(window_pixel_coords, out_shape=out_shape) + chip = self.get_chip(window_pixel_coords, *args, **kwargs) return chip - def _get_chip_by_map_window( - self, - window_map_coords: 'Box', - out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray': - """Same as _get_chip(), but input is a window in map coords. """ + def _get_chip_by_map_window(self, window_map_coords: 'Box', *args, + **kwargs) -> 'np.ndarray': + """Same as _get_chip(), but input is a window in map coords.""" window_pixel_coords = self.crs_transformer.map_to_pixel( window_map_coords, bbox=self.bbox) - chip = self._get_chip(window_pixel_coords, out_shape=out_shape) + chip = self._get_chip(window_pixel_coords, *args, **kwargs) return chip def get_raw_chip(self, diff --git a/rastervision_core/requirements.txt b/rastervision_core/requirements.txt index f0eb34fd6..576c90b23 100644 --- a/rastervision_core/requirements.txt +++ b/rastervision_core/requirements.txt @@ -2,7 +2,7 @@ rastervision_pipeline==0.30.0 shapely==2.0.2 geopandas==0.14.3 numpy==1.26.3 -pillow==10.2.0 +pillow==10.3.0 pyproj==3.6.1 rasterio==1.3.9 pystac==1.9.0 diff --git a/rastervision_pipeline/rastervision/pipeline/cli.py b/rastervision_pipeline/rastervision/pipeline/cli.py index 1445f997c..25263448f 100644 --- a/rastervision_pipeline/rastervision/pipeline/cli.py +++ b/rastervision_pipeline/rastervision/pipeline/cli.py @@ -1,18 +1,20 @@ +from typing import TYPE_CHECKING import sys import os import logging -import importlib -import importlib.util -from typing import List, Dict, Optional, Tuple import click from rastervision.pipeline import (registry_ as registry, rv_config_ as rv_config) from rastervision.pipeline.file_system import (file_to_json, get_tmp_dir) -from rastervision.pipeline.config import build_config, save_pipeline_config +from rastervision.pipeline.config import (build_config, Config, + save_pipeline_config) from rastervision.pipeline.pipeline_config import PipelineConfig +if TYPE_CHECKING: + from rastervision.pipeline.runner import Runner + log = logging.getLogger(__name__) @@ -40,8 +42,9 @@ def convert_bool_args(args: dict) -> dict: return new_args -def get_configs(cfg_module_path: str, runner: str, - args: Dict[str, any]) -> List[PipelineConfig]: +def get_configs(cfg_module_path: str, + runner: str | None = None, + args: dict[str, any] | None = None) -> list[PipelineConfig]: """Get PipelineConfigs from a module. Calls a get_config(s) function with some arguments from the CLI @@ -55,6 +58,26 @@ def get_configs(cfg_module_path: str, runner: str, args: CLI args to pass to the get_config(s) function that comes from the --args option """ + if cfg_module_path.endswith('.json'): + cfgs_json = file_to_json(cfg_module_path) + if not isinstance(cfgs_json, list): + cfgs_json = [cfgs_json] + cfgs = [Config.deserialize(json) for json in cfgs_json] + else: + cfgs = get_configs_from_module(cfg_module_path, runner, args) + + for cfg in cfgs: + if not issubclass(type(cfg), PipelineConfig): + raise TypeError('All objects returned by get_configs in ' + f'{cfg_module_path} must be PipelineConfigs.') + return cfgs + + +def get_configs_from_module(cfg_module_path: str, runner: str, + args: dict[str, any]) -> list[PipelineConfig]: + import importlib + import importlib.util + if cfg_module_path.endswith('.py'): # From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa spec = importlib.util.spec_from_file_location('rastervision.pipeline', @@ -65,20 +88,14 @@ def get_configs(cfg_module_path: str, runner: str, cfg_module = importlib.import_module(cfg_module_path) _get_config = getattr(cfg_module, 'get_config', None) - _get_configs = _get_config - if _get_config is None: - _get_configs = getattr(cfg_module, 'get_configs', None) + _get_configs = getattr(cfg_module, 'get_configs', _get_config) if _get_configs is None: - raise Exception('There must be a get_config or get_configs function ' - f'in {cfg_module_path}.') + raise ImportError('There must be a get_config() or get_configs() ' + f'function in {cfg_module_path}.') + cfgs = _get_configs(runner, **args) if not isinstance(cfgs, list): cfgs = [cfgs] - - for cfg in cfgs: - if not issubclass(type(cfg), PipelineConfig): - raise Exception('All objects returned by get_configs in ' - f'{cfg_module_path} must be PipelineConfigs.') return cfgs @@ -89,8 +106,7 @@ def get_configs(cfg_module_path: str, runner: str, @click.option( '-v', '--verbose', help='Increment the verbosity level.', count=True) @click.option('--tmpdir', help='Root of temporary directories to use.') -def main(ctx: click.Context, profile: Optional[str], verbose: int, - tmpdir: str): +def main(ctx: click.Context, profile: str | None, verbose: int, tmpdir: str): """The main click command. Sets the profile, verbosity, and tmp_dir in RVConfig. @@ -103,20 +119,22 @@ def main(ctx: click.Context, profile: Optional[str], verbose: int, rv_config.set_everett_config(profile=profile) -def _run_pipeline(cfg, - runner, - tmp_dir, - splits=1, - commands=None, +def _run_pipeline(cfg: PipelineConfig, + runner: 'Runner', + tmp_dir: str, + splits: int = 1, + commands: list[str] | None = None, pipeline_run_name: str = 'raster-vision'): cfg.update() cfg.recursive_validate_config() - # This is to run the validation again to check any fields that may have changed - # after the Config was constructed, possibly by the update method. + + # This is to run the validation again to check any fields that may have + # changed after the Config was constructed, possibly by the update method. build_config(cfg.dict()) cfg_json_uri = cfg.get_config_uri() save_pipeline_config(cfg, cfg_json_uri) pipeline = cfg.build(tmp_dir) + if not commands: commands = pipeline.commands @@ -150,8 +168,8 @@ def _run_pipeline(cfg, '--pipeline-run-name', default='raster-vision', help='The name for this run of the pipeline.') -def run(runner: str, cfg_module: str, commands: List[str], - arg: List[Tuple[str, str]], splits: int, pipeline_run_name: str): +def run(runner: str, cfg_module: str, commands: list[str], + arg: list[tuple[str, str]], splits: int, pipeline_run_name: str): """Run COMMANDS within pipelines in CFG_MODULE using RUNNER. RUNNER: name of the Runner to use @@ -178,9 +196,9 @@ def run(runner: str, cfg_module: str, commands: List[str], def _run_command(cfg_json_uri: str, command: str, - split_ind: Optional[int] = None, - num_splits: Optional[int] = None, - runner: Optional[str] = None): + split_ind: int | None = None, + num_splits: int | None = None, + runner: str | None = None): """Run a single command using a serialized PipelineConfig. Args: @@ -229,8 +247,8 @@ def _run_command(cfg_json_uri: str, help='The number of processes to use for running splittable commands') @click.option( '--runner', type=str, help='Name of runner to use', default='inprocess') -def run_command(cfg_json_uri: str, command: str, split_ind: Optional[int], - num_splits: Optional[int], runner: str): +def run_command(cfg_json_uri: str, command: str, split_ind: int | None, + num_splits: int | None, runner: str): """Run a single COMMAND using a serialized PipelineConfig in CFG_JSON_URI.""" _run_command( cfg_json_uri, diff --git a/rastervision_pipeline/rastervision/pipeline/config.py b/rastervision_pipeline/rastervision/pipeline/config.py index f7a950c8b..7d5766ec8 100644 --- a/rastervision_pipeline/rastervision/pipeline/config.py +++ b/rastervision_pipeline/rastervision/pipeline/config.py @@ -115,6 +115,13 @@ def validate_list(self, field: str, valid_options: List[str]): if val not in valid_options: raise ConfigError(f'{val} is not a valid option for {field}') + def dict(self, with_rv_metadata: bool = False, **kwargs) -> dict: + cfg_json = self.json(**kwargs) + cfg_dict = json.loads(cfg_json) + if with_rv_metadata: + cfg_dict['plugin_versions'] = registry.plugin_versions + return cfg_dict + def to_file(self, uri: str, with_rv_metadata: bool = True) -> None: """Save a Config to a JSON file, optionally with RV metadata. @@ -124,13 +131,7 @@ def to_file(self, uri: str, with_rv_metadata: bool = True) -> None: ``plugin_versions``, so that the config can be upgraded when loaded. """ - cfg_json = self.json() - if with_rv_metadata: - # self.dict() --> json_to_file() would be simpler but runs into - # JSON serialization problems - cfg_dict = json.loads(cfg_json) - cfg_dict['plugin_versions'] = registry.plugin_versions - cfg_json = json.dumps(cfg_dict) + cfg_dict = self.dict(with_rv_metadata=with_rv_metadata) json_to_file(cfg_dict, uri) @classmethod @@ -157,7 +158,7 @@ def from_file(cls, uri: str) -> 'Self': Args: uri: URI to load from. """ - cfg_dict = load_config_dict(uri) + cfg_dict = file_to_json(uri) cfg = cls.from_dict(cfg_dict) return cfg @@ -168,6 +169,9 @@ def from_dict(cls, cfg_dict: dict) -> 'Self': Args: cfg_dict: Dict to deserialize. """ + if 'plugin_versions' in cfg_dict: + cfg_dict: dict = upgrade_config(cfg_dict) + cfg_dict.pop('plugin_versions', None) cfg = build_config(cfg_dict) return cfg @@ -192,15 +196,6 @@ def save_pipeline_config(cfg: 'PipelineConfig', output_uri: str) -> None: str_to_file(cfg_json, output_uri) -def load_config_dict(uri: str) -> dict: - """Load a serialized Config from a JSON file as a dict and upgrade it.""" - cfg_dict = file_to_json(uri) - if 'plugin_versions' in cfg_dict: - cfg_dict = upgrade_config(cfg_dict) - cfg_dict.pop('plugin_versions', None) - return cfg_dict - - def build_config(x: Union[dict, List[Union[dict, Config]], Config] ) -> Union[Config, List[Config]]: """Build a Config from various types of input. @@ -216,9 +211,10 @@ def build_config(x: Union[dict, List[Union[dict, Config]], Config] Config: the corresponding Config(s) """ if isinstance(x, dict): - new_x = {} - for k, v in x.items(): - new_x[k] = build_config(v) + new_x = { + k: build_config(v) + for k, v in x.items() if k not in ('plugin_versions', 'rv_config') + } type_hint = new_x.get('type_hint') if type_hint is not None: config_cls = registry.get_config(type_hint) @@ -311,12 +307,11 @@ def upgrade_config( to the current version. """ plugin_versions = config_dict.get('plugin_versions') - plugin_versions = upgrade_plugin_versions(plugin_versions) if plugin_versions is None: - raise ConfigError( - 'Configuration is missing plugin_version field so is not backward ' - 'compatible.') - return _upgrade_config(config_dict, plugin_versions) + return config_dict + plugin_versions = upgrade_plugin_versions(plugin_versions) + out = _upgrade_config(config_dict, plugin_versions) + return out def get_plugin(config_cls: Type) -> str: diff --git a/rastervision_pipeline/rastervision/pipeline/file_system/utils.py b/rastervision_pipeline/rastervision/pipeline/file_system/utils.py index 8e43a3dd3..5048d3053 100644 --- a/rastervision_pipeline/rastervision/pipeline/file_system/utils.py +++ b/rastervision_pipeline/rastervision/pipeline/file_system/utils.py @@ -219,8 +219,10 @@ def file_exists(uri, fs=None, include_dir=True) -> bool: return fs.file_exists(uri, include_dir) -def list_paths(uri: str, ext: str = '', - fs: Optional[FileSystem] = None) -> List[str]: +def list_paths(uri: str, + ext: str = '', + fs: Optional[FileSystem] = None, + **kwargs) -> List[str]: """List paths rooted at URI. Optionally only includes paths with a certain file extension. @@ -230,6 +232,7 @@ def list_paths(uri: str, ext: str = '', ext: the optional file extension to filter by fs: if supplied, use fs instead of automatically chosen FileSystem for uri + **kwargs: extra kwargs to pass to fs.list_paths(). """ if uri is None: return None @@ -237,7 +240,7 @@ def list_paths(uri: str, ext: str = '', if not fs: fs = FileSystem.get_file_system(uri, 'r') - return fs.list_paths(uri, ext=ext) + return fs.list_paths(uri, ext=ext, **kwargs) def upload_or_copy(src_path: str, diff --git a/rastervision_pipeline/rastervision/pipeline/pipeline_config.py b/rastervision_pipeline/rastervision/pipeline/pipeline_config.py index c62fa255a..b051ef998 100644 --- a/rastervision_pipeline/rastervision/pipeline/pipeline_config.py +++ b/rastervision_pipeline/rastervision/pipeline/pipeline_config.py @@ -44,3 +44,6 @@ def build(self, tmp_dir: str) -> 'Pipeline': """ from rastervision.pipeline.pipeline import Pipeline # noqa return Pipeline(self, tmp_dir) + + def dict(self, with_rv_metadata: bool = True, **kwargs) -> dict: + return super().dict(with_rv_metadata=with_rv_metadata, **kwargs) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py index 914de02b3..cb1a4b205 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py @@ -1118,7 +1118,8 @@ def setup_model(self, if self.model is None: self.model = self.build_model(model_def_path=model_def_path) self.model.to(device=self.device) - if self.is_ddp_process: # pragma: no cover + if self.is_ddp_process and not isinstance(self.model, + DDP): # pragma: no cover self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) self.model = DDP(self.model, device_ids=[self.ddp_local_rank]) self.load_init_weights(model_weights_path=model_weights_path) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index 523a41808..e0aec3912 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -1180,11 +1180,12 @@ class GeoDataConfig(DataConfig): {}, description='Window sampling config.') def __repr_args__(self): - ds = self.scene_dataset - ds_repr = (f'<{len(ds.train_scenes)} train_scenes, ' - f'{len(ds.validation_scenes)} validation_scenes, ' - f'{len(ds.test_scenes)} test_scenes>') - out = [('scene_dataset', ds_repr), ('sampling', str(self.sampling))] + ds_str = repr(self.scene_dataset) + if isinstance(self.sampling, dict): + sampling_str = f'Dict with {len(self.sampling)} keys' + else: + sampling_str = str(self.sampling) + out = [('scene_dataset', ds_str), ('sampling', sampling_str)] return out @validator('sampling') diff --git a/rastervision_pytorch_learner/requirements.txt b/rastervision_pytorch_learner/requirements.txt index 264a2bac9..9439ab293 100644 --- a/rastervision_pytorch_learner/requirements.txt +++ b/rastervision_pytorch_learner/requirements.txt @@ -12,4 +12,4 @@ psutil==5.9.3 opencv-python-headless==4.9.0.80 matplotlib==3.8.4 tqdm==4.66.2 -onnx==1.15.0 +onnx==1.16.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ffb341a5..aa75bc063 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,6 +6,6 @@ unify==0.5 sphinx-autobuild==2021.3.14 seaborn==0.13.2 jupyter==1.0.0 -jupyterlab==4.1.2 +jupyterlab==4.1.8 jupyter_contrib_nbextensions==0.7.0 pystac_client==0.7.6 diff --git a/scripts/test b/scripts/test index 694e56e9b..9296acb0f 100755 --- a/scripts/test +++ b/scripts/test @@ -47,7 +47,7 @@ if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then -v "$(pwd):/opt/src" \ --rm -t \ "raster-vision-${IMAGE_TYPE}" \ - coverage xml + coverage xml --omit=/opt/data/* --skip-empty ;; *) echo "Invalid argument. Run --help for usage." @@ -67,7 +67,7 @@ if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then ./scripts/integration_tests # Create new coverage reports - coverage xml + coverage xml --omit=/opt/data/* --skip-empty fi fi fi diff --git a/scripts/unit_tests b/scripts/unit_tests index 59124676f..aed7489f7 100755 --- a/scripts/unit_tests +++ b/scripts/unit_tests @@ -28,6 +28,7 @@ else if ! [ -x "$(command -v coverage)" ]; then python -m unittest discover -t "$SRC_DIR" tests -vf else - coverage run -m unittest discover -t "$SRC_DIR" tests -vf && coverage html + coverage run -m unittest discover -t "$SRC_DIR" tests -vf && \ + coverage html --omit=/opt/data/* --skip-empty --precision=2 fi fi diff --git a/tests/README.md b/tests/README.md index 871469254..31e61b1dc 100644 --- a/tests/README.md +++ b/tests/README.md @@ -50,7 +50,7 @@ coverage run -m unittest discover -t . tests -vf Generate an HTML report from the `.coverage` file: ```sh -coverage html +coverage html --omit=/opt/data/* --skip-empty --precision=2 ``` See `coverage help` for other available report formats. diff --git a/tests/pipeline/test_cli.py b/tests/pipeline/test_cli.py index 2b1f41d4e..71daa2e4b 100644 --- a/tests/pipeline/test_cli.py +++ b/tests/pipeline/test_cli.py @@ -1,7 +1,14 @@ +import os +from os.path import join import unittest import shutil -from rastervision.pipeline.cli import main, print_error, convert_bool_args +from rastervision.pipeline.file_system.utils import get_tmp_dir +from rastervision.pipeline.config import Config +from rastervision.pipeline.cli import (convert_bool_args, get_configs, main, + print_error) +from rastervision.pipeline_example_plugin1.sample_pipeline import ( + SamplePipelineConfig) from click.testing import CliRunner @@ -75,6 +82,31 @@ def test_convert_bool_args(self): args_out = convert_bool_args(args_in) self.assertDictEqual(args_out, dict(a=True, b=False)) + def test_get_configs_json(self): + cfg = SamplePipelineConfig(root_uri='abc', names=['x', 'y', 'z']) + with get_tmp_dir() as tmp_dir: + cfg_path = join(tmp_dir, 'cfg.json') + cfg.to_file(cfg_path) + cfgs = get_configs(cfg_path) + self.assertEqual(len(cfgs), 1) + self.assertEqual(cfgs[0].root_uri, cfg.root_uri) + self.assertListEqual(cfgs[0].names, cfg.names) + + def test_get_configs_no_func(self): + with get_tmp_dir() as tmp_dir: + cfg_path = join(tmp_dir, 'cfg.py') + with open(cfg_path, 'w'): + pass + self.assertRaises(ImportError, lambda: get_configs(cfg_path)) + os.remove(cfg_path) + + def test_get_configs_not_pipeline_cfg(self): + cfg = Config() + with get_tmp_dir() as tmp_dir: + cfg_path = join(tmp_dir, 'cfg.json') + cfg.to_file(cfg_path) + self.assertRaises(TypeError, lambda: get_configs(cfg_path)) + if __name__ == '__main__': unittest.main()