Skip to content

Commit

Permalink
make it easier for others to support their own username label extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
glendc committed Mar 22, 2024
1 parent b5dde27 commit 845c0bb
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 86 deletions.
19 changes: 8 additions & 11 deletions examples/http_connect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//!
//! ```sh
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john:secret' http://www.example.com/
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john-cc-us:secret' http://www.example.com/
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john-red-blue:secret' http://www.example.com/
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john:secret' https://www.example.com/
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john:secret' http://echo.example/foo/bar
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john:secret' -XPOST http://echo.example/lucky/7
Expand All @@ -21,7 +21,7 @@
//!
//! ```sh
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john:secret' http://echo.example/foo/bar
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john-cc-us:secret' http://echo.example/foo/bar
//! curl -v -x http://127.0.0.1:8080 --proxy-user 'john-red-blue:secret' http://echo.example/foo/bar
//! ```
//!
//! You should see in all the above examples the responses from the server.
Expand Down Expand Up @@ -57,7 +57,7 @@ use rama::{
http::{
client::HttpClient,
layer::{
proxy_auth::ProxyAuthLayer,
proxy_auth::{ProxyAuthLayer, ProxyUsernameLabels},
trace::TraceLayer,
upgrade::{UpgradeLayer, Upgraded},
},
Expand All @@ -70,14 +70,13 @@ use rama::{
},
Body, IntoResponse, Request, Response, StatusCode,
},
proxy::ProxyFilter,
rt::Executor,
service::{layer::HijackLayer, service_fn, Context, Service, ServiceBuilder},
tcp::utils::is_connection_error,
};
use serde::Deserialize;
use serde_json::json;
use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{convert::Infallible, ops::Deref, sync::Arc, time::Duration};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};

Expand Down Expand Up @@ -110,11 +109,9 @@ async fn main() {
"127.0.0.1:8080",
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
// - specify it as `with_proxy_filter_labels::<'_'>()`
// in case you want to define a different separator, such as '_'.
// - specify `.with_labels::<T>()` in case you want to use a custom labels extractor.
// - `ProxyAuthLayer::new` can be used for a custom Credentials type.
.layer(ProxyAuthLayer::basic(("john", "secret")).with_proxy_filter_labels_default())
// See [`ProxyAuthLayer::with_labels`] for more information,
// e.g. can also be used to extract upstream proxy filters
.layer(ProxyAuthLayer::basic(("john", "secret")).with_labels::<ProxyUsernameLabels>())
// example of how one might insert an API layer into their proxy
.layer(HijackLayer::new(
DomainMatcher::new("echo.example"),
Expand All @@ -128,7 +125,7 @@ async fn main() {
Json(json!({
"method": req.method().as_str(),
"path": req.uri().path(),
"filter": ctx.get::<ProxyFilter>().map(|f| format!("{:?}", f)),
"username_labels": ctx.get::<ProxyUsernameLabels>().map(|labels| labels.deref()),
}))
},
_ => StatusCode::NOT_FOUND,
Expand Down
175 changes: 135 additions & 40 deletions src/http/layer/proxy_auth/auth.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::{
error::BoxError,
http::headers::{
authorization::{Basic, Credentials},
Authorization,
},
proxy::UsernameConfig,
proxy::{username::UsernameConfigError, ProxyFilter, UsernameConfig},
service::context::Extensions,
};
use std::future::Future;
use std::{
future::Future,
ops::{Deref, DerefMut},
};

/// The `ProxyAuthority` trait is used to determine if a set of [`Credential`]s are authorized.
///
Expand All @@ -28,6 +32,25 @@ pub trait ProxyAuthoritySync<C, L>: Send + Sync + 'static {
fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool;
}

/// A trait to convert a username into a tuple of username and meta info attached to the username.
///
/// This trait is to be implemented in case you want to define your own metadata extractor from the username.
///
/// See [`UsernameConfig`] for an example of why one might want to use this.
pub trait FromUsername {
/// The output type of the username metadata extraction,
/// and which will be added to the [`Request`]'s [`Extensions`].
///
/// [`Request`]: crate::http::Request
type Output: Clone + Send + Sync + 'static;

/// The error type that can be returned when parsing the username went wrong.
type Error;

/// Parse the username and return the username and the metadata.
fn from_username(username: &str) -> Result<(String, Option<Self::Output>), Self::Error>;
}

impl<A, C, L> ProxyAuthority<C, L> for A
where
A: ProxyAuthoritySync<C, L>,
Expand Down Expand Up @@ -55,7 +78,7 @@ impl ProxyAuthoritySync<Basic, ()> for Basic {
}
}

impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for Basic {
impl<T: FromUsername> ProxyAuthoritySync<Basic, T> for Basic {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
let username = credentials.username();
let password = credentials.password();
Expand All @@ -64,8 +87,8 @@ impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for Basic {
return false;
}

let (username, mut filter) = match username.parse::<UsernameConfig<C>>() {
Ok(t) => t.into_parts(),
let (username, mut metadata) = match T::from_username(username) {
Ok(t) => t,
Err(_) => {
return if self == credentials {
ext.insert(self.clone());
Expand All @@ -80,8 +103,8 @@ impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for Basic {
return false;
}

if let Some(filter) = filter.take() {
ext.insert(filter);
if let Some(metadata) = metadata.take() {
ext.insert(metadata);
}
ext.insert(Authorization::basic(username.as_str(), password).0);
true
Expand All @@ -99,7 +122,7 @@ impl ProxyAuthoritySync<Basic, ()> for (&'static str, &'static str) {
}
}

impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for (&'static str, &'static str) {
impl<T: FromUsername> ProxyAuthoritySync<Basic, T> for (&'static str, &'static str) {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
let username = credentials.username();
let password = credentials.password();
Expand All @@ -108,8 +131,8 @@ impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for (&'static s
return false;
}

let (username, mut filter) = match username.parse::<UsernameConfig<C>>() {
Ok(t) => t.into_parts(),
let (username, mut metadata) = match T::from_username(username) {
Ok(t) => t,
Err(_) => {
return if self.0 == credentials.username() && self.1 == credentials.password() {
ext.insert(Authorization::basic(self.0, self.1).0);
Expand All @@ -124,8 +147,8 @@ impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for (&'static s
return false;
}

if let Some(filter) = filter.take() {
ext.insert(filter);
if let Some(metadata) = metadata.take() {
ext.insert(metadata);
}
ext.insert(Authorization::basic(username.as_str(), password).0);
true
Expand All @@ -143,7 +166,7 @@ impl ProxyAuthoritySync<Basic, ()> for (String, String) {
}
}

impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for (String, String) {
impl<T: FromUsername> ProxyAuthoritySync<Basic, T> for (String, String) {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
let username = credentials.username();
let password = credentials.password();
Expand All @@ -152,8 +175,8 @@ impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for (String, St
return false;
}

let (username, mut filter) = match username.parse::<UsernameConfig<C>>() {
Ok(t) => t.into_parts(),
let (username, mut metadata) = match T::from_username(username) {
Ok(t) => t,
Err(_) => {
return if self.0 == credentials.username() && self.1 == credentials.password() {
ext.insert(Authorization::basic(self.0.as_str(), self.1.as_str()).0);
Expand All @@ -168,37 +191,24 @@ impl<const C: char> ProxyAuthoritySync<Basic, UsernameConfig<C>> for (String, St
return false;
}

if let Some(filter) = filter.take() {
ext.insert(filter);
if let Some(metadata) = metadata.take() {
ext.insert(metadata);
}
ext.insert(Authorization::basic(username.as_str(), password).0);
true
}
}

macro_rules! impl_proxy_auth_sync_tuple {
($($T:ident),+ $(,)?) => {
#[allow(unused_parens)]
#[allow(non_snake_case)]
impl<C, L, $($T),+> ProxyAuthoritySync<C, L> for ($($T),+,)
where C: Credentials + Send + 'static,
$(
$T: ProxyAuthoritySync<C, L>,
)+

{
fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool {
let ($($T),+,) = self;
$(
ProxyAuthoritySync::authorized($T, ext, &credentials)
)||+
}
}
};
impl<C, L, T, const N: usize> ProxyAuthoritySync<C, L> for [T; N]
where
C: Credentials + Send + 'static,
T: ProxyAuthoritySync<C, L>,
{
fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool {
self.iter().any(|t| t.authorized(ext, credentials))
}
}

all_the_tuples_no_last_special_case!(impl_proxy_auth_sync_tuple);

impl<C, L, T> ProxyAuthoritySync<C, L> for Vec<T>
where
C: Credentials + Send + 'static,
Expand All @@ -209,11 +219,58 @@ where
}
}

/// A wrapper type to extract username labels and store them as-is in the [`Extensions`].
#[derive(Debug, Clone)]
pub struct ProxyUsernameLabels<const C: char = '-'>(pub Vec<String>);

impl<const C: char> Deref for ProxyUsernameLabels<C> {
type Target = Vec<String>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<const C: char> DerefMut for ProxyUsernameLabels<C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<const C: char> FromUsername for ProxyUsernameLabels<C> {
type Output = Self;
type Error = BoxError;

fn from_username(username: &str) -> Result<(String, Option<Self::Output>), Self::Error> {
let mut it = username.split(C);
let username = match it.next() {
Some(username) => username.to_owned(),
None => return Err("no username found".into()),
};
let labels: Vec<_> = it.map(str::to_owned).collect();
if labels.is_empty() {
Ok((username, None))
} else {
Ok((username, Some(Self(labels))))
}
}
}

impl<const C: char> FromUsername for UsernameConfig<C> {
type Output = ProxyFilter;
type Error = UsernameConfigError;

fn from_username(username: &str) -> Result<(String, Option<Self::Output>), Self::Error> {
let username_cfg: Self = username.parse()?;
let (username, filter) = username_cfg.into_parts();
Ok((username, filter))
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::proxy::{ProxyFilter, UsernameConfig};

use super::ProxyAuthority;
use headers::{authorization::Basic, Authorization};

#[tokio::test]
Expand Down Expand Up @@ -248,6 +305,27 @@ mod test {
assert_eq!(filter.country, Some("us".to_owned()));
}

#[tokio::test]
async fn basic_authorization_with_labels_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
let auths = vec![
Authorization::basic("foo", "bar").0,
Authorization::basic("john", "secret").0,
];

let ext = ProxyAuthority::<_, ProxyUsernameLabels>::authorized(
&auths,
Authorization::basic("john-green-red", "secret").0,
)
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);

let labels: &ProxyUsernameLabels = ext.get().unwrap();
assert_eq!(labels.deref(), &vec!["green".to_owned(), "red".to_owned()]);
}

#[tokio::test]
async fn basic_authorization_with_filter_not_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
Expand All @@ -265,6 +343,23 @@ mod test {
assert!(ext.get::<ProxyFilter>().is_none());
}

#[tokio::test]
async fn basic_authorization_with_labels_not_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
let auths = vec![
Authorization::basic("foo", "bar").0,
Authorization::basic("john", "secret").0,
];

let ext = ProxyAuthority::<_, ProxyUsernameLabels>::authorized(&auths, auth.clone())
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);

assert!(ext.get::<ProxyUsernameLabels>().is_none());
}

#[tokio::test]
async fn basic_authorization_tuple() {
let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")];
Expand Down
Loading

0 comments on commit 845c0bb

Please sign in to comment.