diff --git a/src/api_clients.rs b/src/api_clients.rs index b537f55..008e6c0 100644 --- a/src/api_clients.rs +++ b/src/api_clients.rs @@ -2,36 +2,56 @@ use std::collections::HashMap; use std::env; use std::sync::Arc; -use anyhow::{anyhow, Context, Ok}; +use anyhow::Context; use octocrab::Octocrab; +use tokio::sync::{AcquireError, Semaphore, SemaphorePermit}; use crate::remote::Remote; +#[derive(Debug)] +pub struct Client { + semaphore: Semaphore, + octocrab: Arc, +} + +impl Client { + pub async fn lock(&self) -> Result<(SemaphorePermit<'_>, &Arc), AcquireError> { + let permit = self.semaphore.acquire().await?; + Ok((permit, &self.octocrab)) + } +} + pub struct ClientSet { - octocrab: HashMap>, + clients: HashMap>, } impl ClientSet { pub fn new() -> Self { ClientSet { - octocrab: HashMap::new(), + clients: HashMap::new(), } } - pub fn add(&mut self, remote: &Remote) -> Result<(), anyhow::Error> { + pub fn fill(&mut self, remote: &mut Remote) -> Result<(), anyhow::Error> { + let host = remote.host.to_string(); + let client = self.get_client(&host)?; + remote.client = Some(client); + Ok(()) + } + + fn get_client(&mut self, host: &str) -> Result, anyhow::Error> { + if let Some(client) = self.clients.get(host) { + return Ok(client.clone()); + } + let mut api_endpoint = "https://api.github.com".to_string(); let mut env_name = "GITHUB_TOKEN".to_string(); - if remote.host.to_string() != "github.com" { - api_endpoint = format!("https://{}/api/v3", &remote.host); + if host != "github.com" { + api_endpoint = format!("https://{host}/api/v3"); env_name = format!( "GITHUB_{}_TOKEN", - &remote - .host - .to_string() - .replace('.', "_") - .to_uppercase() - .trim_start_matches("GITHUB_") + host.replace('.', "_").to_uppercase().trim_start_matches("GITHUB_") ); }; @@ -43,14 +63,11 @@ impl ClientSet { .build() .context("failed to build octocrab client")?, ); - self.octocrab.insert(remote.host.to_string(), octocrab::instance()); - - Ok(()) - } - - pub fn get(&self, remote: &Remote) -> Result<&Arc, anyhow::Error> { - self.octocrab - .get(&remote.host.to_string()) - .ok_or(anyhow!("no api client for {}", &remote.host)) + let client = Arc::new(Client { + semaphore: Semaphore::new(5), // i.e. up to 5 API calls in parallel to the same GitHub instance + octocrab: octocrab::instance(), + }); + self.clients.insert(host.to_owned(), client.clone()); + Ok(client) } } diff --git a/src/changes.rs b/src/changes.rs index 30de182..483cc56 100644 --- a/src/changes.rs +++ b/src/changes.rs @@ -3,7 +3,6 @@ use octocrab::models::commits::Commit; use octocrab::models::pulls::Review; use octocrab::models::pulls::ReviewState::Approved; -use crate::api_clients::ClientSet; use crate::remote::Remote; #[derive(Clone, Debug)] @@ -16,21 +15,18 @@ pub struct RepoChangeset { } impl RepoChangeset { - pub async fn analyze_commits(&mut self, client_set: &ClientSet) -> Result<(), anyhow::Error> { - let compare = self - .remote - .compare(client_set, &self.base_commit, &self.head_commit) - .await?; + pub async fn analyze_commits(&mut self) -> Result<(), anyhow::Error> { + let compare = self.remote.compare(&self.base_commit, &self.head_commit).await?; for commit in &compare.commits { - self.analyze_commit(client_set, commit).await?; + self.analyze_commit(commit).await?; } Ok(()) } - async fn analyze_commit(&mut self, client_set: &ClientSet, commit: &Commit) -> Result<(), anyhow::Error> { - let associated_prs = self.remote.associated_prs(client_set, commit).await?; + async fn analyze_commit(&mut self, commit: &Commit) -> Result<(), anyhow::Error> { + let associated_prs = self.remote.associated_prs(commit).await?; let change_commit = CommitMetadata::new(commit); if associated_prs.is_empty() { @@ -43,7 +39,7 @@ impl RepoChangeset { } for associated_pr in &associated_prs { - let pr_reviews = self.remote.pr_reviews(client_set, associated_pr.number).await?; + let pr_reviews = self.remote.pr_reviews(associated_pr.number).await?; let associated_pr_link = Some( associated_pr @@ -53,7 +49,7 @@ impl RepoChangeset { .to_string(), ); - let head_sha = self.remote.pr_head_hash(client_set, associated_pr.number).await?; + let head_sha = self.remote.pr_head_hash(associated_pr.number).await?; if let Some(changeset) = self.changes.iter_mut().find(|cs| cs.pr_link == associated_pr_link) { changeset.commits.push(change_commit.clone()); diff --git a/src/main.rs b/src/main.rs index ca11c4c..b10a426 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,8 +3,8 @@ mod api_clients; mod changes; mod helm_config; -mod repo; mod remote; +mod repo; use std::str; use std::sync::LazyLock; @@ -16,8 +16,8 @@ use clap::builder::styling::Style; use clap::{Parser, Subcommand}; use git2::Repository; use helm_config::ImageRefs; -use url::{Host, Url}; use remote::Remote; +use url::{Host, Url}; const BOLD_UNDERLINE: Style = Style::new().bold().underline(); static GITHUB_TOKEN_HELP: LazyLock = LazyLock::new(|| { @@ -84,8 +84,8 @@ async fn main() -> Result<(), anyhow::Error> { match &cli.command { Commands::Repo { remote } => { - let remote = Remote::parse(remote)?; - api_clients.add(&remote)?; + let mut remote = Remote::parse(remote)?; + api_clients.fill(&mut remote)?; let repo = &mut RepoChangeset { name: remote.repository.clone(), remote, @@ -93,9 +93,7 @@ async fn main() -> Result<(), anyhow::Error> { head_commit: cli.head, changes: Vec::new(), }; - repo.analyze_commits(&api_clients) - .await - .context("while finding reviews")?; + repo.analyze_commits().await.context("while finding reviews")?; print_changes(&[repo.clone()])?; }, Commands::HelmChart { workspace } => { @@ -103,10 +101,8 @@ async fn main() -> Result<(), anyhow::Error> { find_values_yaml(workspace.clone(), &cli.base, &cli.head).context("while finding values.yaml files")?; for repo in &mut changes { - api_clients.add(&repo.remote)?; - repo.analyze_commits(&api_clients) - .await - .context("while collecting repo changes")?; + api_clients.fill(&mut repo.remote)?; + repo.analyze_commits().await.context("while collecting repo changes")?; } print_changes(&changes)?; diff --git a/src/remote.rs b/src/remote.rs index 98672e7..4b10b5f 100644 --- a/src/remote.rs +++ b/src/remote.rs @@ -3,9 +3,12 @@ use octocrab::commits::PullRequestTarget; use octocrab::models::commits::{Commit, CommitComparison}; use octocrab::models::pulls::{PullRequest, Review}; use octocrab::models::repos::RepoCommit; +use octocrab::Octocrab; +use std::sync::Arc; +use tokio::sync::SemaphorePermit; use url::Url; -use crate::api_clients::ClientSet; +use crate::api_clients::Client; #[derive(Clone, Debug)] #[allow(dead_code)] @@ -15,6 +18,7 @@ pub struct Remote { pub owner: String, pub repository: String, pub original: String, + pub client: Option>, } impl Remote { @@ -27,16 +31,19 @@ impl Remote { owner: path_elements[0].to_string(), repository: path_elements[1].trim_end_matches(".git").to_string(), original: url.into(), + client: None, }) } - pub async fn associated_prs( - &self, - client_set: &ClientSet, - commit: &Commit, - ) -> Result, anyhow::Error> { - let mut associated_prs_page = client_set - .get(self)? + async fn get_client(&self) -> Result<(SemaphorePermit<'_>, &Arc), anyhow::Error> { + let client = self.client.as_ref().ok_or(anyhow!("no client attached to remote"))?; + client.lock().await.context("cannot obtain semaphore for client") + } + + pub async fn associated_prs(&self, commit: &Commit) -> Result, anyhow::Error> { + let (_permit, octocrab) = self.get_client().await?; + + let mut associated_prs_page = octocrab .commits(&self.owner, &self.repository) .associated_pull_requests(PullRequestTarget::Sha(commit.clone().sha.clone())) .send() @@ -49,14 +56,10 @@ impl Remote { Ok(associated_prs_page.take_items()) } - pub async fn compare( - &self, - client_set: &ClientSet, - base_commit: &str, - head_commit: &str, - ) -> Result { - client_set - .get(self)? + pub async fn compare(&self, base_commit: &str, head_commit: &str) -> Result { + let (_permit, octocrab) = self.get_client().await?; + + octocrab .commits(&self.owner, &self.repository) .compare(base_commit, head_commit) .send() @@ -69,9 +72,9 @@ impl Remote { )) } - pub async fn pr_head_hash(&self, client_set: &ClientSet, pr_number: u64) -> Result { + pub async fn pr_head_hash(&self, pr_number: u64) -> Result { Ok(self - .pr_commits(client_set, pr_number) + .pr_commits(pr_number) .await .context("failed to get pr commits")? .last() @@ -80,9 +83,10 @@ impl Remote { .clone()) } - pub async fn pr_commits(&self, client_set: &ClientSet, pr_number: u64) -> Result, anyhow::Error> { - let mut pr_commits_page = client_set - .get(self)? + pub async fn pr_commits(&self, pr_number: u64) -> Result, anyhow::Error> { + let (_permit, octocrab) = self.get_client().await?; + + let mut pr_commits_page = octocrab .pulls(&self.owner, &self.repository) .pr_commits(pr_number) .await @@ -101,9 +105,10 @@ impl Remote { Ok(pr_commits) } - pub async fn pr_reviews(&self, client_set: &ClientSet, pr_number: u64) -> Result, anyhow::Error> { - let mut pr_reviews_page = client_set - .get(self)? + pub async fn pr_reviews(&self, pr_number: u64) -> Result, anyhow::Error> { + let (_permit, octocrab) = self.get_client().await?; + + let mut pr_reviews_page = octocrab .pulls(&self.owner, &self.repository) .list_reviews(pr_number) .send()