Skip to content

Commit

Permalink
Use more robust http parsing in the host header tests
Browse files Browse the repository at this point in the history
Signed-off-by: James Rhodes <[email protected]>
  • Loading branch information
jarhodes314 committed Jun 20, 2024
1 parent dc64912 commit d1359b0
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 18 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ glob = "0.3"
heck = "0.4.1"
http = "0.2"
http-body = "0.4"
httparse = "1.9.3"
humantime = "2.1.0"
hyper = { version = "0.14", default-features = false }
hyper-rustls = { version = "0.24", default_features = false, features = [
Expand Down
1 change: 1 addition & 0 deletions crates/extensions/c8y_auth_proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ url = { workspace = true }

[dev-dependencies]
env_logger = { workspace = true }
httparse = { workspace = true }
mockito = { workspace = true }
rcgen = { workspace = true }
rustls = { workspace = true, features = ["dangerous_configuration"] }
Expand Down
65 changes: 49 additions & 16 deletions crates/extensions/c8y_auth_proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ impl From<AppData> for AppState {
target_host: TargetHost {
http: format!("{http}://{host}").into(),
ws: format!("{ws}://{host}").into(),
without_scheme: host.into(),
},
token_manager: value.token_manager,
}
Expand All @@ -175,6 +176,7 @@ impl FromRef<AppState> for SharedTokenManager {
struct TargetHost {
http: Arc<str>,
ws: Arc<str>,
without_scheme: Arc<str>,
}

fn axum_to_tungstenite(message: axum::extract::ws::Message) -> tungstenite::Message {
Expand Down Expand Up @@ -218,6 +220,7 @@ async fn connect_to_websocket(
token: &str,
headers: &HeaderMap<HeaderValue>,
uri: &str,
host: &TargetHost,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
let mut req = Request::builder();
for (name, value) in headers {
Expand All @@ -226,7 +229,7 @@ async fn connect_to_websocket(
req = req.header("Authorization", format!("Bearer {token}"));
let req = req
.uri(uri)
.header(HOST, url::Url::parse(uri).expect("URI must have valid host to reach c8y proxy").host_str().unwrap())
.header(HOST, host.without_scheme.as_ref())
.body(())
.expect("Builder should always work as the headers are copied from a previous request, so must be valid");
tokio_tungstenite::connect_async(req)
Expand All @@ -245,11 +248,11 @@ async fn proxy_ws(
use tungstenite::error::Error;
let uri = format!("{}/{path}", host.ws);
let mut token = retrieve_token.not_matching(None).await;
let c8y = match connect_to_websocket(&token, &headers, &uri).await {
let c8y = match connect_to_websocket(&token, &headers, &uri, &host).await {
Ok(c8y) => Ok(c8y),
Err(Error::Http(res)) if res.status() == StatusCode::UNAUTHORIZED => {
token = retrieve_token.not_matching(Some(&token)).await;
match connect_to_websocket(&token, &headers, &uri).await {
match connect_to_websocket(&token, &headers, &uri, &host).await {
Ok(c8y) => Ok(c8y),
Err(e) => {
Err(anyhow::Error::from(e).context("Failed to connect to proxied websocket"))
Expand Down Expand Up @@ -568,26 +571,56 @@ mod tests {
.unwrap()
.unwrap();

let mut incoming_payload = Vec::with_capacity(256);
tcp_stream.read_buf(&mut incoming_payload).await.unwrap();
let incoming_payload = std::str::from_utf8(&incoming_payload).unwrap();
let request = parse_raw_request(&mut tcp_stream).await;

tcp_stream
.write_all(b"HTTP/1.1 204 No Content")
.await
.unwrap();
assert!(
!incoming_payload.contains(&format!("host: {}", proxy_host)),
"Found host header with incorrect value {proxy_host:?} in incoming request:\n{}",
indent(incoming_payload)
);
assert!(incoming_payload.contains(&format!("host: {}", destination_host)), "Did not find correct host header {destination_host:?} in incoming request:\n{incoming_payload}");
assert_eq!(host_header_values(&request), [&destination_host], "Did not find correct host header. The value should be the proxy destination ({destination_host}), not the proxy itself ({proxy_host})");
}

#[tokio::test]
async fn does_not_forward_host_header_for_websocket_requests() {
let target_host = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let target = target_host.local_addr().unwrap();

let proxy_port = start_server_port(target.port(), vec!["unused token"]);
tokio::spawn(async move {
connect_to_websocket_port(proxy_port).await;
});

let proxy_host = format!("127.0.0.1:{proxy_port}");
let destination_host = format!("127.0.0.1:{}", target.port());

let (mut tcp_stream, _) =
tokio::time::timeout(Duration::from_secs(5), target_host.accept())
.await
.unwrap()
.unwrap();

let request = parse_raw_request(&mut tcp_stream).await;

assert_eq!(host_header_values(&request), [&destination_host], "Did not find correct host header. The value should be the proxy destination ({destination_host}), not the proxy itself ({proxy_host})");
}

async fn parse_raw_request(tcp_stream: &mut TcpStream) -> httparse::Request<'static, 'static> {
let mut incoming_payload = Vec::with_capacity(10000);
tcp_stream.read_buf(&mut incoming_payload).await.unwrap();
let headers = Vec::from([httparse::EMPTY_HEADER; 64]).leak();
let mut request = httparse::Request::new(headers);
request.parse(incoming_payload.leak()).unwrap();

request
}

fn indent(text: &str) -> String {
text.lines()
.map(|l| format!(" {l}"))
fn host_header_values<'a>(request: &httparse::Request<'a, '_>) -> Vec<&'a str> {
request
.headers
.iter()
.filter(|header| header.name.to_lowercase() == "host")
.map(|header| std::str::from_utf8(header.value).unwrap())
.collect::<Vec<_>>()
.join("\n")
}

#[tokio::test]
Expand Down

0 comments on commit d1359b0

Please sign in to comment.