From fee7a73c2c70edabbee8f210fed8972c63989bd5 Mon Sep 17 00:00:00 2001 From: Ned Batchelder Date: Tue, 21 Nov 2023 11:39:28 -0500 Subject: [PATCH] refactor: a helper for paged results "When using a paged API endpoint, I want an iterable of pages" -- No one ever. This helper lets us focus on the things we care about: the items we are asking for. --- edx_repo_tools/repo_checks/repo_checks.py | 55 +++++++++++------------ 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/edx_repo_tools/repo_checks/repo_checks.py b/edx_repo_tools/repo_checks/repo_checks.py index 8eaaa8e8..a522fb01 100644 --- a/edx_repo_tools/repo_checks/repo_checks.py +++ b/edx_repo_tools/repo_checks/repo_checks.py @@ -43,6 +43,13 @@ cache = lru_cache(maxsize=None) +def all_paged_items(func, *args, **kwargs): + """ + Get all items from a GhApi function returning paged results. + """ + return chain.from_iterable(paged(func, *args, per_page=100, **kwargs)) + + def is_security_private_fork(api, org, repo): """ Check to see if a specific repo is a private security fork. @@ -373,14 +380,11 @@ def fix(self, dry_run=False): ) # Check to see if a PR exists - prs = chain.from_iterable( - paged( - self.api.pulls.list, - owner=self.org_name, - repo=self.repo_name, - head=self.branch_name, - per_page=100, - ) + prs = all_paged_items( + self.api.pulls.list, + owner=self.org_name, + repo=self.repo_name, + head=self.branch_name, ) prs = [pr for pr in prs if pr.head.ref == self.branch_name] @@ -438,13 +442,10 @@ def check(self): """ See if our labels exist. """ - existing_labels_from_api = chain.from_iterable( - paged( - self.api.issues.list_labels_for_repo, - self.org_name, - self.repo_name, - per_page=100, - ) + existing_labels_from_api = all_paged_items( + self.api.issues.list_labels_for_repo, + self.org_name, + self.repo_name, ) existing_labels = { self._simplify_label(label.name): { @@ -555,13 +556,10 @@ def is_relevant(self): raise NotImplementedError def check(self): - teams = chain.from_iterable( - paged( - self.api.repos.list_teams, - self.org_name, - self.repo_name, - per_page=100, - ) + teams = all_paged_items( + self.api.repos.list_teams, + self.org_name, + self.repo_name, ) team_permissions = {team.slug: team.permission for team in teams} @@ -943,14 +941,11 @@ def main(org, dry_run, _github_token, check_names, repos, start_at): if not repos: repos = [ repo.name - for repo in chain.from_iterable( - paged( - api.repos.list_for_org, - org, - sort="created", - direction="desc", - per_page=100, - ) + for repo in all_paged_items( + api.repos.list_for_org, + org, + sort="created", + direction="desc", ) ]