Skip to content

Commit

Permalink
Add mocked download tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed May 2, 2024
1 parent 226ec8b commit 135bc85
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
35 changes: 35 additions & 0 deletions tests/datasets/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from collections.abc import Iterator
from typing import BinaryIO

import pytest
from pytest import MonkeyPatch


class ContainerClient:
    """Local-filesystem stand-in for ``azure.storage.blob.ContainerClient``.

    The "container" is the directory ``<account_url>/<container_name>`` and
    every regular file below it is treated as a blob. Blob names are paths
    relative to that directory, using the platform's path separator.
    """

    def __init__(self, account_url: str, container_name: str) -> None:
        """Record the location of the fake container.

        Args:
            account_url: Directory that plays the role of the storage account.
            container_name: Subdirectory of ``account_url`` that plays the
                role of the container.
        """
        self.account_url = account_url
        self.container_name = container_name

    def list_blob_names(self, name_starts_with: str = "") -> Iterator[str]:
        """Lazily list the files stored in the fake container.

        Args:
            name_starts_with: Only names beginning with this string are
                yielded. The empty string (default) matches everything.

        Yields:
            Paths of files below the container directory, relative to it.
            The order is whatever ``os.walk`` produces.
        """
        prefix = os.path.join(self.account_url, self.container_name)
        for root, _dirs, files in os.walk(prefix):
            for file in files:
                # relpath instead of str.replace(prefix + os.sep, ...): the
                # textual replace silently left the prefix in place whenever
                # ``prefix`` already ended in a separator (e.g. an empty
                # container_name), yielding full paths instead of blob names.
                name = os.path.relpath(os.path.join(root, file), prefix)
                if name.startswith(name_starts_with):
                    yield name

    def download_blob(self, blob: str) -> BinaryIO:
        """Open a blob for reading.

        Args:
            blob: Blob name as produced by :meth:`list_blob_names`.

        Returns:
            An open, unbuffered binary file object positioned at the start of
            the blob. The caller is responsible for closing it.

        Raises:
            OSError: If the file cannot be opened (e.g. it does not exist).
        """
        path = os.path.join(self.account_url, self.container_name, blob)
        # TODO: filehandle leak
        # buffering=0 returns the raw file object rather than a buffered
        # reader. NOTE(review): presumably the code under test relies on
        # raw-stream methods such as ``readall`` -- confirm before changing
        # the kind of stream returned here.
        f = open(path, "rb", buffering=0)
        return f


@pytest.fixture
def container_client(monkeypatch: MonkeyPatch) -> None:
    """Swap ``azure.storage.blob.ContainerClient`` for the local fake.

    The requesting test is skipped when ``azure.storage.blob`` is not
    importable or is older than version 12.4. The replacement is made through
    ``monkeypatch``, so it is undone when the test finishes.
    """
    blob_module = pytest.importorskip("azure.storage.blob", minversion="12.4")
    monkeypatch.setattr(blob_module, "ContainerClient", ContainerClient)
19 changes: 18 additions & 1 deletion tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,24 @@ def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch)
)


def test_download_azure_container(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
def test_download_azure_container(tmp_path: Path, container_client: None) -> None:
    """Download the mocked ``cyclone`` container and check a known file arrives.

    ``container_client`` is requested only for its side effect of replacing
    the Azure client, so the "account" here is the local ``tests/data``
    directory rather than a URL.
    """
    prefix = "nasa_tropical_storm_competition_test_source"
    download_azure_container(
        account_url=os.path.join("tests", "data"),
        container_name="cyclone",
        root=str(tmp_path),
        name_starts_with=prefix,
    )
    expected_file = tmp_path / prefix / "collection.json"
    assert os.path.exists(expected_file)


@pytest.mark.slow
def test_download_azure_container_slow(tmp_path: Path) -> None:
pytest.importorskip("azure.storage.blob", minversion="12.4")
split = "train"
account_url = "https://radiantearth.blob.core.windows.net"
container_name = "mlhub"
Expand Down

0 comments on commit 135bc85

Please sign in to comment.