diff --git a/tools/update_required_branch_checks.py b/tools/update_required_branch_checks.py index f65faea9ccf..50aac1901bf 100644 --- a/tools/update_required_branch_checks.py +++ b/tools/update_required_branch_checks.py @@ -17,7 +17,7 @@ from github import Github from typing import List -YOUR_GITHUB_TOKEN = '' +YOUR_GITHUB_TOKEN = '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' TARGET_REPO = 'youtube/cobalt' @@ -28,21 +28,32 @@ LATEST_LTS_RELEASE_NUMBER = 24 +def get_protected_branches() -> List[str]: + branches = ['main'] + for i in range(MINIMUM_LTS_RELEASE_NUMBER, LATEST_LTS_RELEASE_NUMBER + 1): + branches.append(f'{i}.lts.1+') + return branches + + def parse_args(): parser = argparse.ArgumentParser() + parser.add_argument( + '-b', + '--branch', + action='append', + help='Branch to update. Can be repeated to update multiple branches.' + ' Defaults to all protected branches.') parser.add_argument( '--dry_run', action='store_true', default=False, help='Only print protection updates.') - return parser.parse_args() + args = parser.parse_args() + if not args.branch: + args.branch = get_protected_branches() -def get_protected_branches() -> List[str]: - branches = ['main'] - for i in range(MINIMUM_LTS_RELEASE_NUMBER, LATEST_LTS_RELEASE_NUMBER + 1): - branches.append(f'{i}.lts.1+') - return branches + return args def initialize_repo_connection(): @@ -92,9 +103,8 @@ def update_protection_for_branch(repo, branch: str, check_names: List[str]): def main(): args = parse_args() - branches = get_protected_branches() repo = initialize_repo_connection() - for branch in branches: + for branch in args.branch: required_checks = get_required_checks_for_branch(repo, branch) if args.dry_run: print_checks(branch, required_checks)