From 8703336967ab341688ffcd763526ade2663729e8 Mon Sep 17 00:00:00 2001 From: Marlon Baeten Date: Tue, 24 Sep 2024 11:19:47 +0200 Subject: [PATCH] Improve resource authentication --- openadr-vtn/src/api/event.rs | 22 ++++----- openadr-vtn/src/api/program.rs | 20 ++++----- openadr-vtn/src/api/resource.rs | 45 ++++++++++++++++--- openadr-vtn/src/api/ven.rs | 2 +- openadr-vtn/src/data_source/postgres/event.rs | 4 +- .../src/data_source/postgres/program.rs | 4 +- .../src/data_source/postgres/report.rs | 8 ++-- openadr-vtn/src/data_source/postgres/user.rs | 8 ++-- openadr-vtn/src/jwt.rs | 18 +++++++- 9 files changed, 89 insertions(+), 42 deletions(-) diff --git a/openadr-vtn/src/api/event.rs b/openadr-vtn/src/api/event.rs index 5d917eb..6d61f20 100644 --- a/openadr-vtn/src/api/event.rs +++ b/openadr-vtn/src/api/event.rs @@ -571,7 +571,7 @@ mod test { let token = jwt_test_token( &state, vec![ - AuthRole::VEN("ven-1".to_string()), + AuthRole::VEN("ven-1".parse().unwrap()), AuthRole::Business("business-2".to_string()), ], ); @@ -584,19 +584,19 @@ mod test { let (state, _) = state_with_events(vec![], db).await; let mut app = state.clone().into_router(); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = get_help("event-3", &token, &mut app).await; assert_eq!(response.status(), StatusCode::OK); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]); let response = get_help("event-3", &token, &mut app).await; assert_eq!(response.status(), StatusCode::NOT_FOUND); let token = jwt_test_token( &state, vec![ - AuthRole::VEN("ven-2".to_string()), - AuthRole::VEN("ven-1".to_string()), + AuthRole::VEN("ven-2".parse().unwrap()), + AuthRole::VEN("ven-1".parse().unwrap()), ], ); let response = get_help("event-3", &token, &mut app).await; @@ -605,7 +605,7 @@ mod test { let token = jwt_test_token( &state, vec![ - AuthRole::VEN("ven-2".to_string()), + AuthRole::VEN("ven-2".parse().unwrap()), AuthRole::Business("business-2".to_string()), ], ); @@ -618,7 +618,7 @@ mod test { let (state, _) = state_with_events(vec![], db).await; let mut app = state.clone().into_router(); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = retrieve_all_with_filter_help(&mut app, "", &token).await; assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); @@ -628,8 +628,8 @@ mod test { let token = jwt_test_token( &state, vec![ - AuthRole::VEN("ven-1".to_string()), - AuthRole::VEN("ven-2".to_string()), + AuthRole::VEN("ven-1".parse().unwrap()), + AuthRole::VEN("ven-2".parse().unwrap()), ], ); let response = retrieve_all_with_filter_help(&mut app, "", &token).await; @@ -641,7 +641,7 @@ mod test { // VEN should not be able to filter on other ven names, // even if they have a common set of events, // as this would leak information about which events the VENs have in common. - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = retrieve_all_with_filter_help( &mut app, "targetType=VEN_NAME&targetValues=ven-2-name", @@ -710,7 +710,7 @@ mod test { let (state, _) = state_with_events(vec![], db).await; let mut app = state.clone().into_router(); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = help_create_event(&mut app, &default_event_content(), &token).await; assert_eq!(response.status(), StatusCode::FORBIDDEN); diff --git a/openadr-vtn/src/api/program.rs b/openadr-vtn/src/api/program.rs index c10c299..f6ccec9 100644 --- a/openadr-vtn/src/api/program.rs +++ b/openadr-vtn/src/api/program.rs @@ -559,19 +559,19 @@ mod test { let body = response.into_body().collect().await.unwrap().to_bytes(); let program: Program = serde_json::from_slice(&body).unwrap(); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = get_help(&mut app, &token, program.id.as_str()).await; assert_eq!(response.status(), StatusCode::OK); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]); let response = get_help(&mut app, &token, program.id.as_str()).await; assert_eq!(response.status(), StatusCode::NOT_FOUND); let token = jwt_test_token( &state, vec![ - AuthRole::VEN("ven-2".to_string()), - AuthRole::VEN("ven-1".to_string()), + AuthRole::VEN("ven-2".parse().unwrap()), + AuthRole::VEN("ven-1".parse().unwrap()), ], ); let response = get_help(&mut app, &token, program.id.as_str()).await; @@ -583,7 +583,7 @@ mod test { let (state, _) = state_with_programs(vec![], db).await; let mut app = state.clone().into_router(); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = retrieve_all_with_filter_help(&mut app, "", &token).await; assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); @@ -596,7 +596,7 @@ mod test { names.sort(); assert_eq!(names, vec!["program-1", "program-3"]); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]); let response = retrieve_all_with_filter_help(&mut app, "", &token).await; assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); @@ -609,7 +609,7 @@ mod test { names.sort(); assert_eq!(names, vec!["program-1", "program-2"]); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]); let response = retrieve_all_with_filter_help( &mut app, "targetType=VEN_NAME&targetValues=ven-1", @@ -624,8 +624,8 @@ mod test { let token = jwt_test_token( &state, vec![ - AuthRole::VEN("ven-2".to_string()), - AuthRole::VEN("ven-1".to_string()), + AuthRole::VEN("ven-2".parse().unwrap()), + AuthRole::VEN("ven-1".parse().unwrap()), ], ); let response = retrieve_all_with_filter_help(&mut app, "", &token).await; @@ -646,7 +646,7 @@ mod test { let (state, _) = state_with_programs(vec![], db).await; let mut app = state.clone().into_router(); - let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]); + let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]); let response = help_create_program(&mut app, &token, &default_content()).await; assert_eq!(response.status(), StatusCode::FORBIDDEN); diff --git a/openadr-vtn/src/api/resource.rs b/openadr-vtn/src/api/resource.rs index ed6d22e..9c1d0a5 100644 --- a/openadr-vtn/src/api/resource.rs +++ b/openadr-vtn/src/api/resource.rs @@ -1,7 +1,9 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, + async_trait, + extract::{FromRef, FromRequestParts, Path, State}, + http::request::Parts, Json, }; use openadr_wire::ven::VenId; @@ -19,14 +21,43 @@ use crate::{ api::{AppResponse, ValidatedJson, ValidatedQuery}, data_source::ResourceCrud, error::AppError, - jwt::{User, VenManagerUser}, + jwt::{Claims, JwtManager, User}, }; +pub struct ResourceUser(Claims); + +#[async_trait] +impl FromRequestParts for ResourceUser +where + Arc: FromRef, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let User(user_claims) = User::from_request_parts(parts, state).await?; + let Path(ven_id): Path = Path::from_request_parts(parts, state) + .await + .map_err(|_| AppError::BadRequest("a valid VEN id is required"))?; + + if user_claims.is_ven_manager() { + return Ok(ResourceUser(user_claims)); + } + + if user_claims.is_ven() && user_claims.ven_ids().contains(&ven_id) { + return Ok(ResourceUser(user_claims)); + } + + Err(AppError::Forbidden( + "User not authorized to access this resource", + )) + } +} + pub async fn get_all( State(resource_source): State>, Path(ven_id): Path, ValidatedQuery(query_params): ValidatedQuery, - VenManagerUser(user): VenManagerUser, + ResourceUser(user): ResourceUser, ) -> AppResponse> { trace!(?query_params); @@ -41,7 +72,7 @@ pub async fn get( State(resource_source): State>, Path(ven_id): Path, Path(id): Path, - User(user): User, + ResourceUser(user): ResourceUser, ) -> AppResponse { let ven = resource_source.retrieve(&id, ven_id, &user).await?; @@ -50,7 +81,7 @@ pub async fn get( pub async fn add( State(resource_source): State>, - VenManagerUser(user): VenManagerUser, + ResourceUser(user): ResourceUser, Path(ven_id): Path, ValidatedJson(new_resource): ValidatedJson, ) -> Result<(StatusCode, Json), AppError> { @@ -63,7 +94,7 @@ pub async fn edit( State(resource_source): State>, Path(ven_id): Path, Path(id): Path, - VenManagerUser(user): VenManagerUser, + ResourceUser(user): ResourceUser, ValidatedJson(content): ValidatedJson, ) -> AppResponse { let resource = resource_source.update(&id, ven_id, content, &user).await?; @@ -77,7 +108,7 @@ pub async fn delete( State(resource_source): State>, Path(ven_id): Path, Path(id): Path, - VenManagerUser(user): VenManagerUser, + ResourceUser(user): ResourceUser, ) -> AppResponse { let resource = resource_source.delete(&id, ven_id, &user).await?; info!(%id, "deleted resource"); diff --git a/openadr-vtn/src/api/ven.rs b/openadr-vtn/src/api/ven.rs index 42f14a0..cf9d001 100644 --- a/openadr-vtn/src/api/ven.rs +++ b/openadr-vtn/src/api/ven.rs @@ -39,7 +39,7 @@ pub async fn get( User(user): User, ) -> AppResponse { if user.is_ven() { - if !user.ven_ids().iter().any(|vid| vid == id.as_str()) { + if !user.ven_ids().iter().any(|vid| *vid == id) { return Err(AppError::Forbidden("User does not have access to this VEN")); } } else if !user.is_ven_manager() { diff --git a/openadr-vtn/src/data_source/postgres/event.rs b/openadr-vtn/src/data_source/postgres/event.rs index dd97f64..971f3ec 100644 --- a/openadr-vtn/src/data_source/postgres/event.rs +++ b/openadr-vtn/src/data_source/postgres/event.rs @@ -254,7 +254,7 @@ impl Crud for PgEventStorage { "#, id.as_str(), user.is_ven(), - &user.ven_ids(), + &user.ven_ids_string(), user.is_business(), business_ids.as_deref(), ) @@ -304,7 +304,7 @@ impl Crud for PgEventStorage { serde_json::to_value(pg_filter.targets) .map_err(AppError::SerdeJsonInternalServerError)?, user.is_ven(), - &user.ven_ids(), + &user.ven_ids_string(), user.is_business(), business_ids.as_deref(), pg_filter.skip, diff --git a/openadr-vtn/src/data_source/postgres/program.rs b/openadr-vtn/src/data_source/postgres/program.rs index b7528b7..ff6f2c3 100644 --- a/openadr-vtn/src/data_source/postgres/program.rs +++ b/openadr-vtn/src/data_source/postgres/program.rs @@ -294,7 +294,7 @@ impl Crud for PgProgramStorage { "#, id.as_str(), user.is_ven(), - &user.ven_ids() + &user.ven_ids_string() ) .fetch_one(&self.db) .await? @@ -346,7 +346,7 @@ impl Crud for PgProgramStorage { serde_json::to_value(pg_filter.targets) .map_err(AppError::SerdeJsonInternalServerError)?, user.is_ven(), - &user.ven_ids(), + &user.ven_ids_string(), pg_filter.skip, pg_filter.limit, ) diff --git a/openadr-vtn/src/data_source/postgres/report.rs b/openadr-vtn/src/data_source/postgres/report.rs index 6f15015..7bce4c4 100644 --- a/openadr-vtn/src/data_source/postgres/report.rs +++ b/openadr-vtn/src/data_source/postgres/report.rs @@ -108,7 +108,7 @@ impl Crud for PgReportStorage { if !user .ven_ids() .into_iter() - .any(|user_ven| permitted_vens.contains(&user_ven)) + .any(|user_ven| permitted_vens.contains(&user_ven.to_string())) { Err(AppError::NotFound)? }; @@ -168,7 +168,7 @@ impl Crud for PgReportStorage { "#, id.as_str(), user.is_ven(), - &user.ven_ids(), + &user.ven_ids_string(), business_ids.as_deref() ) .fetch_one(&self.db) @@ -201,7 +201,7 @@ impl Crud for PgReportStorage { filter.event_id.clone().map(|x| x.to_string()), filter.client_name, user.is_ven(), - &user.ven_ids(), + &user.ven_ids_string(), business_ids.as_deref(), filter.skip, filter.limit, @@ -241,7 +241,7 @@ impl Crud for PgReportStorage { "#, id.as_str(), user.is_ven(), - &user.ven_ids(), + &user.ven_ids_string(), business_ids.as_deref(), new.program_id.as_str(), new.event_id.as_str(), diff --git a/openadr-vtn/src/data_source/postgres/user.rs b/openadr-vtn/src/data_source/postgres/user.rs index 49c666f..493ae90 100644 --- a/openadr-vtn/src/data_source/postgres/user.rs +++ b/openadr-vtn/src/data_source/postgres/user.rs @@ -3,6 +3,7 @@ use crate::{ jwt::AuthRole, }; use axum::async_trait; +use openadr_wire::IdentifierError; use sqlx::PgPool; pub struct PgAuthSource { @@ -58,10 +59,11 @@ impl AuthSource for PgAuthSource { .ok(); let mut ven_roles = vens - .map(|vens| { + .and_then(|vens| { vens.into_iter() - .map(|ven| AuthRole::VEN(ven.id)) - .collect::>() + .map(|ven| Ok(AuthRole::VEN(ven.id.parse()?))) + .collect::, IdentifierError>>() + .ok() }) .unwrap_or_default(); diff --git a/openadr-vtn/src/jwt.rs b/openadr-vtn/src/jwt.rs index a271b39..5ca2ec3 100644 --- a/openadr-vtn/src/jwt.rs +++ b/openadr-vtn/src/jwt.rs @@ -10,6 +10,7 @@ use axum_extra::{ TypedHeader, }; use jsonwebtoken::{encode, DecodingKey, EncodingKey, Header}; +use openadr_wire::ven::VenId; use tracing::trace; use crate::error::AppError; @@ -26,7 +27,7 @@ pub enum AuthRole { VenManager, Business(String), AnyBusiness, - VEN(String), + VEN(VenId), } impl AuthRole { @@ -83,7 +84,7 @@ pub enum BusinessIds { } impl Claims { - pub fn ven_ids(&self) -> Vec { + pub fn ven_ids(&self) -> Vec { self.roles .iter() .filter_map(|role| { @@ -96,6 +97,19 @@ impl Claims { .collect() } + pub fn ven_ids_string(&self) -> Vec { + self.roles + .iter() + .filter_map(|role| { + if let AuthRole::VEN(id) = role { + Some(id.to_string()) + } else { + None + } + }) + .collect() + } + pub fn business_ids(&self) -> BusinessIds { let mut ids = vec![];