Skip to content

Commit

Permalink
More precise error message, custom error types for proxy.
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 4, 2024
1 parent e4ec74d commit c4be6b2
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 15 deletions.
35 changes: 32 additions & 3 deletions olah/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ def error_repo_not_found() -> JSONResponse:
)


def error_page_not_found() -> Response:
return Response(
def error_page_not_found() -> JSONResponse:
return JSONResponse(
content={"error":"Sorry, we can't find the page you are looking for."},
headers={
"x-error-code": "RepoNotFound",
"x-error-message": "Sorry, we can't find the page you are looking for.",
},
status_code=404,
)

def error_entry_not_found(branch: str, path: str) -> Response:
def error_entry_not_found_branch(branch: str, path: str) -> Response:
return Response(
headers={
"x-error-code": "EntryNotFound",
Expand All @@ -38,6 +39,15 @@ def error_entry_not_found(branch: str, path: str) -> Response:
status_code=404,
)

def error_entry_not_found() -> Response:
return Response(
headers={
"x-error-code": "EntryNotFound",
"x-error-message": "Entry not found",
},
status_code=404,
)

def error_revision_not_found(revision: str) -> Response:
return JSONResponse(
content={"error": f"Invalid rev id: {revision}"},
Expand All @@ -47,3 +57,22 @@ def error_revision_not_found(revision: str) -> Response:
},
status_code=404,
)

# Olah Custom Messages
def error_proxy_timeout() -> Response:
return Response(
headers={
"x-error-code": "ProxyTimeout",
"x-error-message": "Proxy Timeout",
},
status_code=504,
)

def error_proxy_invalid_data() -> Response:
return Response(
headers={
"x-error-code": "ProxyInvalidData",
"x-error-message": "Proxy Invalid Data",
},
status_code=504,
)
24 changes: 24 additions & 0 deletions olah/mirror/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,30 @@ def get_tree(
for r in items:
r.pop("name")
return items

def get_commits(self, commit_hash: str) -> Optional[Dict[str, Any]]:
try:
commit = self._git_repo.commit(commit_hash)
except gitdb.exc.BadName:
return None

parent_commits: List[Commit] = list(commit.parents)
parent_commits = parent_commits.insert(0, commit)
items = []
for each_commit in parent_commits:
item = {
"id": each_commit.hexsha,
"title": each_commit.message,
"message": "",
"authors": [],
"date": each_commit.committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
}
item["authors"].append({
"name": each_commit.author.name,
"avatar": None
})
items.append(item)
return items

def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
try:
Expand Down
105 changes: 105 additions & 0 deletions olah/proxy/commits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import os
from typing import Dict, Literal, Mapping
from urllib.parse import urljoin
from fastapi import FastAPI, Request

import httpx
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils.cache_utils import read_cache_request, write_cache_request
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.repo_utils import get_org_repo
from olah.utils.file_utils import make_dirs


async def _commits_cache_generator(save_path: str):
cache_rq = await read_cache_request(save_path)
yield cache_rq["status_code"]
yield cache_rq["headers"]
yield cache_rq["content"]


async def _commits_proxy_generator(
app: FastAPI,
headers: Dict[str, str],
commits_url: str,
method: str,
params: Mapping[str, str],
allow_cache: bool,
save_path: str,
):
async with httpx.AsyncClient(follow_redirects=True) as client:
content_chunks = []
async with client.stream(
method=method,
url=commits_url,
params=params,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_status_code = response.status_code
response_headers = response.headers
yield response_status_code
yield response_headers

async for raw_chunk in response.aiter_raw():
if not raw_chunk:
continue
content_chunks.append(raw_chunk)
yield raw_chunk

content = bytearray()
for chunk in content_chunks:
content += chunk

if allow_cache and response_status_code == 200:
make_dirs(save_path)
await write_cache_request(
save_path, response_status_code, response_headers, bytes(content)
)


async def commits_generator(
app: FastAPI,
repo_type: Literal["models", "datasets", "spaces"],
org: str,
repo: str,
commit: str,
override_cache: bool,
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")

# save
method = request.method.lower()
repos_path = app.app_settings.config.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/commits/{commit}"
)
save_path = os.path.join(save_dir, f"commits_{method}.json")

use_cache = os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

org_repo = get_org_repo(org, repo)
commits_url = urljoin(
app.app_settings.config.hf_url_base(),
f"/api/{repo_type}/{org_repo}/commits/{commit}",
)
# proxy
if use_cache and not override_cache:
async for item in _commits_cache_generator(save_path):
yield item
else:
async for item in _commits_proxy_generator(
app, headers, commits_url, method, {}, allow_cache, save_path
):
yield item
30 changes: 21 additions & 9 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ORIGINAL_LOC,
)
from olah.cache.olah_cache import OlahCache
from olah.errors import error_entry_not_found, error_proxy_invalid_data, error_proxy_timeout
from olah.proxy.pathsinfo import pathsinfo_generator
from olah.utils.cache_utils import read_cache_request, write_cache_request
from olah.utils.disk_utils import touch_file_access_time
Expand Down Expand Up @@ -474,27 +475,38 @@ async def _file_realtime_stream(


generator = pathsinfo_generator(app, repo_type, org, repo, commit, [file_path], override_cache=False, method="post")
status_code = await generator.__anext__()
headers = await generator.__anext__()
content = await generator.__anext__()
try:
pathsinfo = json.loads(content)
except json.JSONDecodeError:
yield 504
yield {}
yield b""
response = error_proxy_invalid_data()
yield response.status_code
yield response.headers
yield response.body
return

if len(pathsinfo) == 0:
response = error_entry_not_found()
yield response.status_code
yield response.headers
yield response.body
return

if len(pathsinfo) != 1:
yield 504
yield {}
yield b""
response = error_proxy_timeout()
yield response.status_code
yield response.headers
yield response.body
return

pathinfo = pathsinfo[0]
if "size" not in pathinfo:
yield 504
yield {}
yield b""
response = error_proxy_timeout()
yield response.status_code
yield response.headers
yield response.body
return
file_size = pathinfo["size"]

Expand Down
1 change: 1 addition & 0 deletions olah/proxy/pathsinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,6 @@ async def pathsinfo_generator(
if status == 200 and isinstance(content_json, list):
final_content.extend(content_json)

yield 200
yield {'content-type': 'application/json'}
yield json.dumps(final_content, ensure_ascii=True)
Loading

0 comments on commit c4be6b2

Please sign in to comment.