Skip to content

Commit

Permalink
add paths-info api
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 1, 2024
1 parent 6219054 commit de09a24
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 80 deletions.
169 changes: 111 additions & 58 deletions olah/mirror/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Dict, List, Union
import gitdb
from git import Commit, Optional, Repo, Tree
from git.objects.base import IndexObjUnion
from gitdb.base import OStream
import yaml

Expand Down Expand Up @@ -65,66 +66,75 @@ def _get_description(self, commit: Commit) -> str:
readme = self._get_readme(commit)
return self._remove_card(readme)

def _get_tree_files_recursive(self, tree, include_dir=False) -> List[str]:
def _get_tree_filenames_recursive(self, tree, include_dir=False) -> List[str]:
out_paths = []
for entry in tree:
if entry.type == "tree":
out_paths.extend(self._get_tree_files_recursive(entry))
out_paths.extend(self._get_tree_filenames_recursive(entry))
if include_dir:
out_paths.append(entry.path)
else:
out_paths.append(entry.path)
return out_paths

def _get_commit_files_recursive(self, commit: Commit) -> List[str]:
return self._get_tree_files_recursive(commit.tree)
def _get_commit_filenames_recursive(self, commit: Commit) -> List[str]:
return self._get_tree_filenames_recursive(commit.tree)

def _get_tree_files(self, tree: Tree) -> List[Dict[str, Union[int, str]]]:
entries = []
for entry in tree:
lfs = False
if entry.type != "tree":
t = "file"
repr_size = entry.size
if repr_size > 120 and repr_size < 150:
# check lfs
lfs_data = entry.data_stream.read().decode("utf-8")
match_groups = re.match(
r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
lfs_data,
)
if match_groups is not None:
lfs = True
sha256 = match_groups.group(1)
repr_size = int(match_groups.group(2))
lfs_data = {
"oid": sha256,
"size": repr_size,
"pointerSize": entry.size,
}
else:
t = "directory"
repr_size = 0

if not lfs:
entries.append(
{
"type": t,
"oid": entry.hexsha,
"size": repr_size,
"path": entry.name,
}
def _get_path_info(self, entry: IndexObjUnion) -> Dict[str, Union[int, str]]:
lfs = False
if entry.type != "tree":
t = "file"
repr_size = entry.size
if repr_size > 120 and repr_size < 150:
# check lfs
lfs_data = entry.data_stream.read().decode("utf-8")
match_groups = re.match(
r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
lfs_data,
)
else:
entries.append(
{
"type": t,
"oid": entry.hexsha,
if match_groups is not None:
lfs = True
sha256 = match_groups.group(1)
repr_size = int(match_groups.group(2))
lfs_data = {
"oid": sha256,
"size": repr_size,
"path": entry.name,
"lfs": lfs_data,
"pointerSize": entry.size,
}
)
else:
t = "directory"
repr_size = entry.size

if not lfs:
item = {
"type": t,
"oid": entry.hexsha,
"size": repr_size,
"path": entry.path,
"name": entry.name,
}
else:
item = {
"type": t,
"oid": entry.hexsha,
"size": repr_size,
"path": entry.path,
"name": entry.name,
"lfs": lfs_data,
}
return item

def _get_tree_files(
self, tree: Tree, recursive: bool = False
) -> List[Dict[str, Union[int, str]]]:
entries = []
for entry in tree:
entries.append(self._get_path_info(entry=entry))

if recursive:
for entry in tree:
if entry.type == "tree":
entries.extend(self._get_tree_files(entry, recursive=recursive))
return entries

def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]:
Expand All @@ -143,24 +153,67 @@ def _get_earliest_commit(self) -> Commit:

return earliest_commit

def get_index_object_by_path(
    self, commit_hash: str, path: str
) -> Optional[IndexObjUnion]:
    """Resolve *path* inside the given commit to its git tree/blob object.

    Args:
        commit_hash: Revision to look the path up in.
        path: Slash-separated repo-relative path; empty segments ignored.

    Returns:
        The tree or blob object at *path*, or None when the commit hash
        is unknown, the path is empty, or any component does not exist
        (intermediate components must be directories).
    """
    try:
        commit = self._git_repo.commit(commit_hash)
    except gitdb.exc.BadName:
        return None

    parts = [part for part in path.split("/") if part.strip()]
    if not parts:
        return None

    tree = commit.tree
    items = self._get_tree_files(tree=tree)
    last = len(parts) - 1
    for i, part in enumerate(parts):
        if i != last:
            # Intermediate components must be existing directories.
            if not any(
                item["name"] == part and item["type"] == "directory"
                for item in items
            ):
                return None
        else:
            # The final component may be either a file or a directory.
            if all(item["name"] != part for item in items):
                return None
        tree = tree[part]
        if tree.type == "tree":
            items = self._get_tree_files(tree=tree, recursive=False)
    return tree

def get_pathinfos(
    self, commit_hash: str, paths: List[str]
) -> Optional[List[Dict[str, Any]]]:
    """Return a paths-info entry for each resolvable *path* at *commit_hash*.

    Unknown commit hashes yield None; paths that do not resolve are
    silently skipped.  The internal "name" key is stripped from each entry.
    """
    try:
        commit = self._git_repo.commit(commit_hash)
    except gitdb.exc.BadName:
        return None

    infos = [
        self._get_path_info(obj)
        for p in paths
        if (obj := self.get_index_object_by_path(commit_hash=commit_hash, path=p))
        is not None
    ]
    for info in infos:
        info.pop("name")
    return infos

def get_tree(
    self, commit_hash: str, path: str, recursive: bool = False
) -> Optional[List[Dict[str, Any]]]:
    """List the path-info entries under *path* at *commit_hash*.

    Args:
        commit_hash: Revision to resolve *path* in.
        path: Slash-separated repo-relative directory path.
        recursive: When True, include entries of nested sub-trees.

    Returns:
        A list of entry dicts (with the internal "name" key removed), or
        None when the commit hash is invalid or the path does not resolve
        to an existing tree.
    """
    try:
        commit = self._git_repo.commit(commit_hash)
    except gitdb.exc.BadName:
        return None

    index_obj = self.get_index_object_by_path(commit_hash=commit_hash, path=path)
    # Bug fix: the path may not exist (None) or may name a blob; the
    # previous code handed either straight to _get_tree_files and crashed.
    if index_obj is None or index_obj.type != "tree":
        return None
    items = self._get_tree_files(tree=index_obj, recursive=recursive)
    for item in items:
        item.pop("name")
    return items

def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -189,7 +242,7 @@ def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
self._match_card(self._get_readme(commit)), Loader=yaml.CLoader
)
meta.siblings = [
{"rfilename": p} for p in self._get_commit_files_recursive(commit)
{"rfilename": p} for p in self._get_commit_filenames_recursive(commit)
]
meta.createdAt = self._get_earliest_commit().committed_datetime.strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
Expand Down
108 changes: 108 additions & 0 deletions olah/proxy/pathsinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import json
import os
from typing import Dict, List, Literal
from urllib.parse import quote, urljoin
from fastapi import FastAPI, Request

import httpx
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils.cache_utils import _read_cache_request, _write_cache_request
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.repo_utils import get_org_repo
from olah.utils.file_utils import make_dirs


async def _pathsinfo_cache(save_path: str):
    """Replay a cached paths-info response as (status_code, headers, content)."""
    cached = await _read_cache_request(save_path)
    status_code = cached["status_code"]
    headers = cached["headers"]
    content = cached["content"]
    return status_code, headers, content


async def _pathsinfo_proxy(
    app: FastAPI,
    headers: Dict[str, str],
    pathsinfo_url: str,
    allow_cache: bool,
    method: str,
    path: str,
    save_path: str,
):
    """Forward one paths-info request to the upstream hub.

    Sends *path* as the ``paths`` form field, optionally caches a 200
    response at *save_path*, and returns (status_code, headers, content).

    Args:
        app: FastAPI app (unused here; kept for signature parity).
        headers: Incoming request headers to forward upstream.
        pathsinfo_url: Fully-qualified upstream paths-info endpoint.
        allow_cache: Whether the repo's cache rules permit caching.
        method: HTTP method of the original request, lowercased.
        path: Single repo-relative path to query.
        save_path: File path for the cached response.
    """
    headers = dict(headers)
    # The request body is rebuilt below, so the incoming content-length
    # would be stale.  Bug fix: pop with a default so a missing header
    # (e.g. on GET requests) does not raise KeyError.
    headers.pop("content-length", None)
    async with httpx.AsyncClient(follow_redirects=True) as client:
        response = await client.request(
            method=method,
            url=pathsinfo_url,
            headers=headers,
            data={"paths": path},
            timeout=WORKER_API_TIMEOUT,
        )

    # Bug fix: allow_cache was accepted but ignored; honor the repo's
    # cache rules before persisting the response to disk.
    if allow_cache and response.status_code == 200:
        make_dirs(save_path)
        await _write_cache_request(
            save_path,
            response.status_code,
            response.headers,
            bytes(response.content),
        )
    return response.status_code, response.headers, response.content


async def pathsinfo_generator(
    app: FastAPI,
    repo_type: Literal["models", "datasets", "spaces"],
    org: str,
    repo: str,
    commit: str,
    paths: List[str],
    request: Request,
):
    """Serve a paths-info response, merging per-path cached/proxied results.

    For each requested path, replay the cached upstream response if one
    exists, otherwise proxy the request to the hub (caching it when the
    repo's cache rules allow).  Yields the response headers dict first,
    then the merged JSON body.
    """
    # Keep the original request headers intact for the whole loop; the
    # "host" header must not be forwarded upstream.
    request_headers = {k: v for k, v in request.headers.items()}
    request_headers.pop("host")

    method = request.method.lower()
    repos_path = app.app_settings.repos_path

    final_content = []
    for path in paths:
        save_dir = os.path.join(
            repos_path, f"api/{repo_type}/{org}/{repo}/paths-info/{commit}/{path}"
        )
        save_path = os.path.join(save_dir, f"paths-info_{method}.json")

        use_cache = os.path.exists(save_path)
        allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

        org_repo = get_org_repo(org, repo)
        pathsinfo_url = urljoin(
            app.app_settings.config.hf_url_base(),
            f"/api/{repo_type}/{org_repo}/paths-info/{commit}",
        )
        # Bug fix: the results were previously unpacked into `headers`,
        # clobbering the request headers after the first iteration so
        # later upstream calls were made with response headers.  Debug
        # print(path) removed as well.
        if use_cache:
            status, _resp_headers, content = await _pathsinfo_cache(save_path)
        else:
            status, _resp_headers, content = await _pathsinfo_proxy(
                app,
                request_headers,
                pathsinfo_url,
                allow_cache,
                method,
                path,
                save_path,
            )

        try:
            content_json = json.loads(content)
        except json.JSONDecodeError:
            # Skip unparseable upstream/cached payloads instead of failing
            # the whole aggregated response.
            continue
        if status == 200 and isinstance(content_json, list):
            final_content.extend(content_json)

    yield {'content-type': 'application/json'}
    yield json.dumps(final_content, ensure_ascii=True)
Loading

0 comments on commit de09a24

Please sign in to comment.