Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: reload permission on refresh event #54

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/bors/handlers/refresh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ pub async fn refresh_repository<Client: RepositoryClient>(
repo: &mut RepositoryState<Client>,
db: &dyn DbClient,
) -> anyhow::Result<()> {
let timeout = repo.config.timeout;
let res = cancel_timed_out_builds(repo, db).await;
reload_permission(repo).await;

res
}

async fn cancel_timed_out_builds<Client: RepositoryClient>(
repo: &mut RepositoryState<Client>,
db: &dyn DbClient,
) -> anyhow::Result<()> {
let running_builds = db.get_running_builds(&repo.repository).await?;
tracing::info!("Found {} running build(s)", running_builds.len());

for build in running_builds {
let timeout = repo.config.timeout;
if elapsed_time(build.created_at) >= timeout {
tracing::info!("Cancelling build {}", build.commit_sha);

Expand All @@ -41,10 +50,13 @@ pub async fn refresh_repository<Client: RepositoryClient>(
}
}
}

Ok(())
}

async fn reload_permission<Client: RepositoryClient>(repo: &mut RepositoryState<Client>) {
repo.permissions_resolver.reload().await
}

#[cfg(not(test))]
fn now() -> DateTime<Utc> {
Utc::now()
Expand All @@ -68,6 +80,7 @@ fn elapsed_time(date: DateTime<Utc>) -> Duration {
#[cfg(test)]
mod tests {
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use chrono::Utc;
Expand All @@ -77,6 +90,7 @@ mod tests {
use crate::bors::handlers::trybuild::TRY_BRANCH_NAME;
use crate::database::DbClient;
use crate::tests::event::{default_pr_number, WorkflowStartedBuilder};
use crate::tests::permissions::MockPermissions;
use crate::tests::state::{default_repo_name, ClientBuilder, RepoConfigBuilder};

#[tokio::test(flavor = "current_thread")]
Expand All @@ -85,6 +99,17 @@ mod tests {
state.refresh().await;
}

#[tokio::test(flavor = "current_thread")]
async fn refresh_permission() {
let permission_resolver = Arc::new(Mutex::new(MockPermissions::default()));
let mut state = ClientBuilder::default()
.permission_resolver(Box::new(Arc::clone(&permission_resolver)))
.create_state()
.await;
state.refresh().await;
assert_eq!(permission_resolver.lock().unwrap().num_reload_called, 1);
}

#[tokio::test(flavor = "current_thread")]
async fn refresh_do_nothing_before_timeout() {
let mut state = ClientBuilder::default()
Expand Down
41 changes: 8 additions & 33 deletions src/permissions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use axum::async_trait;
use std::collections::HashSet;
use std::time::{Duration, SystemTime};
use tokio::sync::Mutex;

use crate::github::GithubRepoName;
Expand All @@ -16,15 +15,13 @@
#[async_trait]
pub trait PermissionResolver {
async fn has_permission(&self, username: &str, permission: PermissionType) -> bool;
async fn reload(&self);
}

/// For how long should the permissions be cached.
const CACHE_DURATION: Duration = Duration::from_secs(60);

/// Loads permission information from the Rust Team API.
pub struct TeamApiPermissionResolver {
repo: GithubRepoName,
permissions: Mutex<CachedUserPermissions>,
permissions: Mutex<UserPermissions>,
}

impl TeamApiPermissionResolver {
Expand All @@ -33,14 +30,13 @@

Ok(Self {
repo,
permissions: Mutex::new(CachedUserPermissions::new(permissions)),
permissions: Mutex::new(permissions),
})
}

async fn reload_permissions(&self) {
let result = load_permissions(&self.repo).await;
match result {
Ok(perms) => *self.permissions.lock().await = CachedUserPermissions::new(perms),
Ok(perms) => *self.permissions.lock().await = perms,
Err(error) => {
tracing::error!("Cannot reload permissions for {}: {error:?}", self.repo);
}
Expand All @@ -51,16 +47,15 @@
#[async_trait]
impl PermissionResolver for TeamApiPermissionResolver {
async fn has_permission(&self, username: &str, permission: PermissionType) -> bool {
if self.permissions.lock().await.is_stale() {
self.reload_permissions().await;
}

self.permissions
.lock()
.await
.permissions
.has_permission(username, permission)
}

async fn reload(&self) {
self.reload_permissions().await
}
}

pub struct UserPermissions {
Expand All @@ -77,26 +72,6 @@
}
}

struct CachedUserPermissions {
permissions: UserPermissions,
created_at: SystemTime,
}
impl CachedUserPermissions {
fn new(permissions: UserPermissions) -> Self {
Self {
permissions,
created_at: SystemTime::now(),
}
}

fn is_stale(&self) -> bool {
self.created_at
.elapsed()
.map(|duration| duration > CACHE_DURATION)
.unwrap_or(true)
}
}

async fn load_permissions(repo: &GithubRepoName) -> anyhow::Result<UserPermissions> {
tracing::info!("Reloading permissions for repository {repo}");

Expand Down Expand Up @@ -128,7 +103,7 @@
PermissionType::Try => "try",
};

let normalized_name = repository_name.replace("-", "_");

Check warning on line 106 in src/permissions.rs

View workflow job for this annotation

GitHub Actions / Test

single-character string constant used as pattern
let url = format!("https://team-api.infra.rust-lang.org/v1/permissions/bors.{normalized_name}.{permission}.json");
let users = reqwest::get(url)
.await
Expand Down
26 changes: 26 additions & 0 deletions src/tests/permissions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::{Arc, Mutex};

use crate::permissions::{PermissionResolver, PermissionType};
use axum::async_trait;

Expand All @@ -8,6 +10,7 @@ impl PermissionResolver for NoPermissions {
async fn has_permission(&self, _username: &str, _permission: PermissionType) -> bool {
false
}
async fn reload(&self) {}
}

pub struct AllPermissions;
Expand All @@ -17,4 +20,27 @@ impl PermissionResolver for AllPermissions {
async fn has_permission(&self, _username: &str, _permission: PermissionType) -> bool {
true
}
async fn reload(&self) {}
}

pub struct MockPermissions {
pub num_reload_called: i32,
}

impl Default for MockPermissions {
fn default() -> Self {
Self {
num_reload_called: 0,
}
}
}

#[async_trait]
impl PermissionResolver for Arc<Mutex<MockPermissions>> {
async fn has_permission(&self, _username: &str, _permission: PermissionType) -> bool {
false
}
async fn reload(&self) {
self.lock().unwrap().num_reload_called += 1
}
}
Loading