From 7f7d7a5f993c7f6710381c65787f19c4c031afb5 Mon Sep 17 00:00:00 2001 From: Andrew Savage Date: Wed, 21 Jun 2023 16:59:43 +0000 Subject: [PATCH] Add dry run option --- tools/update_required_branch_checks.py | 50 ++++++++++++++++++++------ 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/tools/update_required_branch_checks.py b/tools/update_required_branch_checks.py index 3eb5d778b71b..f65faea9ccfa 100644 --- a/tools/update_required_branch_checks.py +++ b/tools/update_required_branch_checks.py @@ -13,16 +13,31 @@ # limitations under the License. """Updates the requires status checks for a branch.""" +import argparse from github import Github from typing import List YOUR_GITHUB_TOKEN = '' +TARGET_REPO = 'youtube/cobalt' + +EXCLUDED_CHECK_PATTERNS = ['feedback/copybara', '_on_device_', r'${{'] + # Exclude rc_11 and COBALT_9 releases. MINIMUM_LTS_RELEASE_NUMBER = 19 LATEST_LTS_RELEASE_NUMBER = 24 +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--dry_run', + action='store_true', + default=False, + help='Only print protection updates.') + return parser.parse_args() + + def get_protected_branches() -> List[str]: branches = ['main'] for i in range(MINIMUM_LTS_RELEASE_NUMBER, LATEST_LTS_RELEASE_NUMBER + 1): @@ -32,23 +47,27 @@ def get_protected_branches() -> List[str]: def initialize_repo_connection(): g = Github(YOUR_GITHUB_TOKEN) - return g.get_repo('youtube/cobalt') + return g.get_repo(TARGET_REPO) def get_checks_for_branch(repo, branch: str): - latest_open_pr = repo.get_pulls( - state='open', sort='updated', base=branch, direction='desc')[0] - latest_pr_commit = repo.get_commit(latest_open_pr.head) + prs = repo.get_pulls( + state='open', sort='updated', base=branch, direction='desc') + try: + latest_pr = prs[0] + except IndexError: + prs = repo.get_pulls( + state='closed', sort='updated', base=branch, direction='desc') + latest_pr = prs[0] + latest_pr_commit = repo.get_commit(latest_pr.head.sha) checks = latest_pr_commit.get_check_runs() return checks def should_include_run(check_run) -> bool: - # Filter out check runs that have '_on_device_' in the name. - if '_on_device_' in check_run.name: - return False - if check_run.name == 'feedback/copybara': - return False + for pattern in EXCLUDED_CHECK_PATTERNS: + if pattern in check_run.name: + return False return True @@ -59,17 +78,28 @@ def get_required_checks_for_branch(repo, branch: str) -> List[str]: return check_names +def print_checks(branch: str, check_names: List[str]): + print(f'Checks for {branch}:') + for check_name in check_names: + print(check_name) + print() + + def update_protection_for_branch(repo, branch: str, check_names: List[str]): branch = repo.get_branch(branch) branch.edit_protection(contexts=check_names) def main(): + args = parse_args() branches = get_protected_branches() repo = initialize_repo_connection() for branch in branches: required_checks = get_required_checks_for_branch(repo, branch) - update_protection_for_branch(repo, branch, required_checks) + if args.dry_run: + print_checks(branch, required_checks) + else: + update_protection_for_branch(repo, branch, required_checks) if __name__ == '__main__':