diff --git a/tests/datasets/azcopy b/tests/datasets/azcopy new file mode 100755 index 00000000000..7eb861cf378 --- /dev/null +++ b/tests/datasets/azcopy @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Basic mock-up of the azcopy CLI. + +Only needed until azcopy supports local <-> local transfers + +* https://github.com/Azure/azure-storage-azcopy/issues/2669 +""" + +import argparse +import shutil + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + sync = subparsers.add_parser("sync") + sync.add_argument("source") + sync.add_argument("destination") + args, _ = parser.parse_known_args() + shutil.copytree(args.source, args.destination, dirs_exist_ok=True) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 018b0d9bae2..586b38243cf 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -25,6 +25,7 @@ BoundingBox, DatasetNotFoundError, array_to_tensor, + azcopy, concat_samples, disambiguate_timestamp, download_and_extract_archive, @@ -180,6 +181,19 @@ def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) ) +def test_azcopy(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: + source = os.path.join("tests", "data", "cyclone") + if shutil.which("azcopy"): + path = os.path.dirname(os.path.realpath(__file__)) + monkeypatch.setenv("PATH", path, prepend=os.pathsep) + azcopy("sync", source, tmp_path, "--recursive=true") + assert os.path.exists(tmp_path / "nasa_tropical_storm_competition_test_labels") + else: + match = "azcopy is not installed and is required to download this dataset" + with pytest.raises(FileNotFoundError, match=match): + azcopy("sync", source, tmp_path, "--recursive=true") + + def test_download_radiant_mlhub_dataset( tmp_path: Path, monkeypatch: MonkeyPatch ) -> None: diff --git a/torchgeo/datasets/utils.py 
b/torchgeo/datasets/utils.py index 2ef37dfd64d..6bd02a55085 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -12,6 +12,7 @@ import gzip import lzma import os +import subprocess import sys import tarfile from collections.abc import Iterable, Iterator, Sequence @@ -220,6 +221,26 @@ def download_and_extract_archive( extract_archive(archive, extract_root) +def azcopy(*args: Any, **kwargs: Any) -> None: + """Wrapper around ``azcopy`` command. + + Args: + args: Arguments to pass to ``azcopy``. + kwargs: Keyword arguments to pass to ``subprocess.run``. + + Raises: + FileNotFoundError: If ``azcopy`` is not installed. + + .. versionadded:: 0.6 + """ + kwargs["check"] = True + try: + subprocess.run(("azcopy",) + args, **kwargs) + except FileNotFoundError: + msg = "azcopy is not installed and is required to download this dataset" + raise FileNotFoundError(msg) from None + + def download_radiant_mlhub_dataset( dataset_id: str, download_root: str, api_key: str | None = None ) -> None: