Skip to content

Commit

Permalink
Feat: safely import azure-storage-blob and boto3
Browse files Browse the repository at this point in the history
  • Loading branch information
wietzesuijker committed Dec 17, 2024
1 parent d80b5cd commit 69198e5
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 47 deletions.
2 changes: 1 addition & 1 deletion dev-env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- conda-forge

dependencies:
- python =3.8
- python =3.10

# odc-geo dependencies
- pyproj
Expand Down
6 changes: 3 additions & 3 deletions odc/geo/cog/_az.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def max_part(self) -> int:
return 50_000


class MultiPartUpload(AzureLimits, MultiPartUploadBase):
class AzMultiPartUpload(AzureLimits, MultiPartUploadBase):
def __init__(
self, account_url: str, container: str, blob: str, credential: Any = None
):
Expand Down Expand Up @@ -161,11 +161,11 @@ class DelayedAzureWriter(AzureLimits):
Dask-compatible writer for Azure Blob Storage multipart uploads.
"""

def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]):
def __init__(self, mpu: AzMultiPartUpload, kw: dict[str, Any]):
"""
Initialise the Azure writer.
:param mpu: MultiPartUpload instance.
:param mpu: AzMultiPartUpload instance.
:param kw: Additional parameters for the writer.
"""
self.mpu = mpu
Expand Down
6 changes: 3 additions & 3 deletions odc/geo/cog/_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def max_part(self) -> int:
return 10_000


class MultiPartUpload(S3Limits, MultiPartUploadBase):
class S3MultiPartUpload(S3Limits, MultiPartUploadBase):
"""
Dask to S3 dumper.
"""
Expand Down Expand Up @@ -237,7 +237,7 @@ class DelayedS3Writer(S3Limits):

# pylint: disable=import-outside-toplevel,import-error

def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]):
def __init__(self, mpu: S3MultiPartUpload, kw: dict[str, Any]):
self.mpu = mpu
self.kw = kw # mostly ContentType= kinda thing
self._shared_var: Optional["distributed.Variable"] = None
Expand All @@ -263,7 +263,7 @@ def _shared(self, client: "distributed.Client") -> "distributed.Variable":
self._shared_var = Variable(self._build_name("MPUpload"), client)
return self._shared_var

def _ensure_init(self, final_write: bool = False) -> MultiPartUpload:
def _ensure_init(self, final_write: bool = False) -> S3MultiPartUpload:
# pylint: disable=too-many-return-statements
mpu = self.mpu
if mpu.started:
Expand Down
24 changes: 21 additions & 3 deletions odc/geo/cog/_tifffile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,29 @@
import numpy as np
import xarray as xr


from .._interop import have
from ..geobox import GeoBox
from ..math import resolve_nodata
from ..types import Shape2d, SomeNodata, Unset, shape_
from ._az import MultiPartUpload as AzMultiPartUpload
from ._mpu import mpu_write
from ._mpu_fs import MPUFileSink
from ._s3 import MultiPartUpload as S3MultiPartUpload, s3_parse_url

try:
from ._az import AzMultiPartUpload

HAVE_AZURE = True
except ImportError:
AzMultiPartUpload = None
HAVE_AZURE = False
try:
from ._s3 import S3MultiPartUpload, s3_parse_url

HAVE_S3 = True
except ImportError:
S3MultiPartUpload = None
s3_parse_url = None
HAVE_S3 = False

from ._shared import (
GDAL_COMP,
GEOTIFF_TAGS,
Expand Down Expand Up @@ -738,9 +752,13 @@ def save_cog_with_dask(
# Determine output type and initiate uploader
parsed_url = urlparse(dst)
if parsed_url.scheme == "s3":
if not HAVE_S3:
raise ImportError("Install `boto3` to enable S3 support.")
bucket, key = s3_parse_url(dst)
uploader = S3MultiPartUpload(bucket, key, **aws)
elif parsed_url.scheme == "az":
if not HAVE_AZURE:
raise ImportError("Install azure-storage-blob` to enable Azure support.")
uploader = AzMultiPartUpload(
account_url=azure.get("account_url"),
container=parsed_url.netloc,
Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@ tiff =
s3 =
boto3

az =
azure-storage-blob

all =
%(warp)s
%(tiff)s
%(s3)s
%(az)s

test =
pytest
Expand Down
88 changes: 55 additions & 33 deletions tests/test_az.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,50 @@
"""Tests for the Azure MultiPartUpload class."""
"""Tests for the Azure AzMultiPartUpload class."""

import base64
import unittest
from unittest.mock import MagicMock, patch

from odc.geo.cog._az import MultiPartUpload
# Conditional import for Azure support
try:
from odc.geo.cog._az import AzMultiPartUpload

HAVE_AZURE = True
except ImportError:
AzMultiPartUpload = None
HAVE_AZURE = False

def test_mpu_init():
"""Basic test for the MultiPartUpload class."""
account_url = "https://account_name.blob.core.windows.net"
mpu = MultiPartUpload(account_url, "container", "some.blob", None)
if mpu.account_url != account_url:
raise AssertionError(f"mpu.account_url should be '{account_url}'.")
if mpu.container != "container":
raise AssertionError("mpu.container should be 'container'.")
if mpu.blob != "some.blob":
raise AssertionError("mpu.blob should be 'some.blob'.")
if mpu.credential is not None:
raise AssertionError("mpu.credential should be 'None'.")

def require_azure(test_func):
"""Decorator to skip tests if Azure dependencies are not installed."""
return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")(
test_func
)

class TestMultiPartUpload(unittest.TestCase):
"""Test the MultiPartUpload class."""

class TestAzMultiPartUpload(unittest.TestCase):
"""Test the AzMultiPartUpload class."""

@require_azure
def test_mpu_init(self):
"""Basic test for AzMultiPartUpload initialization."""
account_url = "https://account_name.blob.core.windows.net"
mpu = AzMultiPartUpload(account_url, "container", "some.blob", None)

self.assertEqual(mpu.account_url, account_url)
self.assertEqual(mpu.container, "container")
self.assertEqual(mpu.blob, "some.blob")
self.assertIsNone(mpu.credential)

@require_azure
@patch("odc.geo.cog._az.BlobServiceClient")
def test_azure_multipart_upload(self, mock_blob_service_client):
"""Test the MultiPartUpload class."""
# Arrange - mock the Azure Blob SDK
# Mock the blob client and its methods
"""Test the full Azure AzMultiPartUpload functionality."""
# Arrange - Mock Azure Blob SDK client structure
mock_blob_client = MagicMock()
mock_container_client = MagicMock()
mcc = mock_container_client
mock_blob_service_client.return_value.get_container_client.return_value = mcc
mock_blob_service_client.return_value.get_container_client.return_value = (
mock_container_client
)
mock_container_client.get_blob_client.return_value = mock_blob_client

# Simulate return values for Azure Blob SDK methods
Expand All @@ -43,32 +56,41 @@ def test_azure_multipart_upload(self, mock_blob_service_client):
blob = "mock-blob"
credential = "mock-sas-token"

# Act - create an instance of MultiPartUpload and call its methods
azure_upload = MultiPartUpload(account_url, container, blob, credential)
# Act
azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
upload_id = azure_upload.initiate()
part1 = azure_upload.write_part(1, b"first chunk of data")
part2 = azure_upload.write_part(2, b"second chunk of data")
etag = azure_upload.finalise([part1, part2])

# Assert - check the results
# Check that the initiate method behaves as expected
# Correctly calculate block IDs
block_id1 = base64.b64encode(b"block-1").decode("utf-8")
block_id2 = base64.b64encode(b"block-2").decode("utf-8")

# Assert
self.assertEqual(upload_id, "azure-block-upload")
self.assertEqual(etag, "mock-etag")

# Verify the calls to Azure Blob SDK methods
# Verify BlobServiceClient instantiation
mock_blob_service_client.assert_called_once_with(
account_url=account_url, credential=credential
)

# Verify stage_block calls
mock_blob_client.stage_block.assert_any_call(
part1["BlockId"], b"first chunk of data"
block_id=block_id1, data=b"first chunk of data"
)
mock_blob_client.stage_block.assert_any_call(
part2["BlockId"], b"second chunk of data"
block_id=block_id2, data=b"second chunk of data"
)
mock_blob_client.commit_block_list.assert_called_once()
self.assertEqual(etag, "mock-etag")

# Verify block list passed during finalise
# Verify commit_block_list was called correctly
block_list = mock_blob_client.commit_block_list.call_args[0][0]
self.assertEqual(len(block_list), 2)
self.assertEqual(block_list[0].id, part1["BlockId"])
self.assertEqual(block_list[1].id, part2["BlockId"])
self.assertEqual(block_list[0].id, block_id1)
self.assertEqual(block_list[1].id, block_id2)
mock_blob_client.commit_block_list.assert_called_once()


if __name__ == "__main__":
unittest.main()
13 changes: 9 additions & 4 deletions tests/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from odc.geo.cog._s3 import MultiPartUpload
"""Tests for odc.geo.cog._s3."""

from odc.geo.cog._s3 import S3MultiPartUpload

# TODO: moto


def test_s3_mpu():
mpu = MultiPartUpload("bucket", "file.dat")
assert mpu.bucket == "bucket"
assert mpu.key == "file.dat"
"""Test S3MultiPartUpload class initialization."""
mpu = S3MultiPartUpload("bucket", "file.dat")
if mpu.bucket != "bucket":
raise ValueError("Invalid bucket")
if mpu.key != "file.dat":
raise ValueError("Invalid key")

0 comments on commit 69198e5

Please sign in to comment.