Skip to content

Commit

Permalink
authorization bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 5, 2024
1 parent 8c2dd3f commit 1e4d6c9
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 59 deletions.
11 changes: 6 additions & 5 deletions olah/proxy/commits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# https://opensource.org/licenses/MIT.

import os
from typing import Dict, Literal, Mapping
from typing import Dict, Literal, Mapping, Optional
from urllib.parse import urljoin
from fastapi import FastAPI, Request

Expand Down Expand Up @@ -73,13 +73,14 @@ async def commits_generator(
repo: str,
commit: str,
override_cache: bool,
request: Request,
method: str,
authorization: Optional[str],
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")
headers = {}
if authorization is not None:
headers["authorization"] = authorization

# save
method = request.method.lower()
repos_path = app.app_settings.config.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/commits/{commit}"
Expand Down
19 changes: 14 additions & 5 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,17 @@ async def _file_realtime_stream(
if "host" in request_headers:
request_headers["host"] = urlparse(hf_url).netloc


generator = pathsinfo_generator(app, repo_type, org, repo, commit, [file_path], override_cache=False, method="post")
generator = pathsinfo_generator(
app,
repo_type,
org,
repo,
commit,
[file_path],
override_cache=False,
method="post",
authorization=request.headers.get("authorization", None),
)
status_code = await generator.__anext__()
headers = await generator.__anext__()
content = await generator.__anext__()
Expand All @@ -493,14 +502,14 @@ async def _file_realtime_stream(
yield response.headers
yield response.body
return

if len(pathsinfo) != 1:
response = error_proxy_timeout()
yield response.status_code
yield response.headers
yield response.body
return

pathinfo = pathsinfo[0]
if "size" not in pathinfo:
response = error_proxy_timeout()
Expand Down Expand Up @@ -535,7 +544,7 @@ async def _file_realtime_stream(
response_headers["etag"] = f'"{content_hash[:32]}-10"'
yield 200
yield response_headers

async with httpx.AsyncClient() as client:
if method.lower() == "get":
async for each_chunk in _file_chunk_get(
Expand Down
11 changes: 6 additions & 5 deletions olah/proxy/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import shutil
import tempfile
from typing import Dict, Literal
from typing import Dict, Literal, Optional
from urllib.parse import urljoin
from fastapi import FastAPI, Request

Expand Down Expand Up @@ -69,13 +69,14 @@ async def meta_generator(
repo: str,
commit: str,
override_cache: bool,
request: Request,
method: str,
authorization: Optional[str],
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")
headers = {}
if authorization is not None:
headers["authorization"] = authorization

# save
method = request.method.lower()
repos_path = app.app_settings.config.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}"
Expand Down
6 changes: 5 additions & 1 deletion olah/proxy/pathsinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import json
import os
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Optional
from urllib.parse import quote, urljoin
from fastapi import FastAPI, Request

Expand Down Expand Up @@ -66,8 +66,12 @@ async def pathsinfo_generator(
paths: List[str],
override_cache: bool,
method: str,
authorization: Optional[str],
):
headers = {}
if authorization is not None:
headers["authorization"] = authorization

# save
repos_path = app.app_settings.config.repos_path

Expand Down
11 changes: 6 additions & 5 deletions olah/proxy/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# https://opensource.org/licenses/MIT.

import os
from typing import Dict, Literal, Mapping
from typing import Dict, Literal, Mapping, Optional
from urllib.parse import urljoin
from fastapi import FastAPI, Request

Expand Down Expand Up @@ -75,13 +75,14 @@ async def tree_generator(
recursive: bool,
expand: bool,
override_cache: bool,
request: Request,
method: str,
authorization: Optional[str],
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")
headers = {}
if authorization is not None:
headers["authorization"] = authorization

# save
method = request.method.lower()
repos_path = app.app_settings.config.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}"
Expand Down
Loading

0 comments on commit 1e4d6c9

Please sign in to comment.