Skip to content

Commit

Permalink
Add dry run option
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewsavage1 committed Jun 21, 2023
1 parent 50db015 commit 7f7d7a5
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions tools/update_required_branch_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,31 @@
# limitations under the License.
"""Updates the requires status checks for a branch."""

import argparse
from github import Github
from typing import List

YOUR_GITHUB_TOKEN = ''

TARGET_REPO = 'youtube/cobalt'

EXCLUDED_CHECK_PATTERNS = ['feedback/copybara', '_on_device_', r'${{']

# Exclude rc_11 and COBALT_9 releases.
MINIMUM_LTS_RELEASE_NUMBER = 19
LATEST_LTS_RELEASE_NUMBER = 24


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--dry_run',
action='store_true',
default=False,
help='Only print protection updates.')
return parser.parse_args()


def get_protected_branches() -> List[str]:
branches = ['main']
for i in range(MINIMUM_LTS_RELEASE_NUMBER, LATEST_LTS_RELEASE_NUMBER + 1):
Expand All @@ -32,23 +47,27 @@ def get_protected_branches() -> List[str]:

def initialize_repo_connection():
g = Github(YOUR_GITHUB_TOKEN)
return g.get_repo('youtube/cobalt')
return g.get_repo(TARGET_REPO)


def get_checks_for_branch(repo, branch: str):
latest_open_pr = repo.get_pulls(
state='open', sort='updated', base=branch, direction='desc')[0]
latest_pr_commit = repo.get_commit(latest_open_pr.head)
prs = repo.get_pulls(
state='open', sort='updated', base=branch, direction='desc')
try:
latest_pr = prs[0]
except IndexError:
prs = repo.get_pulls(
state='closed', sort='updated', base=branch, direction='desc')
latest_pr = prs[0]
latest_pr_commit = repo.get_commit(latest_pr.head.sha)
checks = latest_pr_commit.get_check_runs()
return checks


def should_include_run(check_run) -> bool:
# Filter out check runs that have '_on_device_' in the name.
if '_on_device_' in check_run.name:
return False
if check_run.name == 'feedback/copybara':
return False
for pattern in EXCLUDED_CHECK_PATTERNS:
if pattern in check_run.name:
return False
return True


Expand All @@ -59,17 +78,28 @@ def get_required_checks_for_branch(repo, branch: str) -> List[str]:
return check_names


def print_checks(branch: str, check_names: List[str]):
print(f'Checks for {branch}:')
for check_name in check_names:
print(check_name)
print()


def update_protection_for_branch(repo, branch: str, check_names: List[str]):
branch = repo.get_branch(branch)
branch.edit_protection(contexts=check_names)


def main():
args = parse_args()
branches = get_protected_branches()
repo = initialize_repo_connection()
for branch in branches:
required_checks = get_required_checks_for_branch(repo, branch)
update_protection_for_branch(repo, branch, required_checks)
if args.dry_run:
print_checks(branch, required_checks)
else:
update_protection_for_branch(repo, branch, required_checks)


if __name__ == '__main__':
Expand Down

0 comments on commit 7f7d7a5

Please sign in to comment.