Feat: allow s3 and/or az dependencies
Merged branch 'develop' into feat/save-cog-to-azure.
wietzesuijker committed Jan 6, 2025
1 parent 0591a68 commit bc8811c
Showing 5 changed files with 116 additions and 123 deletions.
8 changes: 8 additions & 0 deletions odc/geo/_interop.py
@@ -43,6 +43,14 @@ def datacube(self) -> bool:
def tifffile(self) -> bool:
return self._check("tifffile")

@property
def azure(self) -> bool:
return self._check("azure.storage.blob")

@property
def botocore(self) -> bool:
return self._check("botocore")

@staticmethod
def _check(lib_name: str) -> bool:
return importlib.util.find_spec(lib_name) is not None
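Note: the new `azure` and `botocore` properties let callers probe for optional backends through the shared `have` helper without importing them eagerly. A minimal sketch of consuming such a flag (the `None` fallback is illustrative, not part of this commit):

    from odc.geo._interop import have

    if have.azure:
        # azure-storage-blob is importable, so the Azure uploader is usable.
        from odc.geo.cog._az import AzMultiPartUpload
    else:
        AzMultiPartUpload = None  # illustrative fallback
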
13 changes: 9 additions & 4 deletions odc/geo/cog/_az.py
@@ -1,8 +1,7 @@
import base64
from typing import Any, Union

from azure.storage.blob import BlobBlock, BlobServiceClient
from dask.delayed import Delayed
import dask

from ._mpu import mpu_write
from ._multipart import MultiPartUploadBase
@@ -51,6 +50,9 @@ def __init__(
self.credential = credential

# Initialise Azure Blob service client
# pylint: disable=import-outside-toplevel,import-error
from azure.storage.blob import BlobServiceClient

self.blob_service_client = BlobServiceClient(
account_url=account_url, credential=credential
)
Expand Down Expand Up @@ -85,6 +87,9 @@ def finalise(self, parts: list[dict[str, Any]]) -> str:
:param parts: List of uploaded parts metadata.
:return: The ETag of the finalised blob.
"""
# pylint: disable=import-outside-toplevel,import-error
from azure.storage.blob import BlobBlock

block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts]
self.blob_client.commit_block_list(block_list)
return self.blob_client.get_blob_properties().etag
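Note: Azure finalises a block blob by committing a list of base64-encoded block ids. A sketch of the id convention implied by the tests further down (`make_block_id` is a hypothetical helper, not defined in this commit):

    import base64

    def make_block_id(part_no: int) -> str:
        # Azure block ids must be base64 strings; the "block-<n>" naming
        # matches what the tests in this commit expect for parts 1 and 2.
        return base64.b64encode(f"block-{part_no}".encode()).decode("utf-8")

    assert make_block_id(1) == base64.b64encode(b"block-1").decode("utf-8")
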
Expand Down Expand Up @@ -121,7 +126,7 @@ def writer(self, kw: dict[str, Any], client: Any = None):

def upload(
self,
chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
*,
mk_header: Any = None,
mk_footer: Any = None,
@@ -130,7 +135,7 @@ def upload(
spill_sz: int = 20 * (1 << 20),
client: Any = None,
**kw,
) -> "Delayed":
) -> dask.delayed.Delayed:
"""
Upload chunks to Azure Blob Storage with multipart uploads.
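Note: assuming `azure-storage-blob` is installed, usage might look like the sketch below; the account URL, container, blob name, and credential are placeholders:

    import dask.bag
    from odc.geo.cog._az import AzMultiPartUpload

    mpu = AzMultiPartUpload(
        account_url="https://myaccount.blob.core.windows.net",  # placeholder
        container="my-container",
        blob="cogs/output.tif",
        credential=None,  # e.g. a SAS token; None means anonymous access
    )
    chunks = dask.bag.from_sequence([b"first chunk", b"second chunk"], npartitions=2)
    delayed = mpu.upload(chunks)  # a dask Delayed; .compute() runs the upload
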
4 changes: 4 additions & 0 deletions odc/geo/cog/_multipart.py
@@ -1,5 +1,9 @@
"""
Multipart upload interface.
Defines the `MultiPartUploadBase` class for implementing multipart upload functionality.
This interface standardises methods for initiating, uploading, and finalising
multipart uploads across storage backends.
"""

from abc import ABC, abstractmethod
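Note: judging from the methods exercised elsewhere in this commit (`initiate`, `write_part`, `finalise`), the interface plausibly resembles the sketch below; signatures are inferred rather than copied from the file:

    from abc import ABC, abstractmethod
    from typing import Any

    class MultiPartUploadBase(ABC):
        """Inferred sketch of the multipart upload interface."""

        @abstractmethod
        def initiate(self) -> str:
            """Start the upload and return an upload id."""

        @abstractmethod
        def write_part(self, part_no: int, data: bytes) -> dict[str, Any]:
            """Upload one part and return its metadata."""

        @abstractmethod
        def finalise(self, parts: list[dict[str, Any]]) -> str:
            """Commit the uploaded parts and return an ETag-like identifier."""
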
60 changes: 26 additions & 34 deletions odc/geo/cog/_tifffile.py
@@ -11,34 +11,20 @@
from functools import partial
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urlparse
from xml.sax.saxutils import escape as xml_escape

import numpy as np
import xarray as xr


from .._interop import have
from ..geobox import GeoBox
from ..math import resolve_nodata
from ..types import Shape2d, SomeNodata, Unset, shape_
from ._mpu import mpu_write
from ._mpu_fs import MPUFileSink

try:
from ._az import AzMultiPartUpload

HAVE_AZURE = True
except ImportError:
AzMultiPartUpload = None
HAVE_AZURE = False
try:
from ._s3 import S3MultiPartUpload, s3_parse_url

HAVE_S3 = True
except ImportError:
S3MultiPartUpload = None
s3_parse_url = None
HAVE_S3 = False

from ._shared import (
GDAL_COMP,
GEOTIFF_TAGS,
@@ -641,6 +627,7 @@ def save_cog_with_dask(
bigtiff: bool = True,
overview_resampling: Union[int, str] = "nearest",
aws: Optional[dict[str, Any]] = None,
azure: Optional[dict[str, Any]] = None,
client: Any = None,
stats: bool | int = True,
**kw,
@@ -669,13 +656,12 @@

from ..xr import ODCExtensionDa

if aws is None:
aws = {}
aws = aws or {}
azure = azure or {}

upload_params = {k: kw.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in kw}
upload_params.update(
{k: aws.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in aws}
)
upload_params = {
k: kw.pop(k, None) for k in ["writes_per_chunk", "spill_sz"] if k in kw
}
parts_base = kw.pop("parts_base", None)

# Normalise compression settings and remove GDAL compat options from kw
@@ -750,19 +736,25 @@ def save_cog_with_dask(
# Determine output type and initiate uploader
parsed_url = urlparse(dst)
if parsed_url.scheme == "s3":
if not HAVE_S3:
raise ImportError("Install `boto3` to enable S3 support.")
bucket, key = s3_parse_url(dst)
uploader = S3MultiPartUpload(bucket, key, **aws)
if have.s3:
from ._s3 import S3MultiPartUpload, s3_parse_url

bucket, key = s3_parse_url(dst)
uploader = S3MultiPartUpload(bucket, key, **aws)
else:
raise RuntimeError("Please install `boto3` to use S3")
elif parsed_url.scheme == "az":
if not HAVE_AZURE:
raise ImportError("Install azure-storage-blob` to enable Azure support.")
uploader = AzMultiPartUpload(
account_url=azure.get("account_url"),
container=parsed_url.netloc,
blob=parsed_url.path.lstrip("/"),
credential=azure.get("credential"),
)
if have.azure:
from ._az import AzMultiPartUpload

uploader = AzMultiPartUpload(
account_url=azure.get("account_url"),
container=parsed_url.netloc,
blob=parsed_url.path.lstrip("/"),
credential=azure.get("credential"),
)
else:
raise RuntimeError("Please install `azure-storage-blob` to use Azure")
else:
# Assume local disk
write = MPUFileSink(dst, parts_base=parts_base)
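Note: with this dispatch the storage backend is chosen purely by the destination URL scheme. A hedged end-to-end sketch (`xr_zeros` builds a small dask-backed test array; the container, account URL, and credential are placeholders):

    from odc.geo.geobox import GeoBox
    from odc.geo.xr import xr_zeros
    from odc.geo.cog import save_cog_with_dask

    # Small dask-backed, geo-registered array to write (illustrative).
    gbox = GeoBox.from_bbox((-10, -10, 10, 10), "epsg:4326", resolution=0.1)
    xx = xr_zeros(gbox, dtype="uint8", chunks=(100, 100))

    delayed = save_cog_with_dask(
        xx,
        "az://my-container/cogs/output.tif",
        azure={
            "account_url": "https://myaccount.blob.core.windows.net",  # placeholder
            "credential": None,  # or a SAS token / account key
        },
    )
    delayed.compute()  # the upload runs when the Delayed graph executes
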
154 changes: 69 additions & 85 deletions tests/test_az.py
@@ -1,96 +1,80 @@
"""Tests for the Azure AzMultiPartUpload class."""

import base64
import unittest
from unittest.mock import MagicMock, patch

# Conditional import for Azure support
try:
from odc.geo.cog._az import AzMultiPartUpload
import pytest

HAVE_AZURE = True
except ImportError:
AzMultiPartUpload = None
HAVE_AZURE = False
pytest.importorskip("azure.storage.blob")
from odc.geo.cog._az import AzMultiPartUpload # noqa: E402


def require_azure(test_func):
"""Decorator to skip tests if Azure dependencies are not installed."""
return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")(
test_func
@pytest.fixture
def azure_mpu():
"""Fixture for initializing AzMultiPartUpload."""
account_url = "https://account_name.blob.core.windows.net"
return AzMultiPartUpload(account_url, "container", "some.blob", None)


def test_mpu_init(azure_mpu):
"""Basic test for AzMultiPartUpload initialization."""
assert azure_mpu.account_url == "https://account_name.blob.core.windows.net"
assert azure_mpu.container == "container"
assert azure_mpu.blob == "some.blob"
assert azure_mpu.credential is None


@patch("odc.geo.cog._az.BlobServiceClient")
def test_azure_multipart_upload(mock_blob_service_client):
"""Test the full Azure AzMultiPartUpload functionality."""
# Mock Azure Blob SDK client structure
mock_blob_client = MagicMock()
mock_container_client = MagicMock()
mock_blob_service_client.return_value.get_container_client.return_value = (
mock_container_client
)
mock_container_client.get_blob_client.return_value = mock_blob_client

# Simulate return values for Azure Blob SDK methods
mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"

# Test parameters
account_url = "https://mockaccount.blob.core.windows.net"
container = "mock-container"
blob = "mock-blob"
credential = "mock-sas-token"

# Create an instance of AzMultiPartUpload and call its methods
azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
upload_id = azure_upload.initiate()
part1 = azure_upload.write_part(1, b"first chunk of data")
part2 = azure_upload.write_part(2, b"second chunk of data")
etag = azure_upload.finalise([part1, part2])

# Define block IDs
block_id1 = base64.b64encode(b"block-1").decode("utf-8")
block_id2 = base64.b64encode(b"block-2").decode("utf-8")

# Verify the results
assert upload_id == "azure-block-upload"
assert etag == "mock-etag"

# Verify BlobServiceClient instantiation
mock_blob_service_client.assert_called_once_with(
account_url=account_url, credential=credential
)

# Verify stage_block calls
mock_blob_client.stage_block.assert_any_call(
block_id=block_id1, data=b"first chunk of data"
)
mock_blob_client.stage_block.assert_any_call(
block_id=block_id2, data=b"second chunk of data"
)

class TestAzMultiPartUpload(unittest.TestCase):
"""Test the AzMultiPartUpload class."""

@require_azure
def test_mpu_init(self):
"""Basic test for AzMultiPartUpload initialization."""
account_url = "https://account_name.blob.core.windows.net"
mpu = AzMultiPartUpload(account_url, "container", "some.blob", None)

self.assertEqual(mpu.account_url, account_url)
self.assertEqual(mpu.container, "container")
self.assertEqual(mpu.blob, "some.blob")
self.assertIsNone(mpu.credential)

@require_azure
@patch("odc.geo.cog._az.BlobServiceClient")
def test_azure_multipart_upload(self, mock_blob_service_client):
"""Test the full Azure AzMultiPartUpload functionality."""
# Arrange - Mock Azure Blob SDK client structure
mock_blob_client = MagicMock()
mock_container_client = MagicMock()
mock_blob_service_client.return_value.get_container_client.return_value = (
mock_container_client
)
mock_container_client.get_blob_client.return_value = mock_blob_client

# Simulate return values for Azure Blob SDK methods
mock_blob_client.get_blob_properties.return_value.etag = "mock-etag"

# Test parameters
account_url = "https://mockaccount.blob.core.windows.net"
container = "mock-container"
blob = "mock-blob"
credential = "mock-sas-token"

# Act
azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
upload_id = azure_upload.initiate()
part1 = azure_upload.write_part(1, b"first chunk of data")
part2 = azure_upload.write_part(2, b"second chunk of data")
etag = azure_upload.finalise([part1, part2])

# Correctly calculate block IDs
block_id1 = base64.b64encode(b"block-1").decode("utf-8")
block_id2 = base64.b64encode(b"block-2").decode("utf-8")

# Assert
self.assertEqual(upload_id, "azure-block-upload")
self.assertEqual(etag, "mock-etag")

# Verify BlobServiceClient instantiation
mock_blob_service_client.assert_called_once_with(
account_url=account_url, credential=credential
)

# Verify stage_block calls
mock_blob_client.stage_block.assert_any_call(
block_id=block_id1, data=b"first chunk of data"
)
mock_blob_client.stage_block.assert_any_call(
block_id=block_id2, data=b"second chunk of data"
)

# Verify commit_block_list was called correctly
block_list = mock_blob_client.commit_block_list.call_args[0][0]
self.assertEqual(len(block_list), 2)
self.assertEqual(block_list[0].id, block_id1)
self.assertEqual(block_list[1].id, block_id2)
mock_blob_client.commit_block_list.assert_called_once()


if __name__ == "__main__":
unittest.main()
# Verify commit_block_list was called correctly
block_list = mock_blob_client.commit_block_list.call_args[0][0]
assert len(block_list) == 2
assert block_list[0].id == block_id1
assert block_list[1].id == block_id2
mock_blob_client.commit_block_list.assert_called_once()
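A parallel guard for S3-dependent tests would mirror this module; hypothetical sketch, not part of the commit:

    import pytest

    # Skip the whole module when the optional S3 dependency is missing.
    pytest.importorskip("botocore")
    from odc.geo.cog._s3 import S3MultiPartUpload  # noqa: E402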
