diff --git a/fastapi_cdn_host/cli.py b/fastapi_cdn_host/cli.py index 2bde7e3..b320d7a 100644 --- a/fastapi_cdn_host/cli.py +++ b/fastapi_cdn_host/cli.py @@ -7,10 +7,14 @@ from pathlib import Path from typing import Generator, Union +import anyio import typer from rich import print +from rich.progress import Progress, SpinnerColumn, TextColumn from typing_extensions import Annotated +from .client import CdnHostBuilder, HttpSniff + app = typer.Typer() @@ -68,6 +72,44 @@ def patch_app(path: Union[str, Path], remove=True) -> Generator[Path, None, None print(f"Auto remove temp file: {app_file}") +@contextmanager +def progressbar(msg, color="cyan", transient=True) -> Generator[None, None, None]: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + transient=transient, + ) as progress: + progress.add_task(f"[{color}]{msg}...", total=None) + yield + + +async def download_offline_assets(dirname="static") -> None: + cwd = await anyio.Path.cwd() + static_root = cwd / dirname + if not await static_root.exists(): + await static_root.mkdir(parents=True) + print(f"Directory {static_root} created.") + else: + async for p in static_root.glob("swagger-ui*.js"): + relative_path = p.relative_to(cwd) + print(f"{relative_path} already exists. abort!") + return + with progressbar("Comparing cdn host response speed", transient=False): + urls = await CdnHostBuilder.sniff_the_fastest() + print("Result:", urls) + with progressbar("Fetching files from cdn", color="blue"): + url_list = [urls.js, urls.css, urls.redoc] + contents = await HttpSniff.bulk_fetch(url_list, get_content=True) + for url, content in zip(url_list, contents): + if not content: + print(f"Failed to fetch content from {url}") + else: + path = static_root / Path(url).name + size = await path.write_bytes(content) + print(f"Write to {path} with {size=}") + print("Done.") + + @app.command() def dev( path: Annotated[ @@ -93,9 +135,10 @@ def dev( typer.Option(help="Whether enable production mode."), ] = False, ): - if path == "offline": + if str(path) == "offline": # TODO: download assets to local - pass + anyio.run(download_offline_assets) + return with patch_app(path, remove) as file: mode = "run" if prod else "dev" cmd = f"PYTHONPATH=. fastapi {mode} {file}" diff --git a/fastapi_cdn_host/client.py b/fastapi_cdn_host/client.py index 1223ed8..4437afe 100644 --- a/fastapi_cdn_host/client.py +++ b/fastapi_cdn_host/client.py @@ -6,7 +6,18 @@ from enum import Enum from pathlib import Path from ssl import SSLError -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Union, + cast, + overload, +) import anyio import httpx @@ -120,10 +131,20 @@ class AssetUrl: class HttpSniff: - @staticmethod + cached: Dict[str, bytes] = {} + + @classmethod async def fetch( - client: httpx.AsyncClient, url: str, results: list, index: int + cls, + client: httpx.AsyncClient, + url: str, + results: list, + index: int, + try_cache=False, ) -> None: + if try_cache and (content := cls.cached.get(url)): + results[index] = content + return try: r = await client.get(url) except ( @@ -136,42 +157,80 @@ async def fetch( ... else: if r.status_code < 300: - results[index] = r.content + results[index] = cls.cached[url] = r.content @classmethod async def find_fastest_host( cls, urls: List[str], total_seconds=5, loop_interval=0.1 ) -> str: - if us := await cls.get_fast_hosts( + if us := await cls.bulk_fetch( urls, loop_interval, total_seconds, return_first_completed=True ): return us[0] return urls[0] + @overload @classmethod - async def get_fast_hosts( + async def bulk_fetch( cls, urls: List[str], - wait_seconds=0.8, - total_seconds=3, - return_first_completed=False, - ) -> List[str]: + wait_seconds: float = 0.8, + total_seconds: float = 3, + return_first_completed: bool = False, + get_content: Literal[False] = False, + ) -> List[str]: ... + + @overload + @classmethod + async def bulk_fetch( + cls, + urls: List[str], + wait_seconds: float = 0.8, + total_seconds: float = 3, + return_first_completed: bool = False, + get_content: Literal[True] = True, + ) -> List[bytes]: ... + + @classmethod + async def bulk_fetch( + cls, + urls: List[str], + wait_seconds: float = 0.8, + total_seconds: float = 3, + return_first_completed: bool = False, + get_content: bool = False, + ) -> Union[List[str], List[bytes]]: total = len(urls) results = [None] * total - thod = 1 if return_first_completed else total - 1 async with httpx.AsyncClient( timeout=total_seconds, follow_redirects=True ) as client: async with anyio.create_task_group() as tg: for i, url in enumerate(urls): - tg.start_soon(cls.fetch, client, url, results, i) - for _ in range(math.ceil(total_seconds / wait_seconds)): - await anyio.sleep(wait_seconds) - if sum(r is not None for r in results) >= thod: - tg.cancel_scope.cancel() - break + tg.start_soon(cls.fetch, client, url, results, i, get_content) + if not get_content: + thod = 1 if return_first_completed else total - 1 + for _ in range(math.ceil(total_seconds / wait_seconds)): + await anyio.sleep(wait_seconds) + if sum(r is not None for r in results) >= thod: + tg.cancel_scope.cancel() + break + if get_content: + return [i or b"" for i in results] return [url for url, res in zip(urls, results) if res is not None] + @classmethod + async def get_fast_hosts( + cls, + urls: List[str], + wait_seconds=0.8, + total_seconds=3, + return_first_completed=False, + ) -> List[str]: + return await cls.bulk_fetch( + urls, wait_seconds, total_seconds, return_first_completed + ) + class CdnHostBuilder: swagger_ui_version = "5"