Skip to content

Commit

Permalink
Merge pull request #125 from uktrade/feat/tighter-method-definition
Browse files Browse the repository at this point in the history
feat: a Method has a clearer definition (that passes strict type checking)
  • Loading branch information
michalc authored May 26, 2024
2 parents 1da1475 + 05854ad commit 4f16169
Showing 1 changed file with 76 additions and 48 deletions.
124 changes: 76 additions & 48 deletions stream_zip.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime
from struct import Struct
Expand All @@ -11,8 +12,9 @@
from Crypto.Util import Counter
from Crypto.Protocol.KDF import PBKDF2

#################
# Private methods

################################
# Private sentinel objects/types

_NO_COMPRESSION_BUFFERED_32 = object()
_NO_COMPRESSION_BUFFERED_64 = object()
Expand All @@ -24,43 +26,55 @@
_AUTO_UPGRADE_CENTRAL_DIRECTORY = object()
_NO_AUTO_UPGRADE_CENTRAL_DIRECTORY = object()

_MethodReturnValue = Tuple[object, object, Callable[[], 'zlib._Compress'], Optional[int], Optional[int]]

def __NO_COMPRESSION_BUFFERED_32(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']) -> _MethodReturnValue:
return _NO_COMPRESSION_BUFFERED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __NO_COMPRESSION_BUFFERED_64(offset, default_get_compressobj: Callable[[], 'zlib._Compress']) -> _MethodReturnValue:
return _NO_COMPRESSION_BUFFERED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __NO_COMPRESSION_STREAMED_32(uncompressed_size, crc_32) -> 'Method':
def method_compressobj(offset, default_get_compressobj):
return _NO_COMPRESSION_STREAMED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
return method_compressobj

def __NO_COMPRESSION_STREAMED_64(uncompressed_size, crc_32) -> 'Method':
def method_compressobj(offset, default_get_compressobj):
return _NO_COMPRESSION_STREAMED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
return method_compressobj

################
# Public methods

Method = Callable[[int, Callable[[], 'zlib._Compress']], _MethodReturnValue]

def NO_COMPRESSION_32(uncompressed_size: int, crc_32: int) -> Method:
return __NO_COMPRESSION_STREAMED_32(uncompressed_size, crc_32)

def NO_COMPRESSION_64(uncompressed_size: int, crc_32: int) -> Method:
return __NO_COMPRESSION_STREAMED_64(uncompressed_size, crc_32)

def ZIP_32(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']) ->_MethodReturnValue:
return _ZIP_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def ZIP_64(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']) -> _MethodReturnValue:
return _ZIP_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def ZIP_AUTO(uncompressed_size: int, level: int=9) -> Method:
def method_compressobj(offset: int, default_get_compressobj: Callable[[], 'zlib._Compress']):
# Used internally to fetch the (default) zlib Compress object
_CompressObjGetter = Callable[[], 'zlib._Compress']

# Used by the internals of stream_zip - a "public" Method is a tuple of 5 things that controls the
# format/process of making each member of the ZIP file.
_MethodTuple = Tuple[
object, # Sentinel of the methods above
object, # Sentinel of auto upgrade central directory or not
_CompressObjGetter, # Function to get the zlib Compress object for
Optional[int], # The uncompressed size of the file if known
Optional[int], # The CRC32 of the file if known
]

# A "Method" is an instance of a class that has a _get function that returns a _MethodTuple
class Method(ABC):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
pass

class _ZIP_64_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _ZIP_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

class _ZIP_32_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _ZIP_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

class _NO_COMPRESSION_32_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_BUFFERED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __call__(self, uncompressed_size, crc_32, *args: Any, **kwarg: Any) -> Method:
class _NO_COMPRESSION_32_TYPE_STREAMED_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_STREAMED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32

return _NO_COMPRESSION_32_TYPE_STREAMED_TYPE()

class _NO_COMPRESSION_64_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_BUFFERED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None

def __call__(self, uncompressed_size: int, crc_32: int) -> Method:
class _NO_COMPRESSION_64_TYPE_STREAMED_TYPE(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
return _NO_COMPRESSION_STREAMED_64, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
return _NO_COMPRESSION_64_TYPE_STREAMED_TYPE()

class _ZIP_AUTO_TYPE():
def __call__(self, uncompressed_size: int, level: int=9) -> Method:
# The limit of 4293656841 is calculated using the logic from a zlib function
# https://github.com/madler/zlib/blob/04f42ceca40f73e2978b50e93806c2a18c1281fc/deflate.c#L696
# Specifically, worked out by assuming the compressed size of a stream cannot be bigger than
Expand All @@ -73,12 +87,30 @@ def method_compressobj(offset: int, default_get_compressobj: Callable[[], 'zlib.
# https://stackoverflow.com/q/76371334/1319998
# so Python could be causing extra deflate-chunks output which could break the limit. However, couldn't
# get output of sized 4293656841 to break the Zip32 bound of 0xffffffff here for any level, including 0
method = _ZIP_64 if uncompressed_size > 4293656841 or offset > 0xffffffff else _ZIP_32
return (method, _AUTO_UPGRADE_CENTRAL_DIRECTORY, lambda: zlib.compressobj(level=level, memLevel=8, wbits=-zlib.MAX_WBITS), None, None)
return method_compressobj

class _ZIP_AUTO_TYPE_INNER(Method):
def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
method = _ZIP_64 if uncompressed_size > 4293656841 or offset > 0xffffffff else _ZIP_32
return (method, _AUTO_UPGRADE_CENTRAL_DIRECTORY, lambda: zlib.compressobj(level=level, memLevel=8, wbits=-zlib.MAX_WBITS), None, None)

return _ZIP_AUTO_TYPE_INNER()


###############################
# Public sentinel objects/types

# Methods / objects that return Methods (some of these are both)
ZIP_64 = _ZIP_64_TYPE()
ZIP_32 = _ZIP_32_TYPE()
NO_COMPRESSION_32 = _NO_COMPRESSION_32_TYPE()
NO_COMPRESSION_64 = _NO_COMPRESSION_64_TYPE()
ZIP_AUTO = _ZIP_AUTO_TYPE()

# Each member file is a tuple of its name, last modified date, file mode, Method, and its bytes
MemberFile = Tuple[str, datetime, int, Method, Iterable[bytes]]


def stream_zip(files: Iterable[Tuple[str, datetime, int, Method, Iterable[bytes]]], chunk_size: int=65536,
def stream_zip(files: Iterable[MemberFile], chunk_size: int=65536,
get_compressobj=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=9),
extended_timestamps: bool=True,
password: Optional[str]=None,
Expand Down Expand Up @@ -611,11 +643,7 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz
raise UncompressedSizeIntegrityError()

for name, modified_at, mode, method, chunks in files:
method = \
__NO_COMPRESSION_BUFFERED_32 if method is NO_COMPRESSION_32 else \
__NO_COMPRESSION_BUFFERED_64 if method is NO_COMPRESSION_64 else \
method
_method, _auto_upgrade_central_directory, _get_compress_obj, uncompressed_size, crc_32 = method(offset, get_compressobj)
_method, _auto_upgrade_central_directory, _get_compress_obj, uncompressed_size, crc_32 = method._get(offset, get_compressobj)

name_encoded = name.encode('utf-8')
_raise_if_beyond(len(name_encoded), maximum=0xffff, exception_class=NameLengthOverflowError)
Expand Down

0 comments on commit 4f16169

Please sign in to comment.