Skip to content

Commit

Permalink
add expand param to tree api
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 1, 2024
1 parent c5d4b8f commit ba7ba18
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 15 deletions.
32 changes: 26 additions & 6 deletions olah/mirror/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _get_tree_filenames_recursive(self, tree, include_dir=False) -> List[str]:
def _get_commit_filenames_recursive(self, commit: Commit) -> List[str]:
    """Collect the names found under the root tree of ``commit``.

    Thin wrapper: hands ``commit.tree`` to
    ``_get_tree_filenames_recursive`` and returns its result unchanged,
    leaving that helper's ``include_dir`` argument at its default.

    Args:
        commit: Commit whose root tree is walked.

    Returns:
        Whatever ``_get_tree_filenames_recursive`` yields for the root
        tree (annotated as a list of strings).
    """
    root_tree = commit.tree
    return self._get_tree_filenames_recursive(root_tree)

def _get_path_info(self, entry: IndexObjUnion) -> Dict[str, Union[int, str]]:
def _get_path_info(self, entry: IndexObjUnion, expand: bool=False) -> Dict[str, Union[int, str]]:
lfs = False
if entry.type != "tree":
t = "file"
Expand Down Expand Up @@ -122,19 +122,39 @@ def _get_path_info(self, entry: IndexObjUnion) -> Dict[str, Union[int, str]]:
"name": entry.name,
"lfs": lfs_data,
}
if expand:
last_commit = next(self._git_repo.iter_commits(paths=entry.path, max_count=1))
item["lastCommit"] = {
"id": last_commit.hexsha,
"title": last_commit.message,
"date": last_commit.committed_datetime.strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
}
item["security"] = {
"blobId": entry.hexsha,
"name": entry.name,
"safe": True,
"indexed": False,
"avScan": {
"virusFound": False,
"virusNames": None
},
"pickleImportScan": None
}
return item

def _get_tree_files(
    self, tree: Tree, recursive: bool = False, expand: bool = False
) -> List[Dict[str, Union[int, str]]]:
    """Describe the entries of ``tree`` as a flat list of path-info dicts.

    Args:
        tree: Git tree whose direct children are listed.
        recursive: When true, the contents of every sub-tree are appended
            after the entries of the current level (depth-first, one
            sub-tree at a time).
        expand: Forwarded to ``_get_path_info`` for every entry, at every
            recursion level. With ``expand=True`` that helper performs a
            commit-history lookup per entry, so large recursive listings
            can be slow.

    Returns:
        One dict per entry, as produced by ``_get_path_info``. Entries of
        the current level come first, followed by the entries of each
        sub-tree when ``recursive`` is set.
    """
    # Current level first, so callers see a directory's own entries before
    # anything nested below it.
    entries = [self._get_path_info(entry=entry, expand=expand) for entry in tree]

    if recursive:
        for entry in tree:
            if entry.type == "tree":
                entries.extend(
                    self._get_tree_files(entry, recursive=recursive, expand=expand)
                )
    return entries

def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]:
Expand Down Expand Up @@ -203,15 +223,15 @@ def get_pathinfos(
return results

def get_tree(
    self, commit_hash: str, path: str, recursive: bool = False, expand: bool = False
) -> Optional[List[Dict[str, Any]]]:
    """List the tree found at ``path`` in revision ``commit_hash``.

    Args:
        commit_hash: Revision to look up in the mirrored repository.
        path: Path inside the revision whose tree is listed.
        recursive: When true, entries of nested sub-trees are included.
        expand: When true, each entry is built with ``expand=True`` (see
            ``_get_path_info``, which then adds ``lastCommit`` and
            ``security`` fields).

    Returns:
        A list of path-info dicts with the ``"name"`` key removed, or
        ``None`` when the revision cannot be resolved or the path lookup
        yields nothing.
    """
    try:
        # Resolved only to validate the revision; the object itself is not
        # needed because the path lookup below takes the hash directly.
        self._git_repo.commit(commit_hash)
    except gitdb.exc.BadName:
        return None

    index_obj = self.get_index_object_by_path(commit_hash=commit_hash, path=path)
    # Defensive: the lookup helper is not visible here. If it signals a
    # missing path with None, report "not found" the same way as a bad
    # revision instead of failing when the value is iterated.
    if index_obj is None:
        return None

    items = self._get_tree_files(tree=index_obj, recursive=recursive, expand=expand)
    # "name" is stripped from every entry before the list is returned.
    # NOTE(review): presumably it is internal-only and not part of the
    # response schema -- confirm against the upstream API.
    for item in items:
        item.pop("name")
    return items
Expand Down
14 changes: 6 additions & 8 deletions olah/proxy/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# https://opensource.org/licenses/MIT.

import os
from typing import Dict, Literal
from typing import Dict, Literal, Mapping
from urllib.parse import urljoin
from fastapi import FastAPI, Request

Expand All @@ -29,7 +29,7 @@ async def _tree_proxy_generator(
headers: Dict[str, str],
tree_url: str,
method: str,
recursive: bool,
params: Mapping[str, str],
allow_cache: bool,
save_path: str,
):
Expand All @@ -38,7 +38,7 @@ async def _tree_proxy_generator(
async with client.stream(
method=method,
url=tree_url,
params={"recursive": recursive},
params=params,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
Expand Down Expand Up @@ -71,6 +71,7 @@ async def tree_generator(
commit: str,
path: str,
recursive: bool,
expand: bool,
override_cache: bool,
request: Request,
):
Expand All @@ -83,10 +84,7 @@ async def tree_generator(
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}"
)
if not recursive:
save_path = os.path.join(save_dir, f"tree_{method}.json")
else:
save_path = os.path.join(save_dir, f"tree_{method}_recursive.json")
save_path = os.path.join(save_dir, f"tree_{method}_recursive_{recursive}_expand_{expand}.json")

use_cache = os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
Expand All @@ -102,6 +100,6 @@ async def tree_generator(
yield item
else:
async for item in _tree_proxy_generator(
app, headers, tree_url, method, recursive, allow_cache, save_path
app, headers, tree_url, method, {"recursive": recursive, "expand": expand}, allow_cache, save_path
):
yield item
10 changes: 9 additions & 1 deletion olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ async def check_hf_connection() -> None:

@asynccontextmanager
async def lifespan(app: FastAPI):
# TODO: Check repo cache path
await check_hf_connection()
yield

Expand Down Expand Up @@ -249,6 +250,7 @@ async def tree_proxy_common(
commit: str,
path: str,
recursive: bool,
expand: bool,
request: Request,
) -> Response:
# FIXME: do not show the private repos to other user besides owner, even though the repo was cached
Expand All @@ -263,7 +265,7 @@ async def tree_proxy_common(
git_path = os.path.join(mirror_path, repo_type, org, repo)
if os.path.exists(git_path):
local_repo = LocalMirrorRepo(git_path, repo_type, org, repo)
tree_data = local_repo.get_tree(commit, path, recursive=recursive)
tree_data = local_repo.get_tree(commit, path, recursive=recursive, expand=expand)
if tree_data is None:
continue
return JSONResponse(content=tree_data)
Expand Down Expand Up @@ -302,6 +304,7 @@ async def tree_proxy_common(
commit=commit_sha,
path=path,
recursive=recursive,
expand=expand,
override_cache=True,
request=request,
)
Expand All @@ -314,6 +317,7 @@ async def tree_proxy_common(
commit=commit_sha,
path=path,
recursive=recursive,
expand=expand,
override_cache=False,
request=request,
)
Expand All @@ -335,6 +339,7 @@ async def tree_proxy_commit2(
file_path: str,
request: Request,
recursive: bool = False,
expand: bool=False,
):
return await tree_proxy_common(
repo_type=repo_type,
Expand All @@ -343,6 +348,7 @@ async def tree_proxy_commit2(
commit=commit,
path=file_path,
recursive=recursive,
expand=expand,
request=request,
)

Expand All @@ -356,6 +362,7 @@ async def tree_proxy_commit(
file_path: str,
request: Request,
recursive: bool = False,
expand: bool=False,
):
org, repo = parse_org_repo(org_repo)
if org is None and repo is None:
Expand All @@ -368,6 +375,7 @@ async def tree_proxy_commit(
commit=commit,
path=file_path,
recursive=recursive,
expand=expand,
request=request,
)

Expand Down

0 comments on commit ba7ba18

Please sign in to comment.