Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Allow user with multiple roles, add UserManager role #86

Merged
merged 2 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openadr-vtn/src/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ pub async fn token(
};

let expiration = std::time::Duration::from_secs(3600 * 24 * 30);
let token = jwt_manager.create(expiration, user.client_id, user.role, user.ven)?;
let token = jwt_manager.create(expiration, user.client_id, user.roles)?;

Ok(AccessTokenResponse {
access_token: token,
Expand Down
14 changes: 6 additions & 8 deletions openadr-vtn/src/api/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::api::{AppResponse, ValidatedQuery};
use crate::data_source::{Crud, EventCrud};
use crate::error::AppError;
use crate::error::AppError::NotImplemented;
use crate::jwt::{BLUser, User};
use crate::jwt::{BusinessUser, User};

impl EventCrud for RwLock<HashMap<EventId, Event>> {}

Expand Down Expand Up @@ -105,7 +105,7 @@ pub async fn get(

pub async fn add(
State(event_source): State<Arc<dyn EventCrud>>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
Json(new_event): Json<EventContent>,
) -> Result<(StatusCode, Json<Event>), AppError> {
let event = event_source.create(new_event).await?;
Expand All @@ -118,7 +118,7 @@ pub async fn add(
pub async fn edit(
State(event_source): State<Arc<dyn EventCrud>>,
Path(id): Path<EventId>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
Json(content): Json<EventContent>,
) -> AppResponse<Event> {
let event = event_source.update(&id, content).await?;
Expand All @@ -131,7 +131,7 @@ pub async fn edit(
pub async fn delete(
State(event_source): State<Arc<dyn EventCrud>>,
Path(id): Path<EventId>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
) -> AppResponse<Event> {
let event = event_source.delete(&id).await?;
info!(%id, "deleted event");
Expand Down Expand Up @@ -219,8 +219,7 @@ mod test {
store.auth.try_write().unwrap().push(AuthInfo {
client_id: "admin".to_string(),
client_secret: "admin".to_string(),
role: AuthRole::BL,
ven: None,
roles: vec![AuthRole::Business(None), AuthRole::UserManager],
});

{
Expand All @@ -239,8 +238,7 @@ mod test {
.create(
std::time::Duration::from_secs(3600),
"admin".to_string(),
AuthRole::BL,
None,
vec![AuthRole::Business(None), AuthRole::UserManager],
)
.unwrap()
}
Expand Down
14 changes: 6 additions & 8 deletions openadr-vtn/src/api/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use openadr_wire::Program;
use crate::api::{AppResponse, ValidatedQuery};
use crate::data_source::{Crud, ProgramCrud};
use crate::error::AppError;
use crate::jwt::{BLUser, User};
use crate::jwt::{BusinessUser, User};

impl ProgramCrud for RwLock<HashMap<ProgramId, Program>> {}

Expand Down Expand Up @@ -132,7 +132,7 @@ pub async fn get(

pub async fn add(
State(program_source): State<Arc<dyn ProgramCrud>>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
Json(new_program): Json<ProgramContent>,
) -> Result<(StatusCode, Json<Program>), AppError> {
let program = program_source.create(new_program).await?;
Expand All @@ -143,7 +143,7 @@ pub async fn add(
pub async fn edit(
State(program_source): State<Arc<dyn ProgramCrud>>,
Path(id): Path<ProgramId>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
Json(content): Json<ProgramContent>,
) -> AppResponse<Program> {
let program = program_source.update(&id, content).await?;
Expand All @@ -156,7 +156,7 @@ pub async fn edit(
pub async fn delete(
State(program_source): State<Arc<dyn ProgramCrud>>,
Path(id): Path<ProgramId>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
) -> AppResponse<Program> {
let program = program_source.delete(&id).await?;
info!(%id, "deleted program");
Expand Down Expand Up @@ -254,8 +254,7 @@ mod test {
store.auth.try_write().unwrap().push(AuthInfo {
client_id: "admin".to_string(),
client_secret: "admin".to_string(),
role: AuthRole::BL,
ven: None,
roles: vec![AuthRole::Business(None), AuthRole::UserManager],
});

{
Expand All @@ -274,8 +273,7 @@ mod test {
.create(
std::time::Duration::from_secs(3600),
"admin".to_string(),
AuthRole::BL,
None,
vec![AuthRole::Business(None), AuthRole::UserManager],
)
.unwrap()
}
Expand Down
4 changes: 2 additions & 2 deletions openadr-vtn/src/api/report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use openadr_wire::Report;
use crate::api::{AppResponse, ValidatedQuery};
use crate::data_source::{Crud, ReportCrud};
use crate::error::AppError;
use crate::jwt::{BLUser, User};
use crate::jwt::{BusinessUser, User};

impl ReportCrud for RwLock<HashMap<ReportId, Report>> {}

Expand Down Expand Up @@ -136,7 +136,7 @@ pub async fn edit(

pub async fn delete(
State(report_source): State<Arc<dyn ReportCrud>>,
BLUser(_user): BLUser,
BusinessUser(_user): BusinessUser,
Path(id): Path<ReportId>,
) -> AppResponse<Report> {
let report = report_source.delete(&id).await?;
Expand Down
3 changes: 1 addition & 2 deletions openadr-vtn/src/data_source/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ pub trait DataSource: Send + Sync + 'static {
pub struct AuthInfo {
pub client_id: String,
pub client_secret: String,
pub role: AuthRole,
pub ven: Option<String>,
pub roles: Vec<AuthRole>,
}

#[derive(Default, Clone)]
Expand Down
146 changes: 117 additions & 29 deletions openadr-vtn/src/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,83 @@ pub struct JwtManager {
decoding_key: DecodingKey,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "role", content = "id")]
pub enum AuthRole {
BL,
VEN,
UserManager,
Business(Option<String>),
VEN(String),
}

impl AuthRole {
pub fn is_business(&self) -> bool {
matches!(self, AuthRole::Business(_))
}

pub fn is_ven(&self) -> bool {
matches!(self, AuthRole::VEN(_))
}

pub fn is_user_manager(&self) -> bool {
matches!(self, AuthRole::UserManager)
}
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Claims {
exp: usize,
nbf: usize,
pub sub: String,
pub role: AuthRole,
pub ven: Option<String>,
pub roles: Vec<AuthRole>,
}

impl Claims {
pub fn ven_ids(&self) -> Vec<String> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be a new-type instead of String?

Copy link
Member Author

@rnijveld rnijveld Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably will be a newtype eventually, given we already have several of those, but I wanted to keep this PR a little smaller.

self.roles
.iter()
.filter_map(|role| {
if let AuthRole::VEN(id) = role {
Some(id.clone())
} else {
None
}
})
.collect()
}

pub fn business_ids(&self) -> (Vec<String>, bool) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be better as an enum Foo {Any, Specific(Vec<String>)}?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Will change!

let mut allow_any = false;
let ids = self
.roles
.iter()
.filter_map(|role| {
if let AuthRole::Business(id) = role {
if let Some(id) = id {
Some(id.clone())
} else {
allow_any = true;
None
}
} else {
None
}
})
.collect();

(ids, allow_any)
}

pub fn is_ven(&self) -> bool {
self.roles.iter().any(AuthRole::is_ven)
}

pub fn is_business(&self) -> bool {
self.roles.iter().any(AuthRole::is_business)
}

pub fn is_user_manager(&self) -> bool {
self.roles.iter().any(AuthRole::is_user_manager)
}
}

impl JwtManager {
Expand Down Expand Up @@ -63,8 +126,7 @@ impl JwtManager {
&self,
expires_in: std::time::Duration,
client_id: String,
role: AuthRole,
ven: Option<String>,
roles: Vec<AuthRole>,
) -> Result<String, jsonwebtoken::errors::Error> {
let now = chrono::Utc::now();
let exp = now + expires_in;
Expand All @@ -73,8 +135,7 @@ impl JwtManager {
exp: exp.timestamp() as usize,
nbf: now.timestamp() as usize,
sub: client_id,
role,
ven,
roles,
};

let token = encode(&Header::default(), &claims, &self.encoding_key)?;
Expand All @@ -90,9 +151,17 @@ impl JwtManager {
}
}

/// User claims extracted from the request
pub struct User(pub Claims);
pub struct BLUser(pub Claims);
// pub struct VENUser(pub Claims);

/// User claims extracted from the request, with the requirement that the user is a business user
pub struct BusinessUser(pub Claims);

/// User claims extracted from the request, with the requirement that the user is a VEN user
pub struct VENUser(pub Claims);

/// User claims extracted from the request, with the requirement that the user is a user manager
pub struct UserManagerUser(pub Claims);

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for User
Expand Down Expand Up @@ -125,36 +194,55 @@ where
}

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for BLUser
impl<S: Send + Sync> FromRequestParts<S> for BusinessUser
where
Arc<JwtManager>: FromRef<S>,
{
type Rejection = AppError;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = User::from_request_parts(parts, state).await?;

if user.0.role != AuthRole::BL {
let User(user) = User::from_request_parts(parts, state).await?;
if !user.is_business() {
return Err(AppError::Auth(
"User does not have the required role".to_string(),
));
}

Ok(BLUser(user.0))
Ok(BusinessUser(user))
}
}

// #[async_trait]
// impl<S: Send + Sync> FromRequestParts<S> for VENUser where Arc<JwtManager>: FromRef<S> {
// type Rejection = AppError;
#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for VENUser
where
Arc<JwtManager>: FromRef<S>,
{
type Rejection = AppError;

// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// let user = User::from_request_parts(parts, state).await?;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let User(user) = User::from_request_parts(parts, state).await?;
if !user.is_ven() {
return Err(AppError::Auth(
"User does not have the required role".to_string(),
));
}
Ok(VENUser(user))
}
}

// if user.0.role != AuthRole::VEN {
// return Err(AppError::Auth("User does not have the required role".to_string()));
// }
#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for UserManagerUser
where
Arc<JwtManager>: FromRef<S>,
{
type Rejection = AppError;

// Ok(VENUser(user.0))
// }
// }
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let User(user) = User::from_request_parts(parts, state).await?;
if !user.is_user_manager() {
return Err(AppError::Auth(
"User does not have the required role".to_string(),
));
}
Ok(UserManagerUser(user))
}
}
3 changes: 1 addition & 2 deletions openadr-vtn/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ async fn main() {
storage.auth.write().await.push(AuthInfo {
client_id: "admin".to_string(),
client_secret: "admin".to_string(),
role: AuthRole::BL,
ven: None,
roles: vec![AuthRole::Business(None), AuthRole::UserManager],
});
let state = AppState::new(storage, JwtManager::from_base64_secret("test").unwrap());

Expand Down