From 1e4d6c9fcd6fef2a4f4794bc8a69defa31f96c82 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Thu, 5 Sep 2024 13:24:24 +0800 Subject: [PATCH] authorization bug fix --- olah/proxy/commits.py | 11 +-- olah/proxy/files.py | 19 +++-- olah/proxy/meta.py | 11 +-- olah/proxy/pathsinfo.py | 6 +- olah/proxy/tree.py | 11 +-- olah/server.py | 167 +++++++++++++++++++++++++++++++--------- 6 files changed, 166 insertions(+), 59 deletions(-) diff --git a/olah/proxy/commits.py b/olah/proxy/commits.py index 3b6db1f..4e3323e 100644 --- a/olah/proxy/commits.py +++ b/olah/proxy/commits.py @@ -6,7 +6,7 @@ # https://opensource.org/licenses/MIT. import os -from typing import Dict, Literal, Mapping +from typing import Dict, Literal, Mapping, Optional from urllib.parse import urljoin from fastapi import FastAPI, Request @@ -73,13 +73,14 @@ async def commits_generator( repo: str, commit: str, override_cache: bool, - request: Request, + method: str, + authorization: Optional[str], ): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") + headers = {} + if authorization is not None: + headers["authorization"] = authorization # save - method = request.method.lower() repos_path = app.app_settings.config.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/commits/{commit}" diff --git a/olah/proxy/files.py b/olah/proxy/files.py index b9e4ce3..0f17c59 100644 --- a/olah/proxy/files.py +++ b/olah/proxy/files.py @@ -473,8 +473,17 @@ async def _file_realtime_stream( if "host" in request_headers: request_headers["host"] = urlparse(hf_url).netloc - - generator = pathsinfo_generator(app, repo_type, org, repo, commit, [file_path], override_cache=False, method="post") + generator = pathsinfo_generator( + app, + repo_type, + org, + repo, + commit, + [file_path], + override_cache=False, + method="post", + authorization=request.headers.get("authorization", None), + ) status_code = await generator.__anext__() headers = await generator.__anext__() content = await generator.__anext__() @@ -493,14 +502,14 @@ async def _file_realtime_stream( yield response.headers yield response.body return - + if len(pathsinfo) != 1: response = error_proxy_timeout() yield response.status_code yield response.headers yield response.body return - + pathinfo = pathsinfo[0] if "size" not in pathinfo: response = error_proxy_timeout() @@ -535,7 +544,7 @@ async def _file_realtime_stream( response_headers["etag"] = f'"{content_hash[:32]}-10"' yield 200 yield response_headers - + async with httpx.AsyncClient() as client: if method.lower() == "get": async for each_chunk in _file_chunk_get( diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py index a5ae97b..6f0f904 100644 --- a/olah/proxy/meta.py +++ b/olah/proxy/meta.py @@ -8,7 +8,7 @@ import os import shutil import tempfile -from typing import Dict, Literal +from typing import Dict, Literal, Optional from urllib.parse import urljoin from fastapi import FastAPI, Request @@ -69,13 +69,14 @@ async def meta_generator( repo: str, commit: str, override_cache: bool, - request: Request, + method: str, + authorization: Optional[str], ): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") + headers = {} + if authorization is not None: + headers["authorization"] = authorization # save - method = request.method.lower() repos_path = app.app_settings.config.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}" diff --git a/olah/proxy/pathsinfo.py b/olah/proxy/pathsinfo.py index 0c65643..1e9115e 100644 --- a/olah/proxy/pathsinfo.py +++ b/olah/proxy/pathsinfo.py @@ -7,7 +7,7 @@ import json import os -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Optional from urllib.parse import quote, urljoin from fastapi import FastAPI, Request @@ -66,8 +66,12 @@ async def pathsinfo_generator( paths: List[str], override_cache: bool, method: str, + authorization: Optional[str], ): headers = {} + if authorization is not None: + headers["authorization"] = authorization + # save repos_path = app.app_settings.config.repos_path diff --git a/olah/proxy/tree.py b/olah/proxy/tree.py index 9738b78..22e4467 100644 --- a/olah/proxy/tree.py +++ b/olah/proxy/tree.py @@ -6,7 +6,7 @@ # https://opensource.org/licenses/MIT. import os -from typing import Dict, Literal, Mapping +from typing import Dict, Literal, Mapping, Optional from urllib.parse import urljoin from fastapi import FastAPI, Request @@ -75,13 +75,14 @@ async def tree_generator( recursive: bool, expand: bool, override_cache: bool, - request: Request, + method: str, + authorization: Optional[str], ): - headers = {k: v for k, v in request.headers.items()} - headers.pop("host") + headers = {} + if authorization is not None: + headers["authorization"] = authorization # save - method = request.method.lower() repos_path = app.app_settings.config.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" diff --git a/olah/server.py b/olah/server.py index 9da33e8..aea230a 100644 --- a/olah/server.py +++ b/olah/server.py @@ -182,7 +182,7 @@ async def custom_404_handler(_, __): # File Meta Info API Hooks # See also: https://huggingface.co/docs/hub/api#repo-listing-api # ====================== -async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: +async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, method: str, authorization: Optional[str]) -> Response: # FIXME: do not show the private repos to other user besides owner, even though the repo was cached if repo_type not in REPO_TYPES_MAPPING.keys(): return error_page_not_found() @@ -206,15 +206,15 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re try: if not app.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_repo_not_found() if not await check_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_revision_not_found(revision=commit) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ) if commit_sha is None: return error_repo_not_found() @@ -227,7 +227,8 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re repo=repo, commit=commit_sha, override_cache=True, - request=request, + method=method, + authorization=authorization, ) else: generator = meta_generator( @@ -237,7 +238,8 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re repo=repo, commit=commit_sha, override_cache=False, - request=request, + method=method, + authorization=authorization, ) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) @@ -259,7 +261,12 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): else: new_commit = "main" return await meta_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=new_commit, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) @app.head("/api/{repo_type}/{org}/{repo}") @@ -272,27 +279,46 @@ async def meta_proxy(repo_type: str, org: str, repo: str, request: Request): else: new_commit = "main" return await meta_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=new_commit, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) + @app.head("/api/{repo_type}/{org}/{repo}/revision/{commit}") @app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}") async def meta_proxy_commit2( repo_type: str, org: str, repo: str, commit: str, request: Request ): return await meta_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=commit, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) + @app.head("/api/{repo_type}/{org_repo}/revision/{commit}") @app.get("/api/{repo_type}/{org_repo}/revision/{commit}") -async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): +async def meta_proxy_commit( + repo_type: str, org_repo: str, commit: str, request: Request +): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() return await meta_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=commit, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) @@ -305,7 +331,8 @@ async def tree_proxy_common( path: str, recursive: bool, expand: bool, - request: Request, + method: str, + authorization: Optional[str] ) -> Response: # FIXME: do not show the private repos to other user besides owner, even though the repo was cached path = clean_path(path) @@ -331,15 +358,15 @@ async def tree_proxy_common( try: if not app.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_repo_not_found() if not await check_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_revision_not_found(revision=commit) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ) if commit_sha is None: return error_repo_not_found() @@ -355,7 +382,8 @@ async def tree_proxy_common( recursive=recursive, expand=expand, override_cache=True, - request=request, + method=method, + authorization=authorization, ) else: generator = tree_generator( @@ -368,7 +396,8 @@ async def tree_proxy_common( recursive=recursive, expand=expand, override_cache=False, - request=request, + method=method, + authorization=authorization, ) status_code = await generator.__anext__() @@ -399,7 +428,8 @@ async def tree_proxy_commit2( path=file_path, recursive=recursive, expand=expand, - request=request, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) @@ -426,11 +456,12 @@ async def tree_proxy_commit( path=file_path, recursive=recursive, expand=expand, - request=request, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) # Git Pathsinfo -async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: str, paths: List[str], request: Request) -> Response: +async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: str, paths: List[str], method: str, authorization: Optional[str]) -> Response: # TODO: the head method of meta apis # FIXME: do not show the private repos to other user besides owner, even though the repo was cached paths = [clean_path(path) for path in paths] @@ -456,22 +487,42 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st try: if not app.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_repo_not_found() if not await check_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_revision_not_found(revision=commit) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ) if commit_sha is None: return error_repo_not_found() if not app.app_settings.config.offline and commit_sha != commit: - generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=True, method=request.method.lower()) + generator = pathsinfo_generator( + app, + repo_type, + org, + repo, + commit_sha, + paths, + override_cache=True, + method=method, + authorization=authorization, + ) else: - generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=False, method=request.method.lower()) + generator = pathsinfo_generator( + app, + repo_type, + org, + repo, + commit_sha, + paths, + override_cache=False, + method=method, + authorization=authorization, + ) status_code = await generator.__anext__() headers = await generator.__anext__() return StreamingResponse(generator, status_code=status_code, headers=headers) @@ -479,29 +530,54 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st traceback.print_exc() return Response(status_code=504) + @app.head("/api/{repo_type}/{org}/{repo}/paths-info/{commit}") @app.post("/api/{repo_type}/{org}/{repo}/paths-info/{commit}") async def pathsinfo_proxy_commit2( - repo_type: str, org: str, repo: str, commit: str, paths: Annotated[List[str], Form()], request: Request + repo_type: str, + org: str, + repo: str, + commit: str, + paths: Annotated[List[str], Form()], + request: Request, ): return await pathsinfo_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, paths=paths, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=commit, + paths=paths, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) @app.head("/api/{repo_type}/{org_repo}/paths-info/{commit}") @app.post("/api/{repo_type}/{org_repo}/paths-info/{commit}") -async def pathsinfo_proxy_commit(repo_type: str, org_repo: str, commit: str, paths: Annotated[List[str], Form()], request: Request): +async def pathsinfo_proxy_commit( + repo_type: str, + org_repo: str, + commit: str, + paths: Annotated[List[str], Form()], + request: Request, +): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() return await pathsinfo_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, paths=paths, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=commit, + paths=paths, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) + # Git Commits -async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response: +async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, method: str, authorization: Optional[str]) -> Response: # FIXME: do not show the private repos to other user besides owner, even though the repo was cached if repo_type not in REPO_TYPES_MAPPING.keys(): return error_page_not_found() @@ -525,15 +601,15 @@ async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, try: if not app.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_repo_not_found() if not await check_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ): return error_revision_not_found(revision=commit) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit=commit, - authorization=request.headers.get("authorization", None), + authorization=authorization, ) if commit_sha is None: return error_repo_not_found() @@ -546,7 +622,8 @@ async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, repo=repo, commit=commit_sha, override_cache=True, - request=request, + method=method, + authorization=authorization, ) else: generator = commits_generator( @@ -556,7 +633,8 @@ async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, repo=repo, commit=commit_sha, override_cache=False, - request=request, + method=method, + authorization=authorization, ) status_code = await generator.__anext__() headers = await generator.__anext__() @@ -572,21 +650,34 @@ async def commits_proxy_commit2( repo_type: str, org: str, repo: str, commit: str, request: Request ): return await commits_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=commit, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) @app.head("/api/{repo_type}/{org_repo}/commits/{commit}") @app.get("/api/{repo_type}/{org_repo}/commits/{commit}") -async def commits_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request): +async def commits_proxy_commit( + repo_type: str, org_repo: str, commit: str, request: Request +): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() return await commits_proxy_common( - repo_type=repo_type, org=org, repo=repo, commit=commit, request=request + repo_type=repo_type, + org=org, + repo=repo, + commit=commit, + method=request.method.lower(), + authorization=request.headers.get("authorization", None), ) + # ====================== # Authentication API Hooks # ======================