From 7b6b5534f72cea2049d74510262ad29614a0904c Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Fri, 13 Dec 2024 00:25:30 +0000 Subject: [PATCH] feat: save cog with dask to azure --- odc/geo/cog/_az.py | 187 ++++++++++++++++++++++++++++++++++++++ odc/geo/cog/_multipart.py | 81 +++++++++++++++++ odc/geo/cog/_s3.py | 61 +++++++------ odc/geo/cog/_tifffile.py | 85 +++++++++-------- tests/test_az.py | 62 +++++++++++++ 5 files changed, 408 insertions(+), 68 deletions(-) create mode 100644 odc/geo/cog/_az.py create mode 100644 odc/geo/cog/_multipart.py create mode 100644 tests/test_az.py diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py new file mode 100644 index 00000000..d3e30a36 --- /dev/null +++ b/odc/geo/cog/_az.py @@ -0,0 +1,187 @@ +import base64 +from typing import Any, Dict, List, Union + +from azure.storage.blob import BlobBlock, BlobServiceClient +from dask.delayed import Delayed + +from ._mpu import mpu_write +from ._multipart import MultiPartUploadBase + + +class AzureLimits: + """ + Common Azure writer settings. + """ + + @property + def min_write_sz(self) -> int: + # Azure minimum write size for blocks (default is 4 MiB) + return 4 * (1 << 20) + + @property + def max_write_sz(self) -> int: + # Azure maximum write size for blocks (default is 100 MiB) + return 100 * (1 << 20) + + @property + def min_part(self) -> int: + return 1 + + @property + def max_part(self) -> int: + # Azure supports up to 50,000 blocks per blob + return 50_000 + + +class MultiPartUpload(AzureLimits, MultiPartUploadBase): + def __init__(self, account_url: str, container: str, blob: str, credential: Any = None): + """ + Initialise Azure multipart upload. + + :param account_url: URL of the Azure storage account. + :param container: Name of the container. + s :param blob: Name of the blob. + :param credential: Authentication credentials (e.g., SAS token or key). + """ + self.account_url = account_url + self.container = container + self.blob = blob + self.credential = credential + + # Initialise Azure Blob service client + self.blob_service_client = BlobServiceClient(account_url=account_url, credential=credential) + self.container_client = self.blob_service_client.get_container_client(container) + self.blob_client = self.container_client.get_blob_client(blob) + + self.block_ids: List[str] = [] + + def initiate(self, **kwargs) -> str: + """ + Initialise the upload. No-op for Azure. + """ + return "azure-block-upload" + + def write_part(self, part: int, data: bytes) -> Dict[str, Any]: + """ + Stage a block in Azure. + + :param part: Part number (unique). + :param data: Data for this part. + :return: A dictionary containing part information. + """ + block_id = base64.b64encode(f"block-{part}".encode()).decode() + self.blob_client.stage_block(block_id=block_id, data=data) + self.block_ids.append(block_id) + return {"PartNumber": part, "BlockId": block_id} + + def finalise(self, parts: List[Dict[str, Any]]) -> str: + """ + Commit the block list to finalise the upload. + + :param parts: List of uploaded parts metadata. + :return: The ETag of the finalised blob. + """ + block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts] + self.blob_client.commit_block_list(block_list) + return self.blob_client.get_blob_properties().etag + + def cancel(self): + """ + Cancel the upload by clearing the block list. + """ + self.block_ids.clear() + + @property + def url(self) -> str: + """ + Get the Azure blob URL. + + :return: The full URL of the blob. + """ + return self.blob_client.url + + @property + def started(self) -> bool: + """ + Check if any blocks have been staged. + + :return: True if blocks have been staged, False otherwise. + """ + return bool(self.block_ids) + + def writer(self, kw: Dict[str, Any], client: Any = None): + """ + Return a stateless writer compatible with Dask. + """ + return DelayedAzureWriter(self, kw) + + def upload( + self, + chunks: Union["dask.bag.Bag", List["dask.bag.Bag"]], + *, + mk_header: Any = None, + mk_footer: Any = None, + user_kw: Dict[str, Any] = None, + writes_per_chunk: int = 1, + spill_sz: int = 20 * (1 << 20), + client: Any = None, + **kw, + ) -> "Delayed": + """ + Upload chunks to Azure Blob Storage with multipart uploads. + + :param chunks: Dask bag of chunks to upload. + :param mk_header: Function to create header data. + :param mk_footer: Function to create footer data. + :param user_kw: User-provided metadata for the upload. + :param writes_per_chunk: Number of writes per chunk. + :param spill_sz: Spill size for buffering data. + :param client: Dask client for distributed execution. + :return: A Dask delayed object representing the finalised upload. + """ + write = self.writer(kw, client=client) if spill_sz else None + return mpu_write( + chunks, + write, + mk_header=mk_header, + mk_footer=mk_footer, + user_kw=user_kw, + writes_per_chunk=writes_per_chunk, + spill_sz=spill_sz, + dask_name_prefix="azure-finalise", + ) + + +class DelayedAzureWriter(AzureLimits): + """ + Dask-compatible writer for Azure Blob Storage multipart uploads. + """ + + def __init__(self, mpu: MultiPartUpload, kw: Dict[str, Any]): + """ + Initialise the Azure writer. + + :param mpu: MultiPartUpload instance. + :param kw: Additional parameters for the writer. + """ + self.mpu = mpu + self.kw = kw # Additional metadata like ContentType + + def __call__(self, part: int, data: bytes) -> Dict[str, Any]: + """ + Write a single part to Azure Blob Storage. + + :param part: Part number. + :param data: Chunk data. + :return: Metadata for the written part. + """ + return self.mpu.write_part(part, data) + + def finalise(self, parts: List[Dict[str, Any]]) -> str: + """ + Finalise the upload by committing the block list. + + :param parts: List of uploaded parts metadata. + :return: ETag of the finalised blob. + """ + return self.mpu.finalise(parts) \ No newline at end of file diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py new file mode 100644 index 00000000..9d65b331 --- /dev/null +++ b/odc/geo/cog/_multipart.py @@ -0,0 +1,81 @@ +""" +Multipart upload interface. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Union + +import dask.bag + + +class MultiPartUploadBase(ABC): + """Abstract base class for multipart upload.""" + + @abstractmethod + def initiate(self, **kwargs) -> str: + """Initiate a multipart upload and return an identifier.""" + pass + + @abstractmethod + def write_part(self, part: int, data: bytes) -> Dict[str, Any]: + """Upload a single part.""" + pass + + @abstractmethod + def finalise(self, parts: List[Dict[str, Any]]) -> str: + """Finalise the upload with a list of parts.""" + pass + + @abstractmethod + def cancel(self, other: str = ""): + """Cancel the multipart upload.""" + pass + + @property + @abstractmethod + def url(self) -> str: + """Return the URL of the upload target.""" + pass + + @property + @abstractmethod + def started(self) -> bool: + """Check if the multipart upload has been initiated.""" + pass + + @abstractmethod + def writer(self, kw: Dict[str, Any], *, client: Any = None) -> Any: + """ + Return a Dask-compatible writer for multipart uploads. + + :param kw: Additional parameters for the writer. + :param client: Dask client for distributed execution. + """ + pass + + @abstractmethod + def upload( + self, + chunks: Union["dask.bag.Bag", List["dask.bag.Bag"]], + *, + mk_header: Any = None, + mk_footer: Any = None, + user_kw: Dict[str, Any] = None, + writes_per_chunk: int = 1, + spill_sz: int = 20 * (1 << 20), + client: Any = None, + **kw, + ) -> Any: + """ + Orchestrate the upload process with multipart uploads. + + :param chunks: Dask bag of chunks to upload. + :param mk_header: Function to create header data. + :param mk_footer: Function to create footer data. + :param user_kw: User-provided metadata for the upload. + :param writes_per_chunk: Number of writes per chunk. + :param spill_sz: Spill size for buffering data. + :param client: Dask client for distributed execution. + :return: A Dask delayed object representing the finalised upload. + """ + pass diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index b2b3a198..5029456b 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -5,11 +5,13 @@ from __future__ import annotations from threading import Lock -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from cachetools import cached from ._mpu import PartsWriter, SomeData, mpu_write +from ._multipart import MultiPartUploadBase + if TYPE_CHECKING: import dask.bag @@ -17,7 +19,7 @@ from botocore.credentials import ReadOnlyCredentials from dask.delayed import Delayed -_state: dict[str, Any] = {} +_state: Dict[str, Any] = {} def _mpu_local_lock(k="mpu_lock") -> Lock: @@ -68,7 +70,7 @@ def max_part(self) -> int: return 10_000 -class MultiPartUpload(S3Limits): +class MultiPartUpload(S3Limits, MultiPartUploadBase): """ Dask to S3 dumper. """ @@ -92,6 +94,7 @@ def __init__( @cached({}) def s3_client(self): + """Return the S3 client.""" # pylint: disable=import-outside-toplevel,import-error from botocore.session import Session @@ -108,15 +111,15 @@ def s3_client(self): ) def initiate(self, **kw) -> str: + """Initiate the S3 multipart upload.""" assert self.uploadId == "" s3 = self.s3_client() - rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key, **kw) - uploadId = rr["UploadId"] - self.uploadId = uploadId - return uploadId + self.uploadId = rr["UploadId"] + return self.uploadId - def write_part(self, part: int, data: SomeData) -> dict[str, Any]: + def write_part(self, part: int, data: SomeData) -> Dict[str, Any]: + """Write a single part to S3.""" s3 = self.s3_client() assert self.uploadId != "" rr = s3.upload_part( @@ -126,31 +129,33 @@ def write_part(self, part: int, data: SomeData) -> dict[str, Any]: Key=self.key, UploadId=self.uploadId, ) - etag = rr["ETag"] - return {"PartNumber": part, "ETag": etag} + return {"PartNumber": part, "ETag": rr["ETag"]} @property def url(self) -> str: + """Return the S3 URL of the object.""" return f"s3://{self.bucket}/{self.key}" - def finalise(self, parts: list[dict[str, Any]]) -> str: + def finalise(self, parts: List[Dict[str, Any]]) -> str: + """Finalise the multipart upload.""" + # pylint: disable=import-outside-toplevel,import-error s3 = self.s3_client() assert self.uploadId - rr = s3.complete_multipart_upload( Bucket=self.bucket, Key=self.key, UploadId=self.uploadId, MultipartUpload={"Parts": parts}, ) - return rr["ETag"] - + @property def started(self) -> bool: + """Check if the multipart upload has been initiated.""" return len(self.uploadId) > 0 def cancel(self, other: str = ""): + """Cancel the multipart upload.""" uploadId = other if other else self.uploadId if not uploadId: return @@ -169,23 +174,23 @@ def cancel(self, other: str = ""): if uploadId == self.uploadId: self.uploadId = "" - def list_active(self): + def list_active(self) -> List[str]: + """List active multipart uploads.""" s3 = self.s3_client() rr = s3.list_multipart_uploads(Bucket=self.bucket, Prefix=self.key) return [x["UploadId"] for x in rr.get("Uploads", [])] def read(self, **kw): + """Read the object directly from S3.""" s3 = self.s3_client() return s3.get_object(Bucket=self.bucket, Key=self.key, **kw)["Body"].read() def __dask_tokenize__(self): - return ( - self.bucket, - self.key, - self.uploadId, - ) + """Dask-specific tokenization for S3 uploads.""" + return (self.bucket, self.key, self.uploadId) def writer(self, kw, *, client: Any = None) -> PartsWriter: + """Return a Dask-compatible writer.""" if client is None: client = _dask_client() writer = DelayedS3Writer(self, kw) @@ -193,19 +198,19 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter: writer.prep_client(client) return writer - # pylint: disable=too-many-arguments def upload( self, - chunks: "dask.bag.Bag" | list["dask.bag.Bag"], + chunks: "dask.bag.Bag" | List["dask.bag.Bag"], *, mk_header: Any = None, mk_footer: Any = None, - user_kw: dict[str, Any] | None = None, + user_kw: Dict[str, Any] | None = None, writes_per_chunk: int = 1, spill_sz: int = 20 * (1 << 20), client: Any = None, **kw, ) -> "Delayed": + """Upload chunks to S3 with multipart uploads.""" write = self.writer(kw, client=client) if spill_sz else None return mpu_write( chunks, @@ -233,7 +238,7 @@ class DelayedS3Writer(S3Limits): # pylint: disable=import-outside-toplevel,import-error - def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + def __init__(self, mpu: MultiPartUpload, kw: Dict[str, Any]): self.mpu = mpu self.kw = kw # mostly ContentType= kinda thing self._shared_var: Optional["distributed.Variable"] = None @@ -279,7 +284,7 @@ def _ensure_init(self, final_write: bool = False) -> MultiPartUpload: uploadId = _safe_get(shared_state, 0.1) if uploadId is not None: - # someone else initialized it + # someone else initialised it mpu.uploadId = uploadId return mpu @@ -287,7 +292,7 @@ def _ensure_init(self, final_write: bool = False) -> MultiPartUpload: with lock: uploadId = _safe_get(shared_state, 0.1) if uploadId is not None: - # someone else initialized it while we were getting a lock + # someone else initialised it while we were getting a lock mpu.uploadId = uploadId return mpu @@ -301,11 +306,11 @@ def _ensure_init(self, final_write: bool = False) -> MultiPartUpload: assert mpu.started or final_write return mpu - def __call__(self, part: int, data: SomeData) -> dict[str, Any]: + def __call__(self, part: int, data: SomeData) -> Dict[str, Any]: mpu = self._ensure_init() return mpu.write_part(part, data) - def finalise(self, parts: list[dict[str, Any]]) -> Any: + def finalise(self, parts: List[Dict[str, Any]]) -> Any: assert len(parts) > 0 mpu = self._ensure_init() etag = mpu.finalise(parts) diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 60bd1fac..030ae533 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -11,18 +11,21 @@ from functools import partial from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from urllib.parse import urlparse from xml.sax.saxutils import escape as xml_escape import numpy as np import xarray as xr + from .._interop import have from ..geobox import GeoBox from ..math import resolve_nodata from ..types import Shape2d, SomeNodata, Unset, shape_ +from ._az import MultiPartUpload as AzMultiPartUpload from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -from ._s3 import MultiPartUpload, s3_parse_url +from ._s3 import MultiPartUpload as S3MultiPartUpload, s3_parse_url from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -41,13 +44,13 @@ def _render_gdal_metadata( - band_stats: list[dict[str, float]] | dict[str, float] | None, + band_stats: List[Dict[str, float]] | Dict[str, float] | None, precision: int = 10, pad: int = 0, eol: str = "", gdal_metadata_extra: Optional[List[str]] = None, ) -> str: - def _item(sample: int, stats: dict[str, float]) -> str: + def _item(sample: int, stats: Dict[str, float]) -> str: return eol.join( [ f'{v:{pad}.{precision}f}' @@ -471,7 +474,7 @@ def _patch_hdr( tiles: List[Tuple[int, Tuple[int, int, int, int]]], meta: CogMeta, hdr0: bytes, - stats: Optional[list[dict[str, float]]] = None, + stats: Optional[List[Dict[str, float]]] = None, gdal_metadata_extra: Optional[List[str]] = None, ) -> bytes: # pylint: disable=import-outside-toplevel,import-error @@ -625,15 +628,16 @@ def save_cog_with_dask( bigtiff: bool = True, overview_resampling: Union[int, str] = "nearest", aws: Optional[Dict[str, Any]] = None, + azure: Optional[Dict[str, Any]] = None, client: Any = None, stats: bool | int = True, **kw, ) -> Any: """ - Save a Cloud Optimized GeoTIFF to S3 or file with Dask. + Save a Cloud Optimized GeoTIFF to S3, Azure Blob Storage, or file with Dask. :param xx: Pixels as :py:class:`xarray.DataArray` backed by Dask - :param dst: S3 url or a file path on shared storage + :param dst: S3, Azure URL, or file path :param compression: Compression to use, default is ``DEFLATE`` :param level: Compression "level", depends on chosen compression :param predictor: TIFF predictor setting @@ -642,6 +646,7 @@ def save_cog_with_dask( :param blocksize: Configure blocksizes for main and overview images :param bigtiff: Generate BigTIFF by default, set to ``False`` to disable :param aws: Configure AWS write access + :param azure: Azure credentials/config :param client: Dask client :param stats: Set to ``False`` to disable stats computation @@ -652,34 +657,33 @@ def save_cog_with_dask( from ..xr import ODCExtensionDa - if aws is None: - aws = {} - - upload_params = {k: kw.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in kw} - upload_params.update( - {k: aws.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in aws} - ) + # Normalise AWS/Azure params + aws = aws or {} + azure = azure or {} + upload_params = {k: kw.pop(k, None) for k in ["writes_per_chunk", "spill_sz"] if k in kw} parts_base = kw.pop("parts_base", None) - # normalize compression and remove GDAL compat options from kw + # Normalise compression settings and remove GDAL compat options from kw predictor, compression, compressionargs = _norm_compression_tifffile( xx.dtype, predictor, compression, compressionargs, level=level, kw=kw ) + xx_odc = xx.odc assert isinstance(xx_odc, ODCExtensionDa) assert isinstance(xx_odc.geobox, GeoBox) or xx_odc.geobox is None - + ydim = xx_odc.ydim - data_chunks: Tuple[int, int] = xx.data.chunksize[ydim : ydim + 2] + data_chunks: Tuple[int, int] = xx.data.chunksize[ydim:ydim + 2] if isinstance(blocksize, Unset): blocksize = [data_chunks, int(max(*data_chunks) // 2)] + # Metadata band_names = _band_names(xx) sample_descriptions_metadata = _gdal_sample_descriptions(band_names) - no_metadata = (stats is False) and not band_names - gdal_metadata = None if no_metadata else "" + gdal_metadata = None if stats is False and not band_names else "" + # Prepare COG metadata and header meta, hdr0 = _make_empty_cog( xx.shape, xx.dtype, @@ -696,9 +700,7 @@ def save_cog_with_dask( hdr0 = bytes(hdr0) if band_names and len(band_names) != meta.nsamples: - raise ValueError( - f"Found {len(band_names)} band names ({band_names}) but there are {meta.nsamples} bands." - ) + raise ValueError(f"Found {len(band_names)} band names ({band_names}), expected {meta.nsamples} bands.") layers = _pyramids_from_cog_metadata(xx, meta, resampling=overview_resampling) @@ -711,6 +713,7 @@ def save_cog_with_dask( layers[stats].data, nodata=xx_odc.nodata, yaxis=xx_odc.ydim ) + # Prepare tiles _tiles: List["dask.bag.Bag"] = [] for scale_idx, (mm, img) in enumerate(zip(meta.flatten(), layers)): for sample_idx in range(meta.num_planes): @@ -728,19 +731,23 @@ def save_cog_with_dask( "_stats": _stats, } - tiles_write_order = _tiles[::-1] - if len(tiles_write_order) > 4: - tiles_write_order = [ - dask.bag.concat(tiles_write_order[:4]), - *tiles_write_order[4:], - ] - - bucket, key = s3_parse_url(dst) - if not bucket: - # assume disk output + # Determine output type and initiate uploader + parsed_url = urlparse(dst) + if parsed_url.scheme == "s3": + bucket, key = s3_parse_url(dst) + uploader = S3MultiPartUpload(bucket, key, **aws) + elif parsed_url.scheme == "az": + uploader = AzMultiPartUpload( + account_url=azure.get("account_url"), + container=parsed_url.netloc, + blob=parsed_url.path.lstrip("/"), + credential=azure.get("credential"), + ) + else: + # Assume local disk write = MPUFileSink(dst, parts_base=parts_base) return mpu_write( - tiles_write_order, + _tiles[::-1], write, mk_header=_patch_hdr, user_kw={ @@ -752,15 +759,13 @@ def save_cog_with_dask( **upload_params, ) - upload_params["ContentType"] = ( - "image/tiff;application=geotiff;profile=cloud-optimized" - ) + # Upload tiles + tiles_write_order = _tiles[::-1] # Reverse tiles for writing + if len(tiles_write_order) > 4: # Optimize for larger datasets + tiles_write_order = [dask.bag.concat(tiles_write_order[:4]), *tiles_write_order[4:]] + - cleanup = aws.pop("cleanup", False) - s3_sink = MultiPartUpload(bucket, key, **aws) - if cleanup: - s3_sink.cancel("all") - return s3_sink.upload( + return uploader.upload( tiles_write_order, mk_header=_patch_hdr, user_kw={ diff --git a/tests/test_az.py b/tests/test_az.py new file mode 100644 index 00000000..010b4167 --- /dev/null +++ b/tests/test_az.py @@ -0,0 +1,62 @@ +"""Tests for the Azure MultiPartUpload class.""" +import unittest +from unittest.mock import MagicMock, patch + +from odc.geo.cog._az import MultiPartUpload + + +def test_s3_mpu(): + """Basic test for the MultiPartUpload class.""" + account_url = "https://account_name.blob.core.windows.net" + mpu = MultiPartUpload(account_url, "container", "some.blob", None) + assert mpu.account_url == "https://account_name.blob.core.windows.net" + assert mpu.container == "container" + assert mpu.blob == "some.blob" + assert mpu.credential is None + + +class TestMultiPartUpload(unittest.TestCase): + """Test the MultiPartUpload class.""" + @patch("odc.geo.cog._az.BlobServiceClient") + def test_azure_multipart_upload(self, mock_blob_service_client): + """Test the MultiPartUpload class.""" + # Arrange - mock the Azure Blob SDK + # Mock the blob client and its methods + mock_blob_client = MagicMock() + mock_container_client = MagicMock() + mcc = mock_container_client + mock_blob_service_client.return_value.get_container_client.return_value = mcc + mock_container_client.get_blob_client.return_value = mock_blob_client + + # Simulate return values for Azure Blob SDK methods + mock_blob_client.get_blob_properties.return_value.etag = "mock-etag" + + # Test parameters + account_url = "https://mockaccount.blob.core.windows.net" + container = "mock-container" + blob = "mock-blob" + credential = "mock-sas-token" + + # Act - create an instance of MultiPartUpload and call its methods + azure_upload = MultiPartUpload(account_url, container, blob, credential) + upload_id = azure_upload.initiate() + part1 = azure_upload.write_part(1, b"first chunk of data") + part2 = azure_upload.write_part(2, b"second chunk of data") + etag = azure_upload.finalise([part1, part2]) + + # Assert - check the results + # Check that the initiate method behaves as expected + self.assertEqual(upload_id, "azure-block-upload") + + # Verify the calls to Azure Blob SDK methods + mock_blob_service_client.assert_called_once_with(account_url=account_url, credential=credential) + mock_blob_client.stage_block.assert_any_call(part1["BlockId"], b"first chunk of data") + mock_blob_client.stage_block.assert_any_call(part2["BlockId"], b"second chunk of data") + mock_blob_client.commit_block_list.assert_called_once() + self.assertEqual(etag, "mock-etag") + + # Verify block list passed during finalise + block_list = mock_blob_client.commit_block_list.call_args[0][0] + self.assertEqual(len(block_list), 2) + self.assertEqual(block_list[0].id, part1["BlockId"]) + self.assertEqual(block_list[1].id, part2["BlockId"])