Skip to content

Commit

Permalink
Datasets: add azcopy download support
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed May 3, 2024
1 parent bd9c757 commit 38113c1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/datasets/azcopy
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Basic mock-up of the azcopy CLI.
Only needed until azcopy supports local <-> local transfers
* https://github.com/Azure/azure-storage-azcopy/issues/2669
"""

import argparse
import shutil


def main(argv: "list[str] | None" = None) -> None:
    """Emulate ``azcopy sync <source> <destination>`` with a local copy.

    Only the ``sync`` subcommand is understood. The ``source`` directory tree
    is copied into ``destination``; an already existing destination is
    allowed and is copied into.

    Args:
        argv: Command-line arguments without the program name. When ``None``,
            argparse falls back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # Require a subcommand so a bare invocation ends in a usage error
    # instead of an AttributeError on the missing ``source`` attribute.
    subparsers = parser.add_subparsers(dest="command", required=True)
    sync = subparsers.add_parser("sync")
    sync.add_argument("source")
    sync.add_argument("destination")
    # parse_known_args tolerates flags this mock does not declare
    # (e.g. ``--recursive=true``); they are simply ignored.
    args, _ = parser.parse_known_args(argv)
    shutil.copytree(args.source, args.destination, dirs_exist_ok=True)


if __name__ == "__main__":
    main()
14 changes: 14 additions & 0 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
BoundingBox,
DatasetNotFoundError,
array_to_tensor,
azcopy,
concat_samples,
disambiguate_timestamp,
download_and_extract_archive,
Expand Down Expand Up @@ -180,6 +181,19 @@ def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch)
)


def test_azcopy(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
    """Exercise the ``azcopy`` wrapper on both of its code paths.

    If an ``azcopy`` executable is found on ``PATH``, the directory holding
    this test file is prepended to ``PATH`` before syncing the cyclone test
    data into ``tmp_path`` and checking that the labels directory appears.
    Otherwise the wrapper is expected to raise ``FileNotFoundError`` with its
    "not installed" message.
    """
    cyclone_dir = os.path.join("tests", "data", "cyclone")

    if not shutil.which("azcopy"):
        expected_msg = (
            "azcopy is not installed and is required to download this dataset"
        )
        with pytest.raises(FileNotFoundError, match=expected_msg):
            azcopy("sync", cyclone_dir, tmp_path, "--recursive=true")
        return

    # Put this file's directory first on PATH so that an ``azcopy`` located
    # next to the tests is resolved before any other one.
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    monkeypatch.setenv("PATH", tests_dir, prepend=os.pathsep)
    azcopy("sync", cyclone_dir, tmp_path, "--recursive=true")
    labels_dir = tmp_path / "nasa_tropical_storm_competition_test_labels"
    assert os.path.exists(labels_dir)


def test_download_radiant_mlhub_dataset(
tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
Expand Down
21 changes: 21 additions & 0 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import gzip
import lzma
import os
import subprocess
import sys
import tarfile
from collections.abc import Iterable, Iterator, Sequence
Expand Down Expand Up @@ -220,6 +221,26 @@ def download_and_extract_archive(
extract_archive(archive, extract_root)


def azcopy(*args: Any, **kwargs: Any) -> None:
    """Wrapper around ``azcopy`` command.

    The command is executed with ``subprocess.run``. ``check=True`` is always
    enforced, overriding any ``check`` value supplied by the caller.

    Args:
        args: Arguments to pass to ``azcopy``.
        kwargs: Keyword arguments to pass to ``subprocess.run``.

    Raises:
        FileNotFoundError: If ``azcopy`` is not installed.

    .. versionadded:: 0.6
    """
    kwargs["check"] = True
    try:
        subprocess.run(("azcopy", *args), **kwargs)
    except FileNotFoundError as err:
        msg = "azcopy is not installed and is required to download this dataset"
        # Chain explicitly so the original OS-level error stays attached as
        # the cause instead of appearing as an unrelated secondary failure.
        raise FileNotFoundError(msg) from err


def download_radiant_mlhub_dataset(
dataset_id: str, download_root: str, api_key: str | None = None
) -> None:
Expand Down

0 comments on commit 38113c1

Please sign in to comment.