Skip to content

Commit

Permalink
tree api bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Aug 22, 2024
1 parent 5b74eba commit 3501520
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 28 deletions.
10 changes: 10 additions & 0 deletions olah/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ def error_page_not_found() -> Response:
},
status_code=404,
)

def error_entry_not_found(branch: str, path: str) -> Response:
return Response(
headers={
"x-error-code": "EntryNotFound",
"x-error-message": f"{path} does not exist on \"{branch}\"",
},
status_code=404,
)

85 changes: 79 additions & 6 deletions olah/mirror/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,70 @@ def _get_description(self, commit: Commit) -> str:
readme = self._get_readme(commit)
return self._remove_card(readme)

def _get_entry_files(self, tree, include_dir=False) -> List[str]:
def _get_tree_files_recursive(self, tree, include_dir=False) -> List[str]:
out_paths = []
for entry in tree:
if entry.type == "tree":
out_paths.extend(self._get_entry_files(entry))
out_paths.extend(self._get_tree_files_recursive(entry))
if include_dir:
out_paths.append(entry.path)
else:
out_paths.append(entry.path)
return out_paths

def _get_tree_files(self, commit: Commit) -> List[str]:
return self._get_entry_files(commit.tree)
def _get_commit_files_recursive(self, commit: Commit) -> List[str]:
return self._get_tree_files_recursive(commit.tree)

def _get_tree_files(self, tree: Tree) -> List[Dict[str, Union[int, str]]]:
entries = []
for entry in tree:
lfs = False
if entry.type != "tree":
t = "file"
repr_size = entry.size
if repr_size > 120 and repr_size < 150:
# check lfs
lfs_data = entry.data_stream.read().decode("utf-8")
match_groups = re.match(
r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
lfs_data,
)
if match_groups is not None:
lfs = True
sha256 = match_groups.group(1)
repr_size = int(match_groups.group(2))
lfs_data = {
"oid": sha256,
"size": repr_size,
"pointerSize": entry.size,
}
else:
t = "directory"
repr_size = 0

if not lfs:
entries.append(
{
"type": t,
"oid": entry.hexsha,
"size": repr_size,
"path": entry.name,
}
)
else:
entries.append(
{
"type": t,
"oid": entry.hexsha,
"size": repr_size,
"path": entry.name,
"lfs": lfs_data,
}
)
return entries

def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]:
return self._get_tree_files(commit.tree)

def _get_earliest_commit(self) -> Commit:
earliest_commit = None
Expand All @@ -92,7 +143,27 @@ def _get_earliest_commit(self) -> Commit:

return earliest_commit

def get_meta(self, commit_hash: str) -> Dict[str, Any]:
def get_tree(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]:
try:
commit = self._git_repo.commit(commit_hash)
except gitdb.exc.BadName:
return None

path_part = path.split("/")
tree = commit.tree
items = self._get_tree_files(tree=tree)
for part in path_part:
if len(part.strip()) == 0:
continue
if part not in [
item["path"] for item in items if item["type"] == "directory"
]:
return None
tree = tree[part]
items = self._get_tree_files(tree=tree)
return items

def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
try:
commit = self._git_repo.commit(commit_hash)
except gitdb.exc.BadName:
Expand All @@ -117,7 +188,9 @@ def get_meta(self, commit_hash: str) -> Dict[str, Any]:
meta.cardData = yaml.load(
self._match_card(self._get_readme(commit)), Loader=yaml.CLoader
)
meta.siblings = [{"rfilename": p} for p in self._get_tree_files(commit)]
meta.siblings = [
{"rfilename": p} for p in self._get_commit_files_recursive(commit)
]
meta.createdAt = self._get_earliest_commit().committed_datetime.strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
Expand Down
23 changes: 14 additions & 9 deletions olah/proxy/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ async def tree_proxy_cache(
org: str,
repo: str,
commit: str,
path: str,
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
Expand All @@ -42,16 +43,15 @@ async def tree_proxy_cache(
method = request.method.lower()
repos_path = app.app_settings.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}"
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}"
)
save_path = os.path.join(save_dir, f"tree_{method}.json")
make_dirs(save_path)

# url
org_repo = get_org_repo(org, repo)
tree_url = urljoin(
app.app_settings.config.hf_url_base(),
f"/api/{repo_type}/{org_repo}/tree/{commit}",
f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}",
)
headers = {}
if "authorization" in request.headers:
Expand All @@ -65,9 +65,12 @@ async def tree_proxy_cache(
follow_redirects=True,
)
if response.status_code == 200:
make_dirs(save_path)
await _write_cache_request(
save_path, response.status_code, response.headers, response.content
)
elif response.status_code == 404:
pass
else:
raise Exception(
f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}"
Expand Down Expand Up @@ -104,9 +107,11 @@ async def _tree_proxy_generator(
for chunk in content_chunks:
content += chunk

await _write_cache_request(
save_path, response_status_code, response_headers, bytes(content)
)
if response_status_code == 200:
make_dirs(save_path)
await _write_cache_request(
save_path, response_status_code, response_headers, bytes(content)
)


async def tree_generator(
Expand All @@ -115,6 +120,7 @@ async def tree_generator(
org: str,
repo: str,
commit: str,
path: str,
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
Expand All @@ -124,18 +130,17 @@ async def tree_generator(
method = request.method.lower()
repos_path = app.app_settings.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}"
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}"
)
save_path = os.path.join(save_dir, f"tree_{method}.json")
make_dirs(save_path)

use_cache = os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

org_repo = get_org_repo(org, repo)
tree_url = urljoin(
app.app_settings.config.hf_url_base(),
f"/api/{repo_type}/{org_repo}/tree/{commit}",
f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}",
)
# proxy
if use_cache:
Expand Down
27 changes: 14 additions & 13 deletions olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import httpx

from olah.proxy.tree import tree_generator, tree_proxy_cache
from olah.utils.url_utils import clean_path

BASE_SETTINGS = False
if not BASE_SETTINGS:
Expand Down Expand Up @@ -223,9 +224,10 @@ async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request:
)

# Git Tree
async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, request: Request) -> Response:
async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, path: str, request: Request) -> Response:
# TODO: the head method of meta apis
# FIXME: do not show the private repos to other user besides owner, even though the repo was cached
path = clean_path(path)
if repo_type not in REPO_TYPES_MAPPING.keys():
return error_page_not_found()
if not await check_proxy_rules_hf(app, repo_type, org, repo):
Expand All @@ -236,8 +238,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re
git_path = os.path.join(mirror_path, repo_type, org, repo)
if os.path.exists(git_path):
local_repo = LocalMirrorRepo(git_path, repo_type, org, repo)
# TODO: Local git repo trees
tree_data = local_repo.get_tree(commit)
tree_data = local_repo.get_tree(commit, path)
if tree_data is None:
continue
return JSONResponse(content=tree_data)
Expand Down Expand Up @@ -268,33 +269,33 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re
return error_repo_not_found()
# if branch name and online mode, refresh branch info
if not app.app_settings.config.offline and commit_sha != commit:
await tree_proxy_cache(app, repo_type, org, repo, commit, request)
generator = tree_generator(app, repo_type, org, repo, commit_sha, request)
await tree_proxy_cache(app, repo_type, org, repo, commit, path, request)
generator = tree_generator(app, repo_type, org, repo, commit_sha, path, request)
headers = await generator.__anext__()
return StreamingResponse(generator, headers=headers)
except httpx.ConnectTimeout:
traceback.print_exc()
return Response(status_code=504)

@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}")
@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}")
@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}")
@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}/{file_path:path}")
async def tree_proxy_commit2(
repo_type: str, org: str, repo: str, commit: str, request: Request
repo_type: str, org: str, repo: str, commit: str, file_path: str, request: Request
):
return await tree_proxy_common(
repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request
)


@app.head("/api/{repo_type}/{org_repo}/tree/{commit}")
@app.get("/api/{repo_type}/{org_repo}/tree/{commit}")
async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
@app.head("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}")
@app.get("/api/{repo_type}/{org_repo}/tree/{commit}/{file_path:path}")
async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, file_path: str, request: Request):
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
return error_repo_not_found()

return await tree_proxy_common(
repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
repo_type=repo_type, org=org, repo=repo, commit=commit, path=file_path, request=request
)


Expand Down
9 changes: 9 additions & 0 deletions olah/utils/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,12 @@ def remove_query_param(url: str, param_name: str) -> str:
new_url = urlunparse(parsed_url._replace(query=new_query))

return new_url


def clean_path(path: str) -> str:
while ".." in path:
path = path.replace("..", "")
path = path.replace("\\", "/")
while "//" in path:
path = path.replace("//", "/")
return path

0 comments on commit 3501520

Please sign in to comment.