Skip to content

Commit

Permalink
Update trymerge
Browse files Browse the repository at this point in the history
  • Loading branch information
malfet committed Dec 27, 2023
1 parent 68b56e3 commit c3ecd42
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions .github/scripts/trymerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,18 @@
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Pattern, Tuple
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Pattern,
Tuple,
)
from warnings import warn

import yaml
Expand Down Expand Up @@ -615,11 +626,11 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -
def _revlist_to_prs(
repo: GitRepo,
pr: "GitHubPR",
rev_list: List[str],
rev_list: Iterable[str],
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
) -> List[Tuple["GitHubPR", str]]:
rc: List[Tuple[GitHubPR, str]] = []
for idx, rev in enumerate(reversed(rev_list)):
for idx, rev in enumerate(rev_list):
msg = repo.commit_message(rev)
m = RE_PULL_REQUEST_RESOLVED.search(msg)
if m is None:
Expand Down Expand Up @@ -658,7 +669,7 @@ def skip_func(idx: int, candidate: "GitHubPR") -> bool:
return True

assert pr.is_ghstack_pr()
entire_stack = _revlist_to_prs(repo, pr, rev_list, skip_func)
entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)

for stacked_pr, rev in entire_stack:
if stacked_pr.is_closed():
Expand Down Expand Up @@ -1790,22 +1801,38 @@ def validate_revert(


def get_ghstack_dependent_prs(
repo: GitRepo, pr: GitHubPR
repo: GitRepo, pr: GitHubPR, only_closed: bool = True
) -> List[Tuple[str, GitHubPR]]:
"""
Get the PRs in the stack that are above this PR (inclusive).
Throws error if stack have branched or original branches are gone
"""
assert pr.is_ghstack_pr()
orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
if len(rev_list) == 0:
raise RuntimeError(
f"PR {pr.pr_num} does not have any revisions associated with it"
)
skip_len = len(rev_list) - 1
for branch in repo.branches_containing_ref(orig_ref):
candidate = repo.revlist(f"{pr.default_branch()}..{branch}")
# Pick longest candidate
if len(candidate) > len(rev_list):
candidate, rev_list = rev_list, candidate
# Validate that candidate always ends rev-list
if rev_list[-len(candidate) :] != candidate:
raise RuntimeError("Ieee...")
raise RuntimeError(
f"Branch {branch} revlist {', '.join(candidate)} is not a subset of {', '.join(rev_list)}"
)
# Remove commits original PR depends on
if skip_len > 0:
rev_list = rev_list[:-skip_len]
rc: List[Tuple[str, GitHubPR]] = []
for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
if not pr_.is_closed():
if not only_closed:
rc.append(("", pr_))
continue
commit_sha = get_pr_commit_sha(repo, pr_)
rc.append((commit_sha, pr_))
Expand Down Expand Up @@ -1876,11 +1903,17 @@ def try_revert(
if comment_id is not None
else "\n"
)
shas_and_prs = (
get_ghstack_dependent_prs(repo, pr)
if pr.is_ghstack_pr()
else [(commit_sha, pr)]
)
shas_and_prs = [(commit_sha, pr)]
if pr.is_ghstack_pr():
try:
shas_and_prs = get_ghstack_dependent_prs(repo, pr)
prs_to_revert = " ".join([t[1].get_pr_url() for t in shas_and_prs])
print(f"About to stack of PRs: {prs_to_revert}")
except Exception as e:
print(
f"Failed to fetch dependent PRs: {str(e)}, fall over to single revert"
)

do_revert_prs(
repo,
shas_and_prs,
Expand Down

0 comments on commit c3ecd42

Please sign in to comment.