From d1359b03248938753d3fb8f2c16f8d9dc181b055 Mon Sep 17 00:00:00 2001 From: James Rhodes Date: Thu, 20 Jun 2024 14:38:36 +0100 Subject: [PATCH] Use more robust http parsing in the host header tests Signed-off-by: James Rhodes --- Cargo.lock | 5 +- Cargo.toml | 1 + crates/extensions/c8y_auth_proxy/Cargo.toml | 1 + .../extensions/c8y_auth_proxy/src/server.rs | 65 ++++++++++++++----- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ca0d3d3525f..ea2b637ad9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -625,6 +625,7 @@ dependencies = [ "camino", "env_logger", "futures", + "httparse", "hyper", "mockito", "rcgen", @@ -1577,9 +1578,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.8.0" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" [[package]] name = "httpdate" diff --git a/Cargo.toml b/Cargo.toml index b30caa75342..e424a157fec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,7 @@ glob = "0.3" heck = "0.4.1" http = "0.2" http-body = "0.4" +httparse = "1.9.3" humantime = "2.1.0" hyper = { version = "0.14", default-features = false } hyper-rustls = { version = "0.24", default_features = false, features = [ diff --git a/crates/extensions/c8y_auth_proxy/Cargo.toml b/crates/extensions/c8y_auth_proxy/Cargo.toml index 511e7b24446..59a215640cb 100644 --- a/crates/extensions/c8y_auth_proxy/Cargo.toml +++ b/crates/extensions/c8y_auth_proxy/Cargo.toml @@ -35,6 +35,7 @@ url = { workspace = true } [dev-dependencies] env_logger = { workspace = true } +httparse = { workspace = true } mockito = { workspace = true } rcgen = { workspace = true } rustls = { workspace = true, features = ["dangerous_configuration"] } diff --git a/crates/extensions/c8y_auth_proxy/src/server.rs b/crates/extensions/c8y_auth_proxy/src/server.rs index cb7c7269581..4c49e38f66f 100644 --- a/crates/extensions/c8y_auth_proxy/src/server.rs +++ b/crates/extensions/c8y_auth_proxy/src/server.rs @@ -153,6 +153,7 @@ impl From for AppState { target_host: TargetHost { http: format!("{http}://{host}").into(), ws: format!("{ws}://{host}").into(), + without_scheme: host.into(), }, token_manager: value.token_manager, } @@ -175,6 +176,7 @@ impl FromRef for SharedTokenManager { struct TargetHost { http: Arc, ws: Arc, + without_scheme: Arc, } fn axum_to_tungstenite(message: axum::extract::ws::Message) -> tungstenite::Message { @@ -218,6 +220,7 @@ async fn connect_to_websocket( token: &str, headers: &HeaderMap, uri: &str, + host: &TargetHost, ) -> Result>, tokio_tungstenite::tungstenite::Error> { let mut req = Request::builder(); for (name, value) in headers { @@ -226,7 +229,7 @@ async fn connect_to_websocket( req = req.header("Authorization", format!("Bearer {token}")); let req = req .uri(uri) - .header(HOST, url::Url::parse(uri).expect("URI must have valid host to reach c8y proxy").host_str().unwrap()) + .header(HOST, host.without_scheme.as_ref()) .body(()) .expect("Builder should always work as the headers are copied from a previous request, so must be valid"); tokio_tungstenite::connect_async(req) @@ -245,11 +248,11 @@ async fn proxy_ws( use tungstenite::error::Error; let uri = format!("{}/{path}", host.ws); let mut token = retrieve_token.not_matching(None).await; - let c8y = match connect_to_websocket(&token, &headers, &uri).await { + let c8y = match connect_to_websocket(&token, &headers, &uri, &host).await { Ok(c8y) => Ok(c8y), Err(Error::Http(res)) if res.status() == StatusCode::UNAUTHORIZED => { token = retrieve_token.not_matching(Some(&token)).await; - match connect_to_websocket(&token, &headers, &uri).await { + match connect_to_websocket(&token, &headers, &uri, &host).await { Ok(c8y) => Ok(c8y), Err(e) => { Err(anyhow::Error::from(e).context("Failed to connect to proxied websocket")) @@ -568,26 +571,56 @@ mod tests { .unwrap() .unwrap(); - let mut incoming_payload = Vec::with_capacity(256); - tcp_stream.read_buf(&mut incoming_payload).await.unwrap(); - let incoming_payload = std::str::from_utf8(&incoming_payload).unwrap(); + let request = parse_raw_request(&mut tcp_stream).await; + tcp_stream .write_all(b"HTTP/1.1 204 No Content") .await .unwrap(); - assert!( - !incoming_payload.contains(&format!("host: {}", proxy_host)), - "Found host header with incorrect value {proxy_host:?} in incoming request:\n{}", - indent(incoming_payload) - ); - assert!(incoming_payload.contains(&format!("host: {}", destination_host)), "Did not find correct host header {destination_host:?} in incoming request:\n{incoming_payload}"); + assert_eq!(host_header_values(&request), [&destination_host], "Did not find correct host header. The value should be the proxy destination ({destination_host}), not the proxy itself ({proxy_host})"); + } + + #[tokio::test] + async fn does_not_forward_host_header_for_websocket_requests() { + let target_host = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let target = target_host.local_addr().unwrap(); + + let proxy_port = start_server_port(target.port(), vec!["unused token"]); + tokio::spawn(async move { + connect_to_websocket_port(proxy_port).await; + }); + + let proxy_host = format!("127.0.0.1:{proxy_port}"); + let destination_host = format!("127.0.0.1:{}", target.port()); + + let (mut tcp_stream, _) = + tokio::time::timeout(Duration::from_secs(5), target_host.accept()) + .await + .unwrap() + .unwrap(); + + let request = parse_raw_request(&mut tcp_stream).await; + + assert_eq!(host_header_values(&request), [&destination_host], "Did not find correct host header. The value should be the proxy destination ({destination_host}), not the proxy itself ({proxy_host})"); + } + + async fn parse_raw_request(tcp_stream: &mut TcpStream) -> httparse::Request<'static, 'static> { + let mut incoming_payload = Vec::with_capacity(10000); + tcp_stream.read_buf(&mut incoming_payload).await.unwrap(); + let headers = Vec::from([httparse::EMPTY_HEADER; 64]).leak(); + let mut request = httparse::Request::new(headers); + request.parse(incoming_payload.leak()).unwrap(); + + request } - fn indent(text: &str) -> String { - text.lines() - .map(|l| format!(" {l}")) + fn host_header_values<'a>(request: &httparse::Request<'a, '_>) -> Vec<&'a str> { + request + .headers + .iter() + .filter(|header| header.name.to_lowercase() == "host") + .map(|header| std::str::from_utf8(header.value).unwrap()) .collect::>() - .join("\n") } #[tokio::test]