From 7a455f4199c34fd884e0dcd6f586b840fde0b3e3 Mon Sep 17 00:00:00 2001 From: "Kyle D. McCormick" Date: Mon, 5 Aug 2024 13:26:12 -0400 Subject: [PATCH] refactor: register repo checks with a decorator Removes a bit of redundancy / potential typos inherent in the CHECK list. This will also make it easier to break repo_checks.py into multiple modules, if we ever decide to do that. --- edx_repo_tools/repo_checks/repo_checks.py | 40 ++++++++++++++--------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/edx_repo_tools/repo_checks/repo_checks.py b/edx_repo_tools/repo_checks/repo_checks.py index d817e957..a0807e5a 100644 --- a/edx_repo_tools/repo_checks/repo_checks.py +++ b/edx_repo_tools/repo_checks/repo_checks.py @@ -14,6 +14,7 @@ import importlib.resources import re import textwrap +import typing as t from functools import cache from itertools import chain from pprint import pformat @@ -104,11 +105,24 @@ class Check: (is_relevant, check, fix, and dry_run). """ + _registered = {} + def __init__(self, api: GhApi, org: str, repo: str): self.api = api self.org_name = org self.repo_name = repo + @staticmethod + def register(check_subclass: type[t.Self]): + """ + Decorate a Check subclass so that it will be available in main() + """ + Check._registered[check_subclass.__name__] = check_subclass + + @staticmethod + def get_registered_checks() -> dict[str, type[t.Self]]: + return Check._registered.copy() + def is_relevant(self) -> bool: """ Checks to see if the given check is relevant to run on the @@ -152,6 +166,7 @@ def dry_run(self): raise NotImplementedError +@Check.register class EnsureRepoSettings(Check): """ There are certain settings that we agree we want to be set a specific way on all repos. This check @@ -240,6 +255,7 @@ def fix(self, dry_run=False): return steps +@Check.register class EnsureNoAdminOrMaintainTeams(Check): """ Teams should not be granted `admin` or `maintain` access to a repository unless the access @@ -309,6 +325,7 @@ def fix(self, dry_run=False): return steps +@Check.register class EnsureWorkflowTemplates(Check): """ There are certain github action workflows that we to exist on all @@ -594,6 +611,7 @@ def fix(self, dry_run=False): return steps +@Check.register class EnsureLabels(Check): """ All repos in the org should have certain labels. @@ -782,6 +800,7 @@ def fix(self, dry_run=False): raise +@Check.register class RequireTriageTeamAccess(RequireTeamPermission): """ Ensure that the openedx-triage team grants Triage access to every public repo in the org. @@ -797,6 +816,7 @@ def is_relevant(self): return is_public(self.api, self.org_name, self.repo_name) +@Check.register class RequiredCLACheck(Check): """ This class validates the following: @@ -1057,6 +1077,7 @@ def _get_update_params_from_get_branch_protection(self): return params +@Check.register class EnsureNoDirectRepoAccessToUsers(Check): """ Users should not have direct repo access @@ -1114,19 +1135,6 @@ def fix(self, dry_run=False): return steps -CHECKS = [ - RequiredCLACheck, - RequireTriageTeamAccess, - EnsureLabels, - EnsureWorkflowTemplates, - EnsureNoAdminOrMaintainTeams, - EnsureRepoSettings, - EnsureNoDirectRepoAccessToUsers, -] -CHECKS_BY_NAME = {check_cls.__name__: check_cls for check_cls in CHECKS} -CHECKS_BY_NAME_LOWER = {check_cls.__name__.lower(): check_cls for check_cls in CHECKS} - - @click.command() @click.option( "--github-token", @@ -1154,7 +1162,7 @@ def fix(self, dry_run=False): "check_names", default=None, multiple=True, - type=click.Choice(CHECKS_BY_NAME.keys(), case_sensitive=False), + type=click.Choice(Check.get_registered_checks().keys(), case_sensitive=False), help=f"Limit to specific check(s), case-insensitive.", ) @click.option( @@ -1193,9 +1201,9 @@ def main(org, dry_run, _github_token, check_names, repos, start_at): click.secho("No Actual Changes Being Made", fg="yellow") if check_names: - active_checks = [CHECKS_BY_NAME[check_name] for check_name in check_names] + active_checks = [Check.get_registered_checks()[check_name] for check_name in check_names] else: - active_checks = CHECKS + active_checks = list(Check.get_registered_checks().values()) click.secho(f"The following checks will be run:", fg="magenta", bold=True) active_checks_string = "\n".join( "\t" + check_cls.__name__ for check_cls in active_checks