diff --git a/Cargo.lock b/Cargo.lock index 61e26d3..e2cb534 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -656,6 +656,7 @@ dependencies = [ "proc-macro2", "quote", "syn", + "unicode-xid", ] [[package]] @@ -3101,6 +3102,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unsafe-libyaml" version = "0.2.11" diff --git a/Cargo.toml b/Cargo.toml index 69858d5..16c49de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt", "loc url = "2.5.4" urlencoding = "2.1.3" uuid = { version = "1.11.0", features = ["v7", "serde"] } +derive_more = { version = "1.0.0", features = ["display", "error"] } [target.'cfg(not(target_family = "unix"))'.dependencies] crossterm = { version = "0.28.1" } diff --git a/src/tunnel/server/handler_http2.rs b/src/tunnel/server/handler_http2.rs index 2a7ae12..b5fff65 100644 --- a/src/tunnel/server/handler_http2.rs +++ b/src/tunnel/server/handler_http2.rs @@ -1,5 +1,5 @@ use crate::restrictions::types::RestrictionsRules; -use crate::tunnel::server::utils::{bad_request, inject_cookie}; +use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse}; use crate::tunnel::server::WsServer; use crate::tunnel::transport; use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite}; @@ -22,7 +22,7 @@ pub(super) async fn http_server_upgrade( restrict_path_prefix: Option, client_addr: SocketAddr, mut req: Request, -) -> Response>> { +) -> HttpResponse { let (remote_addr, local_rx, local_tx, need_cookie) = match server .handle_tunnel_request(restrictions, restrict_path_prefix, client_addr, &req) .await diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index 44e6df8..3eee4c6 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -1,5 +1,5 @@ use crate::restrictions::types::RestrictionsRules; -use crate::tunnel::server::utils::{bad_request, inject_cookie}; +use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse}; use crate::tunnel::server::WsServer; use crate::tunnel::transport; use crate::tunnel::transport::websocket::mk_websocket_tunnel; @@ -21,7 +21,7 @@ pub(super) async fn ws_server_upgrade( restrict_path_prefix: Option, client_addr: SocketAddr, mut req: Request, -) -> Response>> { +) -> HttpResponse { if !fastwebsockets::upgrade::is_upgrade_request(&req) { warn!("Rejecting connection with bad upgrade request: {}", req.uri()); return bad_request(); diff --git a/src/tunnel/server/server.rs b/src/tunnel/server/server.rs index 8da0f14..53984aa 100644 --- a/src/tunnel/server/server.rs +++ b/src/tunnel/server/server.rs @@ -33,6 +33,7 @@ use crate::tunnel::server::handler_websocket::ws_server_upgrade; use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer; use crate::tunnel::server::utils::{ bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel, + HttpResponse, }; use crate::tunnel::tls_reloader::TlsReloader; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -89,68 +90,52 @@ impl WsServer { Pin>, bool, ), - Response>>, + HttpResponse, > { - match extract_x_forwarded_for(req) { - Ok(Some((x_forward_for, x_forward_for_str))) => { - info!("Request X-Forwarded-For: {:?}", x_forward_for); - Span::current().record("forwarded_for", x_forward_for_str); - client_addr.set_ip(x_forward_for); - } - Ok(_) => {} - Err(_err) => return Err(bad_request()), + if let Some((x_forward_for, x_forward_for_str)) = extract_x_forwarded_for(req) { + info!("Request X-Forwarded-For: {x_forward_for:?}"); + Span::current().record("forwarded_for", x_forward_for_str); + client_addr.set_ip(x_forward_for); }; - let path_prefix = match extract_path_prefix(req) { - Ok(p) => p, - Err(_err) => return Err(bad_request()), - }; + let path_prefix = extract_path_prefix(req.uri().path()).map_err(|err| { + warn!("Rejecting connection with {err}: {}", req.uri()); + bad_request() + })?; if let Some(restrict_path) = restrict_path_prefix { if path_prefix != restrict_path { warn!( - "Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)", - path_prefix, restrict_path + "Client requested upgrade path '{path_prefix}' does not match upgrade path restriction '{restrict_path}' (mTLS, etc.)" ); return Err(bad_request()); } } - let jwt = match extract_tunnel_info(req) { - Ok(jwt) => jwt, - Err(_err) => return Err(bad_request()), - }; + let jwt = extract_tunnel_info(req)?; Span::current().record("id", &jwt.claims.id); Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); - let remote = match RemoteAddr::try_from(jwt.claims) { - Ok(remote) => remote, - Err(err) => { - warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri()); - return Err(bad_request()); - } - }; + let remote = RemoteAddr::try_from(jwt.claims).map_err(|err| { + warn!("Rejecting connection with bad tunnel info: {err} {}", req.uri()); + bad_request() + })?; - let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) { - Some(matched_restriction) => { - info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name); - matched_restriction - } - None => { - warn!("Rejecting connection with not allowed destination: {:?}", remote); - return Err(bad_request()); - } - }; + let restriction = validate_tunnel(&remote, path_prefix, &restrictions).ok_or_else(|| { + warn!("Rejecting connection with not allowed destination: {remote:?}"); + bad_request() + })?; + info!("Tunnel accepted due to matched restriction: {}", restriction.name); let req_protocol = remote.protocol.clone(); let inject_cookie = req_protocol.is_dynamic_reverse_tunnel(); - let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await { - Ok(ret) => ret, - Err(err) => { - warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); - return Err(bad_request()); - } - }; + let tunnel = self + .exec_tunnel(restriction, remote, client_addr) + .await + .map_err(|err| { + warn!("Rejecting connection with bad upgrade request: {err} {}", req.uri()); + bad_request() + })?; let (remote_addr, local_rx, local_tx) = tunnel; info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port); diff --git a/src/tunnel/server/utils.rs b/src/tunnel/server/utils.rs index 35afd5b..e4f6515 100644 --- a/src/tunnel/server/utils.rs +++ b/src/tunnel/server/utils.rs @@ -5,6 +5,7 @@ use crate::restrictions::types::{ use crate::tunnel::transport::{jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig, JWT_HEADER_PREFIX}; use crate::tunnel::RemoteAddr; use bytes::Bytes; +use derive_more::{Display, Error}; use http_body_util::combinators::BoxBody; use http_body_util::Either; use hyper::body::{Body, Incoming}; @@ -17,7 +18,9 @@ use tracing::{error, info, warn}; use url::Host; use uuid::Uuid; -pub(super) fn bad_request() -> Response>> { +pub type HttpResponse = Response>>; + +pub(super) fn bad_request() -> HttpResponse { http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(Either::Left("Invalid request".to_string())) @@ -48,42 +51,41 @@ pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) - } #[inline] -pub(super) fn extract_x_forwarded_for(req: &Request) -> Result, ()> { - let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else { - return Ok(None); - }; +pub(super) fn extract_x_forwarded_for(req: &Request) -> Option<(IpAddr, &str)> { + let x_forward_for = req.headers().get("X-Forwarded-For")?; // X-Forwarded-For: , , let x_forward_for = x_forward_for.to_str().unwrap_or_default(); let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for); let ip: Option = x_forward_for.parse().ok(); - Ok(ip.map(|ip| (ip, x_forward_for))) + ip.map(|ip| (ip, x_forward_for)) } #[inline] -pub(super) fn extract_path_prefix(req: &Request) -> Result<&str, ()> { - let path = req.uri().path(); - let min_len = min(path.len(), 1); - if &path[0..min_len] != "/" { - warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); - return Err(()); +pub(super) fn extract_path_prefix(path: &str) -> Result<&str, PathPrefixErr> { + if !path.starts_with('/') { + return Err(PathPrefixErr::BadPathPrefix); } - let Some((l, r)) = path[min_len..].split_once('/') else { - warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Err(()); - }; + let (l, r) = path[1..].split_once('/').ok_or(PathPrefixErr::BadUpgradeRequest)?; - if !r.ends_with("events") { - warn!("Rejecting connection with bad upgrade request: {}", req.uri()); - return Err(()); + match r.ends_with("events") { + true => Ok(l), + false => Err(PathPrefixErr::BadUpgradeRequest), } +} - Ok(l) +#[derive(Debug, Display, Error)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub(super) enum PathPrefixErr { + #[display("bad path prefix in upgrade request")] + BadPathPrefix, + #[display("bad upgrade request")] + BadUpgradeRequest, } #[inline] -pub(super) fn extract_tunnel_info(req: &Request) -> Result, ()> { +pub(super) fn extract_tunnel_info(req: &Request) -> anyhow::Result, HttpResponse> { let jwt = req .headers() .get(SEC_WEBSOCKET_PROTOCOL) @@ -93,19 +95,13 @@ pub(super) fn extract_tunnel_info(req: &Request) -> Result jwt, - err => { - warn!( - "error while decoding jwt for tunnel info {:?} header {:?}", - err, - req.headers().get(SEC_WEBSOCKET_PROTOCOL) - ); - return Err(()); - } - }; - - Ok(jwt) + jwt_token_to_tunnel(jwt).map_err(|err| { + warn!( + "error while decoding jwt for tunnel info {err:?} header {:?}", + req.headers().get(SEC_WEBSOCKET_PROTOCOL) + ); + bad_request() + }) } impl RestrictionConfig { @@ -497,4 +493,31 @@ mod tests { assert!(!config.is_allowed(&remote)); assert!(!AllowConfig::from(config.clone()).is_allowed(&remote)); } + + #[test] + fn test_extract_path_prefix_happy_path() { + assert_eq!(extract_path_prefix("/prefix/events"), Ok("prefix")); + assert_eq!(extract_path_prefix("/prefix/a/events"), Ok("prefix")); + assert_eq!(extract_path_prefix("/prefix/a/b/events"), Ok("prefix")); + } + + #[test] + fn test_extract_path_prefix_no_events_suffix() { + assert_eq!(extract_path_prefix("/prefix/events/"), Err(PathPrefixErr::BadUpgradeRequest)); + assert_eq!(extract_path_prefix("/prefix"), Err(PathPrefixErr::BadUpgradeRequest)); + assert_eq!(extract_path_prefix("/prefixevents"), Err(PathPrefixErr::BadUpgradeRequest)); + assert_eq!(extract_path_prefix("/prefix/event"), Err(PathPrefixErr::BadUpgradeRequest)); + assert_eq!(extract_path_prefix("/prefix/a"), Err(PathPrefixErr::BadUpgradeRequest)); + assert_eq!(extract_path_prefix("/prefix/a/b"), Err(PathPrefixErr::BadUpgradeRequest)); + } + + #[test] + fn test_extract_path_prefix_no_slash_prefix() { + assert_eq!(extract_path_prefix(""), Err(PathPrefixErr::BadPathPrefix)); + assert_eq!(extract_path_prefix("p"), Err(PathPrefixErr::BadPathPrefix)); + assert_eq!(extract_path_prefix("\\"), Err(PathPrefixErr::BadPathPrefix)); + assert_eq!(extract_path_prefix("prefix/events"), Err(PathPrefixErr::BadPathPrefix)); + assert_eq!(extract_path_prefix("prefix/a/events"), Err(PathPrefixErr::BadPathPrefix)); + assert_eq!(extract_path_prefix("prefix/a/b/events"), Err(PathPrefixErr::BadPathPrefix)); + } }