From 6429b410c3881b331242ce392092f04a1f6b5ba1 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 22 Jul 2024 20:48:25 +0800 Subject: [PATCH 01/11] fix git version for windows --- src/lean_dojo/constants.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/lean_dojo/constants.py b/src/lean_dojo/constants.py index aa50261..c7a5fc3 100644 --- a/src/lean_dojo/constants.py +++ b/src/lean_dojo/constants.py @@ -71,20 +71,25 @@ assert re.fullmatch(r"\d+g", TACTIC_MEMORY_LIMIT) -def check_git_version(min_version: Tuple[int, int, int]) -> Tuple[int, int, int]: +def check_git_version(min_version: Tuple[int, int, int]) -> None: """Check the version of Git installed on the system.""" - res = subprocess.run("git --version", shell=True, capture_output=True, check=True) - output = res.stdout.decode() - error = res.stderr.decode() - assert error == "", error - m = re.match(r"git version (?P[0-9.]+)", output) - version = tuple(int(_) for _ in m["version"].split(".")) - - version_str = ".".join(str(_) for _ in version) - min_version_str = ".".join(str(_) for _ in min_version) - assert ( - version >= min_version - ), f"Git version {version_str} is too old. Please upgrade to at least {min_version_str}." - + try: + res = subprocess.run("git --version", shell=True, capture_output=True, check=True) + output = res.stdout.decode().strip() + error = res.stderr.decode() + assert error == "", error + match = re.search(r"git version (\d+\.\d+\.\d+)", output) + if not match: + raise ValueError("Could not parse Git version from the output.") + # Convert version number string to tuple of integers + version = tuple(int(_) for _ in match.group(1).split('.')) + + version_str = ".".join(str(_) for _ in version) + min_version_str = ".".join(str(_) for _ in min_version) + assert ( + version >= min_version + ), f"Git version {version_str} is too old. Please upgrade to at least {min_version_str}." + except subprocess.CalledProcessError as e: + raise Exception(f"Failed to run git command: {e}") check_git_version((2, 25, 0)) From 042452c7dfb263552ebe4f91743fce2353a74c0e Mon Sep 17 00:00:00 2001 From: Kaiyu Yang Date: Mon, 22 Jul 2024 20:52:04 -0400 Subject: [PATCH 02/11] Update __init__.py --- src/lean_dojo/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/lean_dojo/__init__.py b/src/lean_dojo/__init__.py index a1297d6..5024a5d 100644 --- a/src/lean_dojo/__init__.py +++ b/src/lean_dojo/__init__.py @@ -28,8 +28,3 @@ from .interaction.parse_goals import Declaration, Goal, parse_goals from .data_extraction.lean import get_latest_commit, LeanGitRepo, LeanFile, Theorem, Pos from .constants import __version__ - -if os.geteuid() == 0: - logger.warning( - "Running LeanDojo as the root user may cause unexpected issues. Proceed with caution." - ) From 12d23cb8e5acfe9854f8e31115bfc13e51cd5310 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Tue, 23 Jul 2024 08:05:14 +0800 Subject: [PATCH 03/11] fix commit url & git clone for windows --- .gitignore | 3 +++ pyproject.toml | 1 + src/lean_dojo/data_extraction/lean.py | 16 +++++++++------- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index cd92677..8d89e22 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # Pyre type checker .pyre/ + +# vscode debug config +.vscode/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4625fc2..d6c7061 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "python-dotenv", "loguru", "filelock", + "gitpython", "psutil", "pexpect", "types-psutil", diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 5d7b140..ad4a2a2 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -17,6 +17,8 @@ from github.Repository import Repository from github.GithubException import GithubException from typing import List, Dict, Any, Generator, Union, Optional, Tuple, Iterator +from git import Repo + from ..utils import ( execute, @@ -455,7 +457,7 @@ def is_lean4(self) -> bool: @property def commit_url(self) -> str: - return os.path.join(self.url, f"tree/{self.commit}") + return f"{self.url}/tree/{self.commit}" def show(self) -> None: """Show the repo in the default browser.""" @@ -467,12 +469,12 @@ def exists(self) -> bool: def clone_and_checkout(self) -> None: """Clone the repo to the current working directory and checkout a specific commit.""" logger.debug(f"Cloning {self}") - execute(f"git clone -n --recursive {self.url}", capture_output=True) - with working_directory(self.name): - execute( - f"git checkout {self.commit} && git submodule update --recursive", - capture_output=True, - ) + # Clone the repository + repo = Repo.clone_from(self.url, Path(self.name)) + # Checkout the specific commit + repo.git.checkout(self.commit) + # Initialize and update submodules + repo.submodule_update(recursive=True) def get_dependencies( self, path: Union[str, Path, None] = None From 97565149d5b261af989d55be0db3d453862feda7 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Wed, 24 Jul 2024 03:18:34 +0800 Subject: [PATCH 04/11] add tests for lean repo & dojo --- src/lean_dojo/data_extraction/lean.py | 8 +++----- tests/conftest.py | 8 ++++++++ tests/data_extraction/test_lean_repo.py | 10 ++++++++++ tests/interaction/test_interaction.py | 16 ++++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 tests/data_extraction/test_lean_repo.py create mode 100644 tests/interaction/test_interaction.py diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index ad4a2a2..aeea52f 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -469,12 +469,10 @@ def exists(self) -> bool: def clone_and_checkout(self) -> None: """Clone the repo to the current working directory and checkout a specific commit.""" logger.debug(f"Cloning {self}") - # Clone the repository - repo = Repo.clone_from(self.url, Path(self.name)) - # Checkout the specific commit + repo = Repo.clone_from(self.url, Path(self.name), no_checkout=True) repo.git.checkout(self.commit) - # Initialize and update submodules - repo.submodule_update(recursive=True) + repo.submodule_update(init=True, recursive=True) + def get_dependencies( self, path: Union[str, Path, None] = None diff --git a/tests/conftest.py b/tests/conftest.py index 1e14c96..3767c1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ AESOP_URL = "https://github.com/leanprover-community/aesop" MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4" LEAN4_EXAMPLE_URL = "https://github.com/yangky11/lean4-example" +EXAMPLE_COMMIT_HASH = "3f8c5eb303a225cdef609498b8d87262e5ef344b" URLS = [ BATTERIES_URL, AESOP_URL, @@ -14,6 +15,13 @@ LEAN4_EXAMPLE_URL, ] +@pytest.fixture(scope="session") +def example_commit_hash(): + return EXAMPLE_COMMIT_HASH + +@pytest.fixture(scope="session") +def lean4_example_url(): + return LEAN4_EXAMPLE_URL @pytest.fixture(scope="session") def monkeysession(): diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py new file mode 100644 index 0000000..4e1e331 --- /dev/null +++ b/tests/data_extraction/test_lean_repo.py @@ -0,0 +1,10 @@ +# test for the class `LeanGitRepo` +from lean_dojo import LeanGitRepo + +def test_lean_git_repo(lean4_example_url, example_commit_hash): + repo = LeanGitRepo(lean4_example_url, example_commit_hash) + assert repo.url == lean4_example_url + assert repo.commit == example_commit_hash + assert repo.exists() + assert repo.name == "lean4-example" + assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" \ No newline at end of file diff --git a/tests/interaction/test_interaction.py b/tests/interaction/test_interaction.py new file mode 100644 index 0000000..6288bd9 --- /dev/null +++ b/tests/interaction/test_interaction.py @@ -0,0 +1,16 @@ +from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem + +def test_remote_interact(lean4_example_url): + repo = LeanGitRepo(url=lean4_example_url, commit="main") + theorem = Theorem(repo, "Lean4Example.lean", "hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == 'a b c : Nat\n⊢ a + b + c = a + c + b' + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == 'a b c : Nat\n⊢ a + (b + c) = a + c + b' + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished) \ No newline at end of file From a8a0977be231113debaab149df78fbfa407d70cc Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Wed, 24 Jul 2024 03:47:16 +0800 Subject: [PATCH 05/11] simplify commit hash & delay initalization of lean4 repo --- src/lean_dojo/data_extraction/lean.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 5d7b140..8e5a75c 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -43,7 +43,7 @@ ) GITHUB = Github() -LEAN4_REPO = GITHUB.get_repo("leanprover/lean4") +LEAN4_REPO = None """The GitHub Repo for Lean 4 itself.""" _URL_REGEX = re.compile(r"(?P.*?)/*") @@ -88,15 +88,9 @@ def _to_commit_hash(repo: Repository, label: str) -> str: logger.debug(f"Querying the commit hash for {repo.name} {label}") try: - return repo.get_branch(label).commit.sha - except GithubException: - pass - - for tag in repo.get_tags(): - if tag.name == label: - return tag.commit.sha - - raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") + return repo.get_commit(label).sha + except Exception: + raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") @dataclass(eq=True, unsafe_hash=True) @@ -328,6 +322,9 @@ def get_lean4_version_from_config(toolchain: str) -> str: def get_lean4_commit_from_config(config_dict: Dict[str, Any]) -> str: """Return the required Lean commit given a ``lean-toolchain`` config.""" assert "content" in config_dict, "config_dict must have a 'content' field" + global LEAN4_REPO + if LEAN4_REPO is None: + LEAN4_REPO = GITHUB.get_repo("leanprover/lean4") config = config_dict["content"].strip() prefix = "leanprover/lean4:" assert config.startswith(prefix), f"Invalid Lean 4 version: {config}" @@ -447,7 +444,7 @@ def from_path(cls, path: Path) -> "LeanGitRepo": @property def name(self) -> str: - return self.repo.name + return os.path.basename(self.url) @property def is_lean4(self) -> bool: From fab19e2b72458723d2ede2020c108cc0f8750d27 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Wed, 24 Jul 2024 11:47:28 +0800 Subject: [PATCH 06/11] lean version: use string instead of commit --- src/lean_dojo/data_extraction/lean.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 8e5a75c..427b531 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -427,9 +427,14 @@ def __post_init__(self) -> None: lean_version = self.commit else: config = self.get_config("lean-toolchain") - lean_version = get_lean4_commit_from_config(config) - v = get_lean4_version_from_config(config["content"]) - if not is_supported_version(v): + toolchain = config["content"] + m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip()) + if m is not None: + lean_version = m["version"] + else: + # lean_version_commit = get_lean4_commit_from_config(config) + lean_version = get_lean4_version_from_config(toolchain) + if not is_supported_version(lean_version): logger.warning( f"{self} relies on an unsupported Lean version: {lean_version}" ) From 60a500ee23ddae962590bab989ae14d84a57c933 Mon Sep 17 00:00:00 2001 From: Kaiyu Yang Date: Wed, 24 Jul 2024 19:55:06 +0000 Subject: [PATCH 07/11] format code --- src/lean_dojo/constants.py | 9 ++++++--- src/lean_dojo/data_extraction/lean.py | 1 - tests/conftest.py | 3 +++ tests/data_extraction/test_lean_repo.py | 3 ++- tests/interaction/test_interaction.py | 7 ++++--- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/lean_dojo/constants.py b/src/lean_dojo/constants.py index c7a5fc3..fd3f74b 100644 --- a/src/lean_dojo/constants.py +++ b/src/lean_dojo/constants.py @@ -74,7 +74,9 @@ def check_git_version(min_version: Tuple[int, int, int]) -> None: """Check the version of Git installed on the system.""" try: - res = subprocess.run("git --version", shell=True, capture_output=True, check=True) + res = subprocess.run( + "git --version", shell=True, capture_output=True, check=True + ) output = res.stdout.decode().strip() error = res.stderr.decode() assert error == "", error @@ -82,8 +84,8 @@ def check_git_version(min_version: Tuple[int, int, int]) -> None: if not match: raise ValueError("Could not parse Git version from the output.") # Convert version number string to tuple of integers - version = tuple(int(_) for _ in match.group(1).split('.')) - + version = tuple(int(_) for _ in match.group(1).split(".")) + version_str = ".".join(str(_) for _ in version) min_version_str = ".".join(str(_) for _ in min_version) assert ( @@ -92,4 +94,5 @@ def check_git_version(min_version: Tuple[int, int, int]) -> None: except subprocess.CalledProcessError as e: raise Exception(f"Failed to run git command: {e}") + check_git_version((2, 25, 0)) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index aeea52f..c12fae7 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -473,7 +473,6 @@ def clone_and_checkout(self) -> None: repo.git.checkout(self.commit) repo.submodule_update(init=True, recursive=True) - def get_dependencies( self, path: Union[str, Path, None] = None ) -> Dict[str, "LeanGitRepo"]: diff --git a/tests/conftest.py b/tests/conftest.py index 3767c1d..7113916 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,14 +15,17 @@ LEAN4_EXAMPLE_URL, ] + @pytest.fixture(scope="session") def example_commit_hash(): return EXAMPLE_COMMIT_HASH + @pytest.fixture(scope="session") def lean4_example_url(): return LEAN4_EXAMPLE_URL + @pytest.fixture(scope="session") def monkeysession(): with pytest.MonkeyPatch.context() as mp: diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py index 4e1e331..005218c 100644 --- a/tests/data_extraction/test_lean_repo.py +++ b/tests/data_extraction/test_lean_repo.py @@ -1,10 +1,11 @@ # test for the class `LeanGitRepo` from lean_dojo import LeanGitRepo + def test_lean_git_repo(lean4_example_url, example_commit_hash): repo = LeanGitRepo(lean4_example_url, example_commit_hash) assert repo.url == lean4_example_url assert repo.commit == example_commit_hash assert repo.exists() assert repo.name == "lean4-example" - assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" \ No newline at end of file + assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" diff --git a/tests/interaction/test_interaction.py b/tests/interaction/test_interaction.py index 6288bd9..1e0b858 100644 --- a/tests/interaction/test_interaction.py +++ b/tests/interaction/test_interaction.py @@ -1,16 +1,17 @@ from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem + def test_remote_interact(lean4_example_url): repo = LeanGitRepo(url=lean4_example_url, commit="main") theorem = Theorem(repo, "Lean4Example.lean", "hello_world") # initial state dojo, state_0 = Dojo(theorem).__enter__() - assert state_0.pp == 'a b c : Nat\n⊢ a + b + c = a + c + b' + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = a + c + b" # state after running a tactic state_1 = dojo.run_tac(state_0, "rw [add_assoc]") - assert state_1.pp == 'a b c : Nat\n⊢ a + (b + c) = a + c + b' + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" # state after running another a sorry tactic assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() # finish proof final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") - assert isinstance(final_state, ProofFinished) \ No newline at end of file + assert isinstance(final_state, ProofFinished) From eabaa222dca8db700a66d3b8b5eed822ad470cf3 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Thu, 25 Jul 2024 14:35:16 +0800 Subject: [PATCH 08/11] update cache method --- src/lean_dojo/data_extraction/cache.py | 6 +++-- src/lean_dojo/data_extraction/lean.py | 29 +++++++++++++++---------- src/lean_dojo/data_extraction/trace.py | 2 +- tests/data_extraction/test_lean_repo.py | 18 +++++++++++++++ 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/lean_dojo/data_extraction/cache.py b/src/lean_dojo/data_extraction/cache.py index acfe1c9..e878ff2 100644 --- a/src/lean_dojo/data_extraction/cache.py +++ b/src/lean_dojo/data_extraction/cache.py @@ -90,10 +90,12 @@ def get(self, url: str, commit: str) -> Optional[Path]: else: return None - def store(self, src: Path) -> Path: + def store(self, src: Path, fmt_name: str = "") -> Path: """Store a traced repo at path ``src``. Return its path in the cache.""" url, commit = get_repo_info(src) - dirpath = self.cache_dir / _format_dirname(url, commit) + if fmt_name == "": # if not specified, extract from the traced repo + fmt_name = _format_dirname(url, commit) + dirpath = self.cache_dir / fmt_name _, repo_name = _split_git_url(url) if not dirpath.exists(): with self.lock: diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 0daed9e..cc37e91 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -28,6 +28,7 @@ working_directory, ) from ..constants import LEAN4_URL +from .cache import _format_dirname GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN", None) @@ -90,9 +91,14 @@ def _to_commit_hash(repo: Repository, label: str) -> str: logger.debug(f"Querying the commit hash for {repo.name} {label}") try: - return repo.get_commit(label).sha - except Exception: - raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") + return repo.get_branch(label).commit.sha + except GithubException: + pass + + for tag in repo.get_tags(): + if tag.name == label: + return tag.commit.sha + raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") @dataclass(eq=True, unsafe_hash=True) @@ -429,13 +435,7 @@ def __post_init__(self) -> None: lean_version = self.commit else: config = self.get_config("lean-toolchain") - toolchain = config["content"] - m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip()) - if m is not None: - lean_version = m["version"] - else: - # lean_version_commit = get_lean4_commit_from_config(config) - lean_version = get_lean4_version_from_config(toolchain) + lean_version = get_lean4_version_from_config(config["content"]) if not is_supported_version(lean_version): logger.warning( f"{self} relies on an unsupported Lean version: {lean_version}" @@ -444,9 +444,9 @@ def __post_init__(self) -> None: object.__setattr__(self, "lean_version", lean_version) @classmethod - def from_path(cls, path: Path) -> "LeanGitRepo": + def from_path(cls, path: Union[Path, str]) -> "LeanGitRepo": """Construct a :class:`LeanGitRepo` object from the path to a local Git repo.""" - url, commit = get_repo_info(path) + url, commit = get_repo_info(Path(path)) return cls(url, commit) @property @@ -461,6 +461,11 @@ def is_lean4(self) -> bool: def commit_url(self) -> str: return f"{self.url}/tree/{self.commit}" + @property + def format_dirname(self) -> str: + """Return the formatted cache directory name""" + return _format_dirname(self.url, self.commit) + def show(self) -> None: """Show the repo in the default browser.""" webbrowser.open(self.commit_url) diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index e5b2492..f30a511 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -212,7 +212,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: _trace(repo, build_deps) traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps) traced_repo.save_to_disk() - path = cache.store(tmp_dir / repo.name) + path = cache.store(tmp_dir / repo.name, repo.format_dirname) else: logger.debug("The traced repo is available in the cache.") return path diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py index 005218c..9660c3c 100644 --- a/tests/data_extraction/test_lean_repo.py +++ b/tests/data_extraction/test_lean_repo.py @@ -1,5 +1,7 @@ # test for the class `LeanGitRepo` from lean_dojo import LeanGitRepo +from lean_dojo.data_extraction.lean import _to_commit_hash +from lean_dojo.constants import LEAN4_URL def test_lean_git_repo(lean4_example_url, example_commit_hash): @@ -9,3 +11,19 @@ def test_lean_git_repo(lean4_example_url, example_commit_hash): assert repo.exists() assert repo.name == "lean4-example" assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" + # test cache directory + assert ( + repo.format_dirname + == "yangky11-lean4-example-3f8c5eb303a225cdef609498b8d87262e5ef344b" + ) + # test commit hash + ## test commit hash + assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash + ## test branch, assume the branch is not changed + assert _to_commit_hash(repo, "paper") == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ## test tag + lean4_repo = LeanGitRepo(LEAN4_URL, "master") + assert ( + _to_commit_hash(lean4_repo, "v4.9.1") + == "1b78cb4836cf626007bd38872956a6fab8910993" + ) From ff0f9f475e09e70019cde469de26e52ee1fff686 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Fri, 26 Jul 2024 03:30:25 +0800 Subject: [PATCH 09/11] update functions for git Repos & add tests --- src/lean_dojo/data_extraction/cache.py | 24 +++--- src/lean_dojo/data_extraction/lean.py | 89 ++++++++++++++----- src/lean_dojo/data_extraction/trace.py | 4 +- src/lean_dojo/utils.py | 28 ++++++ tests/conftest.py | 6 ++ tests/data_extraction/test_cache.py | 55 ++++++++++++ tests/data_extraction/test_lean_repo.py | 109 ++++++++++++++++++++---- 7 files changed, 266 insertions(+), 49 deletions(-) create mode 100644 tests/data_extraction/test_cache.py diff --git a/src/lean_dojo/data_extraction/cache.py b/src/lean_dojo/data_extraction/cache.py index e878ff2..829b7e8 100644 --- a/src/lean_dojo/data_extraction/cache.py +++ b/src/lean_dojo/data_extraction/cache.py @@ -59,11 +59,11 @@ def __post_init__(self): lock_path = self.cache_dir.with_suffix(".lock") object.__setattr__(self, "lock", FileLock(lock_path)) - def get(self, url: str, commit: str) -> Optional[Path]: + def get(self, url: str, commit: str, prefix: str = "") -> Optional[Path]: """Get the path of a traced repo with URL ``url`` and commit hash ``commit``. Return None if no such repo can be found.""" _, repo_name = _split_git_url(url) dirname = _format_dirname(url, commit) - dirpath = self.cache_dir / dirname + dirpath = self.cache_dir / prefix / dirname with self.lock: if dirpath.exists(): @@ -90,18 +90,20 @@ def get(self, url: str, commit: str) -> Optional[Path]: else: return None - def store(self, src: Path, fmt_name: str = "") -> Path: - """Store a traced repo at path ``src``. Return its path in the cache.""" - url, commit = get_repo_info(src) - if fmt_name == "": # if not specified, extract from the traced repo - fmt_name = _format_dirname(url, commit) - dirpath = self.cache_dir / fmt_name - _, repo_name = _split_git_url(url) + def store(self, src: Path, rel_cache_dir: Path) -> Path: + """Store a traced repo at path ``src``. Return its path in the cache. + + Args: + src (Path): Path to the repo. + rel_cache_name (Path): The relative path of the stored repo in the cache. + """ + dirpath = self.cache_dir / rel_cache_dir.parent + cache_path = self.cache_dir / rel_cache_dir if not dirpath.exists(): with self.lock: with report_critical_failure(_CACHE_CORRPUTION_MSG): - shutil.copytree(src, dirpath / repo_name) - return dirpath / repo_name + shutil.copytree(src, cache_path) + return cache_path cache = Cache(CACHE_DIR) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index cc37e91..b9c5f60 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -18,6 +18,9 @@ from github.GithubException import GithubException from typing import List, Dict, Any, Generator, Union, Optional, Tuple, Iterator from git import Repo +from ..constants import TMP_DIR +import uuid +import shutil from ..utils import ( @@ -26,6 +29,8 @@ url_exists, get_repo_info, working_directory, + is_git_repo, + repo_type_of_url, ) from ..constants import LEAN4_URL from .cache import _format_dirname @@ -52,18 +57,46 @@ _URL_REGEX = re.compile(r"(?P.*?)/*") -def normalize_url(url: str) -> str: +def normalize_url(url: str, repo_type: str = "github") -> str: + if repo_type == "local": + return os.path.abspath(url) # Convert to absolute path if local return _URL_REGEX.fullmatch(url)["url"] # Remove trailing `/`. @cache -def url_to_repo(url: str, num_retries: int = 2) -> Repository: +def url_to_repo( + url: str, + num_retries: int = 2, + repo_type: Union[str, None] = None, + tmp_dir: Union[Path] = None, +) -> Union[Repo, Repository]: + """Convert a URL to a Repo object. + + Args: + url (str): The URL of the repository. + num_retries (int): Number of retries in case of failure. + repo_type (Optional[str]): The type of the repository. Defaults to None. + tmp_dir (Optional[Path]): The temporary directory to clone the repo to. Defaults to None. + + Returns: + Repo: A Git Repo object. + """ url = normalize_url(url) backoff = 1 - + tmp_dir = tmp_dir or os.path.join(TMP_DIR or "/tmp", str(uuid.uuid4())[:8]) + repo_type = repo_type or repo_type_of_url(url) while True: try: - return GITHUB.get_repo("/".join(url.split("/")[-2:])) + if repo_type == "github": + return GITHUB.get_repo("/".join(url.split("/")[-2:])) + with working_directory(tmp_dir): + repo_name = os.path.basename(url) + if repo_type == "local": + assert is_git_repo(url), f"Local path {url} is not a git repo" + shutil.copytree(url, repo_name) + return Repo(repo_name) + else: + return Repo.clone_from(url, repo_name) except Exception as ex: if num_retries <= 0: raise ex @@ -77,7 +110,10 @@ def url_to_repo(url: str, num_retries: int = 2) -> Repository: def get_latest_commit(url: str) -> str: """Get the hash of the latest commit of the Git repo at ``url``.""" repo = url_to_repo(url) - return repo.get_branch(repo.default_branch).commit.sha + if isinstance(repo, Repository): + return repo.get_branch(repo.default_branch).commit.sha + else: + return repo.head.commit.hexsha def cleanse_string(s: Union[str, Path]) -> str: @@ -85,20 +121,28 @@ def cleanse_string(s: Union[str, Path]) -> str: return str(s).replace("/", "_").replace(":", "_") -@cache -def _to_commit_hash(repo: Repository, label: str) -> str: +def _to_commit_hash(repo: Union[Repository, Repo], label: str) -> str: """Convert a tag or branch to a commit hash.""" - logger.debug(f"Querying the commit hash for {repo.name} {label}") - - try: - return repo.get_branch(label).commit.sha - except GithubException: - pass - - for tag in repo.get_tags(): - if tag.name == label: - return tag.commit.sha - raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") + # GitHub repository + if isinstance(repo, Repository): + logger.debug(f"Querying the commit hash for {repo.name} {label}") + try: + commit = repo.get_commit(label).sha + except GithubException as e: + raise ValueError(f"Invalid tag or branch: `{label}` for {repo.name}") + # Local or remote Git repository + elif isinstance(repo, Repo): + logger.debug( + f"Querying the commit hash for {repo.working_dir} repository {label}" + ) + try: + # Resolve the label to a commit hash + commit = repo.commit(label).hexsha + except Exception as e: + raise ValueError(f"Error converting ref to commit hash: {e}") + else: + raise TypeError("Unsupported repository type") + return commit @dataclass(eq=True, unsafe_hash=True) @@ -320,6 +364,11 @@ def __getitem__(self, key) -> str: _LEAN4_VERSION_REGEX = re.compile(r"leanprover/lean4:(?P.+?)") +def is_commit_hash(s: str): + """Check if a string is a valid commit hash.""" + return len(s) == 40 and _COMMIT_REGEX.fullmatch(s) + + def get_lean4_version_from_config(toolchain: str) -> str: """Return the required Lean version given a ``lean-toolchain`` config.""" m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip()) @@ -419,12 +468,12 @@ def __post_init__(self) -> None: object.__setattr__(self, "repo", url_to_repo(self.url)) # Convert tags or branches to commit hashes - if not (len(self.commit) == 40 and _COMMIT_REGEX.fullmatch(self.commit)): + if not is_commit_hash(self.commit): if (self.url, self.commit) in info_cache.tag2commit: commit = info_cache.tag2commit[(self.url, self.commit)] else: commit = _to_commit_hash(self.repo, self.commit) - assert _COMMIT_REGEX.fullmatch(commit), f"Invalid commit hash: {commit}" + assert is_commit_hash(commit) info_cache.tag2commit[(self.url, self.commit)] = commit object.__setattr__(self, "commit", commit) diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index f30a511..56f2358 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -212,7 +212,9 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: _trace(repo, build_deps) traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps) traced_repo.save_to_disk() - path = cache.store(tmp_dir / repo.name, repo.format_dirname) + src_dir = tmp_dir / repo.name + rel_cache_dir = Path(repo.format_dirname / repo.name) + path = cache.store(src_dir, rel_cache_dir) else: logger.debug("The traced repo is available in the cache.") return path diff --git a/src/lean_dojo/utils.py b/src/lean_dojo/utils.py index 0386949..12cf101 100644 --- a/src/lean_dojo/utils.py +++ b/src/lean_dojo/utils.py @@ -16,6 +16,7 @@ from contextlib import contextmanager from ray.util.actor_pool import ActorPool from typing import Tuple, Union, List, Generator, Optional +from urllib.parse import urlparse from .constants import NUM_WORKERS, TMP_DIR, LEAN4_PACKAGES_DIR, LEAN4_BUILD_DIR @@ -144,6 +145,33 @@ def camel_case(s: str) -> str: return _CAMEL_CASE_REGEX.sub(" ", s).title().replace(" ", "") +def repo_type_of_url(url: str) -> Union[str, None]: + """Get the type of the repository. + + Args: + url (str): The URL of the repository. + Returns: + str: The type of the repository. + """ + parsed_url = urlparse(url) + if parsed_url.scheme in ["http", "https"]: + # case 1 - GitHub URL + if "github.com" in url: + if not url.startswith("https://"): + logger.warning(f"{url} should start with https://") + return + else: + return "github" + # case 2 - remote URL + elif url_exists(url): # not check whether it is a git URL + return "remote" + # case 3 - local path + elif is_git_repo(Path(parsed_url.path)): + return "local" + logger.warning(f"{url} is not a valid URL") + return None + + @cache def get_repo_info(path: Path) -> Tuple[str, str]: """Get the URL and commit hash of the Git repo at ``path``. diff --git a/tests/conftest.py b/tests/conftest.py index 7113916..d581c62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4" LEAN4_EXAMPLE_URL = "https://github.com/yangky11/lean4-example" EXAMPLE_COMMIT_HASH = "3f8c5eb303a225cdef609498b8d87262e5ef344b" +REMOTE_EXAMPLE_URL = "https://gitee.com/rexzong/lean4-example" URLS = [ BATTERIES_URL, AESOP_URL, @@ -16,6 +17,11 @@ ] +@pytest.fixture(scope="session") +def remote_example_url(): + return REMOTE_EXAMPLE_URL + + @pytest.fixture(scope="session") def example_commit_hash(): return EXAMPLE_COMMIT_HASH diff --git a/tests/data_extraction/test_cache.py b/tests/data_extraction/test_cache.py new file mode 100644 index 0000000..5f8004a --- /dev/null +++ b/tests/data_extraction/test_cache.py @@ -0,0 +1,55 @@ +# test for cache manager +from git import Repo +from lean_dojo.utils import working_directory +from pathlib import Path +from lean_dojo.data_extraction.lean import _format_dirname +from lean_dojo.data_extraction.cache import cache + + +def test_get_cache(lean4_example_url, remote_example_url, example_commit_hash): + # Note: The `git.Repo` requires the local repo to be cloned in a directory + # all cached repos are stored in CACHE_DIR/repos + prefix = "repos" + + # test local repo cache + with working_directory() as tmp_dir: + # assume that the local repo placed in `/.../testrepo/lean4-example` + repo = Repo.clone_from(lean4_example_url, "testrepo/lean4-example") + repo.git.checkout(example_commit_hash) + local_dir = tmp_dir / "testrepo/lean4-example" + # use local_dir as the key to store the cache + rel_cache_dir = ( + prefix + / Path(_format_dirname(str(local_dir), example_commit_hash)) + / local_dir.name + ) + cache.store(local_dir, rel_cache_dir) + # get the cache + local_url, local_commit = str(local_dir), example_commit_hash + repo_cache = cache.get(local_url, local_commit, prefix) + assert ( + _format_dirname(local_url, local_commit) + == f"{local_dir.parent.name}-{local_dir.name}-{local_commit}" + ) + assert repo_cache is not None + + # test remote repo cache + with working_directory() as tmp_dir: + repo = Repo.clone_from(remote_example_url, "lean4-example") + repo.git.checkout(example_commit_hash) + tmp_remote_dir = tmp_dir / "lean4-example" + # use remote url as the key to store the cache + rel_cache_dir = ( + prefix + / Path(_format_dirname(str(remote_example_url), example_commit_hash)) + / tmp_remote_dir.name + ) + cache.store(tmp_remote_dir, rel_cache_dir) + # get the cache + remote_url, remote_commit = remote_example_url, example_commit_hash + repo_cache = cache.get(remote_url, remote_commit, prefix) + assert repo_cache is not None + assert ( + _format_dirname(remote_url, remote_commit) + == f"rexzong-lean4-example-{example_commit_hash}" + ) diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py index 9660c3c..c25b18c 100644 --- a/tests/data_extraction/test_lean_repo.py +++ b/tests/data_extraction/test_lean_repo.py @@ -1,10 +1,99 @@ # test for the class `LeanGitRepo` from lean_dojo import LeanGitRepo -from lean_dojo.data_extraction.lean import _to_commit_hash from lean_dojo.constants import LEAN4_URL +from git import Repo +from github.Repository import Repository +from lean_dojo.utils import working_directory, repo_type_of_url +from lean_dojo.data_extraction.lean import ( + _to_commit_hash, + url_to_repo, + get_latest_commit, + is_commit_hash, + GITHUB, +) -def test_lean_git_repo(lean4_example_url, example_commit_hash): +def test_url_to_repo(lean4_example_url, remote_example_url): + repo_name = "lean4-example" + + # 1. github + ## test get_latest_commit + gh_cm_hash = get_latest_commit(lean4_example_url) + assert is_commit_hash(gh_cm_hash) + + ## test url_to_repo & repo_type_of_url + github_repo = url_to_repo(lean4_example_url) + assert repo_type_of_url(lean4_example_url) == "github" + assert isinstance(github_repo, Repository) + assert github_repo.name == repo_name + + # 2. local + with working_directory() as tmp_dir: + + ## clone from github + Repo.clone_from(lean4_example_url, repo_name) + + ## test get_latest_commit + local_url = str((tmp_dir / repo_name).absolute()) + assert get_latest_commit(local_url) == gh_cm_hash + + ## test url_to_repo & repo_type_of_url + local_repo = url_to_repo(local_url, repo_type="local") + assert repo_type_of_url(local_url) == "local" + assert isinstance(local_repo, Repo) + assert ( + local_repo.working_dir != local_url + ), "The working directory should not be the same as the original repo" + + # 3. remote + with working_directory(): + remote_repo = url_to_repo(remote_example_url) + assert repo_type_of_url(remote_example_url) == "remote" + assert isinstance(remote_repo, Repo) + re_cm_hash = get_latest_commit(remote_example_url) + tmp_repo_path = str(remote_repo.working_dir) + assert re_cm_hash == get_latest_commit(tmp_repo_path) + + +def test_to_commit_hash(lean4_example_url, remote_example_url, example_commit_hash): + # 1. github + repo = GITHUB.get_repo("yangky11/lean4-example") + ## commit hash + assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash + ## branch, assume this branch is not changing + assert _to_commit_hash(repo, "paper") == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + gh_main_hash = _to_commit_hash(repo, "main") + ## git tag + assert ( + _to_commit_hash(GITHUB.get_repo("leanprover/lean4"), "v4.9.1") + == "1b78cb4836cf626007bd38872956a6fab8910993" + ) + # 2. local + with working_directory(): + repo = Repo.clone_from(lean4_example_url, "lean4-example") + repo.git.checkout(example_commit_hash) + repo.create_tag("v0.1.0") # create a tag + repo.git.checkout("main") + assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash + assert _to_commit_hash(repo, "main") == gh_main_hash + assert ( + _to_commit_hash(repo, "origin/paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + assert _to_commit_hash(repo, "v0.1.0") == example_commit_hash + + # 3. remote + with working_directory(): + repo = url_to_repo(remote_example_url) + assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash + assert ( + _to_commit_hash(repo, "origin/paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + # no tags in the remote repo + + +def test_git_lean_repo(lean4_example_url, example_commit_hash): repo = LeanGitRepo(lean4_example_url, example_commit_hash) assert repo.url == lean4_example_url assert repo.commit == example_commit_hash @@ -12,18 +101,4 @@ def test_lean_git_repo(lean4_example_url, example_commit_hash): assert repo.name == "lean4-example" assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" # test cache directory - assert ( - repo.format_dirname - == "yangky11-lean4-example-3f8c5eb303a225cdef609498b8d87262e5ef344b" - ) - # test commit hash - ## test commit hash - assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash - ## test branch, assume the branch is not changed - assert _to_commit_hash(repo, "paper") == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" - ## test tag - lean4_repo = LeanGitRepo(LEAN4_URL, "master") - assert ( - _to_commit_hash(lean4_repo, "v4.9.1") - == "1b78cb4836cf626007bd38872956a6fab8910993" - ) + assert repo.format_dirname == f"yangky11-lean4-example-{example_commit_hash}" From 6c7ef63875ab82a2c2f1112baaf5f45b1e347768 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Fri, 26 Jul 2024 04:04:05 +0800 Subject: [PATCH 10/11] fix url_exists --- src/lean_dojo/data_extraction/trace.py | 2 +- src/lean_dojo/utils.py | 8 ++++++-- tests/data_extraction/test_trace.py | 8 ++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index 56f2358..9527c5a 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -213,7 +213,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps) traced_repo.save_to_disk() src_dir = tmp_dir / repo.name - rel_cache_dir = Path(repo.format_dirname / repo.name) + rel_cache_dir = Path(repo.format_dirname) / repo.name path = cache.store(src_dir, rel_cache_dir) else: logger.debug("The traced repo is available in the cache.") diff --git a/src/lean_dojo/utils.py b/src/lean_dojo/utils.py index 12cf101..2ee132e 100644 --- a/src/lean_dojo/utils.py +++ b/src/lean_dojo/utils.py @@ -237,9 +237,13 @@ def read_url(url: str, num_retries: int = 2) -> str: @cache def url_exists(url: str) -> bool: - """Return True if the URL ``url`` exists.""" + """Return True if the URL ``url`` exists, using the GITHUB_ACCESS_TOKEN for authentication if provided.""" try: - with urllib.request.urlopen(url) as _: + request = urllib.request.Request(url) + gh_token = os.getenv("GITHUB_ACCESS_TOKEN") + if gh_token is not None: + request.add_header("Authorization", f"token {gh_token}") + with urllib.request.urlopen(request) as _: return True except urllib.error.HTTPError: return False diff --git a/tests/data_extraction/test_trace.py b/tests/data_extraction/test_trace.py index 0064933..71788fd 100644 --- a/tests/data_extraction/test_trace.py +++ b/tests/data_extraction/test_trace.py @@ -1,5 +1,13 @@ from pathlib import Path from lean_dojo import * +from lean_dojo.data_extraction.cache import cache + + +def test_example_trace(lean4_example_repo): + trace_repo = trace(lean4_example_repo) + repo = trace_repo.repo + path = cache.get(repo.url, repo.commit) + assert path is not None def test_trace(traced_repo): From e5941ff350445edb89ab17b70ad2e2a8be8195a0 Mon Sep 17 00:00:00 2001 From: rexwzh <1073853456@qq.com> Date: Fri, 26 Jul 2024 10:47:05 +0800 Subject: [PATCH 11/11] allow repo_type of git@github.com --- src/lean_dojo/data_extraction/lean.py | 34 +++++++++++++++++++++++-- src/lean_dojo/utils.py | 28 -------------------- tests/data_extraction/test_lean_repo.py | 5 +++- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index b9c5f60..ed82cdf 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -21,16 +21,15 @@ from ..constants import TMP_DIR import uuid import shutil +from urllib.parse import urlparse from ..utils import ( - execute, read_url, url_exists, get_repo_info, working_directory, is_git_repo, - repo_type_of_url, ) from ..constants import LEAN4_URL from .cache import _format_dirname @@ -56,6 +55,8 @@ _URL_REGEX = re.compile(r"(?P.*?)/*") +_SSH_TO_HTTPS_REGEX = re.compile(r"^git@github\.com:(.+)/(.+)(?:\.git)?$") + def normalize_url(url: str, repo_type: str = "github") -> str: if repo_type == "local": @@ -63,6 +64,35 @@ def normalize_url(url: str, repo_type: str = "github") -> str: return _URL_REGEX.fullmatch(url)["url"] # Remove trailing `/`. +def repo_type_of_url(url: str) -> Union[str, None]: + """Get the type of the repository. + + Args: + url (str): The URL of the repository. + Returns: + str: The type of the repository. + """ + m = _SSH_TO_HTTPS_REGEX.match(url) + url = f"https://github.com/{m.group(1)}/{m.group(2)}" if m else url + parsed_url = urlparse(url) + if parsed_url.scheme in ["http", "https"]: + # case 1 - GitHub URL + if "github.com" in url: + if not url.startswith("https://"): + logger.warning(f"{url} should start with https://") + return + else: + return "github" + # case 2 - remote URL + elif url_exists(url): # not check whether it is a git URL + return "remote" + # case 3 - local path + elif is_git_repo(Path(parsed_url.path)): + return "local" + logger.warning(f"{url} is not a valid URL") + return None + + @cache def url_to_repo( url: str, diff --git a/src/lean_dojo/utils.py b/src/lean_dojo/utils.py index 2ee132e..0adf60f 100644 --- a/src/lean_dojo/utils.py +++ b/src/lean_dojo/utils.py @@ -16,7 +16,6 @@ from contextlib import contextmanager from ray.util.actor_pool import ActorPool from typing import Tuple, Union, List, Generator, Optional -from urllib.parse import urlparse from .constants import NUM_WORKERS, TMP_DIR, LEAN4_PACKAGES_DIR, LEAN4_BUILD_DIR @@ -145,33 +144,6 @@ def camel_case(s: str) -> str: return _CAMEL_CASE_REGEX.sub(" ", s).title().replace(" ", "") -def repo_type_of_url(url: str) -> Union[str, None]: - """Get the type of the repository. - - Args: - url (str): The URL of the repository. - Returns: - str: The type of the repository. - """ - parsed_url = urlparse(url) - if parsed_url.scheme in ["http", "https"]: - # case 1 - GitHub URL - if "github.com" in url: - if not url.startswith("https://"): - logger.warning(f"{url} should start with https://") - return - else: - return "github" - # case 2 - remote URL - elif url_exists(url): # not check whether it is a git URL - return "remote" - # case 3 - local path - elif is_git_repo(Path(parsed_url.path)): - return "local" - logger.warning(f"{url} is not a valid URL") - return None - - @cache def get_repo_info(path: Path) -> Tuple[str, str]: """Get the URL and commit hash of the Git repo at ``path``. diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py index c25b18c..cf8b4d4 100644 --- a/tests/data_extraction/test_lean_repo.py +++ b/tests/data_extraction/test_lean_repo.py @@ -3,9 +3,10 @@ from lean_dojo.constants import LEAN4_URL from git import Repo from github.Repository import Repository -from lean_dojo.utils import working_directory, repo_type_of_url +from lean_dojo.utils import working_directory from lean_dojo.data_extraction.lean import ( _to_commit_hash, + repo_type_of_url, url_to_repo, get_latest_commit, is_commit_hash, @@ -24,6 +25,8 @@ def test_url_to_repo(lean4_example_url, remote_example_url): ## test url_to_repo & repo_type_of_url github_repo = url_to_repo(lean4_example_url) assert repo_type_of_url(lean4_example_url) == "github" + assert repo_type_of_url("git@github.com:yangky11/lean4-example.git") == "github" + assert repo_type_of_url("git@github.com:yangky11/lean4-example") == "github" assert isinstance(github_repo, Repository) assert github_repo.name == repo_name