From 5abae3918b37fb4042c23b8c69fe9f8ec1819bb0 Mon Sep 17 00:00:00 2001 From: Michal Charemza Date: Sat, 25 May 2024 17:50:33 +0100 Subject: [PATCH] feat: type annotate async function (and an internal function) This adds basic type annotation to async_stream_zip (and another internal function). And there is a bit of rejigging to get types to check, but other than adding return types the client-facing API should be the same. This is inspired by https://github.com/uktrade/stream-zip/discussions/120 There will probably be more changes after this PR. --- stream_zip.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/stream_zip.py b/stream_zip.py index 278bc95..761f150 100644 --- a/stream_zip.py +++ b/stream_zip.py @@ -3,7 +3,7 @@ import asyncio import secrets import zlib -from typing import Any, Iterable, Tuple, Optional, Deque, Type, AsyncIterable +from typing import Any, Iterable, Tuple, Optional, Deque, Type, AsyncIterable, Callable from Crypto.Cipher import AES from Crypto.Hash import HMAC, SHA1 @@ -156,7 +156,7 @@ def _raise_if_beyond(offset: int, maximum: int, exception_class: Type[Exception] if offset > maximum: raise exception_class() - def _with_returned(gen): + def _with_returned(gen) -> Tuple[Callable[[], Optional[Any]], Iterable[bytes]]: # We leverage the not-often used "return value" of generators. Here, we want to iterate # over chunks (to encrypt them), but still return the same "return value". So we use a # bit of a trick to extract the return value but still have access to the chunks as @@ -719,7 +719,7 @@ def _no_compression_streamed_data(chunks, uncompressed_size, crc_32, maximum_siz async def async_stream_zip(member_files, *args, **kwargs) -> AsyncIterable[bytes]: - async def to_async_iterable(sync_iterable): + async def to_async_iterable(sync_iterable) -> AsyncIterable[Any]: # asyncio.to_thread is not available until Python 3.9, and StopIteration doesn't get # propagated by run_in_executor, so we use a sentinel to detect the end of the iterable done = object() @@ -729,17 +729,18 @@ async def to_async_iterable(sync_iterable): try: import contextvars except ImportError: - get_args = lambda: (next, it, done) + get_func_args: Callable[[], Tuple[Callable[..., Any], Tuple[Any, ...]]] = lambda: (next, (it, done)) else: - get_args = lambda: (contextvars.copy_context().run, next, it, done) + get_func_args = lambda: (contextvars.copy_context().run, (next, it, done)) while True: - value = await loop.run_in_executor(None, *get_args()) + func, args = get_func_args() + value = await loop.run_in_executor(None, func, *args) if value is done: break yield value - def to_sync_iterable(async_iterable): + def to_sync_iterable(async_iterable) -> Iterable[Any]: # The built-in aiter and anext functions are not available until Python 3.10 async_it = async_iterable.__aiter__() while True: