diff --git a/src/lean_dojo/data_extraction/cache.py b/src/lean_dojo/data_extraction/cache.py index acfe1c9..bbbafd3 100644 --- a/src/lean_dojo/data_extraction/cache.py +++ b/src/lean_dojo/data_extraction/cache.py @@ -8,12 +8,11 @@ from loguru import logger from filelock import FileLock from dataclasses import dataclass, field -from typing import Optional, Tuple, Generator +from typing import Optional, Tuple, Generator, Union from ..utils import ( execute, url_exists, - get_repo_info, report_critical_failure, ) from ..constants import ( @@ -34,11 +33,6 @@ def _split_git_url(url: str) -> Tuple[str, str]: return user_name, repo_name -def _format_dirname(url: str, commit: str) -> str: - user_name, repo_name = _split_git_url(url) - return f"{user_name}-{repo_name}-{commit}" - - _CACHE_CORRPUTION_MSG = "The cache may have been corrputed!" @@ -59,16 +53,16 @@ def __post_init__(self): lock_path = self.cache_dir.with_suffix(".lock") object.__setattr__(self, "lock", FileLock(lock_path)) - def get(self, url: str, commit: str) -> Optional[Path]: - """Get the path of a traced repo with URL ``url`` and commit hash ``commit``. Return None if no such repo can be found.""" - _, repo_name = _split_git_url(url) - dirname = _format_dirname(url, commit) + def get(self, rel_cache_dir: Path) -> Optional[Path]: + """Get the cache repo at ``CACHE_DIR / rel_cache_dir`` from the cache.""" + dirname = rel_cache_dir.parent dirpath = self.cache_dir / dirname + cache_path = self.cache_dir / rel_cache_dir with self.lock: if dirpath.exists(): - assert (dirpath / repo_name).exists() - return dirpath / repo_name + assert cache_path.exists() + return cache_path elif not DISABLE_REMOTE_CACHE: url = os.path.join(REMOTE_CACHE_URL, f"{dirname}.tar.gz") @@ -83,23 +77,27 @@ def get(self, url: str, commit: str) -> Optional[Path]: with tarfile.open(f"{dirpath}.tar.gz") as tar: tar.extractall(self.cache_dir) os.remove(f"{dirpath}.tar.gz") - assert (dirpath / repo_name).exists() + assert (cache_path).exists() - return dirpath / repo_name + return cache_path else: return None - def store(self, src: Path) -> Path: - """Store a traced repo at path ``src``. Return its path in the cache.""" - url, commit = get_repo_info(src) - dirpath = self.cache_dir / _format_dirname(url, commit) - _, repo_name = _split_git_url(url) + def store(self, src: Path, rel_cache_dir: Path) -> Path: + """Store a repo at path ``src``. Return its path in the cache. + + Args: + src (Path): Path to the repo. + rel_cache_dir (Path): The relative path of the stored repo in the cache. + """ + dirpath = self.cache_dir / rel_cache_dir.parent + cache_path = self.cache_dir / rel_cache_dir if not dirpath.exists(): with self.lock: with report_critical_failure(_CACHE_CORRPUTION_MSG): - shutil.copytree(src, dirpath / repo_name) - return dirpath / repo_name + shutil.copytree(src, cache_path) + return cache_path cache = Cache(CACHE_DIR) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index c12fae7..bb46bf0 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -9,6 +9,7 @@ import time import urllib import webbrowser +import shutil from pathlib import Path from loguru import logger from functools import cache @@ -17,19 +18,22 @@ from github.Repository import Repository from github.GithubException import GithubException from typing import List, Dict, Any, Generator, Union, Optional, Tuple, Iterator -from git import Repo - - +from git import Repo, BadName +from ..constants import TMP_DIR +import uuid +import shutil +from urllib.parse import urlparse +from .cache import cache as repo_cache +from .cache import _split_git_url from ..utils import ( - execute, read_url, url_exists, get_repo_info, working_directory, + is_git_repo, ) from ..constants import LEAN4_URL - GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN", None) """GiHub personal access token is optional. If provided, it can increase the rate limit for GitHub API calls. @@ -45,24 +49,96 @@ ) GITHUB = Github() -LEAN4_REPO = GITHUB.get_repo("leanprover/lean4") +LEAN4_REPO = None """The GitHub Repo for Lean 4 itself.""" _URL_REGEX = re.compile(r"(?P.*?)/*") +_SSH_TO_HTTPS_REGEX = re.compile(r"^git@github\.com:(.+)/(.+)(?:\.git)?$") + +REPO_CACHE_PREFIX = "repos" + -def normalize_url(url: str) -> str: +def normalize_url(url: str, repo_type: str = "github") -> str: + if repo_type == "local": + return os.path.abspath(url) # Convert to absolute path if local return _URL_REGEX.fullmatch(url)["url"] # Remove trailing `/`. +def repo_type_of_url(url: str) -> Union[str, None]: + """Get the type of the repository. + + Args: + url (str): The URL of the repository. + Returns: + str: The type of the repository. + """ + m = _SSH_TO_HTTPS_REGEX.match(url) + url = f"https://github.com/{m.group(1)}/{m.group(2)}" if m else url + parsed_url = urlparse(url) + if parsed_url.scheme in ["http", "https"]: + # case 1 - GitHub URL + if "github.com" in url: + if not url.startswith("https://"): + logger.warning(f"{url} should start with https://") + return + else: + return "github" + # case 2 - remote URL + elif url_exists(url): # not check whether it is a git URL + return "remote" + # case 3 - local path + elif is_git_repo(Path(parsed_url.path)): + return "local" + logger.warning(f"{url} is not a valid URL") + return None + + +def _format_dirname(url: str, commit: str) -> str: + user_name, repo_name = _split_git_url(url) + repo_type = repo_type_of_url(url) + assert repo_type is not None, f"Invalid url {url}" + if repo_type == "github": + return f"{user_name}-{repo_name}-{commit}" + else: # git repo + return f"gitpython-{repo_name}-{commit}" + + @cache -def url_to_repo(url: str, num_retries: int = 2) -> Repository: +def url_to_repo( + url: str, + num_retries: int = 2, + repo_type: Union[str, None] = None, + tmp_dir: Union[Path] = None, +) -> Union[Repo, Repository]: + """Convert a URL to a Repo object. + + Args: + url (str): The URL of the repository. + num_retries (int): Number of retries in case of failure. + repo_type (Optional[str]): The type of the repository. Defaults to None. + tmp_dir (Optional[Path]): The temporary directory to clone the repo to. Defaults to None. + + Returns: + Repo: A Git Repo object. + """ url = normalize_url(url) backoff = 1 - + tmp_dir = tmp_dir or os.path.join(TMP_DIR or "/tmp", str(uuid.uuid4())[:8]) + repo_type = repo_type or repo_type_of_url(url) + assert repo_type is not None, f"Invalid url {url}" while True: try: - return GITHUB.get_repo("/".join(url.split("/")[-2:])) + if repo_type == "github": + return GITHUB.get_repo("/".join(url.split("/")[-2:])) + with working_directory(tmp_dir): + repo_name = os.path.basename(url) + if repo_type == "local": + assert is_git_repo(url), f"Local path {url} is not a git repo" + shutil.copytree(url, repo_name) + return Repo(repo_name) + else: + return Repo.clone_from(url, repo_name) except Exception as ex: if num_retries <= 0: raise ex @@ -76,7 +152,10 @@ def url_to_repo(url: str, num_retries: int = 2) -> Repository: def get_latest_commit(url: str) -> str: """Get the hash of the latest commit of the Git repo at ``url``.""" repo = url_to_repo(url) - return repo.get_branch(repo.default_branch).commit.sha + if isinstance(repo, Repository): + return repo.get_branch(repo.default_branch).commit.sha + else: + return repo.head.commit.hexsha def cleanse_string(s: Union[str, Path]) -> str: @@ -84,21 +163,28 @@ def cleanse_string(s: Union[str, Path]) -> str: return str(s).replace("/", "_").replace(":", "_") -@cache -def _to_commit_hash(repo: Repository, label: str) -> str: +def _to_commit_hash(repo: Union[Repository, Repo], label: str) -> str: """Convert a tag or branch to a commit hash.""" - logger.debug(f"Querying the commit hash for {repo.name} {label}") - - try: - return repo.get_branch(label).commit.sha - except GithubException: - pass - - for tag in repo.get_tags(): - if tag.name == label: - return tag.commit.sha - - raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") + # GitHub repository + if isinstance(repo, Repository): + logger.debug(f"Querying the commit hash for {repo.name} {label}") + try: + commit = repo.get_commit(label).sha + except GithubException as e: + raise ValueError(f"Invalid tag or branch: `{label}` for {repo.name}") + # Local or remote Git repository + elif isinstance(repo, Repo): + logger.debug( + f"Querying the commit hash for {repo.working_dir} repository {label}" + ) + try: + # Resolve the label to a commit hash + commit = repo.commit(label).hexsha + except Exception as e: + raise ValueError(f"Error converting ref to commit hash: {e}") + else: + raise TypeError("Unsupported repository type") + return commit @dataclass(eq=True, unsafe_hash=True) @@ -320,6 +406,11 @@ def __getitem__(self, key) -> str: _LEAN4_VERSION_REGEX = re.compile(r"leanprover/lean4:(?P.+?)") +def is_commit_hash(s: str): + """Check if a string is a valid commit hash.""" + return len(s) == 40 and _COMMIT_REGEX.fullmatch(s) + + def get_lean4_version_from_config(toolchain: str) -> str: """Return the required Lean version given a ``lean-toolchain`` config.""" m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip()) @@ -329,6 +420,9 @@ def get_lean4_version_from_config(toolchain: str) -> str: def get_lean4_commit_from_config(config_dict: Dict[str, Any]) -> str: """Return the required Lean commit given a ``lean-toolchain`` config.""" + global LEAN4_REPO + if LEAN4_REPO is None: + LEAN4_REPO = GITHUB.get_repo("leanprover/lean4") assert "content" in config_dict, "config_dict must have a 'content' field" config = config_dict["content"].strip() prefix = "leanprover/lean4:" @@ -388,9 +482,10 @@ class LeanGitRepo: """Git repo of a Lean project.""" url: str - """The repo's Github URL. + """The repo's URL. - Note that we only support Github as of now. + It can be a GitHub URL that starts with https:// or git@github.com, + a local path, or any other valid Git URL. """ commit: str @@ -399,42 +494,62 @@ class LeanGitRepo: You can also use tags such as ``v3.5.0``. They will be converted to commit hashes. """ - repo: Repository = field(init=False, repr=False) - """A :class:`github.Repository` object. + repo: Union[Repository, Repo] = field(init=False, repr=False) + """A :class:`github.Repository` object for GitHub repos or + a :class:`git.Repo` object for local or remote Git repos. """ lean_version: str = field(init=False, repr=False) """Required Lean version. """ + repo_type: str = field(init=False, repr=False) + """Type of the repo. It can be ``github``, ``local`` or ``remote``. + """ + def __post_init__(self) -> None: - if "github.com" not in self.url: - raise ValueError(f"{self.url} is not a Github URL") - if not self.url.startswith("https://"): + repo_type = repo_type_of_url(self.url) + if repo_type is None: raise ValueError(f"{self.url} is not a valid URL") - object.__setattr__(self, "url", normalize_url(self.url)) - object.__setattr__(self, "repo", url_to_repo(self.url)) - + object.__setattr__(self, "repo_type", repo_type) + object.__setattr__(self, "url", normalize_url(self.url, repo_type=repo_type)) + # set repo and commit + if repo_type == "github": + repo = url_to_repo(self.url, repo_type=repo_type) + else: + # get repo from cache + rel_cache_dir = lambda url, commit: Path( + f"{REPO_CACHE_PREFIX}/{_format_dirname(url, commit)}/{self.name}" + ) + cache_repo_dir = repo_cache.get(rel_cache_dir(self.url, self.commit)) + if cache_repo_dir is None: + with working_directory() as tmp_dir: + repo = url_to_repo(self.url, repo_type=repo_type, tmp_dir=tmp_dir) + commit = _to_commit_hash(repo, self.commit) + cache_repo_dir = repo_cache.store( + repo.working_dir, rel_cache_dir(self.url, commit) + ) + repo = Repo(cache_repo_dir) # Convert tags or branches to commit hashes - if not (len(self.commit) == 40 and _COMMIT_REGEX.fullmatch(self.commit)): + if not is_commit_hash(self.commit): if (self.url, self.commit) in info_cache.tag2commit: commit = info_cache.tag2commit[(self.url, self.commit)] else: - commit = _to_commit_hash(self.repo, self.commit) - assert _COMMIT_REGEX.fullmatch(commit), f"Invalid commit hash: {commit}" - info_cache.tag2commit[(self.url, self.commit)] = commit + commit = _to_commit_hash(repo, self.commit) + assert is_commit_hash(commit), f"Invalid commit hash: {commit}" + info_cache.tag2commit[(self.url, commit)] = commit object.__setattr__(self, "commit", commit) + object.__setattr__(self, "repo", repo) # Determine the required Lean version. if (self.url, self.commit) in info_cache.lean_version: lean_version = info_cache.lean_version[(self.url, self.commit)] - elif self.is_lean4: - lean_version = self.commit + if self.is_lean4: + lean_version = "latest" # lean4 itself else: config = self.get_config("lean-toolchain") - lean_version = get_lean4_commit_from_config(config) - v = get_lean4_version_from_config(config["content"]) - if not is_supported_version(v): + lean_version = get_lean4_version_from_config(config["content"]) + if not is_supported_version(lean_version): logger.warning( f"{self} relies on an unsupported Lean version: {lean_version}" ) @@ -442,14 +557,14 @@ def __post_init__(self) -> None: object.__setattr__(self, "lean_version", lean_version) @classmethod - def from_path(cls, path: Path) -> "LeanGitRepo": + def from_path(cls, path: Union[Path, str]) -> "LeanGitRepo": """Construct a :class:`LeanGitRepo` object from the path to a local Git repo.""" - url, commit = get_repo_info(path) - return cls(url, commit) + commit = Repo(path).head.commit.hexsha + return cls(str(path), commit) @property def name(self) -> str: - return self.repo.name + return os.path.basename(self.url) @property def is_lean4(self) -> bool: @@ -459,12 +574,29 @@ def is_lean4(self) -> bool: def commit_url(self) -> str: return f"{self.url}/tree/{self.commit}" + @property + def format_dirname(self) -> Path: + """Return the formatted cache directory name""" + assert is_commit_hash(self.commit), f"Invalid commit hash: {self.commit}" + return Path(_format_dirname(self.url, self.commit)) + def show(self) -> None: """Show the repo in the default browser.""" webbrowser.open(self.commit_url) def exists(self) -> bool: - return url_exists(self.commit_url) + if self.repo_type != "github": + repo = self.repo # git repo + try: + repo.commit(self.commit) + return repo.head.commit.hexsha == self.commit + except BadName: + logger.warning( + f"Commit {self.commit} does not exist in this repository." + ) + return False + else: + return url_exists(self.commit_url) def clone_and_checkout(self) -> None: """Clone the repo to the current working directory and checkout a specific commit.""" @@ -605,8 +737,13 @@ def _get_config_url(self, filename: str) -> str: def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]: """Return the repo's files.""" - config_url = self._get_config_url(filename) - content = read_url(config_url, num_retries) + if self.repo_type == "github": + config_url = self._get_config_url(filename) + content = read_url(config_url, num_retries) + else: + working_dir = self.repo.working_dir + with open(os.path.join(working_dir, filename), "r") as f: + content = f.read() if filename.endswith(".toml"): return toml.loads(content) elif filename.endswith(".json"): diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index e5b2492..1fae85a 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -204,7 +204,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: Returns: Path: The path of the traced repo in the cache, e.g. :file:`/home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib-2196ab363eb097c008d4497125e0dde23fb36db2` """ - path = cache.get(repo.url, repo.commit) + path = cache.get(repo.format_dirname / repo.name) if path is None: logger.info(f"Tracing {repo}") with working_directory() as tmp_dir: @@ -212,7 +212,9 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: _trace(repo, build_deps) traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps) traced_repo.save_to_disk() - path = cache.store(tmp_dir / repo.name) + src_dir = tmp_dir / repo.name + rel_cache_dir = Path(repo.format_dirname) / repo.name + path = cache.store(src_dir, rel_cache_dir) else: logger.debug("The traced repo is available in the cache.") return path diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 82317b2..cf4a23a 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -15,6 +15,7 @@ from loguru import logger from dataclasses import dataclass, field from typing import List, Optional, Dict, Any, Tuple, Union +from git import Repo from ..utils import ( is_git_repo, diff --git a/src/lean_dojo/utils.py b/src/lean_dojo/utils.py index 0386949..cc3e211 100644 --- a/src/lean_dojo/utils.py +++ b/src/lean_dojo/utils.py @@ -5,7 +5,7 @@ import os import ray import time -import urllib +import urllib, urllib.request, urllib.error import typing import hashlib import tempfile @@ -16,6 +16,7 @@ from contextlib import contextmanager from ray.util.actor_pool import ActorPool from typing import Tuple, Union, List, Generator, Optional +from urllib.parse import urlparse from .constants import NUM_WORKERS, TMP_DIR, LEAN4_PACKAGES_DIR, LEAN4_BUILD_DIR @@ -144,6 +145,34 @@ def camel_case(s: str) -> str: return _CAMEL_CASE_REGEX.sub(" ", s).title().replace(" ", "") +def repo_type_of_url(url: str) -> str: + """Get the type of the repository. + + Args: + url (str): The URL of the repository. + + Returns: + str: The type of the repository. + """ + parsed_url = urlparse(url) + if parsed_url.scheme in ["http", "https"]: + # case 1 - GitHub URL + if "github.com" in url: + if not url.startswith("https://"): + logger.warning(f"{url} should start with https://") + return + else: + return "github" + # case 2 - remote Git URL + else: + return "remote" + # case 3 - local path + elif os.path.exists(parsed_url.path): + return "local" + else: + logger.warning(f"{url} is not a valid URL") + + @cache def get_repo_info(path: Path) -> Tuple[str, str]: """Get the URL and commit hash of the Git repo at ``path``. @@ -154,20 +183,12 @@ def get_repo_info(path: Path) -> Tuple[str, str]: Returns: Tuple[str, str]: URL and (most recent) hash commit """ - with working_directory(path): - # Get the URL. - url_msg, _ = execute(f"git remote get-url origin", capture_output=True) - url = url_msg.strip() - # Get the commit. - commit_msg, _ = execute(f"git log -n 1", capture_output=True) - m = re.search(r"(?<=^commit )[a-z0-9]+", commit_msg) - assert m is not None - commit = m.group() - - if url.startswith("git@"): - assert url.endswith(".git") - url = url[: -len(".git")].replace(":", "/").replace("git@", "https://") - + url = str(path.absolute()) # use the absolute path + # Get the commit. + commit_msg, _ = execute(f"git log -n 1", capture_output=True) + m = re.search(r"(?<=^commit )[a-z0-9]+", commit_msg) + assert m is not None + commit = m.group() return url, commit @@ -196,7 +217,11 @@ def read_url(url: str, num_retries: int = 2) -> str: backoff = 1 while True: try: - with urllib.request.urlopen(url) as f: + request = urllib.request.Request(url) + gh_token = os.getenv("GITHUB_ACCESS_TOKEN") + if gh_token is not None: + request.add_header("Authorization", f"token {gh_token}") + with urllib.request.urlopen(request) as f: return f.read().decode() except Exception as ex: if num_retries <= 0: @@ -209,9 +234,13 @@ def read_url(url: str, num_retries: int = 2) -> str: @cache def url_exists(url: str) -> bool: - """Return True if the URL ``url`` exists.""" + """Return True if the URL ``url`` exists, using the GITHUB_ACCESS_TOKEN for authentication if provided.""" try: - with urllib.request.urlopen(url) as _: + request = urllib.request.Request(url) + gh_token = os.getenv("GITHUB_ACCESS_TOKEN") + if gh_token is not None: + request.add_header("Authorization", f"token {gh_token}") + with urllib.request.urlopen(request) as _: return True except urllib.error.HTTPError: return False diff --git a/tests/conftest.py b/tests/conftest.py index 7113916..d581c62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4" LEAN4_EXAMPLE_URL = "https://github.com/yangky11/lean4-example" EXAMPLE_COMMIT_HASH = "3f8c5eb303a225cdef609498b8d87262e5ef344b" +REMOTE_EXAMPLE_URL = "https://gitee.com/rexzong/lean4-example" URLS = [ BATTERIES_URL, AESOP_URL, @@ -16,6 +17,11 @@ ] +@pytest.fixture(scope="session") +def remote_example_url(): + return REMOTE_EXAMPLE_URL + + @pytest.fixture(scope="session") def example_commit_hash(): return EXAMPLE_COMMIT_HASH diff --git a/tests/data_extraction/test_cache.py b/tests/data_extraction/test_cache.py new file mode 100644 index 0000000..4741239 --- /dev/null +++ b/tests/data_extraction/test_cache.py @@ -0,0 +1,39 @@ +# test for cache manager +from git import Repo +from lean_dojo.utils import working_directory +from pathlib import Path +from lean_dojo.data_extraction.cache import cache + + +def test_repo_cache(lean4_example_url, remote_example_url, example_commit_hash): + # Note: The `git.Repo` requires the local repo to be cloned in a directory + # all cached repos are stored in CACHE_DIR/repos + prefix = "repos" + repo_name = "lean4-example" + + # test local repo cache + with working_directory() as tmp_dir: + repo = Repo.clone_from(lean4_example_url, repo_name) + repo.git.checkout(example_commit_hash) + local_dir = tmp_dir / repo_name + rel_cache_dir = ( + prefix / Path(f"gitpython-{repo_name}-{example_commit_hash}") / repo_name + ) + cache.store(local_dir, rel_cache_dir) + # get the cache + repo_cache_dir = cache.get(rel_cache_dir) + assert repo_cache_dir is not None + + # test remote repo cache + with working_directory() as tmp_dir: + repo = Repo.clone_from(remote_example_url, repo_name) + tmp_remote_dir = tmp_dir / repo_name + rel_cache_dir = ( + prefix + / Path(f"gitpython-{repo_name}-{repo.head.commit.hexsha}") + / repo_name + ) + cache.store(tmp_remote_dir, rel_cache_dir) + # get the cache + repo_cache = cache.get(rel_cache_dir) + assert repo_cache is not None diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py index 005218c..04abb7f 100644 --- a/tests/data_extraction/test_lean_repo.py +++ b/tests/data_extraction/test_lean_repo.py @@ -1,11 +1,136 @@ # test for the class `LeanGitRepo` from lean_dojo import LeanGitRepo +from git import Repo +from github.Repository import Repository +from lean_dojo.utils import working_directory +from lean_dojo.data_extraction.lean import ( + _to_commit_hash, + repo_type_of_url, + url_to_repo, + get_latest_commit, + is_commit_hash, + GITHUB, + LEAN4_REPO, +) -def test_lean_git_repo(lean4_example_url, example_commit_hash): +def test_github_type(lean4_example_url, example_commit_hash): + repo_name = "lean4-example" + + ## get_latest_commit + gh_cm_hash = get_latest_commit(lean4_example_url) + assert is_commit_hash(gh_cm_hash) + + ## url_to_repo & repo_type_of_url + github_repo = url_to_repo(lean4_example_url) + assert repo_type_of_url(lean4_example_url) == "github" + assert repo_type_of_url("git@github.com:yangky11/lean4-example.git") == "github" + assert repo_type_of_url("git@github.com:yangky11/lean4-example") == "github" + assert isinstance(github_repo, Repository) + assert github_repo.name == repo_name + + ## commit hash + assert _to_commit_hash(github_repo, example_commit_hash) == example_commit_hash + ### test branch, assume this branch is not changing + assert ( + _to_commit_hash(github_repo, "paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + ### test git tag + assert ( + _to_commit_hash(GITHUB.get_repo("leanprover/lean4"), "v4.9.1") + == "1b78cb4836cf626007bd38872956a6fab8910993" + ) + + ## LeanGitRepo + LeanGitRepo(lean4_example_url, "main") # init with branch repo = LeanGitRepo(lean4_example_url, example_commit_hash) assert repo.url == lean4_example_url + assert repo.repo_type == "github" assert repo.commit == example_commit_hash assert repo.exists() - assert repo.name == "lean4-example" + assert repo.name == repo_name + assert repo.lean_version == "v4.7.0" assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" + # cache name + assert isinstance(repo.repo, Repository) + assert str(repo.format_dirname) == f"yangky11-{repo_name}-{example_commit_hash}" + + +def test_remote_type(remote_example_url, example_commit_hash): + repo_name = "lean4-example" + + remote_repo = url_to_repo(remote_example_url) + assert repo_type_of_url(remote_example_url) == "remote" + assert isinstance(remote_repo, Repo) + re_cm_hash = get_latest_commit(remote_example_url) + assert re_cm_hash == get_latest_commit(str(remote_repo.working_dir)) + assert _to_commit_hash(remote_repo, example_commit_hash) == example_commit_hash + assert ( + _to_commit_hash(remote_repo, "origin/paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + + ## LeanGitRepo + LeanGitRepo(remote_example_url, "main") + repo = LeanGitRepo(remote_example_url, example_commit_hash) + assert repo.url == remote_example_url + assert repo.repo_type == "remote" + assert repo.commit == example_commit_hash + assert repo.exists() + assert repo.name == repo_name + assert repo.lean_version == "v4.7.0" + assert repo.commit_url == f"{remote_example_url}/tree/{example_commit_hash}" + # cache name + assert isinstance(repo.repo, Repo) + assert str(repo.format_dirname) == f"gitpython-{repo_name}-{example_commit_hash}" + + +def test_local_type(lean4_example_url, example_commit_hash): + repo_name = "lean4-example" + gh_cm_hash = get_latest_commit(lean4_example_url) + + with working_directory() as tmp_dir: + # git repo placed in `tmp_dir / repo_name` + Repo.clone_from(lean4_example_url, repo_name) + + ## get_latest_commit + local_url = str((tmp_dir / repo_name).absolute()) + assert get_latest_commit(local_url) == gh_cm_hash + + ## url_to_repo & repo_type_of_url + local_repo = url_to_repo(local_url, repo_type="local") + assert repo_type_of_url(local_url) == "local" + assert isinstance(local_repo, Repo) + assert ( + local_repo.working_dir != local_url + ), "The working directory should not be the same as the original repo" + + ## commit hash + repo = Repo(local_url) + repo.git.checkout(example_commit_hash) + repo.create_tag("v0.1.0") # create a tag for the example commit hash + repo.git.checkout("main") # switch back to main branch + assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash + assert ( + _to_commit_hash(repo, "origin/paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + assert _to_commit_hash(repo, "v0.1.0") == example_commit_hash + + ## LeanGitRepo + LeanGitRepo(local_url, "main") + repo = LeanGitRepo(local_url, example_commit_hash) + repo2 = LeanGitRepo.from_path(local_url) # test from_path + assert repo.url == local_url == repo2.url + assert repo.repo_type == "local" == repo2.repo_type + assert repo.commit == example_commit_hash and repo2.commit == gh_cm_hash + assert repo.exists() and repo2.exists() + assert repo.name == repo_name == repo2.name + assert repo.lean_version == "v4.7.0" + # cache name + assert isinstance(repo.repo, Repo) and isinstance(repo2.repo, Repo) + assert ( + str(repo.format_dirname) == f"gitpython-{repo_name}-{example_commit_hash}" + ) + assert str(repo2.format_dirname) == f"gitpython-{repo_name}-{gh_cm_hash}" diff --git a/tests/data_extraction/test_trace.py b/tests/data_extraction/test_trace.py index 0064933..24df62f 100644 --- a/tests/data_extraction/test_trace.py +++ b/tests/data_extraction/test_trace.py @@ -1,5 +1,42 @@ from pathlib import Path from lean_dojo import * +from lean_dojo.data_extraction.cache import cache +from lean_dojo.utils import working_directory +from lean_dojo.data_extraction.lean import url_to_repo +from git import Repo + + +def test_github_trace(lean4_example_url): + # github + github_repo = LeanGitRepo(lean4_example_url, "main") + assert github_repo.repo_type == "github" + trace_repo = trace(github_repo) + path = cache.get(github_repo.format_dirname / github_repo.name) + assert path is not None + + +def test_remote_trace(remote_example_url): + # remote + remote_repo = LeanGitRepo(remote_example_url, "main") + assert remote_repo.repo_type == "remote" + trace_repo = trace(remote_repo) + path = cache.get(remote_repo.format_dirname / remote_repo.name) + assert path is not None + + +def test_local_trace(lean4_example_url): + # local + with working_directory() as tmp_dir: + # git repo placed in `tmp_dir / repo_name` + Repo.clone_from(lean4_example_url, "lean4-example") + local_dir = str((tmp_dir / "lean4-example")) + local_url = str((tmp_dir / "lean4-example").absolute()) + local_repo = LeanGitRepo(local_dir, "main") + assert local_repo.url == local_url + assert local_repo.repo_type == "local" + trace_repo = trace(local_repo) + path = cache.get(local_repo.format_dirname / local_repo.name) + assert path is not None def test_trace(traced_repo): diff --git a/tests/interaction/test_interaction.py b/tests/interaction/test_interaction.py index 1e0b858..3eff4d9 100644 --- a/tests/interaction/test_interaction.py +++ b/tests/interaction/test_interaction.py @@ -1,8 +1,32 @@ from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem +from lean_dojo.utils import working_directory +from git import Repo +import os +# Avoid using remote cache +os.environ["DISABLE_REMOTE_CACHE"] = "true" -def test_remote_interact(lean4_example_url): + +def test_github_interact(lean4_example_url): repo = LeanGitRepo(url=lean4_example_url, commit="main") + assert repo.repo_type == "github" + theorem = Theorem(repo, "Lean4Example.lean", "hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = a + c + b" + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished) + + +def test_remote_interact(remote_example_url): + repo = LeanGitRepo(url=remote_example_url, commit="main") + assert repo.repo_type == "remote" theorem = Theorem(repo, "Lean4Example.lean", "hello_world") # initial state dojo, state_0 = Dojo(theorem).__enter__() @@ -15,3 +39,26 @@ def test_remote_interact(lean4_example_url): # finish proof final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") assert isinstance(final_state, ProofFinished) + + +def test_local_interact(lean4_example_url): + # Clone the GitHub repository to the local path + with working_directory() as tmp_dir: + # git repo placed in `tmp_dir / repo_name` + Repo.clone_from(lean4_example_url, "lean4-example") + + local_dir = str((tmp_dir / "lean4-example")) + repo = LeanGitRepo(local_dir, commit="main") + assert repo.repo_type == "local" + theorem = Theorem(repo, "Lean4Example.lean", "hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = a + c + b" + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished)