diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py new file mode 100644 index 0000000..b0125a5 --- /dev/null +++ b/olah/proxy/tree.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +import os +import shutil +import tempfile +from typing import Dict, Literal +from urllib.parse import urljoin +from fastapi import FastAPI, Request + +import httpx +from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT + +from olah.utils.rule_utils import check_cache_rules_hf +from olah.utils.repo_utils import get_org_repo +from olah.utils.file_utils import make_dirs + + +async def tree_cache_generator(save_path: str): + yield {} + with open(save_path, "rb") as f: + while True: + chunk = f.read(CHUNK_SIZE) + if not chunk: + break + yield chunk + + +async def tree_proxy_cache( + app: FastAPI, + repo_type: Literal["models", "datasets", "spaces"], + org: str, + repo: str, + commit: str, + request: Request, +): + # save + repos_path = app.app_settings.repos_path + save_dir = os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + ) + save_path = os.path.join(save_dir, "tree.json") + make_dirs(save_path) + + # url + org_repo = get_org_repo(org, repo) + tree_url = urljoin( + app.app_settings.config.hf_url_base(), + f"/api/{repo_type}/{org_repo}/tree/{commit}", + ) + headers = {} + if "authorization" in request.headers: + headers["authorization"] = request.headers["authorization"] + async with httpx.AsyncClient() as client: + response = await client.request( + method="GET", + url=tree_url, + headers=headers, + timeout=WORKER_API_TIMEOUT, + follow_redirects=True, + ) + if response.status_code == 200: + with open(save_path, "wb") as tree_file: + tree_file.write(response.content) + else: + raise Exception( + f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}" + ) + + +async def tree_proxy_generator( + app: FastAPI, + headers: Dict[str, str], + tree_url: str, + allow_cache: bool, + save_path: str, +): + async with httpx.AsyncClient(follow_redirects=True) as client: + content_chunks = [] + async with client.stream( + method="GET", + url=tree_url, + headers=headers, + timeout=WORKER_API_TIMEOUT, + ) as response: + response_headers = response.headers + yield response_headers + + async for raw_chunk in response.aiter_raw(): + if not raw_chunk: + continue + content_chunks.append(raw_chunk) + yield raw_chunk + + content = bytearray() + for chunk in content_chunks: + content += chunk + with open(save_path, "wb") as f: + f.write(bytes(content)) + + +async def tree_generator( + app: FastAPI, + repo_type: Literal["models", "datasets", "spaces"], + org: str, + repo: str, + commit: str, + request: Request, +): + headers = {k: v for k, v in request.headers.items()} + headers.pop("host") + + # save + repos_path = app.app_settings.repos_path + save_dir = os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + ) + save_path = os.path.join(save_dir, "tree.json") + make_dirs(save_path) + + use_cache = os.path.exists(save_path) + allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) + + org_repo = get_org_repo(org, repo) + tree_url = urljoin( + app.app_settings.config.hf_url_base(), + f"/api/{repo_type}/{org_repo}/tree/{commit}", + ) + # proxy + if use_cache: + async for item in tree_cache_generator(app, save_path): + yield item + else: + async for item in tree_proxy_generator( + app, headers, tree_url, allow_cache, save_path + ): + yield item diff --git a/olah/server.py b/olah/server.py index 60e4083..9a127c6 100644 --- a/olah/server.py +++ b/olah/server.py @@ -26,6 +26,8 @@ import git import httpx +from olah.proxy.tree import tree_generator, tree_proxy_cache + BASE_SETTINGS = False if not BASE_SETTINGS: try: @@ -184,6 +186,16 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request ) +@app.get("/api/{repo_type}/{org}/{repo}") +async def meta_proxy(repo_type: str, org: str, repo: str, request: Request): + if not app.app_settings.config.offline: + new_commit = await get_newest_commit_hf(app, repo_type, org, repo) + else: + new_commit = "main" + return await meta_proxy_common( + repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request + ) + @app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}") async def meta_proxy_commit2( @@ -204,6 +216,80 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: repo_type=repo_type, org=org, repo=repo, commit=commit, request=request ) +# Git Tree +async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: + # TODO: the head method of meta apis + # FIXME: do not show the private repos to other user besides owner, even though the repo was cached + if repo_type not in REPO_TYPES_MAPPING.keys(): + return error_page_not_found() + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return error_repo_not_found() + # Check Mirror Path + for mirror_path in app.app_settings.config.mirrors_path: + try: + git_path = os.path.join(mirror_path, repo_type, org, repo) + if os.path.exists(git_path): + local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) + meta_data = local_repo.get_tree(commit) + if meta_data is None: + continue + return JSONResponse(content=meta_data) + except git.exc.InvalidGitRepositoryError: + logger.warning(f"Local repository {git_path} is not a valid git reposity.") + continue + + # Proxy the HF File Meta + try: + if not app.app_settings.config.offline and not await check_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ): + return error_repo_not_found() + commit_sha = await get_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ) + if commit_sha is None: + return error_repo_not_found() + # if branch name and online mode, refresh branch info + if not app.app_settings.config.offline and commit_sha != commit: + await tree_proxy_cache(app, repo_type, org, repo, commit, request) + generator = tree_generator(app, repo_type, org, repo, commit_sha, request) + headers = await generator.__anext__() + return StreamingResponse(generator, headers=headers) + except httpx.ConnectTimeout: + traceback.print_exc() + return Response(status_code=504) + + +@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}") +async def tree_proxy_commit2( + repo_type: str, org: str, repo: str, commit: str, request: Request +): + return await tree_proxy_common( + repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + ) + + +@app.get("/api/{repo_type}/{org_repo}/tree/{commit}") +async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): + org, repo = parse_org_repo(org_repo) + if org is None and repo is None: + return error_repo_not_found() + + return await tree_proxy_common( + repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + ) + + # ====================== # Authentication API Hooks # ======================