diff --git a/olah/mirror/repos.py b/olah/mirror/repos.py
index fbae83e..3b3527c 100644
--- a/olah/mirror/repos.py
+++ b/olah/mirror/repos.py
@@ -11,6 +11,7 @@
 from typing import Any, Dict, List, Union
 import gitdb
 from git import Commit, Optional, Repo, Tree
+from git.objects.base import IndexObjUnion
 from gitdb.base import OStream
 import yaml
 
@@ -65,66 +66,75 @@ def _get_description(self, commit: Commit) -> str:
         readme = self._get_readme(commit)
         return self._remove_card(readme)
 
-    def _get_tree_files_recursive(self, tree, include_dir=False) -> List[str]:
+    def _get_tree_filenames_recursive(self, tree, include_dir=False) -> List[str]:
         out_paths = []
         for entry in tree:
             if entry.type == "tree":
-                out_paths.extend(self._get_tree_files_recursive(entry))
+                out_paths.extend(self._get_tree_filenames_recursive(entry))
                 if include_dir:
                     out_paths.append(entry.path)
             else:
                 out_paths.append(entry.path)
         return out_paths
 
-    def _get_commit_files_recursive(self, commit: Commit) -> List[str]:
-        return self._get_tree_files_recursive(commit.tree)
+    def _get_commit_filenames_recursive(self, commit: Commit) -> List[str]:
+        return self._get_tree_filenames_recursive(commit.tree)
 
-    def _get_tree_files(self, tree: Tree) -> List[Dict[str, Union[int, str]]]:
-        entries = []
-        for entry in tree:
-            lfs = False
-            if entry.type != "tree":
-                t = "file"
-                repr_size = entry.size
-                if repr_size > 120 and repr_size < 150:
-                    # check lfs
-                    lfs_data = entry.data_stream.read().decode("utf-8")
-                    match_groups = re.match(
-                        r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
-                        lfs_data,
-                    )
-                    if match_groups is not None:
-                        lfs = True
-                        sha256 = match_groups.group(1)
-                        repr_size = int(match_groups.group(2))
-                        lfs_data = {
-                            "oid": sha256,
-                            "size": repr_size,
-                            "pointerSize": entry.size,
-                        }
-            else:
-                t = "directory"
-                repr_size = 0
-
-            if not lfs:
-                entries.append(
-                    {
-                        "type": t,
-                        "oid": entry.hexsha,
-                        "size": repr_size,
-                        "path": entry.name,
-                    }
+    def _get_path_info(self, entry: IndexObjUnion) -> Dict[str, Union[int, str]]:
+        lfs = False
+        if entry.type != "tree":
+            t = "file"
+            repr_size = entry.size
+            if repr_size > 120 and repr_size < 150:
+                # check lfs
+                lfs_data = entry.data_stream.read().decode("utf-8")
+                match_groups = re.match(
+                    r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
+                    lfs_data,
                 )
-            else:
-                entries.append(
-                    {
-                        "type": t,
-                        "oid": entry.hexsha,
+                if match_groups is not None:
+                    lfs = True
+                    sha256 = match_groups.group(1)
+                    repr_size = int(match_groups.group(2))
+                    lfs_data = {
+                        "oid": sha256,
                         "size": repr_size,
-                        "path": entry.name,
-                        "lfs": lfs_data,
+                        "pointerSize": entry.size,
                     }
-                )
+        else:
+            t = "directory"
+            repr_size = entry.size
+
+        if not lfs:
+            item = {
+                "type": t,
+                "oid": entry.hexsha,
+                "size": repr_size,
+                "path": entry.path,
+                "name": entry.name,
+            }
+        else:
+            item = {
+                "type": t,
+                "oid": entry.hexsha,
+                "size": repr_size,
+                "path": entry.path,
+                "name": entry.name,
+                "lfs": lfs_data,
+            }
+        return item
+
+    def _get_tree_files(
+        self, tree: Tree, recursive: bool = False
+    ) -> List[Dict[str, Union[int, str]]]:
+        entries = []
+        for entry in tree:
+            entries.append(self._get_path_info(entry=entry))
+
+        if recursive:
+            for entry in tree:
+                if entry.type == "tree":
+                    entries.extend(self._get_tree_files(entry, recursive=recursive))
         return entries
 
     def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]:
@@ -143,24 +153,67 @@ def _get_earliest_commit(self) -> Commit:
         return earliest_commit
 
-    def get_tree(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]:
+    def get_index_object_by_path(
+        self, commit_hash: str, path: str
+    ) -> Optional[IndexObjUnion]:
         try:
             commit = self._git_repo.commit(commit_hash)
         except gitdb.exc.BadName:
             return None
         path_part = path.split("/")
+        path_part = [part for part in path_part if len(part.strip()) != 0]
 
         tree = commit.tree
         items = self._get_tree_files(tree=tree)
-        for part in path_part:
-            if len(part.strip()) == 0:
-                continue
-            if part not in [
-                item["path"] for item in items if item["type"] == "directory"
-            ]:
-                return None
+        if len(path_part) == 0:
+            return None
+        for i, part in enumerate(path_part):
+            if i != len(path_part) - 1:
+                if part not in [
+                    item["name"] for item in items if item["type"] == "directory"
+                ]:
+                    return None
+            else:
+                if part not in [
+                    item["name"] for item in items
+                ]:
+                    return None
             tree = tree[part]
-            items = self._get_tree_files(tree=tree)
+            if tree.type == "tree":
+                items = self._get_tree_files(tree=tree, recursive=False)
+        return tree
+
+    def get_pathinfos(
+        self, commit_hash: str, paths: List[str]
+    ) -> Optional[List[Dict[str, Any]]]:
+        try:
+            commit = self._git_repo.commit(commit_hash)
+        except gitdb.exc.BadName:
+            return None
+
+        results = []
+        for path in paths:
+            index_obj = self.get_index_object_by_path(
+                commit_hash=commit_hash, path=path
+            )
+            if index_obj is not None:
+                results.append(self._get_path_info(index_obj))
+
+        for r in results:
+            r.pop("name")
+        return results
+
+    def get_tree(
+        self, commit_hash: str, path: str, recursive: bool = False
+    ) -> Optional[Dict[str, Any]]:
+        try:
+            commit = self._git_repo.commit(commit_hash)
+        except gitdb.exc.BadName:
+            return None
+
+        index_obj = self.get_index_object_by_path(commit_hash=commit_hash, path=path)
+        items = self._get_tree_files(tree=index_obj, recursive=recursive)
+        for r in items:
+            r.pop("name")
         return items
 
     def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
@@ -189,7 +242,7 @@ def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
             self._match_card(self._get_readme(commit)), Loader=yaml.CLoader
         )
         meta.siblings = [
-            {"rfilename": p} for p in self._get_commit_files_recursive(commit)
+            {"rfilename": p} for p in self._get_commit_filenames_recursive(commit)
         ]
         meta.createdAt = self._get_earliest_commit().committed_datetime.strftime(
             "%Y-%m-%dT%H:%M:%S.%fZ"
         )
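A minimal usage sketch of the refactored local-mirror API above (not part of the patch). The mirror path, revision, and file names are placeholders; the constructor call matches the LocalMirrorRepo(git_path, repo_type, org, repo) usage already present in olah/server.py.

from olah.mirror.repos import LocalMirrorRepo

# Hypothetical mirrored repository; point this at an existing local clone.
repo = LocalMirrorRepo("./mirrors/models/myorg/myrepo", "models", "myorg", "myrepo")

# Flat listing of one directory vs. a recursive walk using the new flag.
print(repo.get_tree("main", "subdir", recursive=False))
print(repo.get_tree("main", "subdir", recursive=True))

# paths-info style lookup for a mix of files and directories.
print(repo.get_pathinfos("main", ["README.md", "subdir"]))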
diff --git a/olah/proxy/pathsinfo.py b/olah/proxy/pathsinfo.py
new file mode 100644
index 0000000..e627f68
--- /dev/null
+++ b/olah/proxy/pathsinfo.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2024 XiaHan
+#
+# Use of this source code is governed by an MIT-style
+# license that can be found in the LICENSE file or at
+# https://opensource.org/licenses/MIT.
+
+import json
+import os
+from typing import Dict, List, Literal
+from urllib.parse import quote, urljoin
+from fastapi import FastAPI, Request
+
+import httpx
+from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
+
+from olah.utils.cache_utils import _read_cache_request, _write_cache_request
+from olah.utils.rule_utils import check_cache_rules_hf
+from olah.utils.repo_utils import get_org_repo
+from olah.utils.file_utils import make_dirs
+
+
+async def _pathsinfo_cache(save_path: str):
+    cache_rq = await _read_cache_request(save_path)
+    return cache_rq["status_code"], cache_rq["headers"], cache_rq["content"]
+
+
+async def _pathsinfo_proxy(
+    app: FastAPI,
+    headers: Dict[str, str],
+    pathsinfo_url: str,
+    allow_cache: bool,
+    method: str,
+    path: str,
+    save_path: str,
+):
+    headers = {k: v for k, v in headers.items()}
+    headers.pop("content-length")
+    async with httpx.AsyncClient(follow_redirects=True) as client:
+        response = await client.request(
+            method=method,
+            url=pathsinfo_url,
+            headers=headers,
+            data={"paths": path},
+            timeout=WORKER_API_TIMEOUT,
+        )
+
+        if response.status_code == 200:
+            make_dirs(save_path)
+            await _write_cache_request(
+                save_path,
+                response.status_code,
+                response.headers,
+                bytes(response.content),
+            )
+        return response.status_code, response.headers, response.content
+
+
+async def pathsinfo_generator(
+    app: FastAPI,
+    repo_type: Literal["models", "datasets", "spaces"],
+    org: str,
+    repo: str,
+    commit: str,
+    paths: List[str],
+    request: Request,
+):
+    headers = {k: v for k, v in request.headers.items()}
+    headers.pop("host")
+
+    # save
+    method = request.method.lower()
+    repos_path = app.app_settings.repos_path
+
+    final_content = []
+    for path in paths:
+        save_dir = os.path.join(
+            repos_path, f"api/{repo_type}/{org}/{repo}/paths-info/{commit}/{path}"
+        )
+
+        save_path = os.path.join(save_dir, f"paths-info_{method}.json")
+
+        use_cache = os.path.exists(save_path)
+        allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
+
+        org_repo = get_org_repo(org, repo)
+        pathsinfo_url = urljoin(
+            app.app_settings.config.hf_url_base(),
+            f"/api/{repo_type}/{org_repo}/paths-info/{commit}",
+        )
+        # proxy
+        if use_cache:
+            status, headers, content = await _pathsinfo_cache(save_path)
+        else:
+            print(path)
+            status, headers, content = await _pathsinfo_proxy(
+                app, headers, pathsinfo_url, allow_cache, method, path, save_path
+            )
+
+        try:
+            content_json = json.loads(content)
+        except json.JSONDecodeError:
+            continue
+        if status == 200 and isinstance(content_json, list):
+            final_content.extend(content_json)
+
+    yield {'content-type': 'application/json'}
+    yield json.dumps(final_content, ensure_ascii=True)
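The helper above proxies each requested path as its own single-field form POST against the upstream paths-info endpoint and caches each reply in its own paths-info_{method}.json file. Roughly, the per-path upstream request is equivalent to the following sketch (hub URL and repo id are placeholders):

import httpx

# One request per path, mirroring _pathsinfo_proxy above.
resp = httpx.post(
    "https://huggingface.co/api/models/myorg/myrepo/paths-info/main",
    data={"paths": "config.json"},
    follow_redirects=True,
)
print(resp.status_code, resp.json())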
diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py
index 357470a..1da1a31 100644
--- a/olah/proxy/tree.py
+++ b/olah/proxy/tree.py
@@ -6,8 +6,6 @@
 # https://opensource.org/licenses/MIT.
 
 import os
-import shutil
-import tempfile
 from typing import Dict, Literal
 from urllib.parse import urljoin
 from fastapi import FastAPI, Request
@@ -21,12 +19,6 @@
 from olah.utils.file_utils import make_dirs
 
 
-async def _tree_cache_generator(save_path: str):
-    cache_rq = await _read_cache_request(save_path)
-    yield cache_rq["headers"]
-    yield cache_rq["content"]
-
-
 async def tree_proxy_cache(
     app: FastAPI,
     repo_type: Literal["models", "datasets", "spaces"],
@@ -34,6 +26,7 @@ async def tree_proxy_cache(
     repo: str,
     commit: str,
     path: str,
+    recursive: bool,
     request: Request,
 ):
     headers = {k: v for k, v in request.headers.items()}
@@ -45,7 +38,10 @@
     save_dir = os.path.join(
         repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}"
     )
-    save_path = os.path.join(save_dir, f"tree_{method}.json")
+    if not recursive:
+        save_path = os.path.join(save_dir, f"tree_{method}.json")
+    else:
+        save_path = os.path.join(save_dir, f"tree_{method}_recursive.json")
 
     # url
     org_repo = get_org_repo(org, repo)
@@ -60,6 +56,7 @@
         response = await client.request(
             method=request.method,
             url=tree_url,
+            params={"recursive": recursive},
            headers=headers,
             timeout=WORKER_API_TIMEOUT,
             follow_redirects=True,
@@ -77,12 +74,18 @@
     )
 
 
+async def _tree_cache_generator(save_path: str):
+    cache_rq = await _read_cache_request(save_path)
+    yield cache_rq["headers"]
+    yield cache_rq["content"]
+
 async def _tree_proxy_generator(
     app: FastAPI,
     headers: Dict[str, str],
     tree_url: str,
     allow_cache: bool,
     method: str,
+    recursive: bool,
     save_path: str,
 ):
     async with httpx.AsyncClient(follow_redirects=True) as client:
@@ -90,6 +93,7 @@
         async with client.stream(
             method=method,
             url=tree_url,
+            params={"recursive": recursive},
             headers=headers,
             timeout=WORKER_API_TIMEOUT,
         ) as response:
@@ -121,6 +125,7 @@
     repo: str,
     commit: str,
     path: str,
+    recursive: bool,
     request: Request,
 ):
     headers = {k: v for k, v in request.headers.items()}
@@ -132,7 +137,10 @@
     save_dir = os.path.join(
         repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}"
     )
-    save_path = os.path.join(save_dir, f"tree_{method}.json")
+    if not recursive:
+        save_path = os.path.join(save_dir, f"tree_{method}.json")
+    else:
+        save_path = os.path.join(save_dir, f"tree_{method}_recursive.json")
 
     use_cache = os.path.exists(save_path)
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
@@ -148,6 +156,6 @@
             yield item
     else:
         async for item in _tree_proxy_generator(
-            app, headers, tree_url, allow_cache, method, save_path
+            app, headers, tree_url, allow_cache, method, recursive, save_path
         ):
             yield item
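A small illustration (not part of the patch) of how the two cache variants above land on disk, so recursive and flat listings never overwrite each other; the directory names are placeholders following the repos_path layout used in tree_proxy_cache and tree_generator:

import os

save_dir = "./repos/api/models/myorg/myrepo/tree/main/subdir"  # illustrative save_dir
for recursive in (False, True):
    name = "tree_get_recursive.json" if recursive else "tree_get.json"
    print(os.path.join(save_dir, name))
# ./repos/api/models/myorg/myrepo/tree/main/subdir/tree_get.json
# ./repos/api/models/myorg/myrepo/tree/main/subdir/tree_get_recursive.json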
diff --git a/olah/server.py b/olah/server.py
index e0a4988..cb4741f 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -10,9 +10,9 @@
 import glob
 import argparse
 import traceback
-from typing import Annotated, Optional, Union
+from typing import Annotated, List, Optional, Union
 from urllib.parse import urljoin
-from fastapi import FastAPI, Header, Request
+from fastapi import FastAPI, Header, Request, Form
 from fastapi.responses import (
     FileResponse,
     HTMLResponse,
@@ -26,6 +26,7 @@
 import git
 import httpx
 
+from olah.proxy.pathsinfo import pathsinfo_generator
 from olah.proxy.tree import tree_generator, tree_proxy_cache
 from olah.utils.url_utils import clean_path
 
@@ -223,9 +224,18 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request:
         repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
     )
 
+# Git Tree
-async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, path: str, request: Request) -> Response:
-    # TODO: the head method of meta apis
+async def tree_proxy_common(
+    repo_type: str,
+    org: str,
+    repo: str,
+    commit: str,
+    path: str,
+    recursive: bool,
+    request: Request,
+) -> Response:
+    # TODO ?recursive=True
     # FIXME: do not show the private repos to other user besides owner, even though the repo was cached
     path = clean_path(path)
     if repo_type not in REPO_TYPES_MAPPING.keys():
         return error_page_not_found()
@@ -238,7 +248,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, pa
             git_path = os.path.join(mirror_path, repo_type, org, repo)
             if os.path.exists(git_path):
                 local_repo = LocalMirrorRepo(git_path, repo_type, org, repo)
-                tree_data = local_repo.get_tree(commit, path)
+                tree_data = local_repo.get_tree(commit, path, recursive=recursive)
                 if tree_data is None:
                     continue
                 return JSONResponse(content=tree_data)
@@ -269,33 +279,132 @@
             return error_repo_not_found()
         # if branch name and online mode, refresh branch info
         if not app.app_settings.config.offline and commit_sha != commit:
-            await tree_proxy_cache(app, repo_type, org, repo, commit, path, request)
-        generator = tree_generator(app, repo_type, org, repo, commit_sha, path, request)
+            await tree_proxy_cache(app, repo_type, org, repo, commit, path, recursive, request)
+        generator = tree_generator(app, repo_type, org, repo, commit_sha, path, recursive, request)
         headers = await generator.__anext__()
         return StreamingResponse(generator, headers=headers)
     except httpx.ConnectTimeout:
         traceback.print_exc()
         return Response(status_code=504)
 
+
 @app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}")
 @app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}")
 async def tree_proxy_commit2(
-    repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request
+    repo_type: str,
+    org: str,
+    repo: str,
+    commit: str,
+    file_path: str,
+    request: Request,
+    recursive: bool = False,
 ):
     return await tree_proxy_common(
-        repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        path=file_path,
+        recursive=recursive,
+        request=request,
     )
 
 
 @app.head("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}")
 @app.get("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}")
-async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, file_path: str, request: Request):
+async def tree_proxy_commit(
+    repo_type: str,
+    org_repo: str,
+    commit: str,
+    file_path: str,
+    request: Request,
+    recursive: bool = False,
+):
     org, repo = parse_org_repo(org_repo)
     if org is None and repo is None:
         return error_repo_not_found()
 
     return await tree_proxy_common(
-        repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request
+        repo_type=repo_type,
+        org=org,
+        repo=repo,
+        commit=commit,
+        path=file_path,
+        recursive=recursive,
+        request=request,
+    )
+
+
+# TODO: paths-info
+async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: str, paths: List[str], request: Request) -> Response:
+    # TODO: the head method of meta apis
+    # FIXME: do not show the private repos to other user besides owner, even though the repo was cached
+    paths = [clean_path(path) for path in paths]
+    if repo_type not in REPO_TYPES_MAPPING.keys():
+        return error_page_not_found()
+    if not await check_proxy_rules_hf(app, repo_type, org, repo):
+        return error_repo_not_found()
+    # Check Mirror Path
+    for mirror_path in app.app_settings.config.mirrors_path:
+        try:
+            git_path = os.path.join(mirror_path, repo_type, org, repo)
+            if os.path.exists(git_path):
+                local_repo = LocalMirrorRepo(git_path, repo_type, org, repo)
+                tree_data = local_repo.get_pathinfos(commit, paths)
+                if tree_data is None:
+                    continue
+                return JSONResponse(content=tree_data)
+        except git.exc.InvalidGitRepositoryError:
+            logger.warning(f"Local repository {git_path} is not a valid git repository.")
+            continue
+
+    # Proxy the HF File Meta
+    try:
+        if not app.app_settings.config.offline and not await check_commit_hf(
+            app,
+            repo_type,
+            org,
+            repo,
+            commit=commit,
+            authorization=request.headers.get("authorization", None),
+        ):
+            return error_repo_not_found()
+        commit_sha = await get_commit_hf(
+            app,
+            repo_type,
+            org,
+            repo,
+            commit=commit,
+            authorization=request.headers.get("authorization", None),
+        )
+        if commit_sha is None:
+            return error_repo_not_found()
+        generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, request)
+        headers = await generator.__anext__()
+        return StreamingResponse(generator, headers=headers)
+    except httpx.ConnectTimeout:
+        traceback.print_exc()
+        return Response(status_code=504)
+
+@app.head("/api/{repo_type}/{org}/{repo}/paths-info/{commit}")
+@app.post("/api/{repo_type}/{org}/{repo}/paths-info/{commit}")
+async def pathsinfo_proxy_commit2(
+    repo_type: str, org: str, repo: str, commit: str, paths: Annotated[List[str], Form()], request: Request
+):
+    return await pathsinfo_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=commit, paths=paths, request=request
+    )
+
+
+@app.head("/api/{repo_type}/{org_repo}/paths-info/{commit}")
+@app.post("/api/{repo_type}/{org_repo}/paths-info/{commit}")
+async def pathsinfo_proxy_commit(repo_type: str, org_repo: str, commit: str, paths: Annotated[List[str], Form()], request: Request):
+    org, repo = parse_org_repo(org_repo)
+    if org is None and repo is None:
+        return error_repo_not_found()
+
+    return await pathsinfo_proxy_common(
+        repo_type=repo_type, org=org, repo=repo, commit=commit, paths=paths, request=request
     )
diff --git a/requirements.txt b/requirements.txt
index 404bbcd..52f3e72 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,4 @@ tenacity==8.5.0
 peewee==3.17.6
 typing_inspect==0.9.0
 jinja2==3.1.4
+python-multipart==0.0.9
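End to end, the new behaviour can be exercised against a running olah instance; the sketch below is illustrative only, with the host/port and repo id as placeholder assumptions. The paths-info routes read `paths` from form data, which is why python-multipart was added to requirements.txt.

import httpx

base = "http://localhost:8090"          # assumed olah address
repo = "api/models/myorg/myrepo"        # hypothetical mirrored repo

# Tree listing; the new ?recursive=true flag is forwarded upstream and cached separately.
tree = httpx.get(f"{base}/{repo}/tree/main/", params={"recursive": True})
print(tree.json())

# paths-info is a POST route; httpx encodes the list as repeated form fields.
info = httpx.post(
    f"{base}/{repo}/paths-info/main",
    data={"paths": ["README.md", "config.json"]},
)
print(info.json())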