From 073c270b8c7ae47e400fb5b488c1f5feda930345 Mon Sep 17 00:00:00 2001
From: Michal Charemza <michal@charemza.name>
Date: Sun, 26 May 2024 17:15:47 +0100
Subject: [PATCH] feat: more type annotations (towards being able to enable
 strict checking)

---
 stream_zip.py | 85 +++++++++++++++++++++++++++------------------------
 1 file changed, 45 insertions(+), 40 deletions(-)

diff --git a/stream_zip.py b/stream_zip.py
index 766b2d0..4acf1aa 100644
--- a/stream_zip.py
+++ b/stream_zip.py
@@ -5,7 +5,7 @@
 import asyncio
 import secrets
 import zlib
-from typing import Any, Iterable, Tuple, Optional, Deque, Type, AsyncIterable, Callable
+from typing import Any, Iterable, Generator, Tuple, Optional, Deque, Type, AsyncIterable, Callable
 
 from Crypto.Cipher import AES
 from Crypto.Hash import HMAC, SHA1
@@ -56,7 +56,7 @@ class _NO_COMPRESSION_32_TYPE(Method):
     def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
         return _NO_COMPRESSION_BUFFERED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, None, None
 
-    def __call__(self, uncompressed_size, crc_32, *args: Any, **kwarg: Any) -> Method:
+    def __call__(self, uncompressed_size: int, crc_32: int, *args: Any, **kwarg: Any) -> Method:
         class _NO_COMPRESSION_32_TYPE_STREAMED_TYPE(Method):
             def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _MethodTuple:
                 return _NO_COMPRESSION_STREAMED_32, _NO_AUTO_UPGRADE_CENTRAL_DIRECTORY, default_get_compressobj, uncompressed_size, crc_32
@@ -111,10 +111,10 @@ def _get(self, offset: int, default_get_compressobj: _CompressObjGetter) -> _Met
 
 
 def stream_zip(files: Iterable[MemberFile], chunk_size: int=65536,
-               get_compressobj=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=9),
+               get_compressobj: _CompressObjGetter=lambda: zlib.compressobj(wbits=-zlib.MAX_WBITS, level=9),
                extended_timestamps: bool=True,
                password: Optional[str]=None,
-               get_crypto_random=lambda num_bytes: secrets.token_bytes(num_bytes),
+               get_crypto_random: Callable[[int], bytes]=lambda num_bytes: secrets.token_bytes(num_bytes),
 ) -> Iterable[bytes]:
 
     def evenly_sized(chunks: Iterable[bytes]) -> Iterable[bytes]:
@@ -122,7 +122,7 @@ def evenly_sized(chunks: Iterable[bytes]) -> Iterable[bytes]:
         offset = 0
         it = iter(chunks)
 
-        def up_to(num):
+        def up_to(num: int) -> Iterable[bytes]:
             nonlocal chunk, offset
 
             while num:
@@ -195,7 +195,7 @@ def _raise_if_beyond(offset: int, maximum: int, exception_class: Type[Exception]
             if offset > maximum:
                 raise exception_class()
 
-        def _with_returned(gen) -> Tuple[Callable[[], Optional[Any]], Iterable[bytes]]:
+        def _with_returned(gen: Generator[bytes, None, Any]) -> Tuple[Callable[[], Any], Iterable[bytes]]:
             # We leverage the not-often used "return value" of generators. Here, we want to iterate
             # over chunks (to encrypt them), but still return the same "return value". So we use a
             # bit of a trick to extract the return value but still have access to the chunks as
@@ -208,40 +208,45 @@ def with_return_value() -> Iterable[bytes]:
 
             return ((lambda: return_value), with_return_value())
 
-        def _encrypt_dummy(chunks):
+        def _encrypt_dummy(chunks: Generator[bytes, None, Any]) -> Generator[bytes, None, Any]:
             get_return_value, chunks_with_return = _with_returned(chunks)
             for chunk in chunks_with_return:
                 yield from _(chunk)
             return get_return_value()
 
-        def _encrypt_aes(chunks):
-            key_length = 32
-            salt_length = 16
-            password_verification_length = 2
+        # _get_encrypt_aes is a factory rather than a plain function purely for mypy: its password
+        # parameter is typed as str, so _encrypt_aes can only be created once the caller has ruled
+        # out None, and the PBKDF2 call that receives the password then passes type checking.
+        def _get_encrypt_aes(password: str) -> Callable[[Generator[bytes, None, Any]], Generator[bytes, None, Any]]:
+            def _encrypt_aes(chunks: Generator[bytes, None, Any]) -> Generator[bytes, None, Any]:
+                key_length = 32
+                salt_length = 16
+                password_verification_length = 2
 
-            salt = get_crypto_random(salt_length)
-            yield from _(salt)
+                salt = get_crypto_random(salt_length)
+                yield from _(salt)
 
-            keys = PBKDF2(password, salt, 2 * key_length + password_verification_length, 1000)
-            yield from _(keys[-password_verification_length:])
+                keys = PBKDF2(password, salt, 2 * key_length + password_verification_length, 1000)
+                yield from _(keys[-password_verification_length:])
 
-            encrypter = AES.new(
-                keys[:key_length], AES.MODE_CTR,
-                counter=Counter.new(nbits=128, little_endian=True),
-            )
-            hmac = HMAC.new(keys[key_length:key_length*2], digestmod=SHA1)
+                encrypter = AES.new(
+                    keys[:key_length], AES.MODE_CTR,
+                    counter=Counter.new(nbits=128, little_endian=True),
+                )
+                hmac = HMAC.new(keys[key_length:key_length*2], digestmod=SHA1)
 
-            get_return_value, chunks_with_return = _with_returned(chunks)
-            for chunk in chunks_with_return:
-                encrypted_chunk = encrypter.encrypt(chunk)
-                hmac.update(encrypted_chunk)
-                yield from _(encrypted_chunk)
+                get_return_value, chunks_with_return = _with_returned(chunks)
+                for chunk in chunks_with_return:
+                    encrypted_chunk = encrypter.encrypt(chunk)
+                    hmac.update(encrypted_chunk)
+                    yield from _(encrypted_chunk)
 
-            yield from _(hmac.digest()[:10])
+                yield from _(hmac.digest()[:10])
 
-            return get_return_value()
+                return get_return_value()
+            return _encrypt_aes
 
-        def _zip_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+        def _zip_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
             file_offset = offset
 
             _raise_if_beyond(file_offset, maximum=0xffffffffffffffff, exception_class=OffsetOverflowError)
@@ -308,7 +313,7 @@ def _zip_64_local_header_and_data(compression, aes_size_increase, aes_flags, nam
                 0xffffffff,   # Offset of local header - since zip64
             ), name_encoded, extra
 
-        def _zip_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+        def _zip_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
             file_offset = offset
 
             _raise_if_beyond(file_offset, maximum=0xffffffff, exception_class=OffsetOverflowError)
@@ -363,7 +368,7 @@ def _zip_32_local_header_and_data(compression, aes_size_increase, aes_flags, nam
                 file_offset,
             ), name_encoded, extra
 
-        def _zip_data(chunks, _get_compress_obj, max_uncompressed_size, max_compressed_size) -> Iterable[bytes]:
+        def _zip_data(chunks, _get_compress_obj, max_uncompressed_size, max_compressed_size) -> Generator[bytes, None, Any]:
             uncompressed_size = 0
             compressed_size = 0
             crc_32 = zlib.crc32(b'')
@@ -390,7 +395,7 @@ def _zip_data(chunks, _get_compress_obj, max_uncompressed_size, max_compressed_s
 
             return uncompressed_size, compressed_size, crc_32
 
-        def _no_compression_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+        def _no_compression_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
             file_offset = offset
 
             _raise_if_beyond(file_offset, maximum=0xffffffffffffffff, exception_class=OffsetOverflowError)
@@ -452,7 +457,7 @@ def _no_compression_64_local_header_and_data(compression, aes_size_increase, aes
             ), name_encoded, extra
 
 
-        def _no_compression_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+        def _no_compression_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
             file_offset = offset
 
             _raise_if_beyond(file_offset, maximum=0xffffffff, exception_class=OffsetOverflowError)
@@ -501,7 +506,7 @@ def _no_compression_32_local_header_and_data(compression, aes_size_increase, aes
                file_offset,
             ), name_encoded, extra
 
-        def _no_compression_buffered_data_size_crc_32(chunks, maximum_size) -> Tuple[Iterable[bytes], int, int]:
+        def _no_compression_buffered_data_size_crc_32(chunks: Iterable[bytes], maximum_size: int) -> Tuple[Iterable[bytes], int, int]:
             # We cannot have a data descriptor, and so have to be able to determine the total
             # length and CRC32 before output ofchunks to client code
 
@@ -520,7 +525,7 @@ def _chunks() -> Iterable[bytes]:
 
             return chunks, size, crc_32
 
-        def _no_compression_streamed_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+        def _no_compression_streamed_64_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
             file_offset = offset
 
             _raise_if_beyond(file_offset, maximum=0xffffffffffffffff, exception_class=OffsetOverflowError)
@@ -580,7 +585,7 @@ def _no_compression_streamed_64_local_header_and_data(compression, aes_size_incr
             ), name_encoded, extra
 
 
-        def _no_compression_streamed_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks):
+        def _no_compression_streamed_32_local_header_and_data(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, chunks) -> Generator[bytes, None, Any]:
             file_offset = offset
 
             _raise_if_beyond(file_offset, maximum=0xffffffff, exception_class=OffsetOverflowError)
@@ -627,7 +632,7 @@ def _no_compression_streamed_32_local_header_and_data(compression, aes_size_incr
                file_offset,
             ), name_encoded, extra
 
-        def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_size) -> Iterable[bytes]:
+        def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_size) -> Generator[bytes, None, Any]:
             actual_crc_32 = zlib.crc32(b'')
             size = 0
             for chunk in chunks:
@@ -675,7 +680,7 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz
                 (_no_compression_streamed_32_local_header_and_data, 0)
 
             compression, aes_size_increase, aes_flags, aes_extra, crc_32_mask, encryption_func = \
-                (99, 28, aes_flag, aes_extra_struct.pack(aes_extra_signature, 7, 2, b'AE', 3, raw_compression), 0, _encrypt_aes) if password is not None else \
+                (99, 28, aes_flag, aes_extra_struct.pack(aes_extra_signature, 7, 2, b'AE', 3, raw_compression), 0, _get_encrypt_aes(password)) if password is not None else \
                 (raw_compression, 0, 0, b'', 0xffffffff, _encrypt_dummy)
 
             central_directory_header_entry, name_encoded, extra = yield from data_func(compression, aes_size_increase, aes_flags, name_encoded, mod_at_ms_dos, mod_at_unix_extra, aes_extra, external_attr, uncompressed_size, crc_32, crc_32_mask, _get_compress_obj, encryption_func, evenly_sized(chunks))
@@ -752,9 +757,9 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz
     yield from evenly_sized(zipped_chunks)
 
 
-async def async_stream_zip(member_files, *args, **kwargs) -> AsyncIterable[bytes]:
+async def async_stream_zip(member_files: AsyncIterable[MemberFile], *args: Any, **kwargs: Any) -> AsyncIterable[bytes]:
 
-    async def to_async_iterable(sync_iterable) -> AsyncIterable[Any]:
+    async def to_async_iterable(sync_iterable: Iterable[Any]) -> AsyncIterable[Any]:
         # asyncio.to_thread is not available until Python 3.9, and StopIteration doesn't get
         # propagated by run_in_executor, so we use a sentinel to detect the end of the iterable
         done = object()
@@ -775,7 +780,7 @@ async def to_async_iterable(sync_iterable) -> AsyncIterable[Any]:
                 break
             yield value
 
-    def to_sync_iterable(async_iterable) -> Iterable[Any]:
+    def to_sync_iterable(async_iterable: AsyncIterable[Any]) -> Iterable[Any]:
         # The built-in aiter and anext functions are not available until Python 3.10
         async_it = async_iterable.__aiter__()
         while True: