From 135bc85c7f8f0278de2d408e4e9a922bf181b9d3 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Thu, 2 May 2024 15:49:55 +0200
Subject: [PATCH] Add mocked download tests

---
 tests/datasets/conftest.py   | 35 +++++++++++++++++++++++++++++++++++
 tests/datasets/test_utils.py | 19 ++++++++++++++++++-
 2 files changed, 53 insertions(+), 1 deletion(-)
 create mode 100644 tests/datasets/conftest.py

diff --git a/tests/datasets/conftest.py b/tests/datasets/conftest.py
new file mode 100644
index 00000000000..b2d5573ea8f
--- /dev/null
+++ b/tests/datasets/conftest.py
@@ -0,0 +1,35 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+from collections.abc import Iterator
+from typing import BinaryIO
+
+import pytest
+from pytest import MonkeyPatch
+
+
+class ContainerClient:
+    def __init__(self, account_url: str, container_name: str) -> None:
+        self.account_url = account_url
+        self.container_name = container_name
+
+    def list_blob_names(self, name_starts_with: str = "") -> Iterator[str]:
+        prefix = os.path.join(self.account_url, self.container_name)
+        for root, dirs, files in os.walk(prefix):
+            for file in files:
+                name = os.path.join(root, file).replace(prefix + os.sep, "", 1)
+                if name.startswith(name_starts_with):
+                    yield name
+
+    def download_blob(self, blob: str) -> BinaryIO:
+        path = os.path.join(self.account_url, self.container_name, blob)
+        # TODO: filehandle leak
+        f = open(path, "rb", buffering=0)
+        return f
+
+
+@pytest.fixture
+def container_client(monkeypatch: MonkeyPatch) -> None:
+    asb = pytest.importorskip("azure.storage.blob", minversion="12.4")
+    monkeypatch.setattr(asb, "ContainerClient", ContainerClient)
diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py
index 189bda53737..84ce8f13b69 100644
--- a/tests/datasets/test_utils.py
+++ b/tests/datasets/test_utils.py
@@ -186,7 +186,24 @@ def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch)
     )
 
 
-def test_download_azure_container(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
+def test_download_azure_container(tmp_path: Path, container_client: None) -> None:
+    account_url = os.path.join("tests", "data")
+    container_name = "cyclone"
+    name_starts_with = "nasa_tropical_storm_competition_test_source"
+    download_azure_container(
+        account_url=account_url,
+        container_name=container_name,
+        root=str(tmp_path),
+        name_starts_with=name_starts_with,
+    )
+    assert os.path.exists(
+        tmp_path / "nasa_tropical_storm_competition_test_source" / "collection.json"
+    )
+
+
+@pytest.mark.slow
+def test_download_azure_container_slow(tmp_path: Path) -> None:
+    pytest.importorskip("azure.storage.blob", minversion="12.4")
     split = "train"
     account_url = "https://radiantearth.blob.core.windows.net"
     container_name = "mlhub"