Skip to content

Commit

Permalink
Add more type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewsavage1 committed Jun 27, 2023
1 parent 872c839 commit 49c688a
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions tools/update_required_branch_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@
from typing import List

YOUR_GITHUB_TOKEN = ''
assert YOUR_GITHUB_TOKEN != '', 'YOUR_GITHUB_TOKEN must be set.'

TARGET_REPO = 'youtube/cobalt'

EXCLUDED_CHECK_PATTERNS = ['feedback/copybara', '_on_device_', r'${{']
EXCLUDED_CHECK_PATTERNS = [
'feedback/copybara',
'_on_device_',
'codecov',
# Excludes templated check names.
'${{'
]

# Exclude rc_11 and COBALT_9 releases.
MINIMUM_LTS_RELEASE_NUMBER = 19
Expand All @@ -35,41 +42,21 @@ def get_protected_branches() -> List[str]:
return branches


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'-b',
'--branch',
action='append',
help='Branch to update. Can be repeated to update multiple branches.'
' Defaults to all protected branches.')
parser.add_argument(
'--dry_run',
action='store_true',
default=False,
help='Only print protection updates.')
args = parser.parse_args()

if not args.branch:
args.branch = get_protected_branches()

return args


def initialize_repo_connection():
g = Github(YOUR_GITHUB_TOKEN)
return g.get_repo(TARGET_REPO)


def get_checks_for_branch(repo, branch: str):
def get_checks_for_branch(repo, branch: str) -> None:
prs = repo.get_pulls(
state='open', sort='updated', base=branch, direction='desc')
try:
latest_pr = prs[0]
except IndexError:
prs = repo.get_pulls(
state='closed', sort='updated', base=branch, direction='desc')
latest_pr = prs[0]
state='closed', sort='updated', base=branch, direction='desc')

latest_pr = None
for pr in prs:
if pr.merged:
latest_pr = pr
break

latest_pr_commit = repo.get_commit(latest_pr.head.sha)
checks = latest_pr_commit.get_check_runs()
return checks
Expand All @@ -89,19 +76,41 @@ def get_required_checks_for_branch(repo, branch: str) -> List[str]:
return check_names


def print_checks(branch: str, check_names: List[str]):
def print_checks(branch: str, check_names: List[str]) -> None:
print(f'Checks for {branch}:')
for check_name in check_names:
print(check_name)
print()


def update_protection_for_branch(repo, branch: str, check_names: List[str]):
def update_protection_for_branch(repo, branch: str,
check_names: List[str]) -> None:
branch = repo.get_branch(branch)
branch.edit_protection(contexts=check_names)


def main():
def parse_args() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
'-b',
'--branch',
action='append',
help='Branch to update. Can be repeated to update multiple branches.'
' Defaults to all protected branches.')
parser.add_argument(
'--dry_run',
action='store_true',
default=False,
help='Only print protection updates.')
args = parser.parse_args()

if not args.branch:
args.branch = get_protected_branches()

return args


def main() -> None:
args = parse_args()
repo = initialize_repo_connection()
for branch in args.branch:
Expand Down

0 comments on commit 49c688a

Please sign in to comment.