diff --git a/olah/errors.py b/olah/errors.py index 7768275..a020c57 100644 --- a/olah/errors.py +++ b/olah/errors.py @@ -28,3 +28,13 @@ def error_page_not_found() -> Response: }, status_code=404, ) + +def error_entry_not_found(branch: str, path: str) -> Response: + return Response( + headers={ + "x-error-code": "EntryNotFound", + "x-error-message": f"{path} does not exist on \"{branch}\"", + }, + status_code=404, + ) + diff --git a/olah/mirror/repos.py b/olah/mirror/repos.py index 2c43ddd..fbae83e 100644 --- a/olah/mirror/repos.py +++ b/olah/mirror/repos.py @@ -65,19 +65,70 @@ def _get_description(self, commit: Commit) -> str: readme = self._get_readme(commit) return self._remove_card(readme) - def _get_entry_files(self, tree, include_dir=False) -> List[str]: + def _get_tree_files_recursive(self, tree, include_dir=False) -> List[str]: out_paths = [] for entry in tree: if entry.type == "tree": - out_paths.extend(self._get_entry_files(entry)) + out_paths.extend(self._get_tree_files_recursive(entry)) if include_dir: out_paths.append(entry.path) else: out_paths.append(entry.path) return out_paths - def _get_tree_files(self, commit: Commit) -> List[str]: - return self._get_entry_files(commit.tree) + def _get_commit_files_recursive(self, commit: Commit) -> List[str]: + return self._get_tree_files_recursive(commit.tree) + + def _get_tree_files(self, tree: Tree) -> List[Dict[str, Union[int, str]]]: + entries = [] + for entry in tree: + lfs = False + if entry.type != "tree": + t = "file" + repr_size = entry.size + if repr_size > 120 and repr_size < 150: + # check lfs + lfs_data = entry.data_stream.read().decode("utf-8") + match_groups = re.match( + r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n", + lfs_data, + ) + if match_groups is not None: + lfs = True + sha256 = match_groups.group(1) + repr_size = int(match_groups.group(2)) + lfs_data = { + "oid": sha256, + "size": repr_size, + "pointerSize": entry.size, + } + else: + t = "directory" + repr_size = 0 + + if not lfs: + entries.append( + { + "type": t, + "oid": entry.hexsha, + "size": repr_size, + "path": entry.name, + } + ) + else: + entries.append( + { + "type": t, + "oid": entry.hexsha, + "size": repr_size, + "path": entry.name, + "lfs": lfs_data, + } + ) + return entries + + def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]: + return self._get_tree_files(commit.tree) def _get_earliest_commit(self) -> Commit: earliest_commit = None @@ -92,7 +143,27 @@ def _get_earliest_commit(self) -> Commit: return earliest_commit - def get_meta(self, commit_hash: str) -> Dict[str, Any]: + def get_tree(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]: + try: + commit = self._git_repo.commit(commit_hash) + except gitdb.exc.BadName: + return None + + path_part = path.split("/") + tree = commit.tree + items = self._get_tree_files(tree=tree) + for part in path_part: + if len(part.strip()) == 0: + continue + if part not in [ + item["path"] for item in items if item["type"] == "directory" + ]: + return None + tree = tree[part] + items = self._get_tree_files(tree=tree) + return items + + def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]: try: commit = self._git_repo.commit(commit_hash) except gitdb.exc.BadName: @@ -117,7 +188,9 @@ def get_meta(self, commit_hash: str) -> Dict[str, Any]: meta.cardData = yaml.load( self._match_card(self._get_readme(commit)), Loader=yaml.CLoader ) - meta.siblings = [{"rfilename": p} for p in self._get_tree_files(commit)] + meta.siblings = [ + {"rfilename": p} for p in self._get_commit_files_recursive(commit) + ] meta.createdAt = self._get_earliest_commit().committed_datetime.strftime( "%Y-%m-%dT%H:%M:%S.%fZ" ) diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 2f96f9f..357470a 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -33,6 +33,7 @@ async def tree_proxy_cache( org: str, repo: str, commit: str, + path: str, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -42,16 +43,15 @@ async def tree_proxy_cache( method = request.method.lower() repos_path = app.app_settings.repos_path save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" ) save_path = os.path.join(save_dir, f"tree_{method}.json") - make_dirs(save_path) # url org_repo = get_org_repo(org, repo) tree_url = urljoin( app.app_settings.config.hf_url_base(), - f"/api/{repo_type}/{org_repo}/tree/{commit}", + f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", ) headers = {} if "authorization" in request.headers: @@ -65,9 +65,12 @@ async def tree_proxy_cache( follow_redirects=True, ) if response.status_code == 200: + make_dirs(save_path) await _write_cache_request( save_path, response.status_code, response.headers, response.content ) + elif response.status_code == 404: + pass else: raise Exception( f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}" @@ -104,9 +107,11 @@ async def _tree_proxy_generator( for chunk in content_chunks: content += chunk - await _write_cache_request( - save_path, response_status_code, response_headers, bytes(content) - ) + if response_status_code == 200: + make_dirs(save_path) + await _write_cache_request( + save_path, response_status_code, response_headers, bytes(content) + ) async def tree_generator( @@ -115,6 +120,7 @@ async def tree_generator( org: str, repo: str, commit: str, + path: str, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -124,10 +130,9 @@ async def tree_generator( method = request.method.lower() repos_path = app.app_settings.repos_path save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" ) save_path = os.path.join(save_dir, f"tree_{method}.json") - make_dirs(save_path) use_cache = os.path.exists(save_path) allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) @@ -135,7 +140,7 @@ async def tree_generator( org_repo = get_org_repo(org, repo) tree_url = urljoin( app.app_settings.config.hf_url_base(), - f"/api/{repo_type}/{org_repo}/tree/{commit}", + f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", ) # proxy if use_cache: diff --git a/olah/server.py b/olah/server.py index 49df26f..2736f37 100644 --- a/olah/server.py +++ b/olah/server.py @@ -27,6 +27,7 @@ import httpx from olah.proxy.tree import tree_generator, tree_proxy_cache +from olah.utils.url_utils import clean_path BASE_SETTINGS = False if not BASE_SETTINGS: @@ -223,9 +224,10 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: ) # Git Tree -async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: +async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, path: str, request: Request) -> Response: # TODO: the head method of meta apis # FIXME: do not show the private repos to other user besides owner, even though the repo was cached + path = clean_path(path) if repo_type not in REPO_TYPES_MAPPING.keys(): return error_page_not_found() if not await check_proxy_rules_hf(app, repo_type, org, repo): @@ -236,8 +238,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) - # TODO: Local git repo trees - tree_data = local_repo.get_tree(commit) + tree_data = local_repo.get_tree(commit, path) if tree_data is None: continue return JSONResponse(content=tree_data) @@ -268,33 +269,33 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re return error_repo_not_found() # if branch name and online mode, refresh branch info if not app.app_settings.config.offline and commit_sha != commit: - await tree_proxy_cache(app, repo_type, org, repo, commit, request) - generator = tree_generator(app, repo_type, org, repo, commit_sha, request) + await tree_proxy_cache(app, repo_type, org, repo, commit, path, request) + generator = tree_generator(app, repo_type, org, repo, commit_sha, path, request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) except httpx.ConnectTimeout: traceback.print_exc() return Response(status_code=504) -@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}") -@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}") +@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}") +@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}") async def tree_proxy_commit2( - repo_type: str, org: str, repo: str, commit: str, request: Request + repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request ): return await tree_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request ) -@app.head("/api/{repo_type}/{org_repo}/tree/{commit}") -@app.get("/api/{repo_type}/{org_repo}/tree/{commit}") -async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): +@app.head("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}") +@app.get("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}") +async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, file_path: str, request: Request): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() return await tree_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request ) diff --git a/olah/utils/url_utils.py b/olah/utils/url_utils.py index 588ac9c..d8257c5 100644 --- a/olah/utils/url_utils.py +++ b/olah/utils/url_utils.py @@ -162,3 +162,12 @@ def remove_query_param(url: str, param_name: str) -> str: new_url = urlunparse(parsed_url._replace(query=new_query)) return new_url + + +def clean_path(path: str) -> str: + while ".." in path: + path = path.replace("..", "") + path = path.replace("\\", "/") + while "//" in path: + path = path.replace("//", "/") + return path \ No newline at end of file