From 610ee14bc0326ab978d41a28be241ab7cde9beb4 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 2 Sep 2024 03:46:38 +0800 Subject: [PATCH] clean apis code --- olah/proxy/meta.py | 60 +++++------------------------------------ olah/proxy/pathsinfo.py | 6 ++--- olah/proxy/tree.py | 60 +++-------------------------------------- olah/server.py | 54 ++++++++++++++++++++++++++++++++----- 4 files changed, 60 insertions(+), 120 deletions(-) diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index 3a00f32..60e73f3 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -20,66 +20,18 @@ from olah.utils.repo_utils import get_org_repo from olah.utils.file_utils import make_dirs - async def _meta_cache_generator(save_path: str): cache_rq = await _read_cache_request(save_path) yield cache_rq["headers"] yield cache_rq["content"] -async def meta_proxy_cache( - app: FastAPI, - repo_type: Literal["models", "datasets", "spaces"], - org: str, - repo: str, - commit: str, - request: Request, -): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") - - # save - method = request.method.lower() - repos_path = app.app_settings.repos_path - save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}" - ) - save_path = os.path.join(save_dir, f"meta_{method}.json") - make_dirs(save_path) - - # url - org_repo = get_org_repo(org, repo) - meta_url = urljoin( - app.app_settings.config.hf_url_base(), - f"/api/{repo_type}/{org_repo}/revision/{commit}", - ) - headers = {} - if "authorization" in request.headers: - headers["authorization"] = request.headers["authorization"] - async with httpx.AsyncClient() as client: - response = await client.request( - method=request.method, - url=meta_url, - headers=headers, - timeout=WORKER_API_TIMEOUT, - follow_redirects=True, - ) - if response.status_code == 200: - await _write_cache_request( - save_path, response.status_code, response.headers, response.content - ) - else: - raise 
Exception( - f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}" - ) - - async def _meta_proxy_generator( app: FastAPI, headers: Dict[str, str], meta_url: str, - allow_cache: bool, method: str, + allow_cache: bool, save_path: str, ): async with httpx.AsyncClient(follow_redirects=True) as client: @@ -104,9 +56,10 @@ async def _meta_proxy_generator( for chunk in content_chunks: content += chunk - await _write_cache_request( - save_path, response_status_code, response_headers, bytes(content) - ) + if allow_cache and response_status_code == 200: + await _write_cache_request( + save_path, response_status_code, response_headers, bytes(content) + ) async def meta_generator( @@ -115,6 +68,7 @@ async def meta_generator( org: str, repo: str, commit: str, + override_cache: bool, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -138,7 +92,7 @@ async def meta_generator( f"/api/{repo_type}/{org_repo}/revision/{commit}", ) # proxy - if use_cache: + if use_cache and not override_cache: async for item in _meta_cache_generator(save_path): yield item else: diff --git a/olah/proxy/pathsinfo.py b/olah/proxy/pathsinfo.py index e627f68..bbcdd39 100644 --- a/olah/proxy/pathsinfo.py +++ b/olah/proxy/pathsinfo.py @@ -45,7 +45,7 @@ async def _pathsinfo_proxy( timeout=WORKER_API_TIMEOUT, ) - if response.status_code == 200: + if allow_cache and response.status_code == 200: make_dirs(save_path) await _write_cache_request( save_path, @@ -63,6 +63,7 @@ async def pathsinfo_generator( repo: str, commit: str, paths: List[str], + override_cache: bool, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -89,10 +90,9 @@ async def pathsinfo_generator( f"/api/{repo_type}/{org_repo}/paths-info/{commit}", ) # proxy - if use_cache: + if use_cache and not override_cache: status, headers, content = await _pathsinfo_cache(save_path) else: - print(path) status, headers, content = await _pathsinfo_proxy( app, headers, 
pathsinfo_url, allow_cache, method, path, save_path ) diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 1da1a31..f5a21b6 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -19,61 +19,6 @@ from olah.utils.file_utils import make_dirs -async def tree_proxy_cache( - app: FastAPI, - repo_type: Literal["models", "datasets", "spaces"], - org: str, - repo: str, - commit: str, - path: str, - recursive: bool, - request: Request, -): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") - - # save - method = request.method.lower() - repos_path = app.app_settings.repos_path - save_dir = os.path.join( - repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" - ) - if not recursive: - save_path = os.path.join(save_dir, f"tree_{method}.json") - else: - save_path = os.path.join(save_dir, f"tree_{method}_recursive.json") - - # url - org_repo = get_org_repo(org, repo) - tree_url = urljoin( - app.app_settings.config.hf_url_base(), - f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", - ) - headers = {} - if "authorization" in request.headers: - headers["authorization"] = request.headers["authorization"] - async with httpx.AsyncClient() as client: - response = await client.request( - method=request.method, - url=tree_url, - params={"recursive": recursive}, - headers=headers, - timeout=WORKER_API_TIMEOUT, - follow_redirects=True, - ) - if response.status_code == 200: - make_dirs(save_path) - await _write_cache_request( - save_path, response.status_code, response.headers, response.content - ) - elif response.status_code == 404: - pass - else: - raise Exception( - f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}" - ) - - async def _tree_cache_generator(save_path: str): cache_rq = await _read_cache_request(save_path) yield cache_rq["headers"] @@ -111,7 +56,7 @@ async def _tree_proxy_generator( for chunk in content_chunks: content += chunk - if response_status_code == 200: + if allow_cache and 
response_status_code == 200: make_dirs(save_path) await _write_cache_request( save_path, response_status_code, response_headers, bytes(content) @@ -126,6 +71,7 @@ async def tree_generator( commit: str, path: str, recursive: bool, + override_cache: bool, request: Request, ): headers = {k: v for k, v in request.headers.items()} @@ -151,7 +97,7 @@ async def tree_generator( f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", ) # proxy - if use_cache: + if use_cache and not override_cache: async for item in _tree_cache_generator(save_path): yield item else: diff --git a/olah/server.py b/olah/server.py index cb4741f..34997a5 100644 --- a/olah/server.py +++ b/olah/server.py @@ -123,7 +123,6 @@ class AppSettings(BaseSettings): # See also: https://huggingface.co/docs/hub/api#repo-listing-api # ====================== async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: - # TODO: the head method of meta apis # FIXME: do not show the private repos to other user besides owner, even though the repo was cached if repo_type not in REPO_TYPES_MAPPING.keys(): return error_page_not_found() @@ -166,8 +165,25 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re return error_repo_not_found() # if branch name and online mode, refresh branch info if not app.app_settings.config.offline and commit_sha != commit: - await meta_proxy_cache(app, repo_type, org, repo, commit, request) - generator = meta_generator(app, repo_type, org, repo, commit_sha, request) + generator = meta_generator( + app=app, + repo_type=repo_type, + org=org, + repo=repo, + commit=commit_sha, + override_cache=True, + request=request, + ) + else: + generator = meta_generator( + app=app, + repo_type=repo_type, + org=org, + repo=repo, + commit=commit_sha, + override_cache=False, + request=request, + ) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) except httpx.ConnectTimeout: @@ -235,7 +251,6 
@@ async def tree_proxy_common( recursive: bool, request: Request, ) -> Response: - # TODO ?recursive=True # FIXME: do not show the private repos to other user besides owner, even though the repo was cached path = clean_path(path) if repo_type not in REPO_TYPES_MAPPING.keys(): @@ -279,8 +294,30 @@ async def tree_proxy_common( return error_repo_not_found() # if branch name and online mode, refresh branch info if not app.app_settings.config.offline and commit_sha != commit: - await tree_proxy_cache(app, repo_type, org, repo, commit, path, recursive, request) - generator = tree_generator(app, repo_type, org, repo, commit_sha, path, recursive, request) + generator = tree_generator( + app=app, + repo_type=repo_type, + org=org, + repo=repo, + commit=commit_sha, + path=path, + recursive=recursive, + override_cache=True, + request=request, + ) + else: + generator = tree_generator( + app=app, + repo_type=repo_type, + org=org, + repo=repo, + commit=commit_sha, + path=path, + recursive=recursive, + override_cache=False, + request=request, + ) + headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) except httpx.ConnectTimeout: @@ -379,7 +416,10 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st ) if commit_sha is None: return error_repo_not_found() - generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, request) + if not app.app_settings.config.offline and commit_sha != commit: + generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=True, request=request) + else: + generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=False, request=request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) except httpx.ConnectTimeout: