diff --git a/stream_zip.py b/stream_zip.py
index 766b2d0..4acf1aa 100644
--- a/stream_zip.py
+++ b/stream_zip.py
@@ -5,7 +5,7 @@
 import asyncio
 import secrets
 import zlib
-from typing import Any, Iterable, Tuple, Optional, Deque, Type, AsyncIterable, Callable
+from typing import Any, Iterable, Generator, Tuple, Optional, Deque, Type, AsyncIterable, Callable
 
 from Crypto.Cipher import AES
 from Crypto.Hash import HMAC, SHA1
@@ -56,7 +56,7 @@ class _NO_COMPRESSION_32_TYPE(Method):
     def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
         return _NO_COMPRESSION_BUFFERED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None
 
-    def __call__(self, uncompressed_size, crc_32, *args: Any, **kwarg: Any) -> Method:
+    def __call__(self, uncompressed_size: int, crc_32: int, *args: Any, **kwarg: Any) -> Method:
         class _NO_COMPRESSION_32_TYPE_STREAMED_TYPE(Method):
             def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
                 return _NO_COMPRESSION_STREAMED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
@@ -111,10 +111,10 @@ def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _Met
 
 
 def stream_zip(files: Iterable[MemberFile], chunk_size: int=65536,
-               get_compressobj=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=9),
+               get_compressobj: _CompressObjGetter=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=9),
                extended_timestamps: bool=True,
                password: Optional[str]=None,
-               get_crypto_random=lambda num_bytes: secrets.token_bytes(num_bytes),
+               get_crypto_random: Callable[[int], bytes]=lambda num_bytes: secrets.token_bytes(num_bytes),
 ) -> Iterable[bytes]:
 
     def evenly_sized(chunks: Iterable[bytes]) -> Iterable[bytes]:
@@ -122,7 +122,7 @@ def evenly_sized(chunks: Iterable[bytes]) -> Iterable[bytes]:
         offset = 0
         it = iter(chunks)
 
-        def up_to(num):
+        def up_to(num: int) -> Iterable[bytes]:
             nonlocal chunk, offset
 
             while num:
@@ -195,7 +195,7 @@ def _raise_if_beyond(offset: int, maximum: int, exception_class: Type[Exception]
         if offset > maximum:
             raise exception_class()
 
-    def _with_returned(gen) -> Tuple[Callable[[], Optional[Any]], Iterable[bytes]]:
+    def _with_returned(gen: Generator[bytes, None, Any]) -> Tuple[Callable[[], Any], Iterable[bytes]]:
         # We leverage the not-often used "return value" of generators. Here, we want to iterate
         # over chunks (to encrypt them), but still return the same "return value". So we use a
         # bit of a trick to extract the return value but still have access to the chunks as
@@ -208,40 +208,45 @@ def with_return_value() -> Iterable[bytes]:
 
         return ((lambda: return_value), with_return_value())
 
-    def _encrypt_dummy(chunks):
+    def _encrypt_dummy(chunks: Generator[bytes, None, Any]) -> Generator[bytes, None, Any]:
         get_return_value, chunks_with_return = _with_returned(chunks)
         for chunk in chunks_with_return:
             yield from _(chunk)
         return get_return_value()
 
-    def _encrypt_aes(chunks):
-        key_length = 32
-        salt_length = 16
-        password_verification_length = 2
+    # This slightly complex getter allows mypy to work out that the _encrypt_aes function is
+    # only called when we have a non-None password, which then passes type checking for the
+    # PBKDF2 function that the password is passed into
+    def _get_encrypt_aes(password: str) -> Callable[[Generator[bytes, None, Any]], Generator[bytes, None, Any]]:
+        def _encrypt_aes(chunks: Generator[bytes, None, Any]) -> Generator[bytes, None, Any]:
+            key_length = 32
+            salt_length = 16
+            password_verification_length = 2
 
-        salt = get_crypto_random(salt_length)
-        yield from _(salt)
+            salt = get_crypto_random(salt_length)
+            yield from _(salt)
 
-        keys = PBKDF2(password, salt, 2 * key_length + password_verification_length, 1000)
-        yield from _(keys[-password_verification_length:])
+            keys = PBKDF2(password, salt, 2 * key_length + password_verification_length, 1000)
+            yield from _(keys[-password_verification_length:])
 
-        encrypter = AES.new(
-            keys[:key_length], AES.MODE_CTR,
-            counter=Counter.new(nbits=128, little_endian=True),
-        )
-        hmac = HMAC.new(keys[key_length:key_length*2], digestmod=SHA1)
+            encrypter = AES.new(
+                keys[:key_length], AES.MODE_CTR,
+                counter=Counter.new(nbits=128, little_endian=True),
+            )
+            hmac = HMAC.new(keys[key_length:key_length*2], digestmod=SHA1)
 
-        get_return_value, chunks_with_return = _with_returned(chunks)
-        for chunk in chunks_with_return:
-            encrypted_chunk = encrypter.encrypt(chunk)
-            hmac.update(encrypted_chunk)
-            yield from _(encrypted_chunk)
+            get_return_value, chunks_with_return = _with_returned(chunks)
+            for chunk in chunks_with_return:
+                encrypted_chunk = encrypter.encrypt(chunk)
+                hmac.update(encrypted_chunk)
+                yield from _(encrypted_chunk)
 
-        yield from _(hmac.digest()[:10])
+            yield from _(hmac.digest()[:10])
 
-        return get_return_value()
+            return get_return_value()
+        return _encrypt_aes
 
-    def _zip_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+    def _zip_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
         file_offset = offset
 
         _raise_if_beyond(file_offset, maximum=0xffffffffffffffff, exception_class=OffsetOverflowError)
@@ -308,7 +313,7 @@ def _zip_64_local_header_and_data(compression, aes_size_increase, aes_flags, nam
             0xffffffff,  # Offset of local header - since zip64
         ), name_encoded, extra
 
-    def _zip_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+    def _zip_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
         file_offset = offset
 
         _raise_if_beyond(file_offset, maximum=0xffffffff, exception_class=OffsetOverflowError)
@@ -363,7 +368,7 @@ def _zip_32_local_header_and_data(compression, aes_size_increase, aes_flags, nam
             file_offset,
         ), name_encoded, extra
 
-    def _zip_data(chunks, _get_compress_obj, max_uncompressed_size, max_compressed_size) -> Iterable[bytes]:
+    def _zip_data(chunks, _get_compress_obj, max_uncompressed_size, max_compressed_size) -> Generator[bytes, None, Any]:
         uncompressed_size = 0
         compressed_size = 0
         crc_32 = zlib.crc32(b'')
@@ -390,7 +395,7 @@ def _zip_data(chunks, _get_compress_obj, max_uncompressed_size, max_compressed_s
 
         return uncompressed_size, compressed_size, crc_32
 
-    def _no_compression_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+    def _no_compression_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
        file_offset = offset
 
        _raise_if_beyond(file_offset, maximum=0xffffffffffffffff, exception_class=OffsetOverflowError)
@@ -452,7 +457,7 @@ def _no_compression_64_local_header_and_data(compression, aes_size_increase, aes
         ), name_encoded, extra
 
-    def _no_compression_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+    def _no_compression_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
         file_offset = offset
 
         _raise_if_beyond(file_offset, maximum=0xffffffff, exception_class=OffsetOverflowError)
@@ -501,7 +506,7 @@ def _no_compression_32_local_header_and_data(compression, aes_size_increase, aes
             file_offset,
         ), name_encoded, extra
 
-    def _no_compression_buffered_data_size_crc_32(chunks, maximum_size) -> Tuple[Iterable[bytes], int, int]:
+    def _no_compression_buffered_data_size_crc_32(chunks: Iterable[bytes], maximum_size: int) -> Tuple[Iterable[bytes], int, int]:
         # We cannot have a data descriptor, and so have to be able to determine the total
         # length and CRC32 before output ofchunks to client code
 
@@ -520,7 +525,7 @@ def _chunks() -> Iterable[bytes]:
 
         return chunks, size, crc_32
 
-    def _no_compression_streamed_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+    def _no_compression_streamed_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
         file_offset = offset
 
         _raise_if_beyond(file_offset, maximum=0xffffffffffffffff, exception_class=OffsetOverflowError)
@@ -580,7 +585,7 @@ def _no_compression_streamed_64_local_header_and_data(compression, aes_size_incr
         ), name_encoded, extra
 
-    def _no_compression_streamed_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+    def _no_compression_streamed_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
         file_offset = offset
 
         _raise_if_beyond(file_offset, maximum=0xffffffff, exception_class=OffsetOverflowError)
@@ -627,7 +632,7 @@ def _no_compression_streamed_32_local_header_and_data(compression, aes_size_incr
             file_offset,
         ), name_encoded, extra
 
-    def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_size) -> Iterable[bytes]:
+    def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_size) -> Generator[bytes, None, Any]:
         actual_crc_32 = zlib.crc32(b'')
         size = 0
         for chunk in chunks:
@@ -675,7 +680,7 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz
             (_no_compression_streamed_32_local_header_and_data, 0)
 
         compression, aes_size_increase, aes_flags, aes_extra, crc_32_mask, encryption_func = \
-            (99, 28, aes_flag, aes_extra_struct.pack(aes_extra_signature, 7, 2, b'AE', 3, raw_compression), 0, _encrypt_aes) if password is not None else \
+            (99, 28, aes_flag, aes_extra_struct.pack(aes_extra_signature, 7, 2, b'AE', 3, raw_compression), 0, _get_encrypt_aes(password)) if password is not None else \
             (raw_compression, 0, 0, b'', 0xffffffff, _encrypt_dummy)
 
         central_directory_header_entry, name_encoded, extra = yield from data_func(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, evenly_sized(chunks))
@@ -752,9 +757,9 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz
     yield from evenly_sized(zipped_chunks)
 
 
-async def async_stream_zip(member_files, *args, **kwargs) -> AsyncIterable[bytes]:
+async def async_stream_zip(member_files: AsyncIterable[MemberFile], *args: Any, **kwargs: Any) -> AsyncIterable[bytes]:
 
-    async def to_async_iterable(sync_iterable) -> AsyncIterable[Any]:
+    async def to_async_iterable(sync_iterable: Iterable[Any]) -> AsyncIterable[Any]:
         # asyncio.to_thread is not available until Python 3.9, and StopIteration doesn't get
         # propagated by run_in_executor, so we use a sentinel to detect the end of the iterable
         done = object()
@@ -775,7 +780,7 @@ async def to_async_iterable(sync_iterable) -> AsyncIterable[Any]:
                 break
             yield value
 
-    def to_sync_iterable(async_iterable) -> Iterable[Any]:
+    def to_sync_iterable(async_iterable: AsyncIterable[Any]) -> Iterable[Any]:
         # The built-in aiter and anext functions are not available until Python 3.10
         async_it = async_iterable.__aiter__()
         while True:
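Illustrative sketches follow; none of them are part of the patch itself. First, a rough usage sketch showing the now-annotated get_compressobj and get_crypto_random parameters of stream_zip being overridden. It assumes the member-file tuple shape from the stream-zip documentation (name, modification time, mode, method, iterable of bytes chunks); the file name, compression level and password here are made up for the example.

import secrets
import zlib
from datetime import datetime

from stream_zip import ZIP_32, stream_zip

# A single member file: (name, modified_at, mode, method, iterable of bytes chunks)
member_files = (
    ('my-file.txt', datetime.now(), 0o600, ZIP_32, (b'Some bytes',)),
)

zipped_chunks = stream_zip(
    member_files,
    # Matches the annotated _CompressObjGetter parameter: a zero-argument callable
    # returning a zlib compress object configured for raw deflate
    get_compressobj=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=6),
    password='my-password',
    # Matches Callable[[int], bytes]: given a number of bytes, return that many random bytes
    get_crypto_random=lambda num_bytes: secrets.token_bytes(num_bytes),
)

with open('my.zip', 'wb') as f:
    for chunk in zipped_chunks:
        f.write(chunk)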
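The comment in _with_returned refers to the generator "return value" mechanism: a return statement inside a generator becomes the value of a yield from expression (and is carried on StopIteration.value). A minimal standalone sketch of the trick, with made-up names:

from typing import Generator

def producer() -> Generator[bytes, None, int]:
    total = 0
    for chunk in (b'ab', b'cde'):
        total += len(chunk)
        yield chunk
    return total  # becomes the value of `yield from producer()`

def consumer() -> Generator[bytes, None, None]:
    # Re-yields every chunk, but also captures producer's return value -
    # the same shape of trick _with_returned uses, exposing it via a closure
    returned = yield from producer()
    print('producer returned', returned)  # prints 5 once producer is exhausted

for _chunk in consumer():
    pass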
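The _get_encrypt_aes getter exists so that mypy sees a plain str rather than Optional[str] inside the encryption generator. A minimal sketch of the same narrowing-by-closure pattern, with hypothetical names (needs_str, make_worker, run):

from typing import Callable, Optional

def needs_str(value: str) -> str:
    # Stand-in for PBKDF2, which must not receive None
    return value.upper()

def make_worker(password: str) -> Callable[[], str]:
    # Because `password` is declared as `str` here, the inner function type checks
    # without asserts or casts, even though the caller starts from Optional[str]
    def worker() -> str:
        return needs_str(password)
    return worker

def run(password: Optional[str]) -> str:
    worker = make_worker(password) if password is not None else (lambda: needs_str(''))
    return worker()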
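Finally, the sentinel technique described in the to_async_iterable comment can be seen in isolation below; the iterate_in_thread name is invented for the sketch and is not how the library exposes it.

import asyncio
from typing import Any, AsyncIterable, Iterable

async def iterate_in_thread(sync_iterable: Iterable[Any]) -> AsyncIterable[Any]:
    # StopIteration raised inside an executor does not propagate usefully,
    # so `next` is given a sentinel default that marks exhaustion instead
    done = object()
    it = iter(sync_iterable)
    loop = asyncio.get_running_loop()
    while True:
        value = await loop.run_in_executor(None, next, it, done)
        if value is done:
            break
        yield value

async def main() -> None:
    async for value in iterate_in_thread([1, 2, 3]):
        print(value)

asyncio.run(main())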