Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: a Method has a clearer definition (that passes strict type checking) #125

Merged
merged 1 commit into from
May 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 76 additions & 48 deletions stream_zip.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime
from struct import Struct
Expand All @@ -11,8 +12,9 @@
from Crypto.Util import Counter
from Crypto.Protocol.KDF import PBKDF2

#################
# Private methods

################################
# Private sentinel objects/types

_NO_COMPRESSION_BUFFERED_32 = object()
_NO_COMPRESSION_BUFFERED_64 = object()
Expand All @@ -24,43 +26,55 @@
_AUTO_UPGRADE_CENTRAL_DIRECTORY = object()
_NO_AUTO_UPGRADE_CENTRAL_DIRECTORY = object()

_MethodReturnValue = Tuple[object, object, Callable[[], 'zlib._Compress'], Optional[int], Optional[int]]

def __NO_COMPRESSION_BUFFERED_32(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']) -> _MethodReturnValue:
return _NO_COMPRESSION_BUFFERED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __NO_COMPRESSION_BUFFERED_64(offset, default_get_compressobj: Callable[[], 'zlib._Compress']) -> _MethodReturnValue:
return _NO_COMPRESSION_BUFFERED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __NO_COMPRESSION_STREAMED_32(uncompressed_size, crc_32) -> 'Method':
def method_compressobj(offset, default_get_compressobj):
return _NO_COMPRESSION_STREAMED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
return method_compressobj

def __NO_COMPRESSION_STREAMED_64(uncompressed_size, crc_32) -> 'Method':
def method_compressobj(offset, default_get_compressobj):
return _NO_COMPRESSION_STREAMED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
return method_compressobj

################
# Public methods

Method = Callable[[int, Callable[[], 'zlib._Compress']], _MethodReturnValue]

def NO_COMPRESSION_32(uncompressed_size: int, crc_32: int) -> Method:
return __NO_COMPRESSION_STREAMED_32(uncompressed_size, crc_32)

def NO_COMPRESSION_64(uncompressed_size: int, crc_32: int) -> Method:
return __NO_COMPRESSION_STREAMED_64(uncompressed_size, crc_32)

def ZIP_32(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']) ->_MethodReturnValue:
return _ZIP_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def ZIP_64(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']) -> _MethodReturnValue:
return _ZIP_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def ZIP_AUTO(uncompressed_size: int, level: int=9) -> Method:
def method_compressobj(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']):
# Used internally to fetch the (default) zlib Compress object
_CompressObjGetter = Callable[[], 'zlib._Compress']

# Used by the internals of stream_zip - a "public" Method is a tuple of 5 things that controls the
# format/process of making each member of the ZIP file.
_MethodTuple = Tuple[
object, # Sentinel of the methods above
object, # Sentinel of auto upgrade central directory or not
_CompressObjGetter, # Function to get the zlib Compress object for
Optional[int], # The uncompressed size of the file if known
Optional[int], # The CRC32 of the file if known
]

# A "Method" is an instance of a class that has a _get function that returns a _MethodTuple
class Method(ABC):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
pass

class _ZIP_64_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _ZIP_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

class _ZIP_32_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _ZIP_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

class _NO_COMPRESSION_32_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_BUFFERED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __call__(self, uncompressed_size, crc_32, *args: Any, **kwarg: Any) -> Method:
class _NO_COMPRESSION_32_TYPE_STREAMED_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_STREAMED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32

return _NO_COMPRESSION_32_TYPE_STREAMED_TYPE()

class _NO_COMPRESSION_64_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_BUFFERED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __call__(self, uncompressed_size: int, crc_32: int) -> Method:
class _NO_COMPRESSION_64_TYPE_STREAMED_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_STREAMED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
return _NO_COMPRESSION_64_TYPE_STREAMED_TYPE()

class _ZIP_AUTO_TYPE():
def __call__(self, uncompressed_size: int, level: int=9) -> Method:
# The limit of 4293656841 is calculated using the logic from a zlib function
# https://github.com/madler/zlib/blob/04f42ceca40f73e2978b50e93806c2a18c1281fc/deflate.c#L696
# Specifically, worked out by assuming the compressed size of a stream cannot be bigger than
Expand All @@ -73,12 +87,30 @@ def method_compressobj(offset: int, default_get_compressobj: Callable[[], 'zlib.
# https://stackoverflow.com/q/76371334/1319998
# so Python could be causing extra deflate-chunks output which could break the limit. However, couldn't
# get output of sized 4293656841 to break the Zip32 bound of 0xffffffff here for any level, including 0
method = _ZIP_64 if uncompressed_size > 4293656841 or offset > 0xffffffff else _ZIP_32
return (method, _AUTO_UPGRADE_CENTRAL_DIRECTORY, lambda: zlib.compressobj(level=level, memLevel=8, wbits=-zlib.MAX_WBITS), None, None)
return method_compressobj

class _ZIP_AUTO_TYPE_INNER(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
method = _ZIP_64 if uncompressed_size > 4293656841 or offset > 0xffffffff else _ZIP_32
return (method, _AUTO_UPGRADE_CENTRAL_DIRECTORY, lambda: zlib.compressobj(level=level, memLevel=8, wbits=-zlib.MAX_WBITS), None, None)

return _ZIP_AUTO_TYPE_INNER()


###############################
# Public sentinel objects/types

# Methods / objects that return Methods (some of these are both)
ZIP_64 = _ZIP_64_TYPE()
ZIP_32 = _ZIP_32_TYPE()
NO_COMPRESSION_32 = _NO_COMPRESSION_32_TYPE()
NO_COMPRESSION_64 = _NO_COMPRESSION_64_TYPE()
ZIP_AUTO = _ZIP_AUTO_TYPE()

# Each member file is a tuple of its name, last modified date, file mode, Method, and its bytes
MemberFile = Tuple[str, datetime, int, Method, Iterable[bytes]]


def stream_zip(files: Iterable[Tuple[str, datetime, int, Method, Iterable[bytes]]], chunk_size: int=65536,
def stream_zip(files: Iterable[MemberFile], chunk_size: int=65536,
get_compressobj=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=9),
extended_timestamps: bool=True,
password: Optional[str]=None,
Expand Down Expand Up @@ -611,11 +643,7 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz
raise UncompressedSizeIntegrityError()

for name, modified_at, mode, method, chunks in files:
method = \
__NO_COMPRESSION_BUFFERED_32 if method is NO_COMPRESSION_32 else \
__NO_COMPRESSION_BUFFERED_64 if method is NO_COMPRESSION_64 else \
method
_method, _auto_upgrade_central_directory, _get_compress_obj, uncompressed_size, crc_32 = method(offset, get_compressobj)
_method, _auto_upgrade_central_directory, _get_compress_obj, uncompressed_size, crc_32 = method._get(offset, get_compressobj)

name_encoded = name.encode('utf-8')
_raise_if_beyond(len(name_encoded), maximum=0xffff, exception_class=NameLengthOverflowError)
Expand Down
Loading