From 3696a905e1d1544d8ebb8e0459bcb80e3e525b8f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mariusz=20Kry=C5=84ski?=
Date: Thu, 14 Nov 2019 23:22:19 +0100
Subject: [PATCH] S3 streaming

---
 fs_s3fs/_s3fs.py      | 190 ++++----------------------
 fs_s3fs/_s3fs_file.py | 245 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 264 insertions(+), 171 deletions(-)
 create mode 100644 fs_s3fs/_s3fs_file.py

diff --git a/fs_s3fs/_s3fs.py b/fs_s3fs/_s3fs.py
index d7f98e2..ac2a6f1 100644
--- a/fs_s3fs/_s3fs.py
+++ b/fs_s3fs/_s3fs.py
@@ -8,9 +8,7 @@
 from datetime import datetime
 import io
 import itertools
-import os
 from ssl import SSLError
-import tempfile
 import threading
 import mimetypes
 
@@ -18,7 +16,6 @@
 from botocore.exceptions import ClientError, EndpointConnectionError
 
 import six
-from six import text_type
 
 from fs import ResourceType
 from fs.base import FS
@@ -29,6 +26,8 @@
 from fs.path import basename, dirname, forcedir, join, normpath, relpath
 from fs.time import datetime_to_epoch
 
+from ._s3fs_file import S3InputFile, S3OutputFile
+
 
 def _make_repr(class_name, *args, **kwargs):
     """
@@ -57,115 +56,6 @@ def __repr__(self):
     return "{}({})".format(class_name, ", ".join(arguments))
 
 
-class S3File(io.IOBase):
-    """Proxy for a S3 file."""
-
-    @classmethod
-    def factory(cls, filename, mode, on_close):
-        """Create a S3File backed with a temporary file."""
-        _temp_file = tempfile.TemporaryFile()
-        proxy = cls(_temp_file, filename, mode, on_close=on_close)
-        return proxy
-
-    def __repr__(self):
-        return _make_repr(
-            self.__class__.__name__, self.__filename, text_type(self.__mode)
-        )
-
-    def __init__(self, f, filename, mode, on_close=None):
-        self._f = f
-        self.__filename = filename
-        self.__mode = mode
-        self._on_close = on_close
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
-
-    @property
-    def raw(self):
-        return self._f
-
-    def close(self):
-        if self._on_close is not None:
-            self._on_close(self)
-
-    @property
-    def closed(self):
-        return self._f.closed
-
-    def fileno(self):
-        return self._f.fileno()
-
-    def flush(self):
-        return self._f.flush()
-
-    def isatty(self):
-        return self._f.asatty()
-
-    def readable(self):
-        return self.__mode.reading
-
-    def readline(self, limit=-1):
-        return self._f.readline(limit)
-
-    def readlines(self, hint=-1):
-        if hint == -1:
-            return self._f.readlines(hint)
-        else:
-            size = 0
-            lines = []
-            for line in iter(self._f.readline, b""):
-                lines.append(line)
-                size += len(line)
-                if size > hint:
-                    break
-            return lines
-
-    def seek(self, offset, whence=os.SEEK_SET):
-        if whence not in (os.SEEK_CUR, os.SEEK_END, os.SEEK_SET):
-            raise ValueError("invalid value for 'whence'")
-        self._f.seek(offset, whence)
-        return self._f.tell()
-
-    def seekable(self):
-        return True
-
-    def tell(self):
-        return self._f.tell()
-
-    def writable(self):
-        return self.__mode.writing
-
-    def writelines(self, lines):
-        return self._f.writelines(lines)
-
-    def read(self, n=-1):
-        if not self.__mode.reading:
-            raise IOError("not open for reading")
-        return self._f.read(n)
-
-    def readall(self):
-        return self._f.readall()
-
-    def readinto(self, b):
-        return self._f.readinto()
-
-    def write(self, b):
-        if not self.__mode.writing:
-            raise IOError("not open for reading")
-        self._f.write(b)
-        return len(b)
-
-    def truncate(self, size=None):
-        if size is None:
-            size = self._f.tell()
-        self._f.truncate(size)
-        return size
-
-
 @contextlib.contextmanager
 def s3errors(path):
     """Translate S3 errors to FSErrors."""
@@ -527,29 +417,18 @@ def openbin(self, path, mode="r", buffering=-1, **options):
         _path = self.validatepath(path)
         _key = self._path_to_key(_path)
 
-        if _mode.create:
+        if _mode.appending:
+            raise errors.ResourceError(path, msg="append mode is not supported")
 
-            def on_close_create(s3file):
-                """Called when the S3 file closes, to upload data."""
+        if _mode.create:
+            if self.strict:
                 try:
-                    s3file.raw.seek(0)
-                    with s3errors(path):
-                        self.client.upload_fileobj(
-                            s3file.raw,
-                            self._bucket_name,
-                            _key,
-                            ExtraArgs=self._get_upload_args(_key),
-                        )
-                finally:
-                    s3file.raw.close()
-
-            try:
-                dir_path = dirname(_path)
-                if dir_path != "/":
-                    _dir_key = self._path_to_dir_key(dir_path)
-                    self._get_object(dir_path, _dir_key)
-            except errors.ResourceNotFound:
-                raise errors.ResourceNotFound(path)
+                    dir_path = dirname(_path)
+                    if dir_path != "/":
+                        _dir_key = self._path_to_dir_key(dir_path)
+                        self._get_object(dir_path, _dir_key)
+                except errors.ResourceNotFound:
+                    raise errors.ResourceNotFound(path)
 
             try:
                 info = self._getinfo(path)
@@ -561,50 +440,19 @@ def on_close_create(s3file):
                 if info.is_dir:
                     raise errors.FileExpected(path)
 
-            s3file = S3File.factory(path, _mode, on_close=on_close_create)
-            if _mode.appending:
-                try:
-                    with s3errors(path):
-                        self.client.download_fileobj(
-                            self._bucket_name,
-                            _key,
-                            s3file.raw,
-                            ExtraArgs=self.download_args,
-                        )
-                except errors.ResourceNotFound:
-                    pass
-                else:
-                    s3file.seek(0, os.SEEK_END)
-
-            return s3file
+            obj = self.s3.Object(self._bucket_name, _key)
+            return S3OutputFile(
+                obj,
+                upload_kwargs=self._get_upload_args(_key)
+            )
 
         if self.strict:
             info = self.getinfo(path)
             if info.is_dir:
                 raise errors.FileExpected(path)
 
-        def on_close(s3file):
-            """Called when the S3 file closes, to upload the data."""
-            try:
-                if _mode.writing:
-                    s3file.raw.seek(0, os.SEEK_SET)
-                with s3errors(path):
-                    self.client.upload_fileobj(
-                        s3file.raw,
-                        self._bucket_name,
-                        _key,
-                        ExtraArgs=self._get_upload_args(_key),
-                    )
-            finally:
-                s3file.raw.close()
-
-        s3file = S3File.factory(path, _mode, on_close=on_close)
-        with s3errors(path):
-            self.client.download_fileobj(
-                self._bucket_name, _key, s3file.raw, ExtraArgs=self.download_args
-            )
-        s3file.seek(0, os.SEEK_SET)
-        return s3file
+        obj = self.s3.Object(self._bucket_name, _key)
+        return S3InputFile(obj)
 
     def remove(self, path):
         self.check()
diff --git a/fs_s3fs/_s3fs_file.py b/fs_s3fs/_s3fs_file.py
new file mode 100644
index 0000000..ed18f4e
--- /dev/null
+++ b/fs_s3fs/_s3fs_file.py
@@ -0,0 +1,245 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2019 Mariusz Kryński
+# (C) 2019 Michael Penkov
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+"""Implements file-like objects for reading and writing from/to S3."""
+
+import io
+from functools import wraps
+import botocore.exceptions
+import logging
+import sys
+
+logger = logging.getLogger(__name__)
+
+
+def check_if_open(method):
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        if self.closed:
+            logger.warning("file is already closed")
+            return
+        return method(self, *args, **kwargs)
+
+    return wrapper
+
+
+class S3InputFile(io.RawIOBase):
+    def __init__(self, s3_object):
+        self._s3_object = s3_object
+        self._position = 0
+        self._stream = None
+
+    @property
+    def size(self):
+        if not hasattr(self, "_size"):
+            self._size = self._s3_object.content_length
+        return self._size
+
+    @property
+    def has_size(self):
+        return hasattr(self, "_size")
+
+    def _set_position(self, new_position):
+        if new_position != self._position:
+            if self._stream:
+                self._stream.close()
+                self._stream = None
+            self._position = new_position
+
+    def seek(self, offset, whence=io.SEEK_SET):
+        if whence == io.SEEK_SET:
+            self._set_position(offset)
+        elif whence == io.SEEK_CUR:
+            self._set_position(self._position + offset)
+        elif whence == io.SEEK_END:
+            if offset > 0:
+                raise ValueError(
+                    "invalid offset, for SEEK_END it must be less than or equal to 0"
+                )
+            self._set_position(self.size + offset)
+        else:
+            raise ValueError("invalid whence {!r}".format(whence))
+        return self._position
+
+    def read(self, size=-1):
+        if size == 0 or (self.has_size and self._position >= self.size):
+            return b""
+
+        if not self._stream:
+            range_str = "bytes={}-".format(self._position)
+            try:
+                response = self._s3_object.get(Range=range_str)
+            except botocore.exceptions.ClientError as e:
+                error = e.response.get("Error", {})
+                if error.get("Code") == "InvalidRange":
+                    if "ActualObjectSize" in error:
+                        self._size = int(error["ActualObjectSize"])
+                    return b""
+                raise
+            content_range = response.get("ContentRange")
+            if content_range:
+                _, length = content_range.rsplit("/")
+                self._size = int(length)
+            self._stream = response["Body"]
+
+        read_args = (size,) if size >= 0 else ()
+        data = self._stream.read(*read_args)
+        self._position += len(data)
+        return data
+
+    def readall(self):
+        return self.read()
+
+    def readinto(self, buf):
+        data = self.read(len(buf))
+        data_len = len(data)
+        buf[:data_len] = data
+        return data_len
+
+    def close(self):
+        if self._stream:
+            self._stream.close()
+            self._stream = None
+
+    def readable(self):
+        return True
+
+    def seekable(self):
+        return True
+
+
+DEFAULT_MIN_PART_SIZE = 50 * 1024 ** 2
+"""Default minimum part size for S3 multipart uploads"""
+
+MIN_MIN_PART_SIZE = 5 * 1024 ** 2
+"""The absolute minimum permitted by Amazon."""
+
+
+class S3OutputFile(io.BufferedIOBase):
+    """Writes bytes to S3.
+
+    Implements the io.BufferedIOBase interface of the standard library."""
+
+    def __init__(
+        self,
+        s3_object,
+        min_part_size=DEFAULT_MIN_PART_SIZE,
+        upload_kwargs=None,
+    ):
+        self._upload_kwargs = upload_kwargs or {}
+        if min_part_size < MIN_MIN_PART_SIZE:
+            logger.warning(
+                "S3 requires minimum part size >= 5MB; multipart upload may fail"
+            )
+
+        self._object = s3_object
+        self._min_part_size = min_part_size
+        self._mp = self._object.initiate_multipart_upload(**self._upload_kwargs)
+
+        self._buf = b''
+        self._total_bytes = 0
+        self._total_parts = 0
+        self._parts = []
+
+        #
+        # This member is part of the io.BufferedIOBase interface.
+        #
+        self.raw = None
+
+    def flush(self):
+        pass
+
+    @property
+    def closed(self):
+        return self._mp is None
+
+    def writable(self):
+        """Return True if the stream supports writing."""
+        return True
+
+    def tell(self):
+        """Return the current stream position."""
+        return self._total_bytes
+
+    def detach(self):
+        raise io.UnsupportedOperation("detach() not supported")
+
+    @check_if_open
+    def write(self, b):
+        """Write the given buffer (bytes, bytearray, memoryview or any buffer
+        interface implementation) to the S3 file.
+
+        For more information about buffers, see
+        https://docs.python.org/3/c-api/buffer.html
+
+        There's buffering happening under the covers, so this may not actually
+        do any HTTP transfer right away."""
+
+        if self._buf:
+            self._buf += b
+        else:
+            self._buf = b
+
+        length = len(b)
+        self._total_bytes += length
+
+        if len(self._buf) >= self._min_part_size:
+            self._upload_next_part()
+
+        return length
+
+    @check_if_open
+    def close(self):
+        logger.debug("closing")
+
+        if tuple(sys.exc_info()) != (None, None, None):
+            self.terminate()
+            return
+
+        if self._buf:
+            self._upload_next_part()
+
+        if self._total_bytes:
+            self._mp.complete(MultipartUpload={"Parts": self._parts})
+            logger.debug("completed multipart upload")
+        else:
+            #
+            # AWS complains with "The XML you provided was not well-formed or
+            # did not validate against our published schema" when the input is
+            # completely empty => abort the upload, no file created.
+            #
+            # We work around this by creating an empty file explicitly.
+            #
+            logger.debug("empty input, ignoring multipart upload")
+            self.terminate()
+            self._object.put(Body=b"", **self._upload_kwargs)
+        self._mp = None
+        logger.debug("successfully closed")
+
+    @check_if_open
+    def terminate(self):
+        """Cancel the underlying multipart upload."""
+        assert self._mp, "no multipart upload in progress"
+        self._mp.abort()
+        self._mp = None
+
+    def _upload_next_part(self):
+        part_num = self._total_parts + 1
+        logger.info(
+            "uploading part #%i, %i bytes (total %.3fGB)",
+            part_num,
+            len(self._buf),
+            self._total_bytes / 1024.0 ** 3,
+        )
+        part = self._mp.Part(part_num)
+        upload = part.upload(Body=self._buf)
+        self._parts.append({"ETag": upload["ETag"], "PartNumber": part_num})
+        logger.debug("upload of part #%i finished", part_num)
+
+        self._total_parts += 1
+        self._buf = bytes()
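
Usage sketch (illustrative only, not part of the patch; the bucket and key names
are made up): with this change, openbin() hands back the streaming classes
directly. Reads go through S3InputFile, which issues a ranged GET the next time
data is needed after a seek, and writes go through S3OutputFile, which buffers
data into multipart-upload parts (50 MB minimum by default) and completes the
upload on close().

    import io
    from fs_s3fs import S3FS

    s3fs = S3FS("example-bucket")  # hypothetical bucket
    with s3fs.openbin("blob.bin", "wb") as f:
        # parts are uploaded once the internal buffer reaches the minimum part
        # size; the multipart upload is completed when the file is closed
        f.write(b"\x00" * (64 * 1024 ** 2))
    with s3fs.openbin("blob.bin", "rb") as f:
        f.seek(-16, io.SEEK_END)  # seeking drops the current body stream
        tail = f.read()           # the next read opens a new ranged GET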