Skip to content

Commit

Permalink
update cache method
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Jul 25, 2024
1 parent 570e787 commit 25f07ec
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
6 changes: 4 additions & 2 deletions src/lean_dojo/data_extraction/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ def get(self, url: str, commit: str) -> Optional[Path]:
else:
return None

def store(self, src: Path) -> Path:
def store(self, src: Path, fmt_name: str = "") -> Path:
"""Store a traced repo at path ``src``. Return its path in the cache."""
url, commit = get_repo_info(src)
dirpath = self.cache_dir / _format_dirname(url, commit)
if fmt_name == "": # if not specified, extract from the traced repo
fmt_name = _format_dirname(url, commit)
dirpath = self.cache_dir / fmt_name
_, repo_name = _split_git_url(url)
if not dirpath.exists():
with self.lock:
Expand Down
18 changes: 9 additions & 9 deletions src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
working_directory,
)
from ..constants import LEAN4_URL
from .cache import _format_dirname


GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN", None)
Expand Down Expand Up @@ -429,13 +430,7 @@ def __post_init__(self) -> None:
lean_version = self.commit
else:
config = self.get_config("lean-toolchain")
toolchain = config["content"]
m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip())
if m is not None:
lean_version = m["version"]
else:
# lean_version_commit = get_lean4_commit_from_config(config)
lean_version = get_lean4_version_from_config(toolchain)
lean_version = get_lean4_version_from_config(config["content"])
if not is_supported_version(lean_version):
logger.warning(
f"{self} relies on an unsupported Lean version: {lean_version}"
Expand All @@ -444,9 +439,9 @@ def __post_init__(self) -> None:
object.__setattr__(self, "lean_version", lean_version)

@classmethod
def from_path(cls, path: Path) -> "LeanGitRepo":
def from_path(cls, path: Union[Path, str]) -> "LeanGitRepo":
"""Construct a :class:`LeanGitRepo` object from the path to a local Git repo."""
url, commit = get_repo_info(path)
url, commit = get_repo_info(Path(path))
return cls(url, commit)

@property
Expand All @@ -461,6 +456,11 @@ def is_lean4(self) -> bool:
def commit_url(self) -> str:
return f"{self.url}/tree/{self.commit}"

@property
def format_dirname(self) -> str:
"""Return the formatted cache directory name"""
return _format_dirname(self.url, self.commit)

def show(self) -> None:
"""Show the repo in the default browser."""
webbrowser.open(self.commit_url)
Expand Down
2 changes: 1 addition & 1 deletion src/lean_dojo/data_extraction/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path:
_trace(repo, build_deps)
traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps)
traced_repo.save_to_disk()
path = cache.store(tmp_dir / repo.name)
path = cache.store(tmp_dir / repo.name, repo.format_dirname)
else:
logger.debug("The traced repo is available in the cache.")
return path
Expand Down
18 changes: 18 additions & 0 deletions tests/data_extraction/test_lean_repo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# test for the class `LeanGitRepo`
from lean_dojo import LeanGitRepo
from lean_dojo.data_extraction.lean import _to_commit_hash
from lean_dojo.constants import LEAN4_URL


def test_lean_git_repo(lean4_example_url, example_commit_hash):
Expand All @@ -9,3 +11,19 @@ def test_lean_git_repo(lean4_example_url, example_commit_hash):
assert repo.exists()
assert repo.name == "lean4-example"
assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}"
# test cache directory
assert (
repo.format_dirname
== "yangky11-lean4-example-3f8c5eb303a225cdef609498b8d87262e5ef344b"
)
# test commit hash
## test commit hash
assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash
## test branch, assume the branch is not changed
assert _to_commit_hash(repo, "paper") == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c"
## test tag
lean4_repo = LeanGitRepo(LEAN4_URL, "master")
assert (
_to_commit_hash(lean4_repo, "v4.9.1")
== "1b78cb4836cf626007bd38872956a6fab8910993"
)

0 comments on commit 25f07ec

Please sign in to comment.