Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support multiple basic/bearer credentials for 'Authorization' server support #432

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 192 additions & 37 deletions rama-http/src/layer/auth/require_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@
//!
//! Custom validation can be made by implementing [`ValidateRequest`].

use base64::Engine as _;
use std::{fmt, marker::PhantomData};

use crate::layer::validate_request::{
ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
};
use crate::{
Request, Response, StatusCode,
header::{self, HeaderValue},
};
use base64::Engine as _;
use rama_core::Context;
use std::{fmt, marker::PhantomData, sync::Arc};

use rama_net::user::UserId;
use sealed::AuthorizerSeal;

const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;

Expand Down Expand Up @@ -199,29 +199,40 @@ impl<ResBody> fmt::Debug for Bearer<ResBody> {
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Bearer<ResBody>>
// TODO: revisit ValidateRequest and related types so we do not require
// the associated Response types for all these traits. E.g. by forcing
// downstream users that their response bodies can be turned into the standard `rama::http::Body`
impl<S, B, C> ValidateRequest<S, B> for AuthorizeContext<C>
where
ResBody: Default + Send + 'static,
C: Authorizer,
B: Send + 'static,
S: Clone + Send + Sync + 'static,
{
type ResponseBody = ResBody;
type ResponseBody = C::ResBody;

async fn validate(
&self,
ctx: Context<S>,
request: Request<B>,
) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
Some(header_value) if self.credential.is_valid(header_value) => Ok((ctx, request)),
None if self.allow_anonymous => {
let mut ctx = ctx;
ctx.insert(UserId::Anonymous);
Ok((ctx, request))
}
_ => {
let mut res = Response::new(ResBody::default());
let mut res = Response::new(Self::ResponseBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;

if let Some(www_auth) = C::www_authenticate_header() {
res.headers_mut().insert(header::WWW_AUTHENTICATE, www_auth);
} else {
res.headers_mut()
.insert(header::WWW_AUTHENTICATE, "Bearer".parse().unwrap());
}

Err(res)
}
}
Expand Down Expand Up @@ -267,35 +278,99 @@ impl<ResBody> fmt::Debug for Basic<ResBody> {
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Basic<ResBody>>
where
ResBody: Default + Send + 'static,
B: Send + 'static,
S: Clone + Send + Sync + 'static,
{
type ResponseBody = ResBody;
// Private module with the actual implementation details
mod sealed {
use super::*;

async fn validate(
&self,
ctx: Context<S>,
request: Request<B>,
) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
None if self.allow_anonymous => {
let mut ctx = ctx;
ctx.insert(UserId::Anonymous);
Ok((ctx, request))
}
_ => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res.headers_mut()
.insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
Err(res)
}
/// Private trait that contains the actual authorization logic
pub(super) trait AuthorizerSeal: Send + Sync + 'static {
/// Check if the given header value is valid for this authorizer.
fn is_valid(&self, header_value: &HeaderValue) -> bool;

/// Return the WWW-Authenticate header value if applicable.
fn www_authenticate_header() -> Option<HeaderValue>;
}

impl<ResBody: Default + Send + 'static> AuthorizerSeal for Basic<ResBody> {
fn is_valid(&self, header_value: &HeaderValue) -> bool {
header_value == &self.header_value
}

fn www_authenticate_header() -> Option<HeaderValue> {
Some(HeaderValue::from_static("Basic"))
}
}

impl<ResBody: Default + Send + 'static> AuthorizerSeal for Bearer<ResBody> {
fn is_valid(&self, header_value: &HeaderValue) -> bool {
header_value == &self.header_value
}

fn www_authenticate_header() -> Option<HeaderValue> {
None
}
}

impl<T, const N: usize> AuthorizerSeal for [T; N]
where
T: AuthorizerSeal,
{
fn is_valid(&self, header_value: &HeaderValue) -> bool {
self.iter().any(|auth| auth.is_valid(header_value))
}

fn www_authenticate_header() -> Option<HeaderValue> {
T::www_authenticate_header()
}
}

impl<T> AuthorizerSeal for Vec<T>
where
T: AuthorizerSeal,
{
fn is_valid(&self, header_value: &HeaderValue) -> bool {
self.iter().any(|auth| auth.is_valid(header_value))
}

fn www_authenticate_header() -> Option<HeaderValue> {
T::www_authenticate_header()
}
}

impl<T> AuthorizerSeal for Arc<T>
where
T: AuthorizerSeal,
{
fn is_valid(&self, header_value: &HeaderValue) -> bool {
(**self).is_valid(header_value)
}

fn www_authenticate_header() -> Option<HeaderValue> {
T::www_authenticate_header()
}
}
}

/// Trait for authorizing requests.
pub trait Authorizer: sealed::AuthorizerSeal {
type ResBody: Default + Send + 'static;
}

// Implement the public trait for our existing types
impl<ResBody: Default + Send + 'static> Authorizer for Basic<ResBody> {
type ResBody = ResBody;
}
impl<ResBody: Default + Send + 'static> Authorizer for Bearer<ResBody> {
type ResBody = ResBody;
}
impl<T: Authorizer, const N: usize> Authorizer for [T; N] {
type ResBody = T::ResBody;
}
impl<T: Authorizer> Authorizer for Vec<T> {
type ResBody = T::ResBody;
}
impl<T: Authorizer> Authorizer for Arc<T> {
type ResBody = T::ResBody;
}

pub struct AuthorizeContext<C> {
Expand All @@ -304,6 +379,7 @@ pub struct AuthorizeContext<C> {
}

impl<C> AuthorizeContext<C> {
/// Create a new [`AuthorizeContext`] with the given credential.
pub(crate) fn new(credential: C) -> Self {
Self {
credential,
Expand Down Expand Up @@ -332,11 +408,11 @@ impl<C: fmt::Debug> fmt::Debug for AuthorizeContext<C> {

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;

use crate::layer::validate_request::ValidateRequestHeaderLayer;
use crate::{Body, header};

use rama_core::error::BoxError;
use rama_core::service::service_fn;
use rama_core::{Context, Layer, Service};
Expand Down Expand Up @@ -468,8 +544,83 @@ mod tests {
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
#[tokio::test]
async fn multiple_basic_auth_vec() {
let auth1 = Basic::new("user1", "pass1");
let auth2 = Basic::new("user2", "pass2");
let auth_vec = vec![auth1, auth2];
let auth_context = AuthorizeContext::new(auth_vec);
let service = ValidateRequestHeaderLayer::custom(auth_context).layer(service_fn(echo));

// Test first credential
let request = Request::builder()
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("user1:pass1")),
)
.body(Body::default())
.unwrap();
let response = service.serve(Context::default(), request).await.unwrap();
assert_eq!(StatusCode::OK, response.status());

// Test second credential
let request = Request::builder()
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("user2:pass2")),
)
.body(Body::default())
.unwrap();
let response = service.serve(Context::default(), request).await.unwrap();
assert_eq!(StatusCode::OK, response.status());

// Test invalid credential
let request = Request::builder()
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("invalid:invalid")),
)
.body(Body::default())
.unwrap();
let response = service.serve(Context::default(), request).await.unwrap();
assert_eq!(StatusCode::UNAUTHORIZED, response.status());
}

#[tokio::test]
async fn multiple_basic_auth_array() {
let auth1 = Basic::new("user1", "pass1");
let auth_array = [auth1.clone(), auth1.clone()];
let auth_context = AuthorizeContext::new(auth_array);
let service = ValidateRequestHeaderLayer::custom(auth_context).layer(service_fn(echo));

// Test valid credential
let request = Request::builder()
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("user1:pass1")),
)
.body(Body::default())
.unwrap();
let response = service.serve(Context::default(), request).await.unwrap();
assert_eq!(StatusCode::OK, response.status());
}

#[tokio::test]
async fn arc_basic_auth() {
let auth = Basic::new("user", "pass");
let arc_auth = Arc::new(auth);
let auth_context = AuthorizeContext::new(arc_auth);
let service = ValidateRequestHeaderLayer::custom(auth_context).layer(service_fn(echo));

let request = Request::builder()
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("user:pass")),
)
.body(Body::default())
.unwrap();
let response = service.serve(Context::default(), request).await.unwrap();
assert_eq!(StatusCode::OK, response.status());
}

#[tokio::test]
Expand Down Expand Up @@ -532,4 +683,8 @@ mod tests {

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
}