From 6861d728397b9c338b39b6341e9569f67a49df43 Mon Sep 17 00:00:00 2001 From: Maximilian Pohl <maximilian@tweedegolf.com> Date: Mon, 7 Oct 2024 15:37:28 +0200 Subject: [PATCH] Minimize public interface of VTN lib Closes #12 Signed-off-by: Maximilian Pohl <maximilian@tweedegolf.com> --- openadr-vtn/src/api/event.rs | 12 +-- openadr-vtn/src/api/mod.rs | 22 ++--- openadr-vtn/src/api/program.rs | 12 +-- openadr-vtn/src/api/report.rs | 10 +-- openadr-vtn/src/api/resource.rs | 18 ++--- openadr-vtn/src/data_source/mod.rs | 14 ++-- openadr-vtn/src/data_source/postgres/event.rs | 81 +++++++++++-------- .../src/data_source/postgres/program.rs | 70 +++++++++------- .../src/data_source/postgres/report.rs | 14 ++-- .../src/data_source/postgres/resource.rs | 14 ++-- openadr-vtn/src/jwt.rs | 21 ++--- openadr-vtn/src/lib.rs | 2 +- openadr-vtn/src/state.rs | 4 +- 13 files changed, 163 insertions(+), 131 deletions(-) diff --git a/openadr-vtn/src/api/event.rs b/openadr-vtn/src/api/event.rs index 24cc6bf..b1ae474 100644 --- a/openadr-vtn/src/api/event.rs +++ b/openadr-vtn/src/api/event.rs @@ -26,7 +26,7 @@ use crate::{ pub async fn get_all( State(event_source): State<Arc<dyn EventCrud>>, ValidatedQuery(query_params): ValidatedQuery<QueryParams>, - User(user): User, + user: User, ) -> AppResponse<Vec<Event>> { trace!(?query_params); @@ -39,7 +39,7 @@ pub async fn get_all( pub async fn get( State(event_source): State<Arc<dyn EventCrud>>, Path(id): Path<EventId>, - User(user): User, + user: User, ) -> AppResponse<Event> { let event = event_source.retrieve(&id, &user).await?; trace!(%event.id, event.event_name=event.content.event_name, "retrieved event"); @@ -52,7 +52,7 @@ pub async fn add( BusinessUser(user): BusinessUser, ValidatedJson(new_event): ValidatedJson<EventContent>, ) -> Result<(StatusCode, Json<Event>), AppError> { - let event = event_source.create(new_event, &user).await?; + let event = event_source.create(new_event, &User(user)).await?; info!(%event.id, event_name=event.content.event_name, "event created"); @@ -65,7 +65,7 @@ pub async fn edit( BusinessUser(user): BusinessUser, ValidatedJson(content): ValidatedJson<EventContent>, ) -> AppResponse<Event> { - let event = event_source.update(&id, content, &user).await?; + let event = event_source.update(&id, content, &User(user)).await?; info!(%event.id, event_name=event.content.event_name, "event updated"); @@ -77,7 +77,7 @@ pub async fn delete( Path(id): Path<EventId>, BusinessUser(user): BusinessUser, ) -> AppResponse<Event> { - let event = event_source.delete(&id, &user).await?; + let event = event_source.delete(&id, &User(user)).await?; info!(%event.id, event.event_name=event.content.event_name, "deleted event"); Ok(Json(event)) } @@ -173,7 +173,7 @@ mod test { events.push( store .events() - .create(event.clone(), &Claims::any_business_user()) + .create(event.clone(), &User(Claims::any_business_user())) .await .unwrap(), ); diff --git a/openadr-vtn/src/api/mod.rs b/openadr-vtn/src/api/mod.rs index 5934cdf..16d32ad 100644 --- a/openadr-vtn/src/api/mod.rs +++ b/openadr-vtn/src/api/mod.rs @@ -11,24 +11,24 @@ use axum_extra::extract::{Query, QueryRejection}; use serde::de::DeserializeOwned; use validator::Validate; -pub mod auth; -pub mod event; -pub mod program; -pub mod report; -pub mod resource; -pub mod user; -pub mod ven; +pub(crate) mod auth; +pub(crate) mod event; +pub(crate) mod program; +pub(crate) mod report; +pub(crate) mod resource; +pub(crate) mod user; +pub(crate) mod ven; -pub type AppResponse<T> = Result<Json<T>, AppError>; +pub(crate) type AppResponse<T> = Result<Json<T>, AppError>; #[derive(Debug, Clone)] -pub struct ValidatedForm<T>(T); +pub(crate) struct ValidatedForm<T>(T); #[derive(Debug, Clone)] -pub struct ValidatedQuery<T>(pub T); +pub(crate) struct ValidatedQuery<T>(pub T); #[derive(Debug, Clone)] -pub struct ValidatedJson<T>(pub T); +pub(crate) struct ValidatedJson<T>(pub T); #[async_trait] impl<T, S> FromRequest<S> for ValidatedJson<T> diff --git a/openadr-vtn/src/api/program.rs b/openadr-vtn/src/api/program.rs index d9c5d87..e365586 100644 --- a/openadr-vtn/src/api/program.rs +++ b/openadr-vtn/src/api/program.rs @@ -24,7 +24,7 @@ use crate::{ pub async fn get_all( State(program_source): State<Arc<dyn ProgramCrud>>, ValidatedQuery(query_params): ValidatedQuery<QueryParams>, - User(user): User, + user: User, ) -> AppResponse<Vec<Program>> { trace!(?query_params); @@ -36,7 +36,7 @@ pub async fn get_all( pub async fn get( State(program_source): State<Arc<dyn ProgramCrud>>, Path(id): Path<ProgramId>, - User(user): User, + user: User, ) -> AppResponse<Program> { let program = program_source.retrieve(&id, &user).await?; Ok(Json(program)) @@ -47,7 +47,7 @@ pub async fn add( BusinessUser(user): BusinessUser, ValidatedJson(new_program): ValidatedJson<ProgramContent>, ) -> Result<(StatusCode, Json<Program>), AppError> { - let program = program_source.create(new_program, &user).await?; + let program = program_source.create(new_program, &User(user)).await?; Ok((StatusCode::CREATED, Json(program))) } @@ -58,7 +58,7 @@ pub async fn edit( BusinessUser(user): BusinessUser, ValidatedJson(content): ValidatedJson<ProgramContent>, ) -> AppResponse<Program> { - let program = program_source.update(&id, content, &user).await?; + let program = program_source.update(&id, content, &User(user)).await?; info!(%program.id, program.program_name=program.content.program_name, "program updated"); @@ -70,7 +70,7 @@ pub async fn delete( Path(id): Path<ProgramId>, BusinessUser(user): BusinessUser, ) -> AppResponse<Program> { - let program = program_source.delete(&id, &user).await?; + let program = program_source.delete(&id, &User(user)).await?; info!(%id, "deleted program"); Ok(Json(program)) } @@ -175,7 +175,7 @@ mod test { for program in new_programs { let p = store .programs() - .create(program.clone(), &Claims::any_business_user()) + .create(program.clone(), &User(Claims::any_business_user())) .await .unwrap(); assert_eq!(p.content, program); diff --git a/openadr-vtn/src/api/report.rs b/openadr-vtn/src/api/report.rs index dfa20ee..f5d456a 100644 --- a/openadr-vtn/src/api/report.rs +++ b/openadr-vtn/src/api/report.rs @@ -27,7 +27,7 @@ use crate::{ pub async fn get_all( State(report_source): State<Arc<dyn ReportCrud>>, ValidatedQuery(query_params): ValidatedQuery<QueryParams>, - User(user): User, + user: User, ) -> AppResponse<Vec<Report>> { let reports = report_source.retrieve_all(&query_params, &user).await?; @@ -38,7 +38,7 @@ pub async fn get_all( pub async fn get( State(report_source): State<Arc<dyn ReportCrud>>, Path(id): Path<ReportId>, - User(user): User, + user: User, ) -> AppResponse<Report> { let report: Report = report_source.retrieve(&id, &user).await?; Ok(Json(report)) @@ -50,7 +50,7 @@ pub async fn add( VENUser(user): VENUser, ValidatedJson(new_report): ValidatedJson<ReportContent>, ) -> Result<(StatusCode, Json<Report>), AppError> { - let report = report_source.create(new_report, &user).await?; + let report = report_source.create(new_report, &User(user)).await?; info!(%report.id, report_name=?report.content.report_name, "report created"); @@ -64,7 +64,7 @@ pub async fn edit( VENUser(user): VENUser, ValidatedJson(content): ValidatedJson<ReportContent>, ) -> AppResponse<Report> { - let report = report_source.update(&id, content, &user).await?; + let report = report_source.update(&id, content, &User(user)).await?; info!(%report.id, report_name=?report.content.report_name, "report updated"); @@ -78,7 +78,7 @@ pub async fn delete( BusinessUser(user): BusinessUser, Path(id): Path<ReportId>, ) -> AppResponse<Report> { - let report = report_source.delete(&id, &user).await?; + let report = report_source.delete(&id, &User(user)).await?; info!(%id, "deleted report"); Ok(Json(report)) } diff --git a/openadr-vtn/src/api/resource.rs b/openadr-vtn/src/api/resource.rs index 512c886..51c1b11 100644 --- a/openadr-vtn/src/api/resource.rs +++ b/openadr-vtn/src/api/resource.rs @@ -19,15 +19,15 @@ use crate::{ api::{AppResponse, ValidatedJson, ValidatedQuery}, data_source::ResourceCrud, error::AppError, - jwt::{Claims, User}, + jwt::User, }; -fn has_write_permission(user_claims: &Claims, ven_id: &VenId) -> Result<(), AppError> { - if user_claims.is_ven_manager() { +fn has_write_permission(User(claims): &User, ven_id: &VenId) -> Result<(), AppError> { + if claims.is_ven_manager() { return Ok(()); } - if user_claims.is_ven() && user_claims.ven_ids().contains(ven_id) { + if claims.is_ven() && claims.ven_ids().contains(ven_id) { return Ok(()); } @@ -40,7 +40,7 @@ pub async fn get_all( State(resource_source): State<Arc<dyn ResourceCrud>>, Path(ven_id): Path<VenId>, ValidatedQuery(query_params): ValidatedQuery<QueryParams>, - User(user): User, + user: User, ) -> AppResponse<Vec<Resource>> { has_write_permission(&user, &ven_id)?; trace!(?query_params); @@ -55,7 +55,7 @@ pub async fn get_all( pub async fn get( State(resource_source): State<Arc<dyn ResourceCrud>>, Path((ven_id, id)): Path<(VenId, ResourceId)>, - User(user): User, + user: User, ) -> AppResponse<Resource> { has_write_permission(&user, &ven_id)?; let ven = resource_source.retrieve(&id, ven_id, &user).await?; @@ -65,7 +65,7 @@ pub async fn get( pub async fn add( State(resource_source): State<Arc<dyn ResourceCrud>>, - User(user): User, + user: User, Path(ven_id): Path<VenId>, ValidatedJson(new_resource): ValidatedJson<ResourceContent>, ) -> Result<(StatusCode, Json<Resource>), AppError> { @@ -78,7 +78,7 @@ pub async fn add( pub async fn edit( State(resource_source): State<Arc<dyn ResourceCrud>>, Path((ven_id, id)): Path<(VenId, ResourceId)>, - User(user): User, + user: User, ValidatedJson(content): ValidatedJson<ResourceContent>, ) -> AppResponse<Resource> { has_write_permission(&user, &ven_id)?; @@ -92,7 +92,7 @@ pub async fn edit( pub async fn delete( State(resource_source): State<Arc<dyn ResourceCrud>>, Path((ven_id, id)): Path<(VenId, ResourceId)>, - User(user): User, + user: User, ) -> AppResponse<Resource> { has_write_permission(&user, &ven_id)?; let resource = resource_source.delete(&id, ven_id, &user).await?; diff --git a/openadr-vtn/src/data_source/mod.rs b/openadr-vtn/src/data_source/mod.rs index f2c13a3..c6eebb4 100644 --- a/openadr-vtn/src/data_source/mod.rs +++ b/openadr-vtn/src/data_source/mod.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::{ error::AppError, - jwt::{AuthRole, Claims}, + jwt::{AuthRole, Claims, User}, }; #[async_trait] @@ -107,7 +107,7 @@ pub trait ProgramCrud: NewType = ProgramContent, Error = AppError, Filter = crate::api::program::QueryParams, - PermissionFilter = Claims, + PermissionFilter = User, > { } @@ -118,7 +118,7 @@ pub trait ReportCrud: NewType = ReportContent, Error = AppError, Filter = crate::api::report::QueryParams, - PermissionFilter = Claims, + PermissionFilter = User, > { } @@ -129,7 +129,7 @@ pub trait EventCrud: NewType = EventContent, Error = AppError, Filter = crate::api::event::QueryParams, - PermissionFilter = Claims, + PermissionFilter = User, > { } @@ -185,7 +185,7 @@ pub trait ResourceCrud: NewType = ResourceContent, Error = AppError, Filter = crate::api::resource::QueryParams, - PermissionFilter = Claims, + PermissionFilter = User, > { } @@ -252,6 +252,6 @@ pub trait DataSource: Send + Sync + 'static { #[derive(Debug, Clone)] pub struct AuthInfo { - pub client_id: String, - pub roles: Vec<AuthRole>, + pub(crate) client_id: String, + pub(crate) roles: Vec<AuthRole>, } diff --git a/openadr-vtn/src/data_source/postgres/event.rs b/openadr-vtn/src/data_source/postgres/event.rs index 19cc8c7..d64d5c0 100644 --- a/openadr-vtn/src/data_source/postgres/event.rs +++ b/openadr-vtn/src/data_source/postgres/event.rs @@ -5,7 +5,7 @@ use crate::{ Crud, EventCrud, }, error::AppError, - jwt::{BusinessIds, Claims}, + jwt::{BusinessIds, Claims, User}, }; use axum::async_trait; use chrono::{DateTime, Utc}; @@ -197,12 +197,12 @@ impl Crud for PgEventStorage { type NewType = EventContent; type Error = AppError; type Filter = QueryParams; - type PermissionFilter = Claims; + type PermissionFilter = User; async fn create( &self, new: Self::NewType, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { check_write_permission(new.program_id.as_str(), user, &self.db).await?; @@ -231,7 +231,7 @@ impl Crud for PgEventStorage { async fn retrieve( &self, id: &Self::Id, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let business_ids = match user.business_ids() { BusinessIds::Specific(ids) => Some(ids), @@ -266,7 +266,7 @@ impl Crud for PgEventStorage { async fn retrieve_all( &self, filter: &Self::Filter, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Vec<Self::Type>, Self::Error> { let pg_filter: PostgresFilter = filter.into(); trace!(?pg_filter); @@ -326,7 +326,7 @@ impl Crud for PgEventStorage { &self, id: &Self::Id, new: Self::NewType, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { check_write_permission(new.program_id.as_str(), user, &self.db).await?; @@ -377,7 +377,7 @@ impl Crud for PgEventStorage { async fn delete( &self, id: &Self::Id, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let program_id = sqlx::query_as!( PgId, @@ -411,7 +411,7 @@ mod tests { api::event::QueryParams, data_source::{postgres::event::PgEventStorage, Crud}, error::AppError, - jwt::Claims, + jwt::{Claims, User}, }; use chrono::{DateTime, Duration, Utc}; use openadr_wire::{ @@ -525,7 +525,7 @@ mod tests { async fn default_get_all(db: PgPool) { let repo: PgEventStorage = db.into(); let mut events = repo - .retrieve_all(&Default::default(), &Claims::any_business_user()) + .retrieve_all(&Default::default(), &User(Claims::any_business_user())) .await .unwrap(); assert_eq!(events.len(), 3); @@ -542,7 +542,7 @@ mod tests { limit: 1, ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -559,7 +559,7 @@ mod tests { skip: 1, ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -571,7 +571,7 @@ mod tests { skip: 20, ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -589,7 +589,7 @@ mod tests { target_values: Some(vec!["group-1".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -603,7 +603,7 @@ mod tests { target_values: Some(vec!["target-1".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -618,7 +618,7 @@ mod tests { target_values: Some(vec!["not-existent".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -631,7 +631,7 @@ mod tests { target_values: Some(vec!["target-1".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -644,7 +644,7 @@ mod tests { target_values: Some(vec!["target-1".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -662,7 +662,7 @@ mod tests { target_values: Some(vec!["private value".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -679,7 +679,7 @@ mod tests { program_id: Some("program-1".parse().unwrap()), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -693,7 +693,7 @@ mod tests { target_type: Some(TargetLabel::Group), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -706,7 +706,7 @@ mod tests { program_id: Some("not-existent".parse().unwrap()), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -721,7 +721,10 @@ mod tests { async fn get_existing(db: PgPool) { let repo: PgEventStorage = db.into(); let event = repo - .retrieve(&"event-1".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"event-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(event, event_1()); @@ -733,7 +736,7 @@ mod tests { let event = repo .retrieve( &"not-existent".parse().unwrap(), - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await; assert!(matches!(event, Err(AppError::NotFound))); @@ -747,7 +750,7 @@ mod tests { async fn add(db: PgPool) { let repo: PgEventStorage = db.into(); let event = repo - .create(event_1().content, &Claims::any_business_user()) + .create(event_1().content, &User(Claims::any_business_user())) .await .unwrap(); assert_eq!(event.content, event_1().content); @@ -761,7 +764,7 @@ mod tests { async fn add_existing_conflict_name(db: PgPool) { let repo: PgEventStorage = db.into(); let event = repo - .create(event_1().content, &Claims::any_business_user()) + .create(event_1().content, &User(Claims::any_business_user())) .await; assert!(event.is_ok()); } @@ -777,7 +780,7 @@ mod tests { .update( &"event-1".parse().unwrap(), event_1().content, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -801,13 +804,16 @@ mod tests { .update( &"event-1".parse().unwrap(), updated.clone(), - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); assert_eq!(event.content, updated); let event = repo - .retrieve(&"event-1".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"event-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(event.content, updated); @@ -820,7 +826,7 @@ mod tests { .update( &"event-1".parse().unwrap(), event_2().content, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await; assert!(event.is_ok()); @@ -834,18 +840,27 @@ mod tests { async fn delete_existing(db: PgPool) { let repo: PgEventStorage = db.into(); let event = repo - .delete(&"event-1".parse().unwrap(), &Claims::any_business_user()) + .delete( + &"event-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(event, event_1()); let event = repo - .retrieve(&"event-1".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"event-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await; assert!(matches!(event, Err(AppError::NotFound))); let event = repo - .retrieve(&"event-2".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"event-2".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(event, event_2()); @@ -857,7 +872,7 @@ mod tests { let event = repo .delete( &"not-existent".parse().unwrap(), - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await; assert!(matches!(event, Err(AppError::NotFound))); diff --git a/openadr-vtn/src/data_source/postgres/program.rs b/openadr-vtn/src/data_source/postgres/program.rs index ecd864a..8ae1e1a 100644 --- a/openadr-vtn/src/data_source/postgres/program.rs +++ b/openadr-vtn/src/data_source/postgres/program.rs @@ -5,7 +5,7 @@ use crate::{ Crud, ProgramCrud, }, error::AppError, - jwt::Claims, + jwt::User, }; use axum::async_trait; use chrono::{DateTime, Utc}; @@ -171,12 +171,12 @@ impl Crud for PgProgramStorage { type NewType = ProgramContent; type Error = AppError; type Filter = QueryParams; - type PermissionFilter = Claims; + type PermissionFilter = User; async fn create( &self, new: Self::NewType, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let (targets, vens) = extract_vens(new.targets); let business_id = extract_business_id(user)?; @@ -266,7 +266,7 @@ impl Crud for PgProgramStorage { async fn retrieve( &self, id: &Self::Id, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { Ok(sqlx::query_as!( PostgresProgram, @@ -304,7 +304,7 @@ impl Crud for PgProgramStorage { async fn retrieve_all( &self, filter: &Self::Filter, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Vec<Self::Type>, Self::Error> { let pg_filter: PostgresFilter = filter.into(); trace!(?pg_filter); @@ -366,7 +366,7 @@ impl Crud for PgProgramStorage { &self, id: &Self::Id, new: Self::NewType, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let (targets, vens) = extract_vens(new.targets); let business_id = extract_business_id(user)?; @@ -464,7 +464,7 @@ impl Crud for PgProgramStorage { async fn delete( &self, id: &Self::Id, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let business_id = extract_business_id(user)?; @@ -518,6 +518,7 @@ mod tests { }; use sqlx::PgPool; + use crate::jwt::User; impl Default for QueryParams { fn default() -> Self { Self { @@ -613,7 +614,7 @@ mod tests { async fn default_get_all(db: PgPool) { let repo: PgProgramStorage = db.into(); let mut programs = repo - .retrieve_all(&Default::default(), &Claims::any_business_user()) + .retrieve_all(&Default::default(), &User(Claims::any_business_user())) .await .unwrap(); assert_eq!(programs.len(), 3); @@ -630,7 +631,7 @@ mod tests { limit: 1, ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -646,7 +647,7 @@ mod tests { skip: 1, ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -658,7 +659,7 @@ mod tests { skip: 3, ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -676,7 +677,7 @@ mod tests { target_values: Some(vec!["group-1".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -689,7 +690,7 @@ mod tests { target_values: Some(vec!["not-existent".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -702,7 +703,7 @@ mod tests { target_values: Some(vec!["program-2".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -716,7 +717,7 @@ mod tests { target_values: Some(vec!["program-not-existent".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -734,7 +735,7 @@ mod tests { target_values: Some(vec!["private value".to_string()]), ..Default::default() }, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -750,7 +751,10 @@ mod tests { let repo: PgProgramStorage = db.into(); let program = repo - .retrieve(&"program-1".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"program-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(program, program_1()); @@ -762,7 +766,7 @@ mod tests { let program = repo .retrieve( &"program-not-existent".parse().unwrap(), - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await; @@ -779,7 +783,7 @@ mod tests { let repo: PgProgramStorage = db.into(); let program = repo - .create(program_1().content, &Claims::any_business_user()) + .create(program_1().content, &User(Claims::any_business_user())) .await .unwrap(); assert!(program.created_date_time < Utc::now() + Duration::minutes(10)); @@ -793,7 +797,7 @@ mod tests { let repo: PgProgramStorage = db.into(); let program = repo - .create(program_1().content, &Claims::any_business_user()) + .create(program_1().content, &User(Claims::any_business_user())) .await; assert!(matches!(program, Err(AppError::Conflict(_, _)))); } @@ -810,7 +814,7 @@ mod tests { .update( &"program-1".parse().unwrap(), program_1().content, - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); @@ -836,14 +840,17 @@ mod tests { .update( &"program-1".parse().unwrap(), updated.clone(), - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await .unwrap(); assert_eq!(program.content, updated); let program = repo - .retrieve(&"program-1".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"program-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(program.content, updated); @@ -857,18 +864,27 @@ mod tests { async fn delete_existing(db: PgPool) { let repo: PgProgramStorage = db.into(); let program = repo - .delete(&"program-1".parse().unwrap(), &Claims::any_business_user()) + .delete( + &"program-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(program, program_1()); let program = repo - .retrieve(&"program-1".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"program-1".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await; assert!(matches!(program, Err(AppError::NotFound))); let program = repo - .retrieve(&"program-2".parse().unwrap(), &Claims::any_business_user()) + .retrieve( + &"program-2".parse().unwrap(), + &User(Claims::any_business_user()), + ) .await .unwrap(); assert_eq!(program, program_2()); @@ -880,7 +896,7 @@ mod tests { let program = repo .delete( &"program-not-existing".parse().unwrap(), - &Claims::any_business_user(), + &User(Claims::any_business_user()), ) .await; assert!(matches!(program, Err(AppError::NotFound))); diff --git a/openadr-vtn/src/data_source/postgres/report.rs b/openadr-vtn/src/data_source/postgres/report.rs index 6ae2e5c..b1c92bc 100644 --- a/openadr-vtn/src/data_source/postgres/report.rs +++ b/openadr-vtn/src/data_source/postgres/report.rs @@ -5,7 +5,7 @@ use crate::{ Crud, ReportCrud, }, error::AppError, - jwt::Claims, + jwt::User, }; use axum::async_trait; use chrono::{DateTime, Utc}; @@ -85,12 +85,12 @@ impl Crud for PgReportStorage { type NewType = ReportContent; type Error = AppError; type Filter = QueryParams; - type PermissionFilter = Claims; + type PermissionFilter = User; async fn create( &self, new: Self::NewType, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let permitted_vens = sqlx::query_as!( PgId, @@ -156,7 +156,7 @@ impl Crud for PgReportStorage { async fn retrieve( &self, id: &Self::Id, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let business_ids = extract_business_ids(user); @@ -188,7 +188,7 @@ impl Crud for PgReportStorage { async fn retrieve_all( &self, filter: &Self::Filter, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Vec<Self::Type>, Self::Error> { let business_ids = extract_business_ids(user); @@ -235,7 +235,7 @@ impl Crud for PgReportStorage { &self, id: &Self::Id, new: Self::NewType, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let business_ids = extract_business_ids(user); let report: Report = sqlx::query_as!( @@ -280,7 +280,7 @@ impl Crud for PgReportStorage { async fn delete( &self, id: &Self::Id, - user: &Self::PermissionFilter, + User(user): &Self::PermissionFilter, ) -> Result<Self::Type, Self::Error> { let business_ids = extract_business_ids(user); diff --git a/openadr-vtn/src/data_source/postgres/resource.rs b/openadr-vtn/src/data_source/postgres/resource.rs index a84dd7b..61f7679 100644 --- a/openadr-vtn/src/data_source/postgres/resource.rs +++ b/openadr-vtn/src/data_source/postgres/resource.rs @@ -5,7 +5,7 @@ use crate::{ ResourceCrud, VenScopedCrud, }, error::AppError, - jwt::Claims, + jwt::User, }; use axum::async_trait; use chrono::{DateTime, Utc}; @@ -126,7 +126,7 @@ impl VenScopedCrud for PgResourceStorage { type NewType = ResourceContent; type Error = AppError; type Filter = QueryParams; - type PermissionFilter = Claims; + type PermissionFilter = User; async fn create( &self, @@ -361,7 +361,7 @@ mod test { use crate::{ api::resource::QueryParams, data_source::{postgres::resource::PgResourceStorage, VenScopedCrud}, - jwt::AuthRole, + jwt::{AuthRole, User}, }; use sqlx::PgPool; @@ -379,16 +379,16 @@ mod test { #[sqlx::test(fixtures("users", "vens", "resources"))] async fn retrieve_all(db: PgPool) { let repo = PgResourceStorage::from(db.clone()); - let claims = crate::jwt::Claims::new(vec![AuthRole::VenManager]); + let user = User(crate::jwt::Claims::new(vec![AuthRole::VenManager])); let resources = repo - .retrieve_all("ven-1".parse().unwrap(), &Default::default(), &claims) + .retrieve_all("ven-1".parse().unwrap(), &Default::default(), &user) .await .unwrap(); assert_eq!(resources.len(), 2); let resources = repo - .retrieve_all("ven-2".parse().unwrap(), &Default::default(), &claims) + .retrieve_all("ven-2".parse().unwrap(), &Default::default(), &user) .await .unwrap(); assert_eq!(resources.len(), 3); @@ -400,7 +400,7 @@ mod test { }; let resources = repo - .retrieve_all("ven-1".parse().unwrap(), &filters, &claims) + .retrieve_all("ven-1".parse().unwrap(), &filters, &user) .await .unwrap(); assert_eq!(resources.len(), 1); diff --git a/openadr-vtn/src/jwt.rs b/openadr-vtn/src/jwt.rs index 7de54e6..ca5138b 100644 --- a/openadr-vtn/src/jwt.rs +++ b/openadr-vtn/src/jwt.rs @@ -50,11 +50,11 @@ impl AuthRole { } #[derive(Debug, serde::Serialize, serde::Deserialize)] -pub struct Claims { +pub(crate) struct Claims { exp: usize, nbf: usize, - pub sub: String, - pub roles: Vec<AuthRole>, + pub(crate) sub: String, + pub(crate) roles: Vec<AuthRole>, } #[cfg(test)] @@ -162,7 +162,7 @@ impl JwtManager { } /// Create a new JWT token with the given claims and expiration time - pub fn create( + pub(crate) fn create( &self, expires_in: std::time::Duration, client_id: String, @@ -184,7 +184,7 @@ impl JwtManager { } /// Decode and validate a given JWT token, returning the validated claims - pub fn decode_and_validate(&self, token: &str) -> Result<Claims, jsonwebtoken::errors::Error> { + fn decode_and_validate(&self, token: &str) -> Result<Claims, jsonwebtoken::errors::Error> { let validation = jsonwebtoken::Validation::default(); let token_data = jsonwebtoken::decode::<Claims>(token, &self.decoding_key, &validation)?; Ok(token_data.claims) @@ -192,19 +192,20 @@ impl JwtManager { } /// User claims extracted from the request -pub struct User(pub Claims); +pub struct User(pub(crate) Claims); /// User claims extracted from the request, with the requirement that the user is a business user -pub struct BusinessUser(pub Claims); +pub struct BusinessUser(pub(crate) Claims); /// User claims extracted from the request, with the requirement that the user is a VEN user -pub struct VENUser(pub Claims); +pub struct VENUser(pub(crate) Claims); /// User claims extracted from the request, with the requirement that the user is a user manager -pub struct UserManagerUser(pub Claims); +#[allow(dead_code)] +pub struct UserManagerUser(pub(crate) Claims); /// User claims extracted from the request, with the requirement that the user is a VEN manager -pub struct VenManagerUser(pub Claims); +pub struct VenManagerUser(pub(crate) Claims); #[async_trait] impl<S: Send + Sync> FromRequestParts<S> for User diff --git a/openadr-vtn/src/lib.rs b/openadr-vtn/src/lib.rs index 1105340..8dcf299 100644 --- a/openadr-vtn/src/lib.rs +++ b/openadr-vtn/src/lib.rs @@ -1,4 +1,4 @@ -pub mod api; +mod api; pub mod data_source; mod error; pub mod jwt; diff --git a/openadr-vtn/src/state.rs b/openadr-vtn/src/state.rs index 9156545..281819c 100644 --- a/openadr-vtn/src/state.rs +++ b/openadr-vtn/src/state.rs @@ -87,7 +87,7 @@ impl AppState { } } -pub async fn method_not_allowed(req: Request, next: Next) -> impl IntoResponse { +async fn method_not_allowed(req: Request, next: Next) -> impl IntoResponse { let resp = next.run(req).await; let status = resp.status(); match status { @@ -96,7 +96,7 @@ pub async fn method_not_allowed(req: Request, next: Next) -> impl IntoResponse { } } -pub async fn handler_404() -> AppError { +async fn handler_404() -> AppError { AppError::NotFound }