diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py
index 9995a0ab..0b51c1c8 100644
--- a/odc/geo/cog/_az.py
+++ b/odc/geo/cog/_az.py
@@ -2,6 +2,7 @@
 from typing import Any, Union
 
 import dask
+from dask.delayed import Delayed
 
 from ._mpu import mpu_write
 from ._multipart import MultiPartUploadBase
@@ -33,6 +34,11 @@ def max_part(self) -> int:
 
 
 class AzMultiPartUpload(AzureLimits, MultiPartUploadBase):
+    """
+    Azure Blob Storage multipart upload.
+    """
+    # pylint: disable=too-many-instance-attributes
+
     def __init__(
         self, account_url: str, container: str, blob: str, credential: Any = None
     ):
@@ -94,10 +100,11 @@ def finalise(self, parts: list[dict[str, Any]]) -> str:
         self.blob_client.commit_block_list(block_list)
         return self.blob_client.get_blob_properties().etag
 
-    def cancel(self):
+    def cancel(self, other: str = ""):
         """
         Cancel the upload by clearing the block list.
         """
+        assert other == ""
         self.block_ids.clear()
 
     @property
@@ -118,7 +125,7 @@ def started(self) -> bool:
         """
         return bool(self.block_ids)
 
-    def writer(self, kw: dict[str, Any], client: Any = None):
+    def writer(self, kw: dict[str, Any], *, client: Any = None):
         """
         Return a stateless writer compatible with Dask.
         """
@@ -130,12 +137,12 @@ def upload(
         *,
         mk_header: Any = None,
         mk_footer: Any = None,
-        user_kw: dict[str, Any] = None,
+        user_kw: dict[str, Any] | None = None,
         writes_per_chunk: int = 1,
         spill_sz: int = 20 * (1 << 20),
         client: Any = None,
         **kw,
-    ) -> dask.delayed.Delayed:
+    ) -> Delayed:
         """
         Upload chunks to Azure Blob Storage with multipart uploads.
 
diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py
index c9060bee..0fc9b4c8 100644
--- a/odc/geo/cog/_multipart.py
+++ b/odc/geo/cog/_multipart.py
@@ -7,9 +7,11 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import Any, Union
+from typing import Any, Union, TYPE_CHECKING
 
-import dask.bag
+if TYPE_CHECKING:
+    # pylint: disable=import-outside-toplevel,import-error
+    import dask.bag
 
 
 class MultiPartUploadBase(ABC):
@@ -18,34 +20,28 @@ class MultiPartUploadBase(ABC):
     @abstractmethod
     def initiate(self, **kwargs) -> str:
         """Initiate a multipart upload and return an identifier."""
-        pass
 
     @abstractmethod
     def write_part(self, part: int, data: bytes) -> dict[str, Any]:
         """Upload a single part."""
-        pass
 
     @abstractmethod
     def finalise(self, parts: list[dict[str, Any]]) -> str:
         """Finalise the upload with a list of parts."""
-        pass
 
     @abstractmethod
     def cancel(self, other: str = ""):
         """Cancel the multipart upload."""
-        pass
 
     @property
     @abstractmethod
     def url(self) -> str:
         """Return the URL of the upload target."""
-        pass
 
     @property
     @abstractmethod
     def started(self) -> bool:
         """Check if the multipart upload has been initiated."""
-        pass
 
     @abstractmethod
     def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
@@ -55,7 +51,6 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
         :param kw: Additional parameters for the writer.
         :param client: Dask client for distributed execution.
         """
-        pass
 
     @abstractmethod
     def upload(
@@ -64,7 +59,7 @@ def upload(
         *,
         mk_header: Any = None,
         mk_footer: Any = None,
-        user_kw: dict[str, Any] = None,
+        user_kw: dict[str, Any] | None = None,
         writes_per_chunk: int = 1,
         spill_sz: int = 20 * (1 << 20),
         client: Any = None,
@@ -82,4 +77,3 @@ def upload(
         :param client: Dask client for distributed execution.
         :return: A Dask delayed object representing the finalised upload.
         """
-        pass
diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py
index fae90408..3ab86e21 100644
--- a/odc/geo/cog/_tifffile.py
+++ b/odc/geo/cog/_tifffile.py
@@ -24,6 +24,7 @@
 from ..types import Shape2d, SomeNodata, Unset, shape_
 from ._mpu import mpu_write
 from ._mpu_fs import MPUFileSink
+from ._multipart import MultiPartUploadBase
 from ._shared import (
     GDAL_COMP,
@@ -736,22 +737,26 @@ def save_cog_with_dask(
     # Determine output type and initiate uploader
     parsed_url = urlparse(dst)
     if parsed_url.scheme == "s3":
-        if have.s3:
+        if have.botocore:
             from ._s3 import S3MultiPartUpload, s3_parse_url
 
             bucket, key = s3_parse_url(dst)
-            uploader = S3MultiPartUpload(bucket, key, **aws)
+            uploader: MultiPartUploadBase = S3MultiPartUpload(bucket, key, **aws)
         else:
             raise RuntimeError("Please install `boto3` to use S3")
     elif parsed_url.scheme == "az":
         if have.azure:
             from ._az import AzMultiPartUpload
 
+            assert azure is not None
+            assert "account_url" in azure
+            assert "credential" in azure
+
             uploader = AzMultiPartUpload(
-                account_url=azure.get("account_url"),
+                account_url=azure["account_url"],
                 container=parsed_url.netloc,
                 blob=parsed_url.path.lstrip("/"),
-                credential=azure.get("credential"),
+                credential=azure["credential"],
             )
         else:
             raise RuntimeError("Please install `azure-storage-blob` to use Azure")