diff --git a/olah/errors.py b/olah/errors.py
index 77491f7..7bb8b24 100644
--- a/olah/errors.py
+++ b/olah/errors.py
@@ -20,8 +20,9 @@ def error_repo_not_found() -> JSONResponse:
     )
 
 
-def error_page_not_found() -> Response:
-    return Response(
+def error_page_not_found() -> JSONResponse:
+    return JSONResponse(
+        content={"error": "Sorry, we can't find the page you are looking for."},
         headers={
             "x-error-code": "RepoNotFound",
             "x-error-message": "Sorry, we can't find the page you are looking for.",
@@ -29,7 +30,7 @@ def error_page_not_found() -> Response:
         status_code=404,
     )
 
-def error_entry_not_found(branch: str, path: str) -> Response:
+def error_entry_not_found_branch(branch: str, path: str) -> Response:
     return Response(
         headers={
             "x-error-code": "EntryNotFound",
@@ -38,6 +39,16 @@ def error_entry_not_found(branch: str, path: str) -> Response:
         status_code=404,
     )
 
+def error_entry_not_found() -> Response:
+    """Build a body-less 404 carrying the generic ``EntryNotFound`` headers."""
+    return Response(
+        headers={
+            "x-error-code": "EntryNotFound",
+            "x-error-message": "Entry not found",
+        },
+        status_code=404,
+    )
+
 def error_revision_not_found(revision: str) -> Response:
     return JSONResponse(
         content={"error": f"Invalid rev id: {revision}"},
@@ -47,3 +58,24 @@ def error_revision_not_found(revision: str) -> Response:
         },
         status_code=404,
     )
+
+# Olah Custom Messages
+def error_proxy_timeout() -> Response:
+    """Build a body-less 504 tagged ``ProxyTimeout``."""
+    return Response(
+        headers={
+            "x-error-code": "ProxyTimeout",
+            "x-error-message": "Proxy Timeout",
+        },
+        status_code=504,
+    )
+
+def error_proxy_invalid_data() -> Response:
+    """Build a body-less 504 tagged ``ProxyInvalidData``."""
+    return Response(
+        headers={
+            "x-error-code": "ProxyInvalidData",
+            "x-error-message": "Proxy Invalid Data",
+        },
+        status_code=504,
+    )
diff --git a/olah/mirror/repos.py b/olah/mirror/repos.py
index e27e8a5..3bba7c6 100644
--- a/olah/mirror/repos.py
+++ b/olah/mirror/repos.py
@@ -236,6 +236,36 @@ def get_tree(
         for r in items:
             r.pop("name")
         return items
+
+    def get_commits(self, commit_hash: str) -> "Optional[List[Dict[str, Any]]]":
+        """Describe ``commit_hash`` and its direct parents as JSON-ready dicts.
+
+        Returns ``None`` when the revision cannot be resolved in the local
+        mirror; otherwise a list whose first entry is the requested commit.
+        """
+        from datetime import timezone
+
+        try:
+            commit = self._git_repo.commit(commit_hash)
+        except gitdb.exc.BadName:
+            return None
+
+        # ``list.insert`` mutates in place and returns ``None``, so build the
+        # list directly instead of rebinding it to the result of ``insert``.
+        commits = [commit, *commit.parents]
+        items = []
+        for each_commit in commits:
+            # The trailing "Z" denotes UTC, so normalise the commit's own
+            # offset before formatting.
+            committed_utc = each_commit.committed_datetime.astimezone(timezone.utc)
+            items.append({
+                "id": each_commit.hexsha,
+                "title": each_commit.message,
+                "message": "",
+                "authors": [{"name": each_commit.author.name, "avatar": None}],
+                "date": committed_utc.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+            })
+        return items
 
     def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
         try:
diff --git a/olah/proxy/commits.py b/olah/proxy/commits.py
new file mode 100644
index 0000000..3b6db1f
--- /dev/null
+++ b/olah/proxy/commits.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2024 XiaHan
+#
+# Use of this source code is governed by an MIT-style
+# license that can be found in the LICENSE file or at
+# https://opensource.org/licenses/MIT.
+
+import os
+from typing import Dict, Literal, Mapping
+from urllib.parse import urljoin
+from fastapi import FastAPI, Request
+
+import httpx
+from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
+
+from olah.utils.cache_utils import read_cache_request, write_cache_request
+from olah.utils.rule_utils import check_cache_rules_hf
+from olah.utils.repo_utils import get_org_repo
+from olah.utils.file_utils import make_dirs
+
+
+async def _commits_cache_generator(save_path: str):
+    """Replay a cached response: status code, then headers, then content."""
+    cache_rq = await read_cache_request(save_path)
+    yield cache_rq["status_code"]
+    yield cache_rq["headers"]
+    yield cache_rq["content"]
+
+
+async def _commits_proxy_generator(
+    app: FastAPI,
+    headers: Dict[str, str],
+    commits_url: str,
+    method: str,
+    params: Mapping[str, str],
+    allow_cache: bool,
+    save_path: str,
+):
+    """Stream the upstream response and optionally persist it to the cache.
+
+    Yields the status code, then the headers, then each non-empty raw chunk.
+    The response is written to ``save_path`` only when ``allow_cache`` is set
+    and the upstream answered 200.
+    """
+    async with httpx.AsyncClient(follow_redirects=True) as client:
+        content_chunks = []
+        async with client.stream(
+            method=method,
+            url=commits_url,
+            params=params,
+            headers=headers,
+            timeout=WORKER_API_TIMEOUT,
+        ) as response:
+            response_status_code = response.status_code
+            response_headers = response.headers
+            yield response_status_code
+            yield response_headers
+
+            async for raw_chunk in response.aiter_raw():
+                if not raw_chunk:
+                    continue
+                content_chunks.append(raw_chunk)
+                yield raw_chunk
+
+        if allow_cache and response_status_code == 200:
+            make_dirs(save_path)
+            await write_cache_request(
+                save_path,
+                response_status_code,
+                response_headers,
+                b"".join(content_chunks),
+            )
+
+
+async def commits_generator(
+    app: FastAPI,
+    repo_type: Literal["models", "datasets", "spaces"],
+    org: str,
+    repo: str,
+    commit: str,
+    override_cache: bool,
+    request: Request,
+):
+    """Yield status code, headers and body chunks for a commits API request.
+
+    Serves the cached copy when one exists and ``override_cache`` is false;
+    otherwise forwards the request to the configured upstream base URL.
+    """
+    headers = {k: v for k, v in request.headers.items()}
+    # Tolerate requests without a Host header instead of raising KeyError.
+    headers.pop("host", None)
+
+    # save
+    method = request.method.lower()
+    repos_path = app.app_settings.config.repos_path
+    save_dir = os.path.join(
+        repos_path, f"api/{repo_type}/{org}/{repo}/commits/{commit}"
+    )
+    save_path = os.path.join(save_dir, f"commits_{method}.json")
+
+    use_cache = os.path.exists(save_path)
+    allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
+
+    org_repo = get_org_repo(org, repo)
+    commits_url = urljoin(
+        app.app_settings.config.hf_url_base(),
+        f"/api/{repo_type}/{org_repo}/commits/{commit}",
+    )
+    # proxy
+    if use_cache and not override_cache:
+        async for item in _commits_cache_generator(save_path):
+            yield item
+    else:
+        async for item in _commits_proxy_generator(
+            app, headers, commits_url, method, {}, allow_cache, save_path
+        ):
+            yield item
diff --git a/olah/proxy/files.py b/olah/proxy/files.py
index 0240d44..b9e4ce3 100644
--- a/olah/proxy/files.py
+++ b/olah/proxy/files.py
@@ -26,6 +26,7 @@
     ORIGINAL_LOC,
 )
 from olah.cache.olah_cache import OlahCache
+from olah.errors import error_entry_not_found, error_proxy_invalid_data, error_proxy_timeout
 from olah.proxy.pathsinfo import pathsinfo_generator
 from olah.utils.cache_utils import read_cache_request, write_cache_request
 from olah.utils.disk_utils import touch_file_access_time
@@ -474,27 +475,40 @@ async def _file_realtime_stream(
 
 
     generator = pathsinfo_generator(app, repo_type, org, repo, commit, [file_path], override_cache=False, method="post")
+    status_code = await generator.__anext__()
     headers = await generator.__anext__()
     content = await generator.__anext__()
     try:
         pathsinfo = json.loads(content)
     except json.JSONDecodeError:
-        yield 504
-        yield {}
-        yield b""
+        response = error_proxy_invalid_data()
+        yield response.status_code
+        yield response.headers
+        yield response.body
+        return
+
+    if len(pathsinfo) == 0:
+        response = error_entry_not_found()
+        yield response.status_code
+        yield response.headers
+        yield response.body
         return
 
     if len(pathsinfo) != 1:
-        yield 504
-        yield {}
-        yield b""
+        # Several entries for a single requested path is malformed data, not a timeout.
+        response = error_proxy_invalid_data()
+        yield response.status_code
+        yield response.headers
+        yield response.body
         return
 
     pathinfo = pathsinfo[0]
     if "size" not in pathinfo:
-        yield 504
-        yield {}
-        yield b""
+        # An entry without a size is malformed data, not a timeout.
+        response = error_proxy_invalid_data()
+        yield response.status_code
+        yield response.headers
+        yield response.body
         return
 
     file_size = pathinfo["size"]
diff --git a/olah/proxy/pathsinfo.py b/olah/proxy/pathsinfo.py
index 29643e9..0c65643 100644
--- a/olah/proxy/pathsinfo.py
+++ b/olah/proxy/pathsinfo.py
@@ -102,5 +102,6 @@ async def pathsinfo_generator(
         if status == 200 and isinstance(content_json, list):
             final_content.extend(content_json)
 
+    yield 200
     yield {'content-type': 'application/json'}
     yield json.dumps(final_content, ensure_ascii=True)
diff --git a/olah/server.py b/olah/server.py
index 1c0056e..9da33e8 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -28,6 +28,7 @@
 import git
 import httpx
 
+from olah.proxy.commits import commits_generator
 from olah.proxy.pathsinfo import pathsinfo_generator
 from olah.proxy.tree import tree_generator
 from olah.utils.disk_utils import convert_bytes_to_human_readable, convert_to_bytes, get_folder_size, sort_files_by_access_time, sort_files_by_modify_time, sort_files_by_size
@@ -67,6 +68,8 @@
 from olah.constants import REPO_TYPES_MAPPING
 from olah.utils.logging import build_logger
 
+logger = None
+
 # ======================
 # Utilities
 # ======================
@@ -166,6 +169,16 @@ class AppSettings(BaseSettings):
     # The address of the model controller.
     config: OlahConfig = OlahConfig()
 
+
+# ======================
+# Exception handlers
+# ======================
+@app.exception_handler(404)
+async def custom_404_handler(_, __):
+    """Render 404 errors with the shared page-not-found response."""
+    return error_page_not_found()
+
+
 # ======================
 # File Meta Info API Hooks
 # See also: https://huggingface.co/docs/hub/api#repo-listing-api
@@ -417,7 +430,7 @@ async def tree_proxy_commit(
         request=request,
     )
 
-
+# Git Pathsinfo
 async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: str, paths: List[str], request: Request) -> Response:
     # TODO: the head method of meta apis
     # FIXME: do not show the private repos to other user besides owner, even though the repo was cached
@@ -440,7 +453,7 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st
             logger.warning(f"Local repository {git_path} is not a valid git reposity.")
             continue
 
-    # Proxy the HF File Meta
+    # Proxy the HF File pathsinfo
     try:
         if not app.app_settings.config.offline:
             if not await check_commit_hf(app, repo_type, org, repo, commit=None,
@@ -460,8 +473,9 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st
             generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=True, method=request.method.lower())
         else:
             generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=False, method=request.method.lower())
+        status_code = await generator.__anext__()
         headers = await generator.__anext__()
-        return StreamingResponse(generator, headers=headers)
+        return StreamingResponse(generator, status_code=status_code, headers=headers)
     except httpx.ConnectTimeout:
         traceback.print_exc()
         return Response(status_code=504)
@@ -487,6 +501,91 @@ async def pathsinfo_proxy_commit(repo_type: str, org_repo: str, commit: str, pat
         repo_type=repo_type, org=org, repo=repo, commit=commit, paths=paths, request=request
     )
 
+# Git Commits
+async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response:
+    """Serve the commits listing for ``commit`` from a local mirror or upstream.
+
+    Local mirrors are consulted first; when none can resolve the commit the
+    request is proxied (or replayed from cache) through ``commits_generator``.
+    """
+    # FIXME: do not show the private repos to other user besides owner, even though the repo was cached
+    if repo_type not in REPO_TYPES_MAPPING:
+        return error_page_not_found()
+    if not await check_proxy_rules_hf(app, repo_type, org, repo):
+        return error_repo_not_found()
+    # Check Mirror Path
+    for mirror_path in app.app_settings.config.mirrors_path:
+        try:
+            git_path = os.path.join(mirror_path, repo_type, org, repo)
+            if os.path.exists(git_path):
+                local_repo = LocalMirrorRepo(git_path, repo_type, org, repo)
+                commits_data = local_repo.get_commits(commit)
+                if commits_data is None:
+                    continue
+                return JSONResponse(content=commits_data)
+        except git.exc.InvalidGitRepositoryError:
+            logger.warning(f"Local repository {git_path} is not a valid git reposity.")
+            continue
+
+    # Proxy the HF File Commits
+    try:
+        if not app.app_settings.config.offline:
+            if not await check_commit_hf(app, repo_type, org, repo, commit=None,
+                authorization=request.headers.get("authorization", None),
+            ):
+                return error_repo_not_found()
+            if not await check_commit_hf(app, repo_type, org, repo, commit=commit,
+                authorization=request.headers.get("authorization", None),
+            ):
+                return error_revision_not_found(revision=commit)
+        commit_sha = await get_commit_hf(app, repo_type, org, repo, commit=commit,
+            authorization=request.headers.get("authorization", None),
+        )
+        if commit_sha is None:
+            return error_repo_not_found()
+        # if branch name and online mode, refresh branch info
+        override_cache = not app.app_settings.config.offline and commit_sha != commit
+        generator = commits_generator(
+            app=app,
+            repo_type=repo_type,
+            org=org,
+            repo=repo,
+            commit=commit_sha,
+            override_cache=override_cache,
+            request=request,
+        )
+        status_code = await generator.__anext__()
+        headers = await generator.__anext__()
+        return StreamingResponse(generator, status_code=status_code, headers=headers)
+    except httpx.ConnectTimeout:
+        traceback.print_exc()
+        return Response(status_code=504)
+
+
+@app.head("/api/{repo_type}/{org}/{repo}/commits/{commit}")
+@app.get("/api/{repo_type}/{org}/{repo}/commits/{commit}")
+async def commits_proxy_commit2(
+    repo_type: str, org: str, repo: str, commit: str, request: Request
+):
+    """Handle commits requests addressed as ``{org}/{repo}``."""
+    return await commits_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
+    )
+
+
+@app.head("/api/{repo_type}/{org_repo}/commits/{commit}")
+@app.get("/api/{repo_type}/{org_repo}/commits/{commit}")
+async def commits_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
+    """Handle commits requests whose repository arrives as a single path segment."""
+    org, repo = parse_org_repo(org_repo)
+    if org is None and repo is None:
+        return error_repo_not_found()
+
+    return await commits_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
+    )
+
+
 # ======================
 # Authentication API Hooks