Skip to content

Commit

Permalink
Merge pull request #22 from aiokitchen/feature/raise-exceptions-from-…
Browse files Browse the repository at this point in the history
…response

Raise exceptions from response
  • Loading branch information
mosquito authored Aug 14, 2023
2 parents dca4276 + 68044d5 commit 5da4e35
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 93 deletions.
87 changes: 62 additions & 25 deletions aiohttp_s3_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import io
import logging
import os
import sys
import typing as t
from collections import deque
from contextlib import suppress
from functools import partial
from http import HTTPStatus
from itertools import chain
Expand All @@ -16,8 +16,12 @@
from urllib.parse import quote

from aiohttp import ClientSession, hdrs
from aiohttp.client import _RequestContextManager as RequestContextManager
from aiohttp.client_exceptions import ClientError
# noinspection PyProtectedMember
from aiohttp.client import (
_RequestContextManager as RequestContextManager,
ClientResponse,
)
from aiohttp.client_exceptions import ClientError, ClientResponseError
from aiomisc import asyncbackoff, threaded, threaded_iterable
from aws_request_signer import UNSIGNED_PAYLOAD
from multidict import CIMultiDict, CIMultiDictProxy
Expand All @@ -31,23 +35,29 @@
parse_create_multipart_upload_id, parse_list_objects,
)


log = logging.getLogger(__name__)


CHUNK_SIZE = 2 ** 16
DONE = object()
EMPTY_STR_HASH = hashlib.sha256(b"").hexdigest()
PART_SIZE = 5 * 1024 * 1024 # 5MB

HeadersType = t.Union[t.Dict, CIMultiDict, CIMultiDictProxy]


threaded_iterable_constrained = threaded_iterable(max_size=2)


class AwsError(ClientError):
pass
class AwsError(ClientResponseError):
def __init__(
self, resp: ClientResponse, message: str, *history: ClientResponse
):
super().__init__(
headers=resp.headers,
history=(resp, *history),
message=message,
request_info=resp.request_info,
status=resp.status,
)


class AwsUploadError(AwsError):
Expand All @@ -58,6 +68,19 @@ class AwsDownloadError(AwsError):
pass


if sys.version_info < (3, 8):
from contextlib import suppress

@threaded
def unlink_path(path: Path) -> None:
with suppress(FileNotFoundError):
os.unlink(path.resolve())
else:
@threaded
def unlink_path(path: Path) -> None:
path.unlink(missing_ok=True)


@threaded
def concat_files(
target_file: Path, files: t.List[t.IO[bytes]], buffer_size: int,
Expand Down Expand Up @@ -117,7 +140,6 @@ def file_sender(

async_file_sender = threaded_iterable_constrained(file_sender)


DataType = t.Union[bytes, str, t.AsyncIterable[bytes]]
ParamsType = t.Optional[t.Mapping[str, str]]

Expand Down Expand Up @@ -282,8 +304,11 @@ async def _create_multipart_upload(
payload = await resp.read()
if resp.status != HTTPStatus.OK:
raise AwsUploadError(
f"Wrong status code {resp.status} from s3 with message "
f"{payload.decode()}.",
resp,
(
f"Wrong status code {resp.status} from s3 "
f"with message {payload.decode()}."
),
)
return parse_create_multipart_upload_id(payload)

Expand All @@ -308,8 +333,11 @@ async def _complete_multipart_upload(
if resp.status != HTTPStatus.OK:
payload = await resp.text()
raise AwsUploadError(
f"Wrong status code {resp.status} from s3 with message "
f"{payload}.",
resp,
(
f"Wrong status code {resp.status} from s3 "
f"with message {payload}."
),
)

async def _put_part(
Expand All @@ -331,8 +359,11 @@ async def _put_part(
payload = await resp.text()
if resp.status != HTTPStatus.OK:
raise AwsUploadError(
f"Wrong status code {resp.status} from s3 with message "
f"{payload}.",
resp,
(
f"Wrong status code {resp.status} from s3 "
f"with message {payload}."
),
)
return resp.headers["Etag"].strip('"')

Expand Down Expand Up @@ -519,7 +550,6 @@ async def _download_range(
writer: t.Callable[[bytes, int, int], t.Coroutine],
*,
etag: str,
pos: int,
range_start: int,
req_range_start: int,
req_range_end: int,
Expand All @@ -546,8 +576,11 @@ async def _download_range(
async with self.get(object_name, headers=headers, **kwargs) as resp:
if resp.status not in (HTTPStatus.PARTIAL_CONTENT, HTTPStatus.OK):
raise AwsDownloadError(
f"Got wrong status code {resp.status} on range download "
f"of {object_name}",
resp,
(
f"Got wrong status code {resp.status} on "
f"range download of {object_name}"
),
)
while True:
chunk = await resp.content.read(buffer_size)
Expand Down Expand Up @@ -594,7 +627,6 @@ async def _download_worker(
object_name,
writer,
etag=etag,
pos=(req_range_start - range_start),
range_start=range_start,
req_range_start=req_range_start,
req_range_end=req_range_end - 1,
Expand Down Expand Up @@ -632,8 +664,11 @@ async def get_file_parallel(
async with self.head(str(object_name), headers=headers) as resp:
if resp.status != HTTPStatus.OK:
raise AwsDownloadError(
f"Got response for HEAD request for {object_name}"
f"of a wrong status {resp.status}",
resp,
(
f"Got response for HEAD request for "
f"{object_name} of a wrong status {resp.status}"
),
)
etag = resp.headers["Etag"]
file_size = int(resp.headers["Content-Length"])
Expand Down Expand Up @@ -692,8 +727,7 @@ async def get_file_parallel(
"Error on file download. Removing possibly incomplete file %s",
file_path,
)
with suppress(FileNotFoundError):
os.unlink(file_path)
await unlink_path(file_path)
raise

async def list_objects_v2(
Expand Down Expand Up @@ -746,8 +780,11 @@ async def list_objects_v2(
async with self.get(str(object_name), params=params) as resp:
if resp.status != HTTPStatus.OK:
raise AwsDownloadError(
f"Got response with wrong status for GET request for "
f"{object_name} with prefix '{prefix}'",
resp,
(
"Got response with wrong status for GET request "
f"for {object_name} with prefix '{prefix}'"
),
)
payload = await resp.read()
metadata, continuation_token = parse_list_objects(payload)
Expand Down
Loading

0 comments on commit 5da4e35

Please sign in to comment.