Skip to content

Commit

Permalink
Merge branch 'simplify-funcs'
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Jul 29, 2024
2 parents 0a8796d + e5941ff commit 00c332a
Show file tree
Hide file tree
Showing 14 changed files with 488 additions and 354 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,6 @@ dmypy.json

# Pyre type checker
.pyre/
testdata
.vscode

# vscode debug config
.vscode/
5 changes: 0 additions & 5 deletions src/lean_dojo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,3 @@
from .interaction.parse_goals import Declaration, Goal, parse_goals
from .data_extraction.lean import get_latest_commit, LeanGitRepo, LeanFile, Theorem, Pos
from .constants import __version__

if os.geteuid() == 0:
logger.warning(
"Running LeanDojo as the root user may cause unexpected issues. Proceed with caution."
)
34 changes: 21 additions & 13 deletions src/lean_dojo/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,28 @@
assert re.fullmatch(r"\d+g", TACTIC_MEMORY_LIMIT)


def check_git_version(min_version: Tuple[int, int, int]) -> Tuple[int, int, int]:
def check_git_version(min_version: Tuple[int, int, int]) -> None:
"""Check the version of Git installed on the system."""
res = subprocess.run("git --version", shell=True, capture_output=True, check=True)
output = res.stdout.decode()
error = res.stderr.decode()
assert error == "", error
m = re.match(r"git version (?P<version>[0-9.]+)", output)
version = tuple(int(_) for _ in m["version"].split("."))

version_str = ".".join(str(_) for _ in version)
min_version_str = ".".join(str(_) for _ in min_version)
assert (
version >= min_version
), f"Git version {version_str} is too old. Please upgrade to at least {min_version_str}."
try:
res = subprocess.run(
"git --version", shell=True, capture_output=True, check=True
)
output = res.stdout.decode().strip()
error = res.stderr.decode()
assert error == "", error
match = re.search(r"git version (\d+\.\d+\.\d+)", output)
if not match:
raise ValueError("Could not parse Git version from the output.")
# Convert version number string to tuple of integers
version = tuple(int(_) for _ in match.group(1).split("."))

version_str = ".".join(str(_) for _ in version)
min_version_str = ".".join(str(_) for _ in min_version)
assert (
version >= min_version
), f"Git version {version_str} is too old. Please upgrade to at least {min_version_str}."
except subprocess.CalledProcessError as e:
raise Exception(f"Failed to run git command: {e}")


check_git_version((2, 25, 0))
37 changes: 17 additions & 20 deletions src/lean_dojo/data_extraction/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ..utils import (
execute,
url_exists,
get_repo_info,
report_critical_failure,
)
from ..constants import (
Expand All @@ -34,11 +33,6 @@ def _split_git_url(url: str) -> Tuple[str, str]:
return user_name, repo_name


def _format_dirname(url: str, commit: str) -> str:
user_name, repo_name = _split_git_url(url)
return f"{user_name}-{repo_name}-{commit}"


_CACHE_CORRPUTION_MSG = "The cache may have been corrputed!"


Expand All @@ -59,16 +53,16 @@ def __post_init__(self):
lock_path = self.cache_dir.with_suffix(".lock")
object.__setattr__(self, "lock", FileLock(lock_path))

def get(self, url: str, commit: str) -> Optional[Path]:
def get(self, rel_cache_dir: Path) -> Optional[Path]:
"""Get the path of a traced repo with URL ``url`` and commit hash ``commit``. Return None if no such repo can be found."""
_, repo_name = _split_git_url(url)
dirname = _format_dirname(url, commit)
dirname = rel_cache_dir.parent
dirpath = self.cache_dir / dirname
cache_path = self.cache_dir / rel_cache_dir

with self.lock:
if dirpath.exists():
assert (dirpath / repo_name).exists()
return dirpath / repo_name
assert cache_path.exists()
return cache_path

elif not DISABLE_REMOTE_CACHE:
url = os.path.join(REMOTE_CACHE_URL, f"{dirname}.tar.gz")
Expand All @@ -83,20 +77,23 @@ def get(self, url: str, commit: str) -> Optional[Path]:
with tarfile.open(f"{dirpath}.tar.gz") as tar:
tar.extractall(self.cache_dir)
os.remove(f"{dirpath}.tar.gz")
assert (dirpath / repo_name).exists()
assert (cache_path).exists()

return dirpath / repo_name
return cache_path

else:
return None

def store(self, src: Path, cache_path: Union[Path, None]=None) -> Path:
"""Store a traced repo at path ``src``. Return its path in the cache."""
if cache_path is None:
url, commit = get_repo_info(src)
_, repo_name = _split_git_url(url)
cache_path = self.cache_dir / _format_dirname(url, commit) / repo_name
if not cache_path.exists():
def store(self, src: Path, rel_cache_dir: Path) -> Path:
"""Store a traced repo at path ``src``. Return its path in the cache.
Args:
src (Path): Path to the repo.
rel_cache_name (Path): The relative path of the stored repo in the cache.
"""
dirpath = self.cache_dir / rel_cache_dir.parent
cache_path = self.cache_dir / rel_cache_dir
if not dirpath.exists():
with self.lock:
with report_critical_failure(_CACHE_CORRPUTION_MSG):
shutil.copytree(src, cache_path)
Expand Down
Loading

0 comments on commit 00c332a

Please sign in to comment.