Skip to content

Commit

Permalink
feat: support offline command to download assets
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Jul 25, 2024
1 parent 2274037 commit 49c7014
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 19 deletions.
47 changes: 45 additions & 2 deletions fastapi_cdn_host/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
from pathlib import Path
from typing import Generator, Union

import anyio
import typer
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn
from typing_extensions import Annotated

from .client import CdnHostBuilder, HttpSniff

app = typer.Typer()


Expand Down Expand Up @@ -68,6 +72,44 @@ def patch_app(path: Union[str, Path], remove=True) -> Generator[Path, None, None
print(f"Auto remove temp file: {app_file}")


@contextmanager
def progressbar(msg, color="cyan", transient=True) -> Generator[None, None, None]:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
transient=transient,
) as progress:
progress.add_task(f"[{color}]{msg}...", total=None)
yield


async def download_offline_assets(dirname="static") -> None:
cwd = await anyio.Path.cwd()
static_root = cwd / dirname
if not await static_root.exists():
await static_root.mkdir(parents=True)
print(f"Directory {static_root} created.")
else:
async for p in static_root.glob("swagger-ui*.js"):
relative_path = p.relative_to(cwd)
print(f"{relative_path} already exists. abort!")
return
with progressbar("Comparing cdn host response speed", transient=False):
urls = await CdnHostBuilder.sniff_the_fastest()
print("Result:", urls)
with progressbar("Fetching files from cdn", color="blue"):
url_list = [urls.js, urls.css, urls.redoc]
contents = await HttpSniff.bulk_fetch(url_list, get_content=True)
for url, content in zip(url_list, contents):
if not content:
print(f"Failed to fetch content from {url}")
else:
path = static_root / Path(url).name
size = await path.write_bytes(content)
print(f"Write to {path} with {size=}")
print("Done.")


@app.command()
def dev(
path: Annotated[
Expand All @@ -93,9 +135,10 @@ def dev(
typer.Option(help="Whether enable production mode."),
] = False,
):
if path == "offline":
if str(path) == "offline":
# TODO: download assets to local
pass
anyio.run(download_offline_assets)
return
with patch_app(path, remove) as file:
mode = "run" if prod else "dev"
cmd = f"PYTHONPATH=. fastapi {mode} {file}"
Expand Down
93 changes: 76 additions & 17 deletions fastapi_cdn_host/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
from enum import Enum
from pathlib import Path
from ssl import SSLError
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
overload,
)

import anyio
import httpx
Expand Down Expand Up @@ -120,10 +131,20 @@ class AssetUrl:


class HttpSniff:
@staticmethod
cached: Dict[str, bytes] = {}

@classmethod
async def fetch(
client: httpx.AsyncClient, url: str, results: list, index: int
cls,
client: httpx.AsyncClient,
url: str,
results: list,
index: int,
try_cache=False,
) -> None:
if try_cache and (content := cls.cached.get(url)):
results[index] = content
return
try:
r = await client.get(url)
except (
Expand All @@ -136,42 +157,80 @@ async def fetch(
...
else:
if r.status_code < 300:
results[index] = r.content
results[index] = cls.cached[url] = r.content

@classmethod
async def find_fastest_host(
cls, urls: List[str], total_seconds=5, loop_interval=0.1
) -> str:
if us := await cls.get_fast_hosts(
if us := await cls.bulk_fetch(
urls, loop_interval, total_seconds, return_first_completed=True
):
return us[0]
return urls[0]

@overload
@classmethod
async def get_fast_hosts(
async def bulk_fetch(
cls,
urls: List[str],
wait_seconds=0.8,
total_seconds=3,
return_first_completed=False,
) -> List[str]:
wait_seconds: float = 0.8,
total_seconds: float = 3,
return_first_completed: bool = False,
get_content: Literal[False] = False,
) -> List[str]: ...

@overload
@classmethod
async def bulk_fetch(
cls,
urls: List[str],
wait_seconds: float = 0.8,
total_seconds: float = 3,
return_first_completed: bool = False,
get_content: Literal[True] = True,
) -> List[bytes]: ...

@classmethod
async def bulk_fetch(
cls,
urls: List[str],
wait_seconds: float = 0.8,
total_seconds: float = 3,
return_first_completed: bool = False,
get_content: bool = False,
) -> Union[List[str], List[bytes]]:
total = len(urls)
results = [None] * total
thod = 1 if return_first_completed else total - 1
async with httpx.AsyncClient(
timeout=total_seconds, follow_redirects=True
) as client:
async with anyio.create_task_group() as tg:
for i, url in enumerate(urls):
tg.start_soon(cls.fetch, client, url, results, i)
for _ in range(math.ceil(total_seconds / wait_seconds)):
await anyio.sleep(wait_seconds)
if sum(r is not None for r in results) >= thod:
tg.cancel_scope.cancel()
break
tg.start_soon(cls.fetch, client, url, results, i, get_content)
if not get_content:
thod = 1 if return_first_completed else total - 1
for _ in range(math.ceil(total_seconds / wait_seconds)):
await anyio.sleep(wait_seconds)
if sum(r is not None for r in results) >= thod:
tg.cancel_scope.cancel()
break
if get_content:
return [i or b"" for i in results]
return [url for url, res in zip(urls, results) if res is not None]

@classmethod
async def get_fast_hosts(
cls,
urls: List[str],
wait_seconds=0.8,
total_seconds=3,
return_first_completed=False,
) -> List[str]:
return await cls.bulk_fetch(
urls, wait_seconds, total_seconds, return_first_completed
)


class CdnHostBuilder:
swagger_ui_version = "5"
Expand Down

0 comments on commit 49c7014

Please sign in to comment.