Skip to content

Commit

Permalink
fix: mypy errors in cog module
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jan 7, 2025
1 parent bc8811c commit 9db7b9a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
15 changes: 11 additions & 4 deletions odc/geo/cog/_az.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Union

import dask
from dask.delayed import Delayed

from ._mpu import mpu_write
from ._multipart import MultiPartUploadBase
Expand Down Expand Up @@ -33,6 +34,11 @@ def max_part(self) -> int:


class AzMultiPartUpload(AzureLimits, MultiPartUploadBase):
"""
Azure Blob Storage multipart upload.
"""
# pylint: disable=too-many-instance-attributes

def __init__(
self, account_url: str, container: str, blob: str, credential: Any = None
):
Expand Down Expand Up @@ -94,10 +100,11 @@ def finalise(self, parts: list[dict[str, Any]]) -> str:
self.blob_client.commit_block_list(block_list)
return self.blob_client.get_blob_properties().etag

def cancel(self):
def cancel(self, other: str = ""):
    """
    Cancel the upload by clearing the accumulated block list.

    :param other: Unused; accepted only for signature compatibility with
        ``MultiPartUploadBase.cancel``. Must be the empty string.
    :raises ValueError: If ``other`` is non-empty. An explicit raise is
        used instead of ``assert`` because assertions are stripped when
        Python runs with ``-O``.
    """
    if other != "":
        raise ValueError("cancel() does not accept an upload identifier")
    # Dropping the block ids abandons all uncommitted blocks.
    self.block_ids.clear()

@property
Expand All @@ -118,7 +125,7 @@ def started(self) -> bool:
"""
return bool(self.block_ids)

def writer(self, kw: dict[str, Any], client: Any = None):
def writer(self, kw: dict[str, Any], *, client: Any = None):
"""
Return a stateless writer compatible with Dask.
"""
Expand All @@ -130,12 +137,12 @@ def upload(
*,
mk_header: Any = None,
mk_footer: Any = None,
user_kw: dict[str, Any] = None,
user_kw: dict[str, Any] | None = None,
writes_per_chunk: int = 1,
spill_sz: int = 20 * (1 << 20),
client: Any = None,
**kw,
) -> dask.delayed.Delayed:
) -> Delayed:
"""
Upload chunks to Azure Blob Storage with multipart uploads.
Expand Down
16 changes: 5 additions & 11 deletions odc/geo/cog/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Union
from typing import Any, Union, TYPE_CHECKING

import dask.bag
if TYPE_CHECKING:
# pylint: disable=import-outside-toplevel,import-error
import dask.bag


class MultiPartUploadBase(ABC):
Expand All @@ -18,34 +20,28 @@ class MultiPartUploadBase(ABC):
@abstractmethod
def initiate(self, **kwargs) -> str:
    """
    Initiate a multipart upload.

    :param kwargs: Backend-specific options for starting the upload.
    :return: Identifier for the newly created multipart upload.
    """
pass

@abstractmethod
def write_part(self, part: int, data: bytes) -> dict[str, Any]:
    """
    Upload a single part of the multipart upload.

    :param part: Part number ordering this chunk within the upload.
    :param data: Raw bytes for this part.
    :return: Backend-specific part descriptor, later passed to
        :meth:`finalise`.
    """
pass

@abstractmethod
def finalise(self, parts: list[dict[str, Any]]) -> str:
    """
    Finalise the upload from a list of uploaded parts.

    :param parts: Part descriptors as returned by :meth:`write_part`.
    :return: Final identifier of the completed upload (e.g. an etag).
    """
pass

@abstractmethod
def cancel(self, other: str = ""):
    """
    Cancel the multipart upload.

    :param other: Optional backend-specific argument; implementations
        that do not use it expect the default empty string.
    """
pass

@property
@abstractmethod
def url(self) -> str:
    """URL of the upload target."""
pass

@property
@abstractmethod
def started(self) -> bool:
    """``True`` if the multipart upload has been initiated."""
pass

@abstractmethod
def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
Expand All @@ -55,7 +51,6 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
:param kw: Additional parameters for the writer.
:param client: Dask client for distributed execution.
"""
pass

@abstractmethod
def upload(
Expand All @@ -64,7 +59,7 @@ def upload(
*,
mk_header: Any = None,
mk_footer: Any = None,
user_kw: dict[str, Any] = None,
user_kw: dict[str, Any] | None = None,
writes_per_chunk: int = 1,
spill_sz: int = 20 * (1 << 20),
client: Any = None,
Expand All @@ -82,4 +77,3 @@ def upload(
:param client: Dask client for distributed execution.
:return: A Dask delayed object representing the finalised upload.
"""
pass
13 changes: 9 additions & 4 deletions odc/geo/cog/_tifffile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..types import Shape2d, SomeNodata, Unset, shape_
from ._mpu import mpu_write
from ._mpu_fs import MPUFileSink
from ._multipart import MultiPartUploadBase

from ._shared import (
GDAL_COMP,
Expand Down Expand Up @@ -736,22 +737,26 @@ def save_cog_with_dask(
# Determine output type and initiate uploader
parsed_url = urlparse(dst)
if parsed_url.scheme == "s3":
if have.s3:
if have.botocore:
from ._s3 import S3MultiPartUpload, s3_parse_url

bucket, key = s3_parse_url(dst)
uploader = S3MultiPartUpload(bucket, key, **aws)
uploader: MultiPartUploadBase = S3MultiPartUpload(bucket, key, **aws)
else:
raise RuntimeError("Please install `boto3` to use S3")
elif parsed_url.scheme == "az":
if have.azure:
from ._az import AzMultiPartUpload

assert azure is not None
assert "account_url" in azure
assert "credential" in azure

uploader = AzMultiPartUpload(
account_url=azure.get("account_url"),
account_url=azure["account_url"],
container=parsed_url.netloc,
blob=parsed_url.path.lstrip("/"),
credential=azure.get("credential"),
credential=azure["credential"],
)
else:
raise RuntimeError("Please install `azure-storage-blob` to use Azure")
Expand Down

0 comments on commit 9db7b9a

Please sign in to comment.