From 33e7cf1b30472bc9c9fdd0a71d49093bcccddac8 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Wed, 14 Aug 2024 18:38:51 +0800 Subject: [PATCH 1/8] add a cache stats tool --- olah/cache/__init__.py | 0 olah/{utils => cache}/bitset.py | 2 +- olah/{utils => cache}/olah_cache.py | 17 ++++++--- olah/cache/stat.py | 54 +++++++++++++++++++++++++++++ olah/proxy/files.py | 2 +- pyproject.toml | 2 +- 6 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 olah/cache/__init__.py rename olah/{utils => cache}/bitset.py (96%) rename olah/{utils => cache}/olah_cache.py (95%) create mode 100644 olah/cache/stat.py diff --git a/olah/cache/__init__.py b/olah/cache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olah/utils/bitset.py b/olah/cache/bitset.py similarity index 96% rename from olah/utils/bitset.py rename to olah/cache/bitset.py index b290d3a..84c8ae7 100644 --- a/olah/utils/bitset.py +++ b/olah/cache/bitset.py @@ -76,4 +76,4 @@ def __str__(self): Returns: str: A string representation of the Bitset object, showing the binary representation of each byte. """ - return "".join(bin(byte)[2:].zfill(8) for byte in self.bits) + return "".join(bin(byte)[2:].zfill(8)[::-1] for byte in self.bits) diff --git a/olah/utils/olah_cache.py b/olah/cache/olah_cache.py similarity index 95% rename from olah/utils/olah_cache.py rename to olah/cache/olah_cache.py index 3a3d033..3b3c92d 100644 --- a/olah/utils/olah_cache.py +++ b/olah/cache/olah_cache.py @@ -52,7 +52,7 @@ def block_number(self) -> int: return self._block_number @property - def block_mask(self) -> int: + def block_mask(self) -> Bitset: return self._block_mask def get_header_size(self): @@ -76,11 +76,18 @@ def _valid_header(self): @staticmethod def read(stream) -> "OlahCacheHeader": obj = OlahCacheHeader() - magic, version, block_size, file_size, block_mask_size = struct.unpack( - "<4sQQQQ", stream.read(OlahCacheHeader.HEADER_FIX_SIZE) + try: + magic = struct.unpack( + "<4s", stream.read(4) + ) + except struct.error: + raise Exception("File is not a Olah cache file.") + if magic[0] != OlahCacheHeader.MAGIC_NUMBER: + raise Exception("File is not a Olah cache file.") + + version, block_size, file_size, block_mask_size = struct.unpack( + " str: + if size > 1024 * 1024 * 1024: + return f"{int(size / (1024 * 1024 * 1024)):.4f}GB" + elif size > 1024 * 1024: + return f"{int(size / (1024 * 1024)):.4f}MB" + elif size > 1024: + return f"{int(size / (1024)):.4f}KB" + else: + return f"{size:.4f}B" + +def insert_newlines(input_str, every=10): + return '\n'.join(input_str[i:i+every] for i in range(0, len(input_str), every)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Olah Cache Visualization Tool.") + parser.add_argument("--file", "-f", type=str, required=True, help="The path of Olah cache file") + parser.add_argument("--export", "-e", type=str, default="", help="Export the cached file if all blocks are cached") + args = parser.parse_args() + print(args) + + with open(args.file, "rb") as f: + f.seek(0, os.SEEK_END) + bin_size = f.tell() + + try: + cache = OlahCache(args.file) + except Exception as e: + print(e) + sys.exit(1) + print(f"File: {args.file}") + print(f"Olah Cache Version: {cache.header.version}") + print(f"File Size: {get_size_human(cache.header.file_size)}") + print(f"Cache Total Size: {get_size_human(bin_size)}") + print(f"Block Size: {cache.header.block_size}") + print(f"Block Number: {cache.header.block_number}") + print(f"Cache Status: ") + cache_status = cache.header.block_mask.__str__()[:cache.header._block_number] + print(insert_newlines(cache_status, every=50)) + + if args.export != "": + if all([c == "1" for c in cache_status]): + with open(args.file, "rb") as f: + f.seek(cache._get_header_size(), os.SEEK_SET) + with open(args.export, "wb") as fout: + fout.write(f.read()) + else: + print("Some blocks are not cached, so the export is skipped.") \ No newline at end of file diff --git a/olah/proxy/files.py b/olah/proxy/files.py index 715a0ae..28bdbe6 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -24,7 +24,7 @@ HUGGINGFACE_HEADER_X_LINKED_SIZE, ORIGINAL_LOC, ) -from olah.utils.olah_cache import OlahCache +from olah.cache.olah_cache import OlahCache from olah.utils.url_utils import ( RemoteInfo, add_query_param, diff --git a/pyproject.toml b/pyproject.toml index 7d10bd2..48f379c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "olah" -version = "0.2.0" +version = "0.2.1" description = "Self-hosted lightweight huggingface mirror." readme = "README.md" requires-python = ">=3.8" From fe7c68b3ea5dcff53668ea29d4194cca407c8e63 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Thu, 15 Aug 2024 16:47:01 +0800 Subject: [PATCH 2/8] code format --- olah/database/__init__.py | 0 olah/database/models.py | 26 +++++++++++ olah/proxy/lfs.py | 2 +- olah/proxy/meta.py | 9 +++- olah/utils/logging.py | 96 --------------------------------------- olah/utils/olah_utils.py | 17 +++++++ requirements.txt | 1 + 7 files changed, 52 insertions(+), 99 deletions(-) create mode 100644 olah/database/__init__.py create mode 100644 olah/database/models.py create mode 100644 olah/utils/olah_utils.py diff --git a/olah/database/__init__.py b/olah/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olah/database/models.py b/olah/database/models.py new file mode 100644 index 0000000..1d4d9fc --- /dev/null +++ b/olah/database/models.py @@ -0,0 +1,26 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +import os +from peewee import * +import datetime + +from olah.utils.olah_utils import get_olah_path + + + +db_path = os.path.join(get_olah_path(), "database.db") +db = SqliteDatabase(db_path) + + +class BaseModel(Model): + class Meta: + database = db + + +class User(BaseModel): + username = CharField(unique=True) diff --git a/olah/proxy/lfs.py b/olah/proxy/lfs.py index a85dd26..69938b6 100644 --- a/olah/proxy/lfs.py +++ b/olah/proxy/lfs.py @@ -1,6 +1,6 @@ # coding=utf-8 # Copyright 2024 XiaHan -# +# # Use of this source code is governed by an MIT-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index 1fac6ab..9d210d0 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -1,6 +1,6 @@ # coding=utf-8 # Copyright 2024 XiaHan -# +# # Use of this source code is governed by an MIT-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. @@ -29,6 +29,7 @@ async def meta_cache_generator(app: FastAPI, save_path: str): break yield chunk + async def meta_proxy_cache( app: FastAPI, repo_type: Literal["models", "datasets", "spaces"], @@ -66,7 +67,10 @@ async def meta_proxy_cache( with open(save_path, "wb") as meta_file: meta_file.write(response.content) else: - raise Exception(f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}") + raise Exception( + f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}" + ) + async def meta_proxy_generator( app: FastAPI, @@ -98,6 +102,7 @@ async def meta_proxy_generator( with open(save_path, "wb") as f: f.write(bytes(content)) + async def meta_generator( app: FastAPI, repo_type: Literal["models", "datasets", "spaces"], diff --git a/olah/utils/logging.py b/olah/utils/logging.py index 8e9c16d..dac6bc0 100644 --- a/olah/utils/logging.py +++ b/olah/utils/logging.py @@ -21,13 +21,6 @@ from olah.constants import DEFAULT_LOGGER_DIR -server_error_msg = ( - "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" -) -moderation_msg = ( - "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." -) - handler = None @@ -143,95 +136,6 @@ def flush(self): self.linebuf = "" -def disable_torch_init(): - """ - Disable the redundant torch default initialization to accelerate model creation. - """ - import torch - - setattr(torch.nn.Linear, "reset_parameters", lambda self: None) - setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) - - -def get_gpu_memory(max_gpus=None): - """Get available memory for each GPU.""" - gpu_memory = [] - num_gpus = ( - torch.cuda.device_count() - if max_gpus is None - else min(max_gpus, torch.cuda.device_count()) - ) - - for gpu_id in range(num_gpus): - with torch.cuda.device(gpu_id): - device = torch.cuda.current_device() - gpu_properties = torch.cuda.get_device_properties(device) - total_memory = gpu_properties.total_memory / (1024**3) - allocated_memory = torch.cuda.memory_allocated() / (1024**3) - available_memory = total_memory - allocated_memory - gpu_memory.append(available_memory) - return gpu_memory - - -def violates_moderation(text): - """ - Check whether the text violates OpenAI moderation API. - """ - url = "https://api.openai.com/v1/moderations" - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], - } - text = text.replace("\n", "") - data = "{" + '"input": ' + f'"{text}"' + "}" - data = data.encode("utf-8") - try: - ret = requests.post(url, headers=headers, data=data, timeout=5) - flagged = ret.json()["results"][0]["flagged"] - except requests.exceptions.RequestException as e: - flagged = False - except KeyError as e: - flagged = False - - return flagged - - -# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, -# Use this function to make sure it can be correctly loaded. -def clean_flant5_ckpt(ckpt_path): - index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") - index_json = json.load(open(index_file, "r")) - - weightmap = index_json["weight_map"] - - share_weight_file = weightmap["shared.weight"] - share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ - "shared.weight" - ] - - for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: - weight_file = weightmap[weight_name] - weight = torch.load(os.path.join(ckpt_path, weight_file)) - weight[weight_name] = share_weight - torch.save(weight, os.path.join(ckpt_path, weight_file)) - - -def pretty_print_semaphore(semaphore): - if semaphore is None: - return "None" - return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" - - -get_window_url_params_js = """ -function() { - const params = new URLSearchParams(window.location.search); - url_params = Object.fromEntries(params); - console.log("url_params", url_params); - return url_params; - } -""" - - def iter_over_async( async_gen: AsyncGenerator, event_loop: AbstractEventLoop ) -> Generator: diff --git a/olah/utils/olah_utils.py b/olah/utils/olah_utils.py new file mode 100644 index 0000000..dd686e4 --- /dev/null +++ b/olah/utils/olah_utils.py @@ -0,0 +1,17 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +import platform +import os + + +def get_olah_path() -> str: + if platform.system() == "Windows": + olah_path = os.path.expanduser("~\\.olah") + else: + olah_path = os.path.expanduser("~/.olah") + return olah_path diff --git a/requirements.txt b/requirements.txt index b748e2e..f1ec968 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pytest==8.2.2 cachetools==5.4.0 PyYAML==6.0.1 tenacity==8.5.0 +peewee==3.17.6 From 128cb6b70aeeea9c5efacd9bbbe4104b80659c82 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Fri, 16 Aug 2024 01:14:19 +0800 Subject: [PATCH 3/8] git tree api --- olah/proxy/tree.py | 141 +++++++++++++++++++++++++++++++++++++++++++++ olah/server.py | 86 +++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 olah/proxy/tree.py diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py new file mode 100644 index 0000000..b0125a5 --- /dev/null +++ b/olah/proxy/tree.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2024 XiaHan +# +# Use of this source code is governed by an MIT-style +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +import os +import shutil +import tempfile +from typing import Dict, Literal +from urllib.parse import urljoin +from fastapi import FastAPI, Request + +import httpx +from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT + +from olah.utils.rule_utils import check_cache_rules_hf +from olah.utils.repo_utils import get_org_repo +from olah.utils.file_utils import make_dirs + + +async def tree_cache_generator(save_path: str): + yield {} + with open(save_path, "rb") as f: + while True: + chunk = f.read(CHUNK_SIZE) + if not chunk: + break + yield chunk + + +async def tree_proxy_cache( + app: FastAPI, + repo_type: Literal["models", "datasets", "spaces"], + org: str, + repo: str, + commit: str, + request: Request, +): + # save + repos_path = app.app_settings.repos_path + save_dir = os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + ) + save_path = os.path.join(save_dir, "tree.json") + make_dirs(save_path) + + # url + org_repo = get_org_repo(org, repo) + tree_url = urljoin( + app.app_settings.config.hf_url_base(), + f"/api/{repo_type}/{org_repo}/tree/{commit}", + ) + headers = {} + if "authorization" in request.headers: + headers["authorization"] = request.headers["authorization"] + async with httpx.AsyncClient() as client: + response = await client.request( + method="GET", + url=tree_url, + headers=headers, + timeout=WORKER_API_TIMEOUT, + follow_redirects=True, + ) + if response.status_code == 200: + with open(save_path, "wb") as tree_file: + tree_file.write(response.content) + else: + raise Exception( + f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}" + ) + + +async def tree_proxy_generator( + app: FastAPI, + headers: Dict[str, str], + tree_url: str, + allow_cache: bool, + save_path: str, +): + async with httpx.AsyncClient(follow_redirects=True) as client: + content_chunks = [] + async with client.stream( + method="GET", + url=tree_url, + headers=headers, + timeout=WORKER_API_TIMEOUT, + ) as response: + response_headers = response.headers + yield response_headers + + async for raw_chunk in response.aiter_raw(): + if not raw_chunk: + continue + content_chunks.append(raw_chunk) + yield raw_chunk + + content = bytearray() + for chunk in content_chunks: + content += chunk + with open(save_path, "wb") as f: + f.write(bytes(content)) + + +async def tree_generator( + app: FastAPI, + repo_type: Literal["models", "datasets", "spaces"], + org: str, + repo: str, + commit: str, + request: Request, +): + headers = {k: v for k, v in request.headers.items()} + headers.pop("host") + + # save + repos_path = app.app_settings.repos_path + save_dir = os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + ) + save_path = os.path.join(save_dir, "tree.json") + make_dirs(save_path) + + use_cache = os.path.exists(save_path) + allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) + + org_repo = get_org_repo(org, repo) + tree_url = urljoin( + app.app_settings.config.hf_url_base(), + f"/api/{repo_type}/{org_repo}/tree/{commit}", + ) + # proxy + if use_cache: + async for item in tree_cache_generator(app, save_path): + yield item + else: + async for item in tree_proxy_generator( + app, headers, tree_url, allow_cache, save_path + ): + yield item diff --git a/olah/server.py b/olah/server.py index 60e4083..9a127c6 100644 --- a/olah/server.py +++ b/olah/server.py @@ -26,6 +26,8 @@ import git import httpx +from olah.proxy.tree import tree_generator, tree_proxy_cache + BASE_SETTINGS = False if not BASE_SETTINGS: try: @@ -184,6 +186,16 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request ) +@app.get("/api/{repo_type}/{org}/{repo}") +async def meta_proxy(repo_type: str, org: str, repo: str, request: Request): + if not app.app_settings.config.offline: + new_commit = await get_newest_commit_hf(app, repo_type, org, repo) + else: + new_commit = "main" + return await meta_proxy_common( + repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request + ) + @app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}") async def meta_proxy_commit2( @@ -204,6 +216,80 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: repo_type=repo_type, org=org, repo=repo, commit=commit, request=request ) +# Git Tree +async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: + # TODO: the head method of meta apis + # FIXME: do not show the private repos to other user besides owner, even though the repo was cached + if repo_type not in REPO_TYPES_MAPPING.keys(): + return error_page_not_found() + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return error_repo_not_found() + # Check Mirror Path + for mirror_path in app.app_settings.config.mirrors_path: + try: + git_path = os.path.join(mirror_path, repo_type, org, repo) + if os.path.exists(git_path): + local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) + meta_data = local_repo.get_tree(commit) + if meta_data is None: + continue + return JSONResponse(content=meta_data) + except git.exc.InvalidGitRepositoryError: + logger.warning(f"Local repository {git_path} is not a valid git reposity.") + continue + + # Proxy the HF File Meta + try: + if not app.app_settings.config.offline and not await check_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ): + return error_repo_not_found() + commit_sha = await get_commit_hf( + app, + repo_type, + org, + repo, + commit=commit, + authorization=request.headers.get("authorization", None), + ) + if commit_sha is None: + return error_repo_not_found() + # if branch name and online mode, refresh branch info + if not app.app_settings.config.offline and commit_sha != commit: + await tree_proxy_cache(app, repo_type, org, repo, commit, request) + generator = tree_generator(app, repo_type, org, repo, commit_sha, request) + headers = await generator.__anext__() + return StreamingResponse(generator, headers=headers) + except httpx.ConnectTimeout: + traceback.print_exc() + return Response(status_code=504) + + +@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}") +async def tree_proxy_commit2( + repo_type: str, org: str, repo: str, commit: str, request: Request +): + return await tree_proxy_common( + repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + ) + + +@app.get("/api/{repo_type}/{org_repo}/tree/{commit}") +async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): + org, repo = parse_org_repo(org_repo) + if org is None and repo is None: + return error_repo_not_found() + + return await tree_proxy_common( + repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + ) + + # ====================== # Authentication API Hooks # ====================== From 01ff9d6a8c592ee4ce5921089a9ea5933518174e Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Fri, 16 Aug 2024 01:28:15 +0800 Subject: [PATCH 4/8] code format --- olah/proxy/tree.py | 8 ++------ olah/server.py | 6 +++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index b0125a5..875930d 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -40,9 +40,7 @@ async def tree_proxy_cache( ): # save repos_path = app.app_settings.repos_path - save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" - ) + save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}") save_path = os.path.join(save_dir, "tree.json") make_dirs(save_path) @@ -116,9 +114,7 @@ async def tree_generator( # save repos_path = app.app_settings.repos_path - save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" - ) + save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}") save_path = os.path.join(save_dir, "tree.json") make_dirs(save_path) diff --git a/olah/server.py b/olah/server.py index 9a127c6..20ba9f9 100644 --- a/olah/server.py +++ b/olah/server.py @@ -230,10 +230,10 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) - meta_data = local_repo.get_tree(commit) - if meta_data is None: + tree_data = local_repo.get_tree(commit) + if tree_data is None: continue - return JSONResponse(content=meta_data) + return JSONResponse(content=tree_data) except git.exc.InvalidGitRepositoryError: logger.warning(f"Local repository {git_path} is not a valid git reposity.") continue From 13933f35f251cf47187c2c2145ac7b176ce8f41a Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Fri, 16 Aug 2024 01:44:20 +0800 Subject: [PATCH 5/8] update meta and tree apis --- olah/proxy/meta.py | 8 ++++---- olah/proxy/tree.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index 9d210d0..f661d52 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -20,7 +20,7 @@ from olah.utils.file_utils import make_dirs -async def meta_cache_generator(app: FastAPI, save_path: str): +async def _meta_cache_generator(save_path: str): yield {} with open(save_path, "rb") as f: while True: @@ -72,7 +72,7 @@ async def meta_proxy_cache( ) -async def meta_proxy_generator( +async def _meta_proxy_generator( app: FastAPI, headers: Dict[str, str], meta_url: str, @@ -132,10 +132,10 @@ async def meta_generator( ) # proxy if use_cache: - async for item in meta_cache_generator(app, save_path): + async for item in _meta_cache_generator(save_path): yield item else: - async for item in meta_proxy_generator( + async for item in _meta_proxy_generator( app, headers, meta_url, allow_cache, save_path ): yield item diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 875930d..963125c 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -20,7 +20,7 @@ from olah.utils.file_utils import make_dirs -async def tree_cache_generator(save_path: str): +async def _tree_cache_generator(save_path: str): yield {} with open(save_path, "rb") as f: while True: @@ -70,7 +70,7 @@ async def tree_proxy_cache( ) -async def tree_proxy_generator( +async def _tree_proxy_generator( app: FastAPI, headers: Dict[str, str], tree_url: str, @@ -128,10 +128,10 @@ async def tree_generator( ) # proxy if use_cache: - async for item in tree_cache_generator(app, save_path): + async for item in _tree_cache_generator(save_path): yield item else: - async for item in tree_proxy_generator( + async for item in _tree_proxy_generator( app, headers, tree_url, allow_cache, save_path ): yield item From ba65afdc4e6b5414a207c51224c639474017a8b0 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Fri, 16 Aug 2024 02:11:39 +0800 Subject: [PATCH 6/8] skip 404 status --- olah/proxy/files.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/olah/proxy/files.py b/olah/proxy/files.py index 28bdbe6..0484e04 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -194,6 +194,8 @@ async def _file_full_header( ) elif response.status_code == 403: pass + elif response.status_code == 404: + pass else: raise Exception( f"Unexpected HTTP status code {response.status_code}" From 5b74ebaeb8980cfdf6da8c026c26daecfe3c7fec Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Fri, 16 Aug 2024 03:46:45 +0800 Subject: [PATCH 7/8] update api cache format --- olah/proxy/files.py | 43 +---------------------------------- olah/proxy/meta.py | 39 ++++++++++++++++++------------- olah/proxy/tree.py | 47 +++++++++++++++++++++++--------------- olah/server.py | 14 +++++++++--- olah/utils/cache_utils.py | 48 +++++++++++++++++++++++++++++++++++++++ olah/utils/repo_utils.py | 15 +++++++----- 6 files changed, 121 insertions(+), 85 deletions(-) create mode 100644 olah/utils/cache_utils.py diff --git a/olah/proxy/files.py b/olah/proxy/files.py index 0484e04..5d2a5c8 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -25,6 +25,7 @@ ORIGINAL_LOC, ) from olah.cache.olah_cache import OlahCache +from olah.utils.cache_utils import _read_cache_request, _write_cache_request from olah.utils.url_utils import ( RemoteInfo, add_query_param, @@ -80,48 +81,6 @@ def get_contiguous_ranges( range_start_pos = end_pos return ranges_and_cache_list - -async def _write_cache_request( - head_path: str, status_code: int, headers: Dict[str, str], content: bytes -) -> None: - """ - Write the request's status code, headers, and content to a cache file. - - Args: - head_path (str): The path to the cache file. - status_code (int): The status code of the request. - headers (Dict[str, str]): The dictionary of response headers. - content (bytes): The content of the request. - - Returns: - None - """ - rq = { - "status_code": status_code, - "headers": headers, - "content": content.hex(), - } - with open(head_path, "w", encoding="utf-8") as f: - f.write(json.dumps(rq, ensure_ascii=False)) - - -async def _read_cache_request(head_path: str) -> Dict[str, str]: - """ - Read the request's status code, headers, and content from a cache file. - - Args: - head_path (str): The path to the cache file. - - Returns: - Dict[str, str]: A dictionary containing the status code, headers, and content of the request. - """ - with open(head_path, "r", encoding="utf-8") as f: - rq = json.loads(f.read()) - - rq["content"] = bytes.fromhex(rq["content"]) - return rq - - async def _file_full_header( app, save_path: str, diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index f661d52..3a00f32 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -15,19 +15,16 @@ import httpx from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT +from olah.utils.cache_utils import _read_cache_request, _write_cache_request from olah.utils.rule_utils import check_cache_rules_hf from olah.utils.repo_utils import get_org_repo from olah.utils.file_utils import make_dirs async def _meta_cache_generator(save_path: str): - yield {} - with open(save_path, "rb") as f: - while True: - chunk = f.read(CHUNK_SIZE) - if not chunk: - break - yield chunk + cache_rq = await _read_cache_request(save_path) + yield cache_rq["headers"] + yield cache_rq["content"] async def meta_proxy_cache( @@ -38,12 +35,16 @@ async def meta_proxy_cache( commit: str, request: Request, ): + headers = {k: v for k, v in request.headers.items()} + headers.pop("host") + # save + method = request.method.lower() repos_path = app.app_settings.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}" ) - save_path = os.path.join(save_dir, "meta.json") + save_path = os.path.join(save_dir, f"meta_{method}.json") make_dirs(save_path) # url @@ -57,15 +58,16 @@ async def meta_proxy_cache( headers["authorization"] = request.headers["authorization"] async with httpx.AsyncClient() as client: response = await client.request( - method="GET", + method=request.method, url=meta_url, headers=headers, timeout=WORKER_API_TIMEOUT, follow_redirects=True, ) if response.status_code == 200: - with open(save_path, "wb") as meta_file: - meta_file.write(response.content) + await _write_cache_request( + save_path, response.status_code, response.headers, response.content + ) else: raise Exception( f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}" @@ -77,16 +79,18 @@ async def _meta_proxy_generator( headers: Dict[str, str], meta_url: str, allow_cache: bool, + method: str, save_path: str, ): async with httpx.AsyncClient(follow_redirects=True) as client: content_chunks = [] async with client.stream( - method="GET", + method=method, url=meta_url, headers=headers, timeout=WORKER_API_TIMEOUT, ) as response: + response_status_code = response.status_code response_headers = response.headers yield response_headers @@ -99,8 +103,10 @@ async def _meta_proxy_generator( content = bytearray() for chunk in content_chunks: content += chunk - with open(save_path, "wb") as f: - f.write(bytes(content)) + + await _write_cache_request( + save_path, response_status_code, response_headers, bytes(content) + ) async def meta_generator( @@ -115,11 +121,12 @@ async def meta_generator( headers.pop("host") # save + method = request.method.lower() repos_path = app.app_settings.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}" ) - save_path = os.path.join(save_dir, "meta.json") + save_path = os.path.join(save_dir, f"meta_{method}.json") make_dirs(save_path) use_cache = os.path.exists(save_path) @@ -136,6 +143,6 @@ async def meta_generator( yield item else: async for item in _meta_proxy_generator( - app, headers, meta_url, allow_cache, save_path + app, headers, meta_url, allow_cache, method, save_path ): yield item diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 963125c..2f96f9f 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -15,19 +15,16 @@ import httpx from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT +from olah.utils.cache_utils import _read_cache_request, _write_cache_request from olah.utils.rule_utils import check_cache_rules_hf from olah.utils.repo_utils import get_org_repo from olah.utils.file_utils import make_dirs async def _tree_cache_generator(save_path: str): - yield {} - with open(save_path, "rb") as f: - while True: - chunk = f.read(CHUNK_SIZE) - if not chunk: - break - yield chunk + cache_rq = await _read_cache_request(save_path) + yield cache_rq["headers"] + yield cache_rq["content"] async def tree_proxy_cache( @@ -38,10 +35,16 @@ async def tree_proxy_cache( commit: str, request: Request, ): + headers = {k: v for k, v in request.headers.items()} + headers.pop("host") + # save + method = request.method.lower() repos_path = app.app_settings.repos_path - save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}") - save_path = os.path.join(save_dir, "tree.json") + save_dir = os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + ) + save_path = os.path.join(save_dir, f"tree_{method}.json") make_dirs(save_path) # url @@ -55,15 +58,16 @@ async def tree_proxy_cache( headers["authorization"] = request.headers["authorization"] async with httpx.AsyncClient() as client: response = await client.request( - method="GET", + method=request.method, url=tree_url, headers=headers, timeout=WORKER_API_TIMEOUT, follow_redirects=True, ) if response.status_code == 200: - with open(save_path, "wb") as tree_file: - tree_file.write(response.content) + await _write_cache_request( + save_path, response.status_code, response.headers, response.content + ) else: raise Exception( f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}" @@ -75,16 +79,18 @@ async def _tree_proxy_generator( headers: Dict[str, str], tree_url: str, allow_cache: bool, + method: str, save_path: str, ): async with httpx.AsyncClient(follow_redirects=True) as client: content_chunks = [] async with client.stream( - method="GET", + method=method, url=tree_url, headers=headers, timeout=WORKER_API_TIMEOUT, ) as response: + response_status_code = response.status_code response_headers = response.headers yield response_headers @@ -97,8 +103,10 @@ async def _tree_proxy_generator( content = bytearray() for chunk in content_chunks: content += chunk - with open(save_path, "wb") as f: - f.write(bytes(content)) + + await _write_cache_request( + save_path, response_status_code, response_headers, bytes(content) + ) async def tree_generator( @@ -113,9 +121,12 @@ async def tree_generator( headers.pop("host") # save + method = request.method.lower() repos_path = app.app_settings.repos_path - save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}") - save_path = os.path.join(save_dir, "tree.json") + save_dir = os.path.join( + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + ) + save_path = os.path.join(save_dir, f"tree_{method}.json") make_dirs(save_path) use_cache = os.path.exists(save_path) @@ -132,6 +143,6 @@ async def tree_generator( yield item else: async for item in _tree_proxy_generator( - app, headers, tree_url, allow_cache, save_path + app, headers, tree_url, allow_cache, method, save_path ): yield item diff --git a/olah/server.py b/olah/server.py index 20ba9f9..49df26f 100644 --- a/olah/server.py +++ b/olah/server.py @@ -173,6 +173,7 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re return Response(status_code=504) +@app.head("/api/{repo_type}/{org_repo}") @app.get("/api/{repo_type}/{org_repo}") async def meta_proxy(repo_type: str, org_repo: str, request: Request): org, repo = parse_org_repo(org_repo) @@ -180,23 +181,28 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): return error_repo_not_found() if not app.app_settings.config.offline: new_commit = await get_newest_commit_hf(app, repo_type, org, repo) + if new_commit is None: + return error_repo_not_found() else: new_commit = "main" return await meta_proxy_common( repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request ) +@app.head("/api/{repo_type}/{org}/{repo}") @app.get("/api/{repo_type}/{org}/{repo}") async def meta_proxy(repo_type: str, org: str, repo: str, request: Request): if not app.app_settings.config.offline: new_commit = await get_newest_commit_hf(app, repo_type, org, repo) + if new_commit is None: + return error_repo_not_found() else: new_commit = "main" return await meta_proxy_common( repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request ) - +@app.head("/api/{repo_type}/{org}/{repo}/revision/{commit}") @app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}") async def meta_proxy_commit2( repo_type: str, org: str, repo: str, commit: str, request: Request @@ -205,7 +211,7 @@ async def meta_proxy_commit2( repo_type=repo_type, org=org, repo=repo, commit=commit, request=request ) - +@app.head("/api/{repo_type}/{org_repo}/revision/{commit}") @app.get("/api/{repo_type}/{org_repo}/revision/{commit}") async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): org, repo = parse_org_repo(org_repo) @@ -230,6 +236,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) + # TODO: Local git repo trees tree_data = local_repo.get_tree(commit) if tree_data is None: continue @@ -269,7 +276,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re traceback.print_exc() return Response(status_code=504) - +@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}") @app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}") async def tree_proxy_commit2( repo_type: str, org: str, repo: str, commit: str, request: Request @@ -279,6 +286,7 @@ async def tree_proxy_commit2( ) +@app.head("/api/{repo_type}/{org_repo}/tree/{commit}") @app.get("/api/{repo_type}/{org_repo}/tree/{commit}") async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): org, repo = parse_org_repo(org_repo) diff --git a/olah/utils/cache_utils.py b/olah/utils/cache_utils.py new file mode 100644 index 0000000..293922a --- /dev/null +++ b/olah/utils/cache_utils.py @@ -0,0 +1,48 @@ + + + +import json +from typing import Dict, Mapping, Union + + +async def _write_cache_request( + save_path: str, status_code: int, headers: Union[Dict[str, str], Mapping], content: bytes +) -> None: + """ + Write the request's status code, headers, and content to a cache file. + + Args: + head_path (str): The path to the cache file. + status_code (int): The status code of the request. + headers (Dict[str, str]): The dictionary of response headers. + content (bytes): The content of the request. + + Returns: + None + """ + if not isinstance(headers, dict): + headers = {k.lower():v for k, v in headers.items()} + rq = { + "status_code": status_code, + "headers": headers, + "content": content.hex(), + } + with open(save_path, "w", encoding="utf-8") as f: + f.write(json.dumps(rq, ensure_ascii=False)) + + +async def _read_cache_request(save_path: str) -> Dict[str, str]: + """ + Read the request's status code, headers, and content from a cache file. + + Args: + save_path (str): The path to the cache file. + + Returns: + Dict[str, str]: A dictionary containing the status code, headers, and content of the request. + """ + with open(save_path, "r", encoding="utf-8") as f: + rq = json.loads(f.read()) + + rq["content"] = bytes.fromhex(rq["content"]) + return rq diff --git a/olah/utils/repo_utils.py b/olah/utils/repo_utils.py index 41f0c19..ebbcbd1 100644 --- a/olah/utils/repo_utils.py +++ b/olah/utils/repo_utils.py @@ -130,7 +130,7 @@ async def get_newest_commit_hf_offline( repo_type: Optional[Literal["models", "datasets", "spaces"]], org: str, repo: str, -) -> str: +) -> Optional[str]: """ Retrieves the newest commit hash for a repository in offline mode. @@ -146,7 +146,7 @@ async def get_newest_commit_hf_offline( """ repos_path = app.app_settings.repos_path save_dir = get_meta_save_dir(repos_path, repo_type, org, repo) - files = glob.glob(os.path.join(save_dir, "*", "meta.json")) + files = glob.glob(os.path.join(save_dir, "*", "meta_head.json")) time_revisions = [] for file in files: @@ -156,7 +156,10 @@ async def get_newest_commit_hf_offline( time_revisions.append((datetime_object, obj["sha"])) time_revisions = sorted(time_revisions) - return time_revisions[-1][1] + if len(time_revisions) == 0: + return None + else: + return time_revisions[-1][1] async def get_newest_commit_hf( @@ -182,16 +185,16 @@ async def get_newest_commit_hf( app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org}/{repo}" ) if app.app_settings.config.offline: - return get_newest_commit_hf_offline(app, repo_type, org, repo) + return await get_newest_commit_hf_offline(app, repo_type, org, repo) try: async with httpx.AsyncClient() as client: response = await client.get(url, timeout=WORKER_API_TIMEOUT) if response.status_code != 200: - return get_newest_commit_hf_offline(app, repo_type, org, repo) + return await get_newest_commit_hf_offline(app, repo_type, org, repo) obj = json.loads(response.text) return obj.get("sha", None) except: - return get_newest_commit_hf_offline(app, repo_type, org, repo) + return await get_newest_commit_hf_offline(app, repo_type, org, repo) async def get_commit_hf_offline( From 3501520e84e0b0e1cebb29d072357de2c1c3c119 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Thu, 22 Aug 2024 16:22:55 +0800 Subject: [PATCH 8/8] tree api bug fix --- olah/errors.py | 10 +++++ olah/mirror/repos.py | 85 ++++++++++++++++++++++++++++++++++++++--- olah/proxy/tree.py | 23 ++++++----- olah/server.py | 27 ++++++------- olah/utils/url_utils.py | 9 +++++ 5 files changed, 126 insertions(+), 28 deletions(-) diff --git a/olah/errors.py b/olah/errors.py index 7768275..a020c57 100644 --- a/olah/errors.py +++ b/olah/errors.py @@ -28,3 +28,13 @@ def error_page_not_found() -> Response: }, status_code=404, ) + +def error_entry_not_found(branch: str, path: str) -> Response: + return Response( + headers={ + "x-error-code": "EntryNotFound", + "x-error-message": f"{path} does not exist on \"{branch}\"", + }, + status_code=404, + ) + diff --git a/olah/mirror/repos.py b/olah/mirror/repos.py index 2c43ddd..fbae83e 100644 --- a/olah/mirror/repos.py +++ b/olah/mirror/repos.py @@ -65,19 +65,70 @@ def _get_description(self, commit: Commit) -> str: readme = self._get_readme(commit) return self._remove_card(readme) - def _get_entry_files(self, tree, include_dir=False) -> List[str]: + def _get_tree_files_recursive(self, tree, include_dir=False) -> List[str]: out_paths = [] for entry in tree: if entry.type == "tree": - out_paths.extend(self._get_entry_files(entry)) + out_paths.extend(self._get_tree_files_recursive(entry)) if include_dir: out_paths.append(entry.path) else: out_paths.append(entry.path) return out_paths - def _get_tree_files(self, commit: Commit) -> List[str]: - return self._get_entry_files(commit.tree) + def _get_commit_files_recursive(self, commit: Commit) -> List[str]: + return self._get_tree_files_recursive(commit.tree) + + def _get_tree_files(self, tree: Tree) -> List[Dict[str, Union[int, str]]]: + entries = [] + for entry in tree: + lfs = False + if entry.type != "tree": + t = "file" + repr_size = entry.size + if repr_size > 120 and repr_size < 150: + # check lfs + lfs_data = entry.data_stream.read().decode("utf-8") + match_groups = re.match( + r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n", + lfs_data, + ) + if match_groups is not None: + lfs = True + sha256 = match_groups.group(1) + repr_size = int(match_groups.group(2)) + lfs_data = { + "oid": sha256, + "size": repr_size, + "pointerSize": entry.size, + } + else: + t = "directory" + repr_size = 0 + + if not lfs: + entries.append( + { + "type": t, + "oid": entry.hexsha, + "size": repr_size, + "path": entry.name, + } + ) + else: + entries.append( + { + "type": t, + "oid": entry.hexsha, + "size": repr_size, + "path": entry.name, + "lfs": lfs_data, + } + ) + return entries + + def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]: + return self._get_tree_files(commit.tree) def _get_earliest_commit(self) -> Commit: earliest_commit = None @@ -92,7 +143,27 @@ def _get_earliest_commit(self) -> Commit: return earliest_commit - def get_meta(self, commit_hash: str) -> Dict[str, Any]: + def get_tree(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]: + try: + commit = self._git_repo.commit(commit_hash) + except gitdb.exc.BadName: + return None + + path_part = path.split("/") + tree = commit.tree + items = self._get_tree_files(tree=tree) + for part in path_part: + if len(part.strip()) == 0: + continue + if part not in [ + item["path"] for item in items if item["type"] == "directory" + ]: + return None + tree = tree[part] + items = self._get_tree_files(tree=tree) + return items + + def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]: try: commit = self._git_repo.commit(commit_hash) except gitdb.exc.BadName: @@ -117,7 +188,9 @@ def get_meta(self, commit_hash: str) -> Dict[str, Any]: meta.cardData = yaml.load( self._match_card(self._get_readme(commit)), Loader=yaml.CLoader ) - meta.siblings = [{"rfilename": p} for p in self._get_tree_files(commit)] + meta.siblings = [ + {"rfilename": p} for p in self._get_commit_files_recursive(commit) + ] meta.createdAt = self._get_earliest_commit().committed_datetime.strftime( "%Y-%m-%dT%H:%M:%S.%fZ" ) diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 2f96f9f..357470a 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -33,6 +33,7 @@ async def tree_proxy_cache( org: str, repo: str, commit: str, + path: str, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -42,16 +43,15 @@ async def tree_proxy_cache( method = request.method.lower() repos_path = app.app_settings.repos_path save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" ) save_path = os.path.join(save_dir, f"tree_{method}.json") - make_dirs(save_path) # url org_repo = get_org_repo(org, repo) tree_url = urljoin( app.app_settings.config.hf_url_base(), - f"/api/{repo_type}/{org_repo}/tree/{commit}", + f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", ) headers = {} if "authorization" in request.headers: @@ -65,9 +65,12 @@ async def tree_proxy_cache( follow_redirects=True, ) if response.status_code == 200: + make_dirs(save_path) await _write_cache_request( save_path, response.status_code, response.headers, response.content ) + elif response.status_code == 404: + pass else: raise Exception( f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}" @@ -104,9 +107,11 @@ async def _tree_proxy_generator( for chunk in content_chunks: content += chunk - await _write_cache_request( - save_path, response_status_code, response_headers, bytes(content) - ) + if response_status_code == 200: + make_dirs(save_path) + await _write_cache_request( + save_path, response_status_code, response_headers, bytes(content) + ) async def tree_generator( @@ -115,6 +120,7 @@ async def tree_generator( org: str, repo: str, commit: str, + path: str, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -124,10 +130,9 @@ async def tree_generator( method = request.method.lower() repos_path = app.app_settings.repos_path save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}" + repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" ) save_path = os.path.join(save_dir, f"tree_{method}.json") - make_dirs(save_path) use_cache = os.path.exists(save_path) allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) @@ -135,7 +140,7 @@ async def tree_generator( org_repo = get_org_repo(org, repo) tree_url = urljoin( app.app_settings.config.hf_url_base(), - f"/api/{repo_type}/{org_repo}/tree/{commit}", + f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", ) # proxy if use_cache: diff --git a/olah/server.py b/olah/server.py index 49df26f..2736f37 100644 --- a/olah/server.py +++ b/olah/server.py @@ -27,6 +27,7 @@ import httpx from olah.proxy.tree import tree_generator, tree_proxy_cache +from olah.utils.url_utils import clean_path BASE_SETTINGS = False if not BASE_SETTINGS: @@ -223,9 +224,10 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: ) # Git Tree -async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: +async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, path: str, request: Request) -> Response: # TODO: the head method of meta apis # FIXME: do not show the private repos to other user besides owner, even though the repo was cached + path = clean_path(path) if repo_type not in REPO_TYPES_MAPPING.keys(): return error_page_not_found() if not await check_proxy_rules_hf(app, repo_type, org, repo): @@ -236,8 +238,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) - # TODO: Local git repo trees - tree_data = local_repo.get_tree(commit) + tree_data = local_repo.get_tree(commit, path) if tree_data is None: continue return JSONResponse(content=tree_data) @@ -268,33 +269,33 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re return error_repo_not_found() # if branch name and online mode, refresh branch info if not app.app_settings.config.offline and commit_sha != commit: - await tree_proxy_cache(app, repo_type, org, repo, commit, request) - generator = tree_generator(app, repo_type, org, repo, commit_sha, request) + await tree_proxy_cache(app, repo_type, org, repo, commit, path, request) + generator = tree_generator(app, repo_type, org, repo, commit_sha, path, request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) except httpx.ConnectTimeout: traceback.print_exc() return Response(status_code=504) -@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}") -@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}") +@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}") +@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}") async def tree_proxy_commit2( - repo_type: str, org: str, repo: str, commit: str, request: Request + repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request ): return await tree_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request ) -@app.head("/api/{repo_type}/{org_repo}/tree/{commit}") -@app.get("/api/{repo_type}/{org_repo}/tree/{commit}") -async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): +@app.head("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}") +@app.get("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}") +async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, file_path: str, request: Request): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() return await tree_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request ) diff --git a/olah/utils/url_utils.py b/olah/utils/url_utils.py index 588ac9c..d8257c5 100644 --- a/olah/utils/url_utils.py +++ b/olah/utils/url_utils.py @@ -162,3 +162,12 @@ def remove_query_param(url: str, param_name: str) -> str: new_url = urlunparse(parsed_url._replace(query=new_query)) return new_url + + +def clean_path(path: str) -> str: + while ".." in path: + path = path.replace("..", "") + path = path.replace("\\", "/") + while "//" in path: + path = path.replace("//", "/") + return path \ No newline at end of file