Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Commit

Permalink
Improve resource authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
marlonbaeten committed Sep 24, 2024
1 parent 8712ab6 commit 8703336
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 42 deletions.
22 changes: 11 additions & 11 deletions openadr-vtn/src/api/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ mod test {
let token = jwt_test_token(
&state,
vec![
AuthRole::VEN("ven-1".to_string()),
AuthRole::VEN("ven-1".parse().unwrap()),
AuthRole::Business("business-2".to_string()),
],
);
Expand All @@ -584,19 +584,19 @@ mod test {
let (state, _) = state_with_events(vec![], db).await;
let mut app = state.clone().into_router();

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = get_help("event-3", &token, &mut app).await;
assert_eq!(response.status(), StatusCode::OK);

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]);
let response = get_help("event-3", &token, &mut app).await;
assert_eq!(response.status(), StatusCode::NOT_FOUND);

let token = jwt_test_token(
&state,
vec![
AuthRole::VEN("ven-2".to_string()),
AuthRole::VEN("ven-1".to_string()),
AuthRole::VEN("ven-2".parse().unwrap()),
AuthRole::VEN("ven-1".parse().unwrap()),
],
);
let response = get_help("event-3", &token, &mut app).await;
Expand All @@ -605,7 +605,7 @@ mod test {
let token = jwt_test_token(
&state,
vec![
AuthRole::VEN("ven-2".to_string()),
AuthRole::VEN("ven-2".parse().unwrap()),
AuthRole::Business("business-2".to_string()),
],
);
Expand All @@ -618,7 +618,7 @@ mod test {
let (state, _) = state_with_events(vec![], db).await;
let mut app = state.clone().into_router();

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = retrieve_all_with_filter_help(&mut app, "", &token).await;
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
Expand All @@ -628,8 +628,8 @@ mod test {
let token = jwt_test_token(
&state,
vec![
AuthRole::VEN("ven-1".to_string()),
AuthRole::VEN("ven-2".to_string()),
AuthRole::VEN("ven-1".parse().unwrap()),
AuthRole::VEN("ven-2".parse().unwrap()),
],
);
let response = retrieve_all_with_filter_help(&mut app, "", &token).await;
Expand All @@ -641,7 +641,7 @@ mod test {
// VEN should not be able to filter on other ven names,
// even if they have a common set of events,
// as this would leak information about which events the VENs have in common.
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = retrieve_all_with_filter_help(
&mut app,
"targetType=VEN_NAME&targetValues=ven-2-name",
Expand Down Expand Up @@ -710,7 +710,7 @@ mod test {
let (state, _) = state_with_events(vec![], db).await;
let mut app = state.clone().into_router();

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = help_create_event(&mut app, &default_event_content(), &token).await;
assert_eq!(response.status(), StatusCode::FORBIDDEN);

Expand Down
20 changes: 10 additions & 10 deletions openadr-vtn/src/api/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,19 +559,19 @@ mod test {
let body = response.into_body().collect().await.unwrap().to_bytes();
let program: Program = serde_json::from_slice(&body).unwrap();

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = get_help(&mut app, &token, program.id.as_str()).await;
assert_eq!(response.status(), StatusCode::OK);

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]);
let response = get_help(&mut app, &token, program.id.as_str()).await;
assert_eq!(response.status(), StatusCode::NOT_FOUND);

let token = jwt_test_token(
&state,
vec![
AuthRole::VEN("ven-2".to_string()),
AuthRole::VEN("ven-1".to_string()),
AuthRole::VEN("ven-2".parse().unwrap()),
AuthRole::VEN("ven-1".parse().unwrap()),
],
);
let response = get_help(&mut app, &token, program.id.as_str()).await;
Expand All @@ -583,7 +583,7 @@ mod test {
let (state, _) = state_with_programs(vec![], db).await;
let mut app = state.clone().into_router();

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = retrieve_all_with_filter_help(&mut app, "", &token).await;
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
Expand All @@ -596,7 +596,7 @@ mod test {
names.sort();
assert_eq!(names, vec!["program-1", "program-3"]);

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]);
let response = retrieve_all_with_filter_help(&mut app, "", &token).await;
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
Expand All @@ -609,7 +609,7 @@ mod test {
names.sort();
assert_eq!(names, vec!["program-1", "program-2"]);

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-2".parse().unwrap())]);
let response = retrieve_all_with_filter_help(
&mut app,
"targetType=VEN_NAME&targetValues=ven-1",
Expand All @@ -624,8 +624,8 @@ mod test {
let token = jwt_test_token(
&state,
vec![
AuthRole::VEN("ven-2".to_string()),
AuthRole::VEN("ven-1".to_string()),
AuthRole::VEN("ven-2".parse().unwrap()),
AuthRole::VEN("ven-1".parse().unwrap()),
],
);
let response = retrieve_all_with_filter_help(&mut app, "", &token).await;
Expand All @@ -646,7 +646,7 @@ mod test {
let (state, _) = state_with_programs(vec![], db).await;
let mut app = state.clone().into_router();

let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".to_string())]);
let token = jwt_test_token(&state, vec![AuthRole::VEN("ven-1".parse().unwrap())]);
let response = help_create_program(&mut app, &token, &default_content()).await;
assert_eq!(response.status(), StatusCode::FORBIDDEN);

Expand Down
45 changes: 38 additions & 7 deletions openadr-vtn/src/api/resource.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::sync::Arc;

use axum::{
extract::{Path, State},
async_trait,
extract::{FromRef, FromRequestParts, Path, State},
http::request::Parts,
Json,
};
use openadr_wire::ven::VenId;
Expand All @@ -19,14 +21,43 @@ use crate::{
api::{AppResponse, ValidatedJson, ValidatedQuery},
data_source::ResourceCrud,
error::AppError,
jwt::{User, VenManagerUser},
jwt::{Claims, JwtManager, User},
};

pub struct ResourceUser(Claims);

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for ResourceUser
where
Arc<JwtManager>: FromRef<S>,
{
type Rejection = AppError;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let User(user_claims) = User::from_request_parts(parts, state).await?;
let Path(ven_id): Path<VenId> = Path::from_request_parts(parts, state)
.await
.map_err(|_| AppError::BadRequest("a valid VEN id is required"))?;

if user_claims.is_ven_manager() {
return Ok(ResourceUser(user_claims));
}

if user_claims.is_ven() && user_claims.ven_ids().contains(&ven_id) {
return Ok(ResourceUser(user_claims));
}

Err(AppError::Forbidden(
"User not authorized to access this resource",
))
}
}

pub async fn get_all(
State(resource_source): State<Arc<dyn ResourceCrud>>,
Path(ven_id): Path<VenId>,
ValidatedQuery(query_params): ValidatedQuery<QueryParams>,
VenManagerUser(user): VenManagerUser,
ResourceUser(user): ResourceUser,
) -> AppResponse<Vec<Resource>> {
trace!(?query_params);

Expand All @@ -41,7 +72,7 @@ pub async fn get(
State(resource_source): State<Arc<dyn ResourceCrud>>,
Path(ven_id): Path<VenId>,
Path(id): Path<ResourceId>,
User(user): User,
ResourceUser(user): ResourceUser,
) -> AppResponse<Resource> {
let ven = resource_source.retrieve(&id, ven_id, &user).await?;

Expand All @@ -50,7 +81,7 @@ pub async fn get(

pub async fn add(
State(resource_source): State<Arc<dyn ResourceCrud>>,
VenManagerUser(user): VenManagerUser,
ResourceUser(user): ResourceUser,
Path(ven_id): Path<VenId>,
ValidatedJson(new_resource): ValidatedJson<ResourceContent>,
) -> Result<(StatusCode, Json<Resource>), AppError> {
Expand All @@ -63,7 +94,7 @@ pub async fn edit(
State(resource_source): State<Arc<dyn ResourceCrud>>,
Path(ven_id): Path<VenId>,
Path(id): Path<ResourceId>,
VenManagerUser(user): VenManagerUser,
ResourceUser(user): ResourceUser,
ValidatedJson(content): ValidatedJson<ResourceContent>,
) -> AppResponse<Resource> {
let resource = resource_source.update(&id, ven_id, content, &user).await?;
Expand All @@ -77,7 +108,7 @@ pub async fn delete(
State(resource_source): State<Arc<dyn ResourceCrud>>,
Path(ven_id): Path<VenId>,
Path(id): Path<ResourceId>,
VenManagerUser(user): VenManagerUser,
ResourceUser(user): ResourceUser,
) -> AppResponse<Resource> {
let resource = resource_source.delete(&id, ven_id, &user).await?;
info!(%id, "deleted resource");
Expand Down
2 changes: 1 addition & 1 deletion openadr-vtn/src/api/ven.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub async fn get(
User(user): User,
) -> AppResponse<Ven> {
if user.is_ven() {
if !user.ven_ids().iter().any(|vid| vid == id.as_str()) {
if !user.ven_ids().iter().any(|vid| *vid == id) {
return Err(AppError::Forbidden("User does not have access to this VEN"));
}
} else if !user.is_ven_manager() {
Expand Down
4 changes: 2 additions & 2 deletions openadr-vtn/src/data_source/postgres/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ impl Crud for PgEventStorage {
"#,
id.as_str(),
user.is_ven(),
&user.ven_ids(),
&user.ven_ids_string(),
user.is_business(),
business_ids.as_deref(),
)
Expand Down Expand Up @@ -304,7 +304,7 @@ impl Crud for PgEventStorage {
serde_json::to_value(pg_filter.targets)
.map_err(AppError::SerdeJsonInternalServerError)?,
user.is_ven(),
&user.ven_ids(),
&user.ven_ids_string(),
user.is_business(),
business_ids.as_deref(),
pg_filter.skip,
Expand Down
4 changes: 2 additions & 2 deletions openadr-vtn/src/data_source/postgres/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ impl Crud for PgProgramStorage {
"#,
id.as_str(),
user.is_ven(),
&user.ven_ids()
&user.ven_ids_string()
)
.fetch_one(&self.db)
.await?
Expand Down Expand Up @@ -346,7 +346,7 @@ impl Crud for PgProgramStorage {
serde_json::to_value(pg_filter.targets)
.map_err(AppError::SerdeJsonInternalServerError)?,
user.is_ven(),
&user.ven_ids(),
&user.ven_ids_string(),
pg_filter.skip,
pg_filter.limit,
)
Expand Down
8 changes: 4 additions & 4 deletions openadr-vtn/src/data_source/postgres/report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl Crud for PgReportStorage {
if !user
.ven_ids()
.into_iter()
.any(|user_ven| permitted_vens.contains(&user_ven))
.any(|user_ven| permitted_vens.contains(&user_ven.to_string()))
{
Err(AppError::NotFound)?
};
Expand Down Expand Up @@ -168,7 +168,7 @@ impl Crud for PgReportStorage {
"#,
id.as_str(),
user.is_ven(),
&user.ven_ids(),
&user.ven_ids_string(),
business_ids.as_deref()
)
.fetch_one(&self.db)
Expand Down Expand Up @@ -201,7 +201,7 @@ impl Crud for PgReportStorage {
filter.event_id.clone().map(|x| x.to_string()),
filter.client_name,
user.is_ven(),
&user.ven_ids(),
&user.ven_ids_string(),
business_ids.as_deref(),
filter.skip,
filter.limit,
Expand Down Expand Up @@ -241,7 +241,7 @@ impl Crud for PgReportStorage {
"#,
id.as_str(),
user.is_ven(),
&user.ven_ids(),
&user.ven_ids_string(),
business_ids.as_deref(),
new.program_id.as_str(),
new.event_id.as_str(),
Expand Down
8 changes: 5 additions & 3 deletions openadr-vtn/src/data_source/postgres/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
jwt::AuthRole,
};
use axum::async_trait;
use openadr_wire::IdentifierError;
use sqlx::PgPool;

pub struct PgAuthSource {
Expand Down Expand Up @@ -58,10 +59,11 @@ impl AuthSource for PgAuthSource {
.ok();

let mut ven_roles = vens
.map(|vens| {
.and_then(|vens| {
vens.into_iter()
.map(|ven| AuthRole::VEN(ven.id))
.collect::<Vec<_>>()
.map(|ven| Ok(AuthRole::VEN(ven.id.parse()?)))
.collect::<Result<Vec<_>, IdentifierError>>()
.ok()
})
.unwrap_or_default();

Expand Down
18 changes: 16 additions & 2 deletions openadr-vtn/src/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use axum_extra::{
TypedHeader,
};
use jsonwebtoken::{encode, DecodingKey, EncodingKey, Header};
use openadr_wire::ven::VenId;
use tracing::trace;

use crate::error::AppError;
Expand All @@ -26,7 +27,7 @@ pub enum AuthRole {
VenManager,
Business(String),
AnyBusiness,
VEN(String),
VEN(VenId),
}

impl AuthRole {
Expand Down Expand Up @@ -83,7 +84,7 @@ pub enum BusinessIds {
}

impl Claims {
pub fn ven_ids(&self) -> Vec<String> {
pub fn ven_ids(&self) -> Vec<VenId> {
self.roles
.iter()
.filter_map(|role| {
Expand All @@ -96,6 +97,19 @@ impl Claims {
.collect()
}

pub fn ven_ids_string(&self) -> Vec<String> {
self.roles
.iter()
.filter_map(|role| {
if let AuthRole::VEN(id) = role {
Some(id.to_string())
} else {
None
}
})
.collect()
}

pub fn business_ids(&self) -> BusinessIds {
let mut ids = vec![];

Expand Down

0 comments on commit 8703336

Please sign in to comment.