diff --git a/odc/geo/_interop.py b/odc/geo/_interop.py
index e7c61d49..f25ea5d6 100644
--- a/odc/geo/_interop.py
+++ b/odc/geo/_interop.py
@@ -43,6 +43,14 @@ def datacube(self) -> bool:
     def tifffile(self) -> bool:
         return self._check("tifffile")
 
+    @property
+    def azure(self) -> bool:
+        return self._check("azure.storage.blob")
+
+    @property
+    def botocore(self) -> bool:
+        return self._check("botocore")
+
     @staticmethod
     def _check(lib_name: str) -> bool:
         return importlib.util.find_spec(lib_name) is not None
diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py
index 5c79497a..9995a0ab 100644
--- a/odc/geo/cog/_az.py
+++ b/odc/geo/cog/_az.py
@@ -1,8 +1,7 @@
 import base64
 from typing import Any, Union
 
-from azure.storage.blob import BlobBlock, BlobServiceClient
-from dask.delayed import Delayed
+import dask.bag
 
 from ._mpu import mpu_write
 from ._multipart import MultiPartUploadBase
@@ -51,6 +50,9 @@ def __init__(
         self.credential = credential
 
         # Initialise Azure Blob service client
+        # pylint: disable=import-outside-toplevel,import-error
+        from azure.storage.blob import BlobServiceClient
+
         self.blob_service_client = BlobServiceClient(
             account_url=account_url, credential=credential
         )
@@ -85,6 +87,9 @@ def finalise(self, parts: list[dict[str, Any]]) -> str:
         :param parts: List of uploaded parts metadata.
         :return: The ETag of the finalised blob.
         """
+        # pylint: disable=import-outside-toplevel,import-error
+        from azure.storage.blob import BlobBlock
+
         block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts]
         self.blob_client.commit_block_list(block_list)
         return self.blob_client.get_blob_properties().etag
@@ -121,16 +126,16 @@ def writer(self, kw: dict[str, Any], client: Any = None):
 
     def upload(
         self,
-        chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
+        chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
         *,
         mk_header: Any = None,
         mk_footer: Any = None,
         mk_parts: Any = None,
         writes_per_chunk: int = 1,
         spill_sz: int = 20 * (1 << 20),
         client: Any = None,
         **kw,
-    ) -> "Delayed":
+    ) -> "dask.delayed.Delayed":
         """
         Upload chunks to Azure Blob Storage with multipart uploads.
 
diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py
index 0ba3c9b1..c9060bee 100644
--- a/odc/geo/cog/_multipart.py
+++ b/odc/geo/cog/_multipart.py
@@ -1,5 +1,9 @@
 """
 Multipart upload interface.
+
+Defines the `MultiPartUploadBase` class for implementing multipart upload functionality.
+This interface standardises methods for initiating, uploading, and finalising
+multipart uploads across storage backends.
""" from abc import ABC, abstractmethod diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 8d813c4d..fae90408 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -11,11 +11,13 @@ from functools import partial from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from urllib.parse import urlparse from xml.sax.saxutils import escape as xml_escape import numpy as np import xarray as xr + from .._interop import have from ..geobox import GeoBox from ..math import resolve_nodata @@ -23,22 +25,6 @@ from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -try: - from ._az import AzMultiPartUpload - - HAVE_AZURE = True -except ImportError: - AzMultiPartUpload = None - HAVE_AZURE = False -try: - from ._s3 import S3MultiPartUpload, s3_parse_url - - HAVE_S3 = True -except ImportError: - S3MultiPartUpload = None - s3_parse_url = None - HAVE_S3 = False - from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -641,6 +627,7 @@ def save_cog_with_dask( bigtiff: bool = True, overview_resampling: Union[int, str] = "nearest", aws: Optional[dict[str, Any]] = None, + azure: Optional[dict[str, Any]] = None, client: Any = None, stats: bool | int = True, **kw, @@ -669,13 +656,12 @@ def save_cog_with_dask( from ..xr import ODCExtensionDa - if aws is None: - aws = {} + aws = aws or {} + azure = azure or {} - upload_params = {k: kw.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in kw} - upload_params.update( - {k: aws.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in aws} - ) + upload_params = { + k: kw.pop(k, None) for k in ["writes_per_chunk", "spill_sz"] if k in kw + } parts_base = kw.pop("parts_base", None) # Normalise compression settings and remove GDAL compat options from kw @@ -750,19 +736,25 @@ def save_cog_with_dask( # Determine output type and initiate uploader parsed_url = urlparse(dst) if parsed_url.scheme == "s3": - if not HAVE_S3: - raise ImportError("Install `boto3` to enable S3 support.") - bucket, key = s3_parse_url(dst) - uploader = S3MultiPartUpload(bucket, key, **aws) + if have.s3: + from ._s3 import S3MultiPartUpload, s3_parse_url + + bucket, key = s3_parse_url(dst) + uploader = S3MultiPartUpload(bucket, key, **aws) + else: + raise RuntimeError("Please install `boto3` to use S3") elif parsed_url.scheme == "az": - if not HAVE_AZURE: - raise ImportError("Install azure-storage-blob` to enable Azure support.") - uploader = AzMultiPartUpload( - account_url=azure.get("account_url"), - container=parsed_url.netloc, - blob=parsed_url.path.lstrip("/"), - credential=azure.get("credential"), - ) + if have.azure: + from ._az import AzMultiPartUpload + + uploader = AzMultiPartUpload( + account_url=azure.get("account_url"), + container=parsed_url.netloc, + blob=parsed_url.path.lstrip("/"), + credential=azure.get("credential"), + ) + else: + raise RuntimeError("Please install `azure-storage-blob` to use Azure") else: # Assume local disk write = MPUFileSink(dst, parts_base=parts_base) diff --git a/tests/test_az.py b/tests/test_az.py index 024e01ce..9462f091 100644 --- a/tests/test_az.py +++ b/tests/test_az.py @@ -1,96 +1,80 @@ """Tests for the Azure AzMultiPartUpload class.""" import base64 -import unittest from unittest.mock import MagicMock, patch -# Conditional import for Azure support -try: - from odc.geo.cog._az import AzMultiPartUpload +import pytest - HAVE_AZURE = True -except ImportError: - AzMultiPartUpload = None - HAVE_AZURE = False +pytest.importorskip("azure.storage.blob") +from odc.geo.cog._az import 
 
 
-def require_azure(test_func):
-    """Decorator to skip tests if Azure dependencies are not installed."""
-    return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")(
-        test_func
-    )
+@pytest.fixture
+def azure_mpu():
+    """Fixture for initialising AzMultiPartUpload."""
+    account_url = "https://account_name.blob.core.windows.net"
+    return AzMultiPartUpload(account_url, "container", "some.blob", None)
+
+
+def test_mpu_init(azure_mpu):
+    """Basic test for AzMultiPartUpload initialisation."""
+    assert azure_mpu.account_url == "https://account_name.blob.core.windows.net"
+    assert azure_mpu.container == "container"
+    assert azure_mpu.blob == "some.blob"
+    assert azure_mpu.credential is None
+
+
+@patch("azure.storage.blob.BlobServiceClient")
+def test_azure_multipart_upload(mock_blob_service_client):
+    """Test the full Azure AzMultiPartUpload functionality."""
+    # Mock Azure Blob SDK client structure
+    mock_blob_client = MagicMock()
+    mock_container_client = MagicMock()
+    mock_blob_service_client.return_value.get_container_client.return_value = (
+        mock_container_client
+    )
+    mock_container_client.get_blob_client.return_value = mock_blob_client
+
+    # Simulate return values for Azure Blob SDK methods
+    mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"
+
+    # Test parameters
+    account_url = "https://mockaccount.blob.core.windows.net"
+    container = "mock-container"
+    blob = "mock-blob"
+    credential = "mock-sas-token"
+
+    # Create an instance of AzMultiPartUpload and call its methods
+    azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
+    upload_id = azure_upload.initiate()
+    part1 = azure_upload.write_part(1, b"first chunk of data")
+    part2 = azure_upload.write_part(2, b"second chunk of data")
+    etag = azure_upload.finalise([part1, part2])
+
+    # Define block IDs
+    block_id1 = base64.b64encode(b"block-1").decode("utf-8")
+    block_id2 = base64.b64encode(b"block-2").decode("utf-8")
+
+    # Verify the results
+    assert upload_id == "azure-block-upload"
+    assert etag == "mock-etag"
+
+    # Verify BlobServiceClient instantiation
+    mock_blob_service_client.assert_called_once_with(
+        account_url=account_url, credential=credential
+    )
+
+    # Verify stage_block calls
+    mock_blob_client.stage_block.assert_any_call(
+        block_id=block_id1, data=b"first chunk of data"
+    )
+    mock_blob_client.stage_block.assert_any_call(
+        block_id=block_id2, data=b"second chunk of data"
+    )
-
-
-class TestAzMultiPartUpload(unittest.TestCase):
-    """Test the AzMultiPartUpload class."""
-
-    @require_azure
-    def test_mpu_init(self):
-        """Basic test for AzMultiPartUpload initialization."""
-        account_url = "https://account_name.blob.core.windows.net"
-        mpu = AzMultiPartUpload(account_url, "container", "some.blob", None)
-
-        self.assertEqual(mpu.account_url, account_url)
-        self.assertEqual(mpu.container, "container")
-        self.assertEqual(mpu.blob, "some.blob")
-        self.assertIsNone(mpu.credential)
-
-    @require_azure
-    @patch("odc.geo.cog._az.BlobServiceClient")
-    def test_azure_multipart_upload(self, mock_blob_service_client):
-        """Test the full Azure AzMultiPartUpload functionality."""
-        # Arrange - Mock Azure Blob SDK client structure
-        mock_blob_client = MagicMock()
-        mock_container_client = MagicMock()
-        mock_blob_service_client.return_value.get_container_client.return_value = (
-            mock_container_client
-        )
-        mock_container_client.get_blob_client.return_value = mock_blob_client
-
-        # Simulate return values for Azure Blob SDK methods
-        mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"
-
-        # Test parameters
-        account_url = "https://mockaccount.blob.core.windows.net"
-        container = "mock-container"
-        blob = "mock-blob"
-        credential = "mock-sas-token"
-
-        # Act
-        azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
-        upload_id = azure_upload.initiate()
-        part1 = azure_upload.write_part(1, b"first chunk of data")
-        part2 = azure_upload.write_part(2, b"second chunk of data")
-        etag = azure_upload.finalise([part1, part2])
-
-        # Correctly calculate block IDs
-        block_id1 = base64.b64encode(b"block-1").decode("utf-8")
-        block_id2 = base64.b64encode(b"block-2").decode("utf-8")
-
-        # Assert
-        self.assertEqual(upload_id, "azure-block-upload")
-        self.assertEqual(etag, "mock-etag")
-
-        # Verify BlobServiceClient instantiation
-        mock_blob_service_client.assert_called_once_with(
-            account_url=account_url, credential=credential
-        )
-
-        # Verify stage_block calls
-        mock_blob_client.stage_block.assert_any_call(
-            block_id=block_id1, data=b"first chunk of data"
-        )
-        mock_blob_client.stage_block.assert_any_call(
-            block_id=block_id2, data=b"second chunk of data"
-        )
-
-        # Verify commit_block_list was called correctly
-        block_list = mock_blob_client.commit_block_list.call_args[0][0]
-        self.assertEqual(len(block_list), 2)
-        self.assertEqual(block_list[0].id, block_id1)
-        self.assertEqual(block_list[1].id, block_id2)
-        mock_blob_client.commit_block_list.assert_called_once()
-
-
-if __name__ == "__main__":
-    unittest.main()
+
+    # Verify commit_block_list was called correctly
+    block_list = mock_blob_client.commit_block_list.call_args[0][0]
+    assert len(block_list) == 2
+    assert block_list[0].id == block_id1
+    assert block_list[1].id == block_id2
+    mock_blob_client.commit_block_list.assert_called_once()
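
A note on the `have` gating used in `_tifffile.py` above: each property answers "is this optional dependency importable?" without importing it. A minimal sketch of probing it, assuming `have` is the instance exported by `odc.geo._interop` (which the `from .._interop import have` import in this patch implies):

```python
from odc.geo._interop import have

# Each property calls importlib.util.find_spec() for the library,
# so nothing is actually imported until a backend is selected.
if have.azure:
    print("azure.storage.blob installed -> az:// destinations supported")
if have.botocore:
    print("botocore installed -> s3:// destinations supported")
```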
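And an end-to-end sketch of the new user-facing path: the `azure=` dict feeds `AzMultiPartUpload` exactly as in the `save_cog_with_dask` hunk above, with the container taken from the URL netloc and the blob name from the path. The account URL, container, SAS token, and the `xx` array below are placeholders, and the `odc.geo.cog` import location and the `.compute()` call (per the `upload` return annotation) are assumptions, not something this patch demonstrates:

```python
from odc.geo.cog import save_cog_with_dask

# `xx` is assumed to be a dask-backed xarray.DataArray with odc-geo georeferencing.
result = save_cog_with_dask(
    xx,
    "az://my-container/some/prefix/output.tif",  # container from netloc, blob from path
    azure={
        "account_url": "https://myaccount.blob.core.windows.net",  # placeholder
        "credential": "my-sas-token",  # placeholder SAS token, or None
    },
)
result.compute()  # assumed to trigger the chunked multipart upload
```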