Skip to content

Commit

Permalink
Add loop across protected branches
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewsavage1 committed Jun 21, 2023
1 parent 9135cac commit 50db015
Showing 1 changed file with 48 additions and 17 deletions.
65 changes: 48 additions & 17 deletions tools/update_required_branch_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,63 @@
"""Updates the requires status checks for a branch."""

from github import Github
from typing import List

BRANCH_NAME = 'main'
YOUR_GITHUB_TOKEN = ''

YOUR_TOKEN = ''
g = Github(YOUR_TOKEN)
# Exclude rc_11 and COBALT_9 releases.
MINIMUM_LTS_RELEASE_NUMBER = 19
LATEST_LTS_RELEASE_NUMBER = 24

repo = g.get_repo('youtube/cobalt')
latest_open_pr = repo.get_pulls(
state='open', sort='updated', base=BRANCH_NAME, direction='desc')[0]
latest_pr_commit = repo.get_commit(latest_open_pr.head)
check_runs = latest_pr_commit.get_check_runs()

def get_protected_branches() -> List[str]:
branches = ['main']
for i in range(MINIMUM_LTS_RELEASE_NUMBER, LATEST_LTS_RELEASE_NUMBER + 1):
branches.append(f'{i}.lts.1+')
return branches

# Filter out check runs that have '_on_device_' in the name
def should_include_run(check_run):

def initialize_repo_connection():
g = Github(YOUR_GITHUB_TOKEN)
return g.get_repo('youtube/cobalt')


def get_checks_for_branch(repo, branch: str):
latest_open_pr = repo.get_pulls(
state='open', sort='updated', base=branch, direction='desc')[0]
latest_pr_commit = repo.get_commit(latest_open_pr.head)
checks = latest_pr_commit.get_check_runs()
return checks


def should_include_run(check_run) -> bool:
# Filter out check runs that have '_on_device_' in the name.
if '_on_device_' in check_run.name:
return False
if check_run.name == 'feedback/copybara':
return False
return True


filtered_check_runs = [run for run in check_runs if should_include_run(run)]
check_names = [run.name for run in filtered_check_runs]
def get_required_checks_for_branch(repo, branch: str) -> List[str]:
checks = get_checks_for_branch(repo, branch)
filtered_check_runs = [run for run in checks if should_include_run(run)]
check_names = [run.name for run in filtered_check_runs]
return check_names


def update_protection_for_branch(repo, branch: str, check_names: List[str]):
branch = repo.get_branch(branch)
branch.edit_protection(contexts=check_names)


def main():
branches = get_protected_branches()
repo = initialize_repo_connection()
for branch in branches:
required_checks = get_required_checks_for_branch(repo, branch)
update_protection_for_branch(repo, branch, required_checks)

branch = repo.get_branch(BRANCH_NAME)
protection = branch.get_protection()
check_names += protection.required_status_checks.contexts
checks = set(check_names)

branch.edit_protection(contexts=check_names)
if __name__ == '__main__':
main()

0 comments on commit 50db015

Please sign in to comment.