Skip to content

Commit

Permalink
handle_tunnel_request: small code cleanup (#391)
Browse files Browse the repository at this point in the history
more idiomatic, less code, better readability
  • Loading branch information
tsnoam authored Jan 6, 2025
1 parent 4d33b62 commit d8e519c
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 81 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tracing = { version = "0.1.41", features = ["log"] }
url = "2.5.4"
urlencoding = "2.1.3"
uuid = { version = "1.11.0", features = ["v7", "serde"] }
derive_more = { version = "1.0.0", features = ["display", "error"] }

[target.'cfg(not(target_family = "unix"))'.dependencies]
crossterm = { version = "0.28.1" }
Expand Down
4 changes: 2 additions & 2 deletions src/tunnel/server/handler_http2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::server::utils::{bad_request, inject_cookie};
use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse};
use crate::tunnel::server::WsServer;
use crate::tunnel::transport;
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
Expand All @@ -22,7 +22,7 @@ pub(super) async fn http_server_upgrade(
restrict_path_prefix: Option<String>,
client_addr: SocketAddr,
mut req: Request<Incoming>,
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
) -> HttpResponse {
let (remote_addr, local_rx, local_tx, need_cookie) = match server
.handle_tunnel_request(restrictions, restrict_path_prefix, client_addr, &req)
.await
Expand Down
4 changes: 2 additions & 2 deletions src/tunnel/server/handler_websocket.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::server::utils::{bad_request, inject_cookie};
use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse};
use crate::tunnel::server::WsServer;
use crate::tunnel::transport;
use crate::tunnel::transport::websocket::mk_websocket_tunnel;
Expand All @@ -21,7 +21,7 @@ pub(super) async fn ws_server_upgrade(
restrict_path_prefix: Option<String>,
client_addr: SocketAddr,
mut req: Request<Incoming>,
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
) -> HttpResponse {
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
return bad_request();
Expand Down
71 changes: 28 additions & 43 deletions src/tunnel/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::tunnel::server::handler_websocket::ws_server_upgrade;
use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer;
use crate::tunnel::server::utils::{
bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel,
HttpResponse,
};
use crate::tunnel::tls_reloader::TlsReloader;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
Expand Down Expand Up @@ -89,68 +90,52 @@ impl WsServer {
Pin<Box<dyn AsyncWrite + Send>>,
bool,
),
Response<Either<String, BoxBody<Bytes, anyhow::Error>>>,
HttpResponse,
> {
match extract_x_forwarded_for(req) {
Ok(Some((x_forward_for, x_forward_for_str))) => {
info!("Request X-Forwarded-For: {:?}", x_forward_for);
Span::current().record("forwarded_for", x_forward_for_str);
client_addr.set_ip(x_forward_for);
}
Ok(_) => {}
Err(_err) => return Err(bad_request()),
if let Some((x_forward_for, x_forward_for_str)) = extract_x_forwarded_for(req) {
info!("Request X-Forwarded-For: {x_forward_for:?}");
Span::current().record("forwarded_for", x_forward_for_str);
client_addr.set_ip(x_forward_for);
};

let path_prefix = match extract_path_prefix(req) {
Ok(p) => p,
Err(_err) => return Err(bad_request()),
};
let path_prefix = extract_path_prefix(req.uri().path()).map_err(|err| {
warn!("Rejecting connection with {err}: {}", req.uri());
bad_request()
})?;

if let Some(restrict_path) = restrict_path_prefix {
if path_prefix != restrict_path {
warn!(
"Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)",
path_prefix, restrict_path
"Client requested upgrade path '{path_prefix}' does not match upgrade path restriction '{restrict_path}' (mTLS, etc.)"
);
return Err(bad_request());
}
}

let jwt = match extract_tunnel_info(req) {
Ok(jwt) => jwt,
Err(_err) => return Err(bad_request()),
};
let jwt = extract_tunnel_info(req)?;

Span::current().record("id", &jwt.claims.id);
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
let remote = match RemoteAddr::try_from(jwt.claims) {
Ok(remote) => remote,
Err(err) => {
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
return Err(bad_request());
}
};
let remote = RemoteAddr::try_from(jwt.claims).map_err(|err| {
warn!("Rejecting connection with bad tunnel info: {err} {}", req.uri());
bad_request()
})?;

let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) {
Some(matched_restriction) => {
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
matched_restriction
}
None => {
warn!("Rejecting connection with not allowed destination: {:?}", remote);
return Err(bad_request());
}
};
let restriction = validate_tunnel(&remote, path_prefix, &restrictions).ok_or_else(|| {
warn!("Rejecting connection with not allowed destination: {remote:?}");
bad_request()
})?;
info!("Tunnel accepted due to matched restriction: {}", restriction.name);

let req_protocol = remote.protocol.clone();
let inject_cookie = req_protocol.is_dynamic_reverse_tunnel();
let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await {
Ok(ret) => ret,
Err(err) => {
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
return Err(bad_request());
}
};
let tunnel = self
.exec_tunnel(restriction, remote, client_addr)
.await
.map_err(|err| {
warn!("Rejecting connection with bad upgrade request: {err} {}", req.uri());
bad_request()
})?;

let (remote_addr, local_rx, local_tx) = tunnel;
info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port);
Expand Down
91 changes: 57 additions & 34 deletions src/tunnel/server/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::restrictions::types::{
use crate::tunnel::transport::{jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig, JWT_HEADER_PREFIX};
use crate::tunnel::RemoteAddr;
use bytes::Bytes;
use derive_more::{Display, Error};
use http_body_util::combinators::BoxBody;
use http_body_util::Either;
use hyper::body::{Body, Incoming};
Expand All @@ -17,7 +18,9 @@ use tracing::{error, info, warn};
use url::Host;
use uuid::Uuid;

pub(super) fn bad_request() -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
pub type HttpResponse = Response<Either<String, BoxBody<Bytes, anyhow::Error>>>;

pub(super) fn bad_request() -> HttpResponse {
http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Either::Left("Invalid request".to_string()))
Expand Down Expand Up @@ -48,42 +51,41 @@ pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -
}

#[inline]
pub(super) fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &str)>, ()> {
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
return Ok(None);
};
pub(super) fn extract_x_forwarded_for(req: &Request<Incoming>) -> Option<(IpAddr, &str)> {
let x_forward_for = req.headers().get("X-Forwarded-For")?;

// X-Forwarded-For: <client>, <proxy1>, <proxy2>
let x_forward_for = x_forward_for.to_str().unwrap_or_default();
let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for);
let ip: Option<IpAddr> = x_forward_for.parse().ok();
Ok(ip.map(|ip| (ip, x_forward_for)))
ip.map(|ip| (ip, x_forward_for))
}

#[inline]
pub(super) fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, ()> {
let path = req.uri().path();
let min_len = min(path.len(), 1);
if &path[0..min_len] != "/" {
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
return Err(());
pub(super) fn extract_path_prefix(path: &str) -> Result<&str, PathPrefixErr> {
if !path.starts_with('/') {
return Err(PathPrefixErr::BadPathPrefix);
}

let Some((l, r)) = path[min_len..].split_once('/') else {
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
return Err(());
};
let (l, r) = path[1..].split_once('/').ok_or(PathPrefixErr::BadUpgradeRequest)?;

if !r.ends_with("events") {
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
return Err(());
match r.ends_with("events") {
true => Ok(l),
false => Err(PathPrefixErr::BadUpgradeRequest),
}
}

Ok(l)
#[derive(Debug, Display, Error)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub(super) enum PathPrefixErr {
#[display("bad path prefix in upgrade request")]
BadPathPrefix,
#[display("bad upgrade request")]
BadUpgradeRequest,
}

#[inline]
pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelConfig>, ()> {
pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> anyhow::Result<TokenData<JwtTunnelConfig>, HttpResponse> {
let jwt = req
.headers()
.get(SEC_WEBSOCKET_PROTOCOL)
Expand All @@ -93,19 +95,13 @@ pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<J
.or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok()))
.unwrap_or_default();

let jwt = match jwt_token_to_tunnel(jwt) {
Ok(jwt) => jwt,
err => {
warn!(
"error while decoding jwt for tunnel info {:?} header {:?}",
err,
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
);
return Err(());
}
};

Ok(jwt)
jwt_token_to_tunnel(jwt).map_err(|err| {
warn!(
"error while decoding jwt for tunnel info {err:?} header {:?}",
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
);
bad_request()
})
}

impl RestrictionConfig {
Expand Down Expand Up @@ -497,4 +493,31 @@ mod tests {
assert!(!config.is_allowed(&remote));
assert!(!AllowConfig::from(config.clone()).is_allowed(&remote));
}

#[test]
fn test_extract_path_prefix_happy_path() {
assert_eq!(extract_path_prefix("/prefix/events"), Ok("prefix"));
assert_eq!(extract_path_prefix("/prefix/a/events"), Ok("prefix"));
assert_eq!(extract_path_prefix("/prefix/a/b/events"), Ok("prefix"));
}

#[test]
fn test_extract_path_prefix_no_events_suffix() {
assert_eq!(extract_path_prefix("/prefix/events/"), Err(PathPrefixErr::BadUpgradeRequest));
assert_eq!(extract_path_prefix("/prefix"), Err(PathPrefixErr::BadUpgradeRequest));
assert_eq!(extract_path_prefix("/prefixevents"), Err(PathPrefixErr::BadUpgradeRequest));
assert_eq!(extract_path_prefix("/prefix/event"), Err(PathPrefixErr::BadUpgradeRequest));
assert_eq!(extract_path_prefix("/prefix/a"), Err(PathPrefixErr::BadUpgradeRequest));
assert_eq!(extract_path_prefix("/prefix/a/b"), Err(PathPrefixErr::BadUpgradeRequest));
}

#[test]
fn test_extract_path_prefix_no_slash_prefix() {
assert_eq!(extract_path_prefix(""), Err(PathPrefixErr::BadPathPrefix));
assert_eq!(extract_path_prefix("p"), Err(PathPrefixErr::BadPathPrefix));
assert_eq!(extract_path_prefix("\\"), Err(PathPrefixErr::BadPathPrefix));
assert_eq!(extract_path_prefix("prefix/events"), Err(PathPrefixErr::BadPathPrefix));
assert_eq!(extract_path_prefix("prefix/a/events"), Err(PathPrefixErr::BadPathPrefix));
assert_eq!(extract_path_prefix("prefix/a/b/events"), Err(PathPrefixErr::BadPathPrefix));
}
}

0 comments on commit d8e519c

Please sign in to comment.