diff --git a/cftool/web.py b/cftool/web.py index ed8f909..ae1b86c 100644 --- a/cftool/web.py +++ b/cftool/web.py @@ -1,14 +1,30 @@ +import json +import time import socket import logging +import requests +from io import BytesIO from typing import Any from typing import Dict from typing import Type +from typing import TypeVar +from typing import Callable from typing import Optional +from typing import Awaitable from .misc import get_err_msg from .constants import WEB_ERR_CODE +try: + from PIL import Image + from PIL import ImageOps +except: + + class Image: + Image = None + + ImageOps = None try: from fastapi import Response from fastapi import HTTPException @@ -16,6 +32,13 @@ except: Response = HTTPException = None BaseModel = object +try: + from aiohttp import ClientSession +except: + ClientSession = None + + +TResponse = TypeVar("TResponse") class RuntimeError(BaseModel): @@ -72,3 +95,98 @@ def raise_err(err: Exception) -> None: if HTTPException is None: raise raise HTTPException(status_code=WEB_ERR_CODE, detail=get_err_msg(err)) + + +async def get(url: str, session: ClientSession, **kwargs: Any) -> bytes: + async with session.get(url, **kwargs) as response: + return await response.read() + + +async def post( + url: str, + json: Dict[str, Any], + session: ClientSession, + **kwargs: Any, +) -> Dict[str, Any]: + async with session.post(url, json=json, **kwargs) as response: + return await response.json() + + +def log_endpoint(endpoint: str, data: BaseModel) -> None: + msg = f"{endpoint} endpoint entered with kwargs : {json.dumps(data.dict(), ensure_ascii=False)}" + logging.debug(msg) + + +def log_times(endpoint: str, times: Dict[str, float]) -> None: + times["__total__"] = sum(times.values()) + logging.debug(f"elapsed time of endpoint {endpoint} : {json.dumps(times)}") + + +async def download_raw(session: ClientSession, url: str, **kw: Any) -> bytes: + try: + return await get(url, session, **kw) + except Exception: + return requests.get(url, **kw).content + + +async def download_image(session: ClientSession, url: str, **kw: Any) -> Image.Image: + raw_data = None + try: + raw_data = await download_raw(session, url, **kw) + image = Image.open(BytesIO(raw_data)) + try: + image = ImageOps.exif_transpose(image) + finally: + return image + except Exception as err: + if raw_data is None: + msg = f"raw | None | err | {err}" + else: + try: + msg = raw_data.decode("utf-8") + except: + msg = f"raw | {raw_data[:20]} | err | {err}" + raise ValueError(msg) + + +async def retry_with( + download_fn: Callable[[ClientSession, str], Awaitable[TResponse]], + session: ClientSession, + url: str, + retry: int = 3, + interval: int = 1, + **kw: Any, +) -> TResponse: + msg = "" + for i in range(retry): + try: + res = await download_fn(session, url, **kw) + if i > 0: + logging.warning(f"succeeded after {i} retries") + return res + except Exception as err: + msg = str(err) + time.sleep(interval) + raise ValueError(f"{msg}\n(After {retry} retries)") + + +async def download_raw_with_retry( + session: ClientSession, + url: str, + *, + retry: int = 3, + interval: int = 1, + **kw: Any, +) -> bytes: + return await retry_with(download_raw, session, url, retry, interval, **kw) + + +async def download_image_with_retry( + session: ClientSession, + url: str, + *, + retry: int = 3, + interval: int = 1, + **kw: Any, +) -> Image.Image: + return await retry_with(download_image, session, url, retry, interval, **kw)