From 8ba77745754c64ac198d9f416a3d4b79fcb8d70b Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sun, 29 Dec 2024 17:44:18 +0200 Subject: [PATCH 01/21] reduce lifetime of printer closure passed to client middleware --- src/main.rs | 31 ++++++++++++--------------- src/middleware.rs | 54 ++++++++++++++++------------------------------- 2 files changed, 32 insertions(+), 53 deletions(-) diff --git a/src/main.rs b/src/main.rs index 68526a22..699fdd25 100644 --- a/src/main.rs +++ b/src/main.rs @@ -538,21 +538,26 @@ fn run(args: Cli) -> Result { printer.print_request_body(&mut request)?; } + let mut client = ClientWithMiddleware::new(client); + if !args.offline { let mut response = { let history_print = args.history_print.unwrap_or(print); - let mut client = ClientWithMiddleware::new(&client); - if args.all { - client = client.with_printer(|prev_response, next_request| { + if args.follow { + client = client.with(RedirectFollower::new(args.max_redirects.unwrap_or(10))); + } + if let Some(Auth::Digest(username, password)) = &auth { + client = client.with(DigestAuthMiddleware::new(username, password)); + } + client.execute(request, |prev_response, next_request| { + if !args.all { + return Ok(()); + } if history_print.response_headers { printer.print_response_headers(prev_response)?; } if history_print.response_body { - printer.print_response_body( - prev_response, - response_charset, - response_mime, - )?; + printer.print_response_body(prev_response, response_charset, response_mime)?; printer.print_separator()?; } if history_print.response_meta { @@ -565,15 +570,7 @@ fn run(args: Cli) -> Result { printer.print_request_body(next_request)?; } Ok(()) - }); - } - if args.follow { - client = client.with(RedirectFollower::new(args.max_redirects.unwrap_or(10))); - } - if let Some(Auth::Digest(username, password)) = &auth { - client = client.with(DigestAuthMiddleware::new(username, password)); - } - client.execute(request)? + })? }; let status = response.status(); diff --git a/src/middleware.rs b/src/middleware.rs index d60c24a2..c0f61be9 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -24,18 +24,18 @@ impl ResponseExt for Response { } } -type Printer<'a, 'b> = &'a mut (dyn FnMut(&mut Response, &mut Request) -> Result<()> + 'b); +type Printer<'a> = &'a mut (dyn FnMut(&mut Response, &mut Request) -> Result<()> + 'a); -pub struct Context<'a, 'b> { - client: &'a Client, - printer: Option>, +pub struct Context<'a, 'b, 'c> { + client: Client, + printer: Printer<'c>, middlewares: &'a mut [Box], } -impl<'a, 'b> Context<'a, 'b> { +impl<'a, 'b, 'c> Context<'a, 'b, 'c> { fn new( - client: &'a Client, - printer: Option>, + client: Client, + printer: Printer<'c>, middlewares: &'a mut [Box], ) -> Self { Context { @@ -57,8 +57,7 @@ impl<'a, 'b> Context<'a, 'b> { Ok(response) } [ref mut head, tail @ ..] => head.handle( - #[allow(clippy::needless_option_as_deref)] - Context::new(self.client, self.printer.as_deref_mut(), tail), + Context::new(self.client.clone(), self.printer, tail), request, ), } @@ -78,51 +77,34 @@ pub trait Middleware { response: &mut Response, request: &mut Request, ) -> Result<()> { - if let Some(ref mut printer) = ctx.printer { - printer(response, request)?; - } - + (ctx.printer)(response, request)?; Ok(()) } } -pub struct ClientWithMiddleware<'a, T> -where - T: FnMut(&mut Response, &mut Request) -> Result<()>, -{ - client: &'a Client, - printer: Option, +pub struct ClientWithMiddleware<'a> { + client: Client, middlewares: Vec>, } -impl<'a, T> ClientWithMiddleware<'a, T> -where - T: FnMut(&mut Response, &mut Request) -> Result<()> + 'a, -{ - pub fn new(client: &'a Client) -> Self { +impl<'a> ClientWithMiddleware<'a> { + pub fn new(client: Client) -> Self { ClientWithMiddleware { client, - printer: None, middlewares: vec![], } } - pub fn with_printer(mut self, printer: T) -> Self { - self.printer = Some(printer); - self - } - pub fn with(mut self, middleware: impl Middleware + 'a) -> Self { self.middlewares.push(Box::new(middleware)); self } - pub fn execute(&mut self, request: Request) -> Result { - let mut ctx = Context::new( - self.client, - self.printer.as_mut().map(|p| p as _), - &mut self.middlewares[..], - ); + pub fn execute<'b, T>(&mut self, request: Request, mut printer: T) -> Result + where + T: FnMut(&mut Response, &mut Request) -> Result<()> + 'b, + { + let mut ctx = Context::new(self.client.clone(), &mut printer, &mut self.middlewares[..]); ctx.execute(request) } } From aa8fbf562ab15d53a6ef2101d941887be27a4d46 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sun, 29 Dec 2024 18:19:16 +0200 Subject: [PATCH 02/21] support http over unix domain sockets --- Cargo.lock | 1 + Cargo.toml | 3 ++ src/cli.rs | 6 +++ src/main.rs | 47 ++++++++++++-------- src/to_curl.rs | 5 +++ src/unix_socket.rs | 104 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 148 insertions(+), 18 deletions(-) create mode 100644 src/unix_socket.rs diff --git a/Cargo.lock b/Cargo.lock index 458e4574..e3373a65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2523,6 +2523,7 @@ dependencies = [ "env_logger", "flate2", "form_urlencoded", + "http", "http-body-util", "hyper", "hyper-util", diff --git a/Cargo.toml b/Cargo.toml index 0cadc32d..ae6c5656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,10 @@ dirs = "5.0" encoding_rs = "0.8.28" encoding_rs_io = "0.1.7" flate2 = "1.0.22" +http = "1" # Add "tracing" feature to hyper once it stabilizes hyper = { version = "1.2", default-features = false } +hyper-util = { version = "0.1", features = ["tokio"] } indicatif = "0.17" jsonxf = "1.1.0" memchr = "2.4.1" @@ -47,6 +49,7 @@ serde_urlencoded = "0.7.0" supports-hyperlinks = "3.0.0" termcolor = "1.1.2" time = "0.3.16" +tokio = { version = "1", features = ["rt-multi-thread"] } unicode-width = "0.1.9" url = "2.2.2" ruzstd = { version = "0.7", default-features = false, features = ["std"]} diff --git a/src/cli.rs b/src/cli.rs index a6b37d59..5f6f741b 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -347,6 +347,12 @@ Example: --print=Hb" #[clap(short = '6', long)] pub ipv6: bool, + /// Connect using a Unix domain socket. + /// + /// Example: --unix_socket=/var/run/temp.sock + #[clap(long, value_name = "FILE")] + pub unix_socket: Option, + /// Do not attempt to read stdin. /// /// This disables the default behaviour of reading the request body from stdin diff --git a/src/main.rs b/src/main.rs index 699fdd25..66f89903 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,8 @@ mod redirect; mod request_items; mod session; mod to_curl; +#[cfg(target_family = "unix")] +mod unix_socket; mod utils; mod vendored; @@ -35,7 +37,6 @@ use reqwest::header::{ }; use reqwest::tls; use url::Host; -use utils::reason_phrase; use crate::auth::{Auth, DigestAuthMiddleware}; use crate::buffer::Buffer; @@ -45,7 +46,9 @@ use crate::middleware::ClientWithMiddleware; use crate::printer::Printer; use crate::request_items::{Body, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}; use crate::session::Session; -use crate::utils::{test_mode, test_pretend_term, url_with_query}; +#[cfg(target_family = "unix")] +use crate::unix_socket::UnixSocket; +use crate::utils::{reason_phrase, test_mode, test_pretend_term, url_with_query}; use crate::vendored::reqwest_cookie_store; #[cfg(not(any(feature = "native-tls", feature = "rustls")))] @@ -549,27 +552,35 @@ fn run(args: Cli) -> Result { if let Some(Auth::Digest(username, password)) = &auth { client = client.with(DigestAuthMiddleware::new(username, password)); } + #[cfg(target_family = "unix")] + if let Some(unix_socket) = args.unix_socket { + client = client.with(UnixSocket::new(unix_socket)); + } + #[cfg(not(target_family = "unix"))] + if let Some(_) = args.unix_socket { + log::warn!("HTTP over Unix domain sockets is not supported on this platform"); + } client.execute(request, |prev_response, next_request| { if !args.all { return Ok(()); } - if history_print.response_headers { - printer.print_response_headers(prev_response)?; - } - if history_print.response_body { + if history_print.response_headers { + printer.print_response_headers(prev_response)?; + } + if history_print.response_body { printer.print_response_body(prev_response, response_charset, response_mime)?; - printer.print_separator()?; - } - if history_print.response_meta { - printer.print_response_meta(prev_response)?; - } - if history_print.request_headers { - printer.print_request_headers(next_request, &*cookie_jar)?; - } - if history_print.request_body { - printer.print_request_body(next_request)?; - } - Ok(()) + printer.print_separator()?; + } + if history_print.response_meta { + printer.print_response_meta(prev_response)?; + } + if history_print.request_headers { + printer.print_request_headers(next_request, &*cookie_jar)?; + } + if history_print.request_body { + printer.print_request_body(next_request)?; + } + Ok(()) })? }; diff --git a/src/to_curl.rs b/src/to_curl.rs index 744970e4..a6f0f3d7 100644 --- a/src/to_curl.rs +++ b/src/to_curl.rs @@ -299,6 +299,11 @@ pub fn translate(args: Cli) -> Result { cmd.arg(interface); }; + if let Some(unix_socket) = args.unix_socket { + cmd.arg("--unix-socket"); + cmd.arg(unix_socket); + } + if !args.resolve.is_empty() { let port = url .port_or_known_default() diff --git a/src/unix_socket.rs b/src/unix_socket.rs new file mode 100644 index 00000000..7187d76c --- /dev/null +++ b/src/unix_socket.rs @@ -0,0 +1,104 @@ +use anyhow::Result; +use reqwest::blocking::{Request, Response}; +use reqwest::header::{HeaderValue, HOST}; +use std::path::PathBuf; +use std::time::Instant; + +use crate::middleware::{Context, Middleware, ResponseMeta}; +use crate::utils::test_mode; + +pub struct UnixSocket { + rt: tokio::runtime::Runtime, + socket_path: PathBuf, +} + +impl UnixSocket { + pub fn new(socket_path: PathBuf) -> Self { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + Self { rt, socket_path } + } + + pub fn execute(&self, request: Request) -> Result { + self.rt.block_on(async { + // TODO: Support named pipes by replacing tokio::net::UnixStream::connect(..) with: + // + // use std::time::Duration; + // use tokio::net::windows::named_pipe; + // use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY; + // + // let stream = loop { + // match named_pipe::ClientOptions::new().open(r"\\.\pipe\docker_engine") { + // Ok(client) => break client, + // Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + // Err(e) => return Err(e)?, + // } + // + // tokio::time::sleep(Duration::from_millis(50)).await; + // }; + let stream = tokio::net::UnixStream::connect(&self.socket_path).await?; + + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .title_case_headers(true) + .handshake(hyper_util::rt::TokioIo::new(stream)) + .await?; + + tokio::task::spawn(async move { + if let Err(err) = conn.await { + log::error!("Connection failed: {:?}", err); + } + }); + + // TODO: figure out how to support cookies. + // TODO: don't ignore value from --timeout option + let http_request = into_async_request(request)?; + let response = sender.send_request(http_request.try_into()?).await?; + + Ok(Response::from(response.map(|b| reqwest::Body::wrap(b)))) + }) + } +} + +impl Middleware for UnixSocket { + fn handle(&mut self, mut _ctx: Context, request: Request) -> Result { + let starting_time = Instant::now(); + let mut response = self.execute(request)?; + response.extensions_mut().insert(ResponseMeta { + request_duration: starting_time.elapsed(), + content_download_duration: None, + }); + Ok(response) + } +} + +fn into_async_request(mut request: Request) -> Result> { + let mut http_request = http::Request::builder() + .version(request.version()) + .method(request.method()) + .uri(request.url().as_str()) + .body(reqwest::Body::default())?; + + *http_request.headers_mut() = request.headers_mut().clone(); + + if let Some(host) = request.url().host_str() { + http_request.headers_mut().entry(HOST).or_insert_with(|| { + if test_mode() { + HeaderValue::from_str("http.mock") + } else if let Some(port) = request.url().port() { + HeaderValue::from_str(&format!("{}:{}", host, port)) + } else { + HeaderValue::from_str(host) + } + .expect("hostname should already be validated/parsed") + }); + } + + if let Some(body) = request.body_mut().as_mut() { + *http_request.body_mut() = reqwest::Body::from(body.buffer()?.to_owned()); + } + + Ok(http_request) +} From fe56317ad41745f0a239def7785f8e6e32caed55 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sun, 29 Dec 2024 18:21:20 +0200 Subject: [PATCH 03/21] implement test server for http_unix --- tests/server/mod.rs | 129 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 20 deletions(-) diff --git a/tests/server/mod.rs b/tests/server/mod.rs index 9ca6c604..16935bac 100644 --- a/tests/server/mod.rs +++ b/tests/server/mod.rs @@ -2,7 +2,7 @@ // with some slight tweaks use std::convert::Infallible; use std::future::Future; -use std::net; +use std::path::PathBuf; use std::sync::mpsc as std_mpsc; use std::sync::{Arc, Mutex}; use std::thread; @@ -18,8 +18,20 @@ use tokio::sync::oneshot; type Body = Full; type Builder = hyper_util::server::conn::auto::Builder; +enum Addr { + TcpAddr(std::net::SocketAddr), + #[cfg(target_family = "unix")] + UnixAddr(tokio::net::unix::SocketAddr), +} + +enum Listener { + TcpListener(tokio::net::TcpListener), + #[cfg(target_family = "unix")] + UnixListener(tokio::net::UnixListener), +} + pub struct Server { - addr: net::SocketAddr, + addr: Addr, panic_rx: std_mpsc::Receiver<()>, successful_hits: Arc>, total_hits: Arc>, @@ -29,19 +41,43 @@ pub struct Server { impl Server { pub fn base_url(&self) -> String { - format!("http://{}", self.addr) + match self.addr { + Addr::TcpAddr(addr) => format!("http://{}", addr), + #[cfg(target_family = "unix")] + _ => panic!("no base_url for unix server"), + } } pub fn url(&self, path: &str) -> String { - format!("http://{}{}", self.addr, path) + match self.addr { + Addr::TcpAddr(addr) => format!("http://{}{}", addr, path), + #[cfg(target_family = "unix")] + _ => panic!("no url for unix server"), + } } pub fn host(&self) -> String { - String::from("127.0.0.1") + match self.addr { + Addr::TcpAddr(_) => String::from("127.0.0.1"), + #[cfg(target_family = "unix")] + _ => panic!("no host for unix server"), + } + } + + #[cfg(target_family = "unix")] + pub fn socket_path(&self) -> PathBuf { + match &self.addr { + Addr::UnixAddr(addr) => addr.as_pathname().unwrap().to_path_buf(), + _ => panic!("no socket_path for tcp server"), + } } pub fn port(&self) -> u16 { - self.addr.port() + match self.addr { + Addr::TcpAddr(addr) => addr.port(), + #[cfg(target_family = "unix")] + _ => panic!("no port for unix server"), + } } pub fn assert_hits(&self, hits: u8) { @@ -89,13 +125,36 @@ where F: Fn(Request) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { - http_inner(Arc::new(move |req| Box::new(Box::pin(func(req))))) + http_inner(Arc::new(move |req| Box::new(Box::pin(func(req)))), None) +} + +#[cfg(target_family = "unix")] +pub fn http_unix(func: F) -> Server +where + F: Fn(Request) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + use rand::Rng; + let file_name: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(10) + .map(char::from) + .collect(); + let path = PathBuf::from(format!("/tmp/{file_name}.sock")); + if path.exists() { + std::fs::remove_file(&path).expect("could not remove old socket"); + } + + http_inner( + Arc::new(move |req| Box::new(Box::pin(func(req)))), + Some(path), + ) } type Serv = dyn Fn(Request) -> Box + Send + Sync; type ServFut = dyn Future> + Send + Unpin; -fn http_inner(func: Arc) -> Server { +fn http_inner(func: Arc, socket_path: Option) -> Server { // Spawn new runtime in thread to prevent reactor execution context conflict thread::spawn(move || { let rt = runtime::Builder::new_current_thread() @@ -104,12 +163,30 @@ fn http_inner(func: Arc) -> Server { .expect("new rt"); let successful_hits = Arc::new(Mutex::new(0)); let total_hits = Arc::new(Mutex::new(0)); - let listener = rt.block_on(async move { - tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) - .await - .unwrap() + + let (listener, addr) = rt.block_on(async move { + #[allow(unused_variables)] + if let Some(path) = &socket_path { + #[cfg(target_family = "unix")] + { + let listener = tokio::net::UnixListener::bind(path).unwrap(); + let addr = listener.local_addr().unwrap(); + (Listener::UnixListener(listener), Addr::UnixAddr(addr)) + } + + #[cfg(not(target_family = "unix"))] + { + unreachable!("cannot create http_unix server outside of unix target_family") + } + } else { + let listener = + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + (Listener::TcpListener(listener), Addr::TcpAddr(addr)) + } }); - let addr = listener.local_addr().unwrap(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let (panic_tx, panic_rx) = std_mpsc::channel(); @@ -145,14 +222,26 @@ fn http_inner(func: Arc) -> Server { }) }; - let (io, _) = listener.accept().await.unwrap(); - let builder = builder.clone(); - tokio::spawn(async move { - let _ = builder - .serve_connection(hyper_util::rt::TokioIo::new(io), svc) - .await; - }); + match &listener { + Listener::TcpListener(listener) => { + let (io, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + let _ = builder + .serve_connection(hyper_util::rt::TokioIo::new(io), svc) + .await; + }); + } + #[cfg(target_family = "unix")] + Listener::UnixListener(listener) => { + let (io, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + let _ = builder + .serve_connection(hyper_util::rt::TokioIo::new(io), svc) + .await; + }); + } + } } }); let _ = rt.block_on(shutdown_rx); From 531dba91c93d1867a626f38fda776b1e0d0367e5 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sun, 29 Dec 2024 18:23:38 +0200 Subject: [PATCH 04/21] add initial tests for http_unix --- tests/cases/http_unix.rs | 75 ++++++++++++++++++++++++++++++++++++++++ tests/cases/mod.rs | 1 + 2 files changed, 76 insertions(+) create mode 100644 tests/cases/http_unix.rs diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs new file mode 100644 index 00000000..cf7ba796 --- /dev/null +++ b/tests/cases/http_unix.rs @@ -0,0 +1,75 @@ +use indoc::indoc; + +use crate::prelude::*; + +#[test] +fn json_post() { + let server = server::http_unix(|req| async move { + assert_eq!(req.method(), "POST"); + assert_eq!(req.headers()["Content-Type"], "application/json"); + assert_eq!(req.body_as_string().await, "{\"foo\":\"bar\"}"); + + hyper::Response::builder() + .header(hyper::header::CONTENT_TYPE, "application/json") + .body(r#"{"status":"ok"}"#.into()) + .unwrap() + }); + + get_command() + .arg("--print=b") + .arg("--pretty=format") + .arg("post") + .arg("http://example.com") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("foo=bar") + .assert() + .stdout(indoc! {r#" + { + "status": "ok" + } + + + "#}); +} + +#[test] +fn redirects_stay_on_same_server() { + let server = server::http_unix(|req| async move { + match dbg!(req.uri().to_string().as_str()) { + "http://example.com/first_page" => hyper::Response::builder() + .status(302) + .header("Date", "N/A") + .header("Location", "http://localhost:8000/second_page") + .body("redirecting...".into()) + .unwrap(), + "http://localhost:8000/second_page" => hyper::Response::builder() + .status(302) + .header("Date", "N/A") + .header("Location", "/third_page") + .body("redirecting...".into()) + .unwrap(), + "http://localhost:8000/third_page" => hyper::Response::builder() + .header("Date", "N/A") + .body("final destination".into()) + .unwrap(), + _ => panic!("unknown path"), + } + }); + + get_command() + .arg("http://example.com/first_page") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("--follow") + .assert() + .success(); + + server.assert_hits(3); +} + +// TODO: add tests for cookies diff --git a/tests/cases/mod.rs b/tests/cases/mod.rs index 7059ce95..53f7c76d 100644 --- a/tests/cases/mod.rs +++ b/tests/cases/mod.rs @@ -1 +1,2 @@ +mod http_unix; mod logging; From bf4b1b822e62d8c72b28517ea7aa49e22fa5be13 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sun, 29 Dec 2024 19:06:14 +0200 Subject: [PATCH 05/21] fix clippy warnings --- src/unix_socket.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/unix_socket.rs b/src/unix_socket.rs index 7187d76c..a2af0858 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -55,9 +55,9 @@ impl UnixSocket { // TODO: figure out how to support cookies. // TODO: don't ignore value from --timeout option let http_request = into_async_request(request)?; - let response = sender.send_request(http_request.try_into()?).await?; + let response = sender.send_request(http_request).await?; - Ok(Response::from(response.map(|b| reqwest::Body::wrap(b)))) + Ok(Response::from(response.map(reqwest::Body::wrap))) }) } } From 80242c27d5195ea747be546aeedb27740329be69 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sun, 29 Dec 2024 19:19:44 +0200 Subject: [PATCH 06/21] disable http_unix tests in windows --- tests/cases/http_unix.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index cf7ba796..18d74b32 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -1,7 +1,10 @@ +#[cfg(target_family = "unix")] use indoc::indoc; +#[cfg(target_family = "unix")] use crate::prelude::*; +#[cfg(target_family = "unix")] #[test] fn json_post() { let server = server::http_unix(|req| async move { @@ -35,6 +38,7 @@ fn json_post() { "#}); } +#[cfg(target_family = "unix")] #[test] fn redirects_stay_on_same_server() { let server = server::http_unix(|req| async move { From 8cee2fe98ee323b61c7f532bfd2153a5cef15f5c Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 10:17:35 +0200 Subject: [PATCH 07/21] add middleware for managing cookies --- src/cookie.rs | 44 ++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 8 +++++--- src/unix_socket.rs | 1 - 3 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 src/cookie.rs diff --git a/src/cookie.rs b/src/cookie.rs new file mode 100644 index 00000000..1ffa2dd2 --- /dev/null +++ b/src/cookie.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use anyhow::Result; +use reqwest::{ + blocking::{Request, Response}, + cookie::CookieStore, + header, +}; + +use crate::middleware::{Context, Middleware}; + +pub struct CookieMiddleware(Arc); + +impl CookieMiddleware { + pub fn new(cookie_jar: Arc) -> Self { + CookieMiddleware(cookie_jar) + } +} + +impl Middleware for CookieMiddleware { + fn handle(&mut self, mut ctx: Context, mut request: Request) -> Result { + let url = request.url().clone(); + + if let Some(header) = self.0.cookies(&url) { + request + .headers_mut() + .entry(header::COOKIE) + .or_insert(header); + } + + let response = self.next(&mut ctx, request)?; + + let mut cookies = response + .headers() + .get_all(header::SET_COOKIE) + .iter() + .peekable(); + if cookies.peek().is_some() { + self.0.set_cookies(&mut cookies, &url); + } + + Ok(response) + } +} diff --git a/src/main.rs b/src/main.rs index 66f89903..c0a845fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod auth; mod buffer; mod cli; +mod cookie; mod decoder; mod download; mod formatting; @@ -29,6 +30,7 @@ use std::str::FromStr; use std::sync::Arc; use anyhow::{anyhow, Context, Result}; +use cookie::CookieMiddleware; use cookie_store::{CookieStore, RawCookie}; use redirect::RedirectFollower; use reqwest::blocking::Client; @@ -285,9 +287,6 @@ fn run(args: Cli) -> Result { None => client, }; - let cookie_jar = Arc::new(reqwest_cookie_store::CookieStoreMutex::default()); - client = client.cookie_provider(cookie_jar.clone()); - client = match (args.ipv4, args.ipv6) { (true, false) => client.local_address(IpAddr::from(Ipv4Addr::UNSPECIFIED)), (false, true) => client.local_address(IpAddr::from(Ipv6Addr::UNSPECIFIED)), @@ -339,6 +338,8 @@ fn run(args: Cli) -> Result { log::trace!("{client:#?}"); let client = client.build()?; + let cookie_jar = Arc::new(reqwest_cookie_store::CookieStoreMutex::default()); + let mut session = match &args.session { Some(name_or_path) => Some( Session::load_session(url.clone(), name_or_path.clone(), args.is_session_read_only) @@ -552,6 +553,7 @@ fn run(args: Cli) -> Result { if let Some(Auth::Digest(username, password)) = &auth { client = client.with(DigestAuthMiddleware::new(username, password)); } + client = client.with(CookieMiddleware::new(cookie_jar.clone())); #[cfg(target_family = "unix")] if let Some(unix_socket) = args.unix_socket { client = client.with(UnixSocket::new(unix_socket)); diff --git a/src/unix_socket.rs b/src/unix_socket.rs index a2af0858..ec05255d 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -52,7 +52,6 @@ impl UnixSocket { } }); - // TODO: figure out how to support cookies. // TODO: don't ignore value from --timeout option let http_request = into_async_request(request)?; let response = sender.send_request(http_request).await?; From 652e3df82c7abad436ba440a7ef8af47874ae4b3 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 10:44:21 +0200 Subject: [PATCH 08/21] add test for cookies --- tests/cases/http_unix.rs | 109 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 107 insertions(+), 2 deletions(-) diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index 18d74b32..16cb80c8 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -70,10 +70,115 @@ fn redirects_stay_on_same_server() { server.socket_path().to_string_lossy() )) .arg("--follow") + .arg("--verbose") + .arg("--all") .assert() - .success(); + .stdout(indoc! {r#" + GET /first_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 14 + Date: N/A + Location: http://localhost:8000/second_page + + redirecting... + + GET /second_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 14 + Date: N/A + Location: /third_page + + redirecting... + + GET /third_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 200 OK + Content-Length: 17 + Date: N/A + + final destination + "#}); server.assert_hits(3); } -// TODO: add tests for cookies +#[cfg(target_family = "unix")] +#[test] +fn cookies_persist_across_redirects() { + let server = server::http_unix(|req| async move { + match req.uri().path() { + "/first_page" => hyper::Response::builder() + .status(302) + .header("Date", "N/A") + .header("Location", "/second_page") + .header("set-cookie", "hello=world") + .body("redirecting...".into()) + .unwrap(), + "/second_page" => hyper::Response::builder() + .header("Date", "N/A") + .body("final destination".into()) + .unwrap(), + _ => panic!("unknown path"), + } + }); + + get_command() + .arg("localhost:3000/first_page") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("--follow") + .arg("--verbose") + .arg("--all") + .assert() + .stdout(indoc! {r#" + GET /first_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 14 + Date: N/A + Location: /second_page + Set-Cookie: hello=world + + redirecting... + + GET /second_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Cookie: hello=world + Host: http.mock + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 200 OK + Content-Length: 17 + Date: N/A + + final destination + "#}); +} + +// TODO: add tests for connection timeout From b6d59d9f9d16faaa2d5b610ac7b0cac9c50bc0a4 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 11:23:17 +0200 Subject: [PATCH 09/21] avoid mocking host header --- src/printer.rs | 6 ++--- src/unix_socket.rs | 5 +--- tests/cases/http_unix.rs | 10 ++++---- tests/cli.rs | 52 ++++++++++++++++++++-------------------- 4 files changed, 34 insertions(+), 39 deletions(-) diff --git a/src/printer.rs b/src/printer.rs index 729bba26..cd4eb822 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -20,7 +20,7 @@ use crate::{ formatting::serde_json_format, formatting::{get_json_formatter, Highlighter}, middleware::ResponseExt, - utils::{copy_largebuf, test_mode, BUFFER_SIZE}, + utils::{copy_largebuf, BUFFER_SIZE}, }; const BINARY_SUPPRESSOR: &str = concat!( @@ -345,9 +345,7 @@ impl Printer { // even know if we're going to use HTTP/2 yet. headers.entry(HOST).or_insert_with(|| { // Added at https://github.com/hyperium/hyper-util/blob/53aadac50d/src/client/legacy/client.rs#L278 - if test_mode() { - HeaderValue::from_str("http.mock") - } else if let Some(port) = request.url().port() { + if let Some(port) = request.url().port() { HeaderValue::from_str(&format!("{}:{}", host, port)) } else { HeaderValue::from_str(host) diff --git a/src/unix_socket.rs b/src/unix_socket.rs index ec05255d..bb0898af 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -5,7 +5,6 @@ use std::path::PathBuf; use std::time::Instant; use crate::middleware::{Context, Middleware, ResponseMeta}; -use crate::utils::test_mode; pub struct UnixSocket { rt: tokio::runtime::Runtime, @@ -84,9 +83,7 @@ fn into_async_request(mut request: Request) -> Result Date: Mon, 30 Dec 2024 14:31:00 +0200 Subject: [PATCH 10/21] use shortened version of cfg unix check --- src/main.rs | 8 ++++---- tests/cases/http_unix.rs | 10 +++++----- tests/server/mod.rs | 22 +++++++++++----------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/main.rs b/src/main.rs index c0a845fe..47e72259 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,7 +15,7 @@ mod redirect; mod request_items; mod session; mod to_curl; -#[cfg(target_family = "unix")] +#[cfg(unix)] mod unix_socket; mod utils; mod vendored; @@ -48,7 +48,7 @@ use crate::middleware::ClientWithMiddleware; use crate::printer::Printer; use crate::request_items::{Body, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}; use crate::session::Session; -#[cfg(target_family = "unix")] +#[cfg(unix)] use crate::unix_socket::UnixSocket; use crate::utils::{reason_phrase, test_mode, test_pretend_term, url_with_query}; use crate::vendored::reqwest_cookie_store; @@ -554,11 +554,11 @@ fn run(args: Cli) -> Result { client = client.with(DigestAuthMiddleware::new(username, password)); } client = client.with(CookieMiddleware::new(cookie_jar.clone())); - #[cfg(target_family = "unix")] + #[cfg(unix)] if let Some(unix_socket) = args.unix_socket { client = client.with(UnixSocket::new(unix_socket)); } - #[cfg(not(target_family = "unix"))] + #[cfg(not(unix))] if let Some(_) = args.unix_socket { log::warn!("HTTP over Unix domain sockets is not supported on this platform"); } diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index acb77067..c393f8e3 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -1,10 +1,10 @@ -#[cfg(target_family = "unix")] +#[cfg(unix)] use indoc::indoc; -#[cfg(target_family = "unix")] +#[cfg(unix)] use crate::prelude::*; -#[cfg(target_family = "unix")] +#[cfg(unix)] #[test] fn json_post() { let server = server::http_unix(|req| async move { @@ -38,7 +38,7 @@ fn json_post() { "#}); } -#[cfg(target_family = "unix")] +#[cfg(unix)] #[test] fn redirects_stay_on_same_server() { let server = server::http_unix(|req| async move { @@ -119,7 +119,7 @@ fn redirects_stay_on_same_server() { server.assert_hits(3); } -#[cfg(target_family = "unix")] +#[cfg(unix)] #[test] fn cookies_persist_across_redirects() { let server = server::http_unix(|req| async move { diff --git a/tests/server/mod.rs b/tests/server/mod.rs index 16935bac..97f81579 100644 --- a/tests/server/mod.rs +++ b/tests/server/mod.rs @@ -20,13 +20,13 @@ type Builder = hyper_util::server::conn::auto::Builder String { match self.addr { Addr::TcpAddr(addr) => format!("http://{}", addr), - #[cfg(target_family = "unix")] + #[cfg(unix)] _ => panic!("no base_url for unix server"), } } @@ -51,7 +51,7 @@ impl Server { pub fn url(&self, path: &str) -> String { match self.addr { Addr::TcpAddr(addr) => format!("http://{}{}", addr, path), - #[cfg(target_family = "unix")] + #[cfg(unix)] _ => panic!("no url for unix server"), } } @@ -59,12 +59,12 @@ impl Server { pub fn host(&self) -> String { match self.addr { Addr::TcpAddr(_) => String::from("127.0.0.1"), - #[cfg(target_family = "unix")] + #[cfg(unix)] _ => panic!("no host for unix server"), } } - #[cfg(target_family = "unix")] + #[cfg(unix)] pub fn socket_path(&self) -> PathBuf { match &self.addr { Addr::UnixAddr(addr) => addr.as_pathname().unwrap().to_path_buf(), @@ -75,7 +75,7 @@ impl Server { pub fn port(&self) -> u16 { match self.addr { Addr::TcpAddr(addr) => addr.port(), - #[cfg(target_family = "unix")] + #[cfg(unix)] _ => panic!("no port for unix server"), } } @@ -128,7 +128,7 @@ where http_inner(Arc::new(move |req| Box::new(Box::pin(func(req)))), None) } -#[cfg(target_family = "unix")] +#[cfg(unix)] pub fn http_unix(func: F) -> Server where F: Fn(Request) -> Fut + Send + Sync + 'static, @@ -167,14 +167,14 @@ fn http_inner(func: Arc, socket_path: Option) -> Server { let (listener, addr) = rt.block_on(async move { #[allow(unused_variables)] if let Some(path) = &socket_path { - #[cfg(target_family = "unix")] + #[cfg(unix)] { let listener = tokio::net::UnixListener::bind(path).unwrap(); let addr = listener.local_addr().unwrap(); (Listener::UnixListener(listener), Addr::UnixAddr(addr)) } - #[cfg(not(target_family = "unix"))] + #[cfg(not(unix))] { unreachable!("cannot create http_unix server outside of unix target_family") } @@ -232,7 +232,7 @@ fn http_inner(func: Arc, socket_path: Option) -> Server { .await; }); } - #[cfg(target_family = "unix")] + #[cfg(unix)] Listener::UnixListener(listener) => { let (io, _) = listener.accept().await.unwrap(); tokio::spawn(async move { From c7fb645c8908b391c786907fa353233026522151 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 14:48:41 +0200 Subject: [PATCH 11/21] throw an error if unix-socket used in unsupported os --- src/main.rs | 4 +++- tests/cases/http_unix.rs | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 47e72259..4db4f6a4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -560,7 +560,9 @@ fn run(args: Cli) -> Result { } #[cfg(not(unix))] if let Some(_) = args.unix_socket { - log::warn!("HTTP over Unix domain sockets is not supported on this platform"); + return Err(anyhow!( + "HTTP over Unix domain sockets is not supported on this platform" + )); } client.execute(request, |prev_response, next_request| { if !args.all { diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index c393f8e3..d00101f3 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -4,6 +4,21 @@ use indoc::indoc; #[cfg(unix)] use crate::prelude::*; +#[cfg(not(unix))] +#[test] +fn error_on_unsupported_platform() { + use predicates::str::contains; + + get_command() + .arg(format!("--unix-socket=/tmp/missing.sock",)) + .arg(":/index.html") + .assert() + .failure() + .stderr(contains( + "HTTP over Unix domain sockets is not supported on this platform", + )); +} + #[cfg(unix)] #[test] fn json_post() { From e0791b37cabb2876f2205a5d7a1b145b7a00f7ff Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 14:51:01 +0200 Subject: [PATCH 12/21] provide complete example for unix-socket usage --- src/cli.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cli.rs b/src/cli.rs index 5f6f741b..7bf7f0f4 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -349,7 +349,7 @@ Example: --print=Hb" /// Connect using a Unix domain socket. /// - /// Example: --unix_socket=/var/run/temp.sock + /// Example: xh :/index.html --unix-socket=/var/run/temp.sock #[clap(long, value_name = "FILE")] pub unix_socket: Option, From f09a19107aee91acef97dd837e242b6bdb41b836 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 14:58:49 +0200 Subject: [PATCH 13/21] fix missing import --- tests/cases/http_unix.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index d00101f3..87390419 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -1,7 +1,6 @@ #[cfg(unix)] use indoc::indoc; -#[cfg(unix)] use crate::prelude::*; #[cfg(not(unix))] From 67730e1ac1d4257a6985ab146bc945a11b6e5c04 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Mon, 30 Dec 2024 16:24:38 +0200 Subject: [PATCH 14/21] check that host header is passed --- tests/cases/http_unix.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index 87390419..d5590fa0 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -24,6 +24,7 @@ fn json_post() { let server = server::http_unix(|req| async move { assert_eq!(req.method(), "POST"); assert_eq!(req.headers()["Content-Type"], "application/json"); + assert_eq!(req.headers()["Host"], "example.com"); assert_eq!(req.body_as_string().await, "{\"foo\":\"bar\"}"); hyper::Response::builder() From 0a8040728ea4a7784cad0fd4c59d259cd2b7f30e Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Tue, 31 Dec 2024 20:38:49 +0200 Subject: [PATCH 15/21] store unix_client in ClientWithMiddleware --- src/main.rs | 13 ++--------- src/middleware.rs | 56 ++++++++++++++++++++++++++++++++++------------ src/unix_socket.rs | 36 ++++------------------------- 3 files changed, 48 insertions(+), 57 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1467dc71..a07b5fc8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -48,8 +48,6 @@ use crate::middleware::ClientWithMiddleware; use crate::printer::Printer; use crate::request_items::{Body, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}; use crate::session::Session; -#[cfg(unix)] -use crate::unix_socket::UnixSocket; use crate::utils::{reason_phrase, test_mode, test_pretend_term, url_with_query}; use crate::vendored::reqwest_cookie_store; @@ -573,15 +571,8 @@ fn run(args: Cli) -> Result { client = client.with(DigestAuthMiddleware::new(username, password)); } client = client.with(CookieMiddleware::new(cookie_jar.clone())); - #[cfg(unix)] - if let Some(unix_socket) = args.unix_socket { - client = client.with(UnixSocket::new(unix_socket)); - } - #[cfg(not(unix))] - if let Some(_) = args.unix_socket { - return Err(anyhow!( - "HTTP over Unix domain sockets is not supported on this platform" - )); + if let Some(socket_path) = args.unix_socket { + client = client.with_unix_socket(socket_path)?; } client.execute(request, |prev_response, next_request| { if !args.all { diff --git a/src/middleware.rs b/src/middleware.rs index c0f61be9..f335b06e 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,7 +1,10 @@ -use std::time::{Duration, Instant}; +use std::{ + path::PathBuf, + time::{Duration, Instant}, +}; use anyhow::Result; -use reqwest::blocking::{Client, Request, Response}; +use reqwest::blocking::{Request, Response}; #[derive(Clone)] pub struct ResponseMeta { @@ -26,15 +29,15 @@ impl ResponseExt for Response { type Printer<'a> = &'a mut (dyn FnMut(&mut Response, &mut Request) -> Result<()> + 'a); -pub struct Context<'a, 'b, 'c> { - client: Client, +pub struct Context<'a, 'b, 'c, 'd> { + client: &'d Client, printer: Printer<'c>, middlewares: &'a mut [Box], } -impl<'a, 'b, 'c> Context<'a, 'b, 'c> { +impl<'a, 'b, 'c, 'd> Context<'a, 'b, 'c, 'd> { fn new( - client: Client, + client: &'d Client, printer: Printer<'c>, middlewares: &'a mut [Box], ) -> Self { @@ -49,17 +52,20 @@ impl<'a, 'b, 'c> Context<'a, 'b, 'c> { match self.middlewares { [] => { let starting_time = Instant::now(); - let mut response = self.client.execute(request)?; + let mut response = match self.client { + Client::Http(client) => client.execute(request)?, + #[cfg(unix)] + Client::Unix(client) => client.execute(request)?, + }; response.extensions_mut().insert(ResponseMeta { request_duration: starting_time.elapsed(), content_download_duration: None, }); Ok(response) } - [ref mut head, tail @ ..] => head.handle( - Context::new(self.client.clone(), self.printer, tail), - request, - ), + [ref mut head, tail @ ..] => { + head.handle(Context::new(self.client, self.printer, tail), request) + } } } } @@ -82,19 +88,41 @@ pub trait Middleware { } } +enum Client { + Http(reqwest::blocking::Client), + #[cfg(unix)] + Unix(crate::unix_socket::UnixClient), +} + pub struct ClientWithMiddleware<'a> { client: Client, middlewares: Vec>, } impl<'a> ClientWithMiddleware<'a> { - pub fn new(client: Client) -> Self { + pub fn new(client: reqwest::blocking::Client) -> Self { ClientWithMiddleware { - client, + client: Client::Http(client), middlewares: vec![], } } + #[allow(unused)] + pub fn with_unix_socket(mut self, socket_path: PathBuf) -> Result { + #[cfg(not(unix))] + { + return Err(anyhow::anyhow!( + "HTTP over Unix domain sockets is not supported on this platform" + )); + } + + #[cfg(unix)] + { + self.client = Client::Unix(crate::unix_socket::UnixClient::new(socket_path)); + Ok(self) + } + } + pub fn with(mut self, middleware: impl Middleware + 'a) -> Self { self.middlewares.push(Box::new(middleware)); self @@ -104,7 +132,7 @@ impl<'a> ClientWithMiddleware<'a> { where T: FnMut(&mut Response, &mut Request) -> Result<()> + 'b, { - let mut ctx = Context::new(self.client.clone(), &mut printer, &mut self.middlewares[..]); + let mut ctx = Context::new(&self.client, &mut printer, &mut self.middlewares[..]); ctx.execute(request) } } diff --git a/src/unix_socket.rs b/src/unix_socket.rs index bb0898af..713b3723 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -2,16 +2,13 @@ use anyhow::Result; use reqwest::blocking::{Request, Response}; use reqwest::header::{HeaderValue, HOST}; use std::path::PathBuf; -use std::time::Instant; -use crate::middleware::{Context, Middleware, ResponseMeta}; - -pub struct UnixSocket { +pub struct UnixClient { rt: tokio::runtime::Runtime, socket_path: PathBuf, } -impl UnixSocket { +impl UnixClient { pub fn new(socket_path: PathBuf) -> Self { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() @@ -23,21 +20,8 @@ impl UnixSocket { pub fn execute(&self, request: Request) -> Result { self.rt.block_on(async { - // TODO: Support named pipes by replacing tokio::net::UnixStream::connect(..) with: - // - // use std::time::Duration; - // use tokio::net::windows::named_pipe; - // use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY; - // - // let stream = loop { - // match named_pipe::ClientOptions::new().open(r"\\.\pipe\docker_engine") { - // Ok(client) => break client, - // Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), - // Err(e) => return Err(e)?, - // } - // - // tokio::time::sleep(Duration::from_millis(50)).await; - // }; + // TODO: Add support for Windows named pipes by replacing UnixStream with namedPipeClient. + // See https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.ClientOptions.html#method.open let stream = tokio::net::UnixStream::connect(&self.socket_path).await?; let (mut sender, conn) = hyper::client::conn::http1::Builder::new() @@ -60,18 +44,6 @@ impl UnixSocket { } } -impl Middleware for UnixSocket { - fn handle(&mut self, mut _ctx: Context, request: Request) -> Result { - let starting_time = Instant::now(); - let mut response = self.execute(request)?; - response.extensions_mut().insert(ResponseMeta { - request_duration: starting_time.elapsed(), - content_download_duration: None, - }); - Ok(response) - } -} - fn into_async_request(mut request: Request) -> Result> { let mut http_request = http::Request::builder() .version(request.version()) From 31c442058ec8bfc4fa5614c9e0bdf4957108c5e7 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Tue, 7 Jan 2025 12:42:09 +0200 Subject: [PATCH 16/21] warn or error if unix-socket used with unsupported option --- src/cli.rs | 8 ++++++-- src/main.rs | 14 +++++++++++++- src/middleware.rs | 16 +++------------- src/unix_socket.rs | 1 - tests/cases/http_unix.rs | 21 ++++++++++++++++++++- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index e02c32ae..840ee9fb 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -350,7 +350,11 @@ Example: --print=Hb" /// Connect using a Unix domain socket. /// /// Example: xh :/index.html --unix-socket=/var/run/temp.sock - #[clap(long, value_name = "FILE")] + #[clap( + long, + value_name = "FILE", + conflicts_with_all=["proxy", "verify", "cert", "cert_key", "ssl", "resolve", "interface", "ipv4", "ipv6", "https", "http_version"] + )] pub unix_socket: Option, /// Do not attempt to read stdin. @@ -1017,7 +1021,7 @@ impl FromStr for Print { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Timeout(Duration); impl Timeout { diff --git a/src/main.rs b/src/main.rs index 78a8e808..d6f9db31 100644 --- a/src/main.rs +++ b/src/main.rs @@ -578,7 +578,19 @@ fn run(args: Cli) -> Result { } client = client.with(CookieMiddleware::new(cookie_jar.clone())); if let Some(socket_path) = args.unix_socket { - client = client.with_unix_socket(socket_path)?; + #[cfg(not(unix))] + { + return Err(anyhow::anyhow!( + "HTTP over Unix domain sockets is not supported on this platform" + )); + } + #[cfg(unix)] + { + if (args.timeout.and_then(|t| t.as_duration())).is_some() { + log::warn!("Timeout is not supported for HTTP over Unix domain sockets"); + } + client = client.with_unix_socket(socket_path)?; + } } client.execute(request, |prev_response, next_request| { if !args.all { diff --git a/src/middleware.rs b/src/middleware.rs index f335b06e..d7c6d126 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -107,20 +107,10 @@ impl<'a> ClientWithMiddleware<'a> { } } - #[allow(unused)] + #[cfg(unix)] pub fn with_unix_socket(mut self, socket_path: PathBuf) -> Result { - #[cfg(not(unix))] - { - return Err(anyhow::anyhow!( - "HTTP over Unix domain sockets is not supported on this platform" - )); - } - - #[cfg(unix)] - { - self.client = Client::Unix(crate::unix_socket::UnixClient::new(socket_path)); - Ok(self) - } + self.client = Client::Unix(crate::unix_socket::UnixClient::new(socket_path)); + Ok(self) } pub fn with(mut self, middleware: impl Middleware + 'a) -> Self { diff --git a/src/unix_socket.rs b/src/unix_socket.rs index 713b3723..2d1e87c6 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -35,7 +35,6 @@ impl UnixClient { } }); - // TODO: don't ignore value from --timeout option let http_request = into_async_request(request)?; let response = sender.send_request(http_request).await?; diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index d5590fa0..9ad2b653 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -196,4 +196,23 @@ fn cookies_persist_across_redirects() { "#}); } -// TODO: add tests for connection timeout +#[cfg(unix)] +#[test] +fn timeout_is_unsupported_warning() { + let server = server::http_unix(|_req| async move { + hyper::Response::builder() + .header(hyper::header::CONTENT_TYPE, "application/json") + .body(r#"{"status":"ok"}"#.into()) + .unwrap() + }); + + get_command() + .arg(":") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("--timeout=30") + .assert() + .stderr("xh: warning: Timeout is not supported for HTTP over Unix domain sockets\n"); +} From 8bcbb0da90859f4145a6ad4d0a1fe2ed1d533978 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Tue, 7 Jan 2025 12:45:44 +0200 Subject: [PATCH 17/21] disable failing badssl.com tests --- tests/cli.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/cli.rs b/tests/cli.rs index c79dbbb4..306afce2 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -1282,6 +1282,7 @@ fn native_tls_works() { } #[cfg(feature = "online-tests")] +#[ignore = "404 errors"] #[test] fn good_tls_version() { get_command() @@ -1292,6 +1293,7 @@ fn good_tls_version() { } #[cfg(all(feature = "native-tls", feature = "online-tests"))] +#[ignore = "404 errors"] #[test] fn good_tls_version_nativetls() { get_command() From d6019468ff15f0e4c1039593af991eb9ad9088eb Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Tue, 7 Jan 2025 23:46:34 +0200 Subject: [PATCH 18/21] implement read timeout for unix_socket requests --- Cargo.lock | 2 + Cargo.toml | 2 + src/main.rs | 18 ++++-- src/middleware.rs | 8 ++- src/unix_socket.rs | 116 +++++++++++++++++++++++++++++++++++++-- tests/cases/http_unix.rs | 19 +++---- 6 files changed, 143 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a01c232e..ef86b999 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2817,6 +2817,7 @@ dependencies = [ "env_logger", "flate2", "form_urlencoded", + "futures-core", "http", "http-body-util", "hyper", @@ -2833,6 +2834,7 @@ dependencies = [ "once_cell", "os_display", "pem", + "pin-project-lite", "predicates", "rand", "regex-lite", diff --git a/Cargo.toml b/Cargo.toml index 796143e2..4a6c3475 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ dirs = "5.0" encoding_rs = "0.8.28" encoding_rs_io = "0.1.7" flate2 = "1.0.22" +futures-core = { version = "0.3.28", default-features = false } http = "1" # Add "tracing" feature to hyper once it stabilizes hyper = { version = "1.2", default-features = false } @@ -40,6 +41,7 @@ mime_guess = "2.0" once_cell = "1.8.0" os_display = "0.1.3" pem = "3.0" +pin-project-lite = "0.2" regex-lite = "0.1.5" roff = "0.2.1" rpassword = "7.2.0" diff --git a/src/main.rs b/src/main.rs index d6f9db31..73f78305 100644 --- a/src/main.rs +++ b/src/main.rs @@ -91,6 +91,15 @@ fn main() { eprintln!(); eprintln!("Try running without the --native-tls flag."); } + if msg.starts_with("deadline has elapsed") { + process::exit(2); + } + #[cfg(unix)] + { + if err.downcast_ref::().is_some() { + process::exit(2); + } + } if let Some(err) = err.downcast_ref::() { if err.is_timeout() { process::exit(2); @@ -155,6 +164,7 @@ fn run(args: Cli) -> Result { .http1_title_case_headers() .http2_adaptive_window(true) .redirect(reqwest::redirect::Policy::none()) + // TODO: replace with connect_timeout + read_timeout .timeout(args.timeout.and_then(|t| t.as_duration())) .no_gzip() .no_deflate() @@ -586,10 +596,10 @@ fn run(args: Cli) -> Result { } #[cfg(unix)] { - if (args.timeout.and_then(|t| t.as_duration())).is_some() { - log::warn!("Timeout is not supported for HTTP over Unix domain sockets"); - } - client = client.with_unix_socket(socket_path)?; + client = client.with_unix_socket( + socket_path, + args.timeout.and_then(|t| t.as_duration()), + )?; } } client.execute(request, |prev_response, next_request| { diff --git a/src/middleware.rs b/src/middleware.rs index d7c6d126..d3bac384 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -108,8 +108,12 @@ impl<'a> ClientWithMiddleware<'a> { } #[cfg(unix)] - pub fn with_unix_socket(mut self, socket_path: PathBuf) -> Result { - self.client = Client::Unix(crate::unix_socket::UnixClient::new(socket_path)); + pub fn with_unix_socket( + mut self, + socket_path: PathBuf, + timeout: Option, + ) -> Result { + self.client = Client::Unix(crate::unix_socket::UnixClient::new(socket_path, timeout)); Ok(self) } diff --git a/src/unix_socket.rs b/src/unix_socket.rs index 2d1e87c6..aa7104df 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -1,29 +1,44 @@ -use anyhow::Result; +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use anyhow::{anyhow, Result}; +use pin_project_lite::pin_project; use reqwest::blocking::{Request, Response}; use reqwest::header::{HeaderValue, HOST}; -use std::path::PathBuf; +use tokio::time::Sleep; pub struct UnixClient { rt: tokio::runtime::Runtime, socket_path: PathBuf, + timeout: Option, } impl UnixClient { - pub fn new(socket_path: PathBuf) -> Self { + pub fn new(socket_path: PathBuf, timeout: Option) -> Self { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); - Self { rt, socket_path } + Self { + rt, + socket_path, + timeout, + } } pub fn execute(&self, request: Request) -> Result { self.rt.block_on(async { // TODO: Add support for Windows named pipes by replacing UnixStream with namedPipeClient. // See https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.ClientOptions.html#method.open + + // TODO: connection timeout?? let stream = tokio::net::UnixStream::connect(&self.socket_path).await?; + // TODO: connection timeout let (mut sender, conn) = hyper::client::conn::http1::Builder::new() .title_case_headers(true) .handshake(hyper_util::rt::TokioIo::new(stream)) @@ -36,9 +51,22 @@ impl UnixClient { }); let http_request = into_async_request(request)?; - let response = sender.send_request(http_request).await?; - Ok(Response::from(response.map(reqwest::Body::wrap))) + let response = if let Some(timeout) = self.timeout { + tokio::time::timeout(timeout, sender.send_request(http_request)) + .await + .map_err(|_| anyhow!(TimeoutError))? + } else { + sender.send_request(http_request).await + }?; + + Ok(Response::from(response.map(|body| { + if let Some(timeout) = self.timeout { + reqwest::Body::wrap(ReadTimeoutBody::new(body, timeout)) + } else { + reqwest::Body::wrap(body) + } + }))) }) } } @@ -69,3 +97,79 @@ fn into_async_request(mut request: Request) -> Result std::fmt::Result { + write!(f, "operation timed out") + } +} + +// Copied from https://github.com/seanmonstar/reqwest/blob/8b8fdd2552ad645c7e9dd494930b3e95e2aedef2/src/async_impl/body.rs#L347 +// with some slight tweaks +pin_project! { + pub(crate) struct ReadTimeoutBody { + #[pin] + inner: B, + #[pin] + sleep: Option, + timeout: Duration, + } +} + +impl ReadTimeoutBody { + fn new(body: B, timeout: Duration) -> ReadTimeoutBody { + ReadTimeoutBody { + inner: body, + sleep: None, + timeout, + } + } +} + +impl hyper::body::Body for ReadTimeoutBody +where + B: hyper::body::Body, + B::Error: Into>, +{ + type Data = B::Data; + type Error = anyhow::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { + let mut this = self.project(); + + // Start the `Sleep` if not active. + let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { + some + } else { + this.sleep.set(Some(tokio::time::sleep(*this.timeout))); + this.sleep.as_mut().as_pin_mut().unwrap() + }; + + // Error if the timeout has expired. + if let Poll::Ready(()) = sleep_pinned.poll(cx) { + return Poll::Ready(Some(Err(anyhow!(TimeoutError)))); + } + + let item = futures_core::ready!(this.inner.poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(|e| anyhow!(e.into()))); + // a ready frame means timeout is reset + this.sleep.set(None); + Poll::Ready(item) + } + + #[inline] + fn size_hint(&self) -> hyper::body::SizeHint { + self.inner.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } +} diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs index 9ad2b653..8c2012eb 100644 --- a/tests/cases/http_unix.rs +++ b/tests/cases/http_unix.rs @@ -1,13 +1,12 @@ #[cfg(unix)] use indoc::indoc; +use predicates::str::contains; use crate::prelude::*; #[cfg(not(unix))] #[test] fn error_on_unsupported_platform() { - use predicates::str::contains; - get_command() .arg(format!("--unix-socket=/tmp/missing.sock",)) .arg(":/index.html") @@ -198,13 +197,12 @@ fn cookies_persist_across_redirects() { #[cfg(unix)] #[test] -fn timeout_is_unsupported_warning() { - let server = server::http_unix(|_req| async move { - hyper::Response::builder() - .header(hyper::header::CONTENT_TYPE, "application/json") - .body(r#"{"status":"ok"}"#.into()) - .unwrap() +fn timeout() { + let mut server = server::http_unix(|_req| async move { + tokio::time::sleep(std::time::Duration::from_secs_f32(0.5)).await; + hyper::Response::default() }); + server.disable_hit_checks(); get_command() .arg(":") @@ -212,7 +210,8 @@ fn timeout_is_unsupported_warning() { "--unix-socket={}", server.socket_path().to_string_lossy() )) - .arg("--timeout=30") + .arg("--timeout=0.1") .assert() - .stderr("xh: warning: Timeout is not supported for HTTP over Unix domain sockets\n"); + .code(2) + .stderr(contains("operation timed out")); } From c84e8ef7522a302acab352292ebe5359435d3572 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Wed, 8 Jan 2025 21:46:00 +0200 Subject: [PATCH 19/21] implement connect timeout for unix_socket requests --- src/unix_socket.rs | 58 ++++++++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/src/unix_socket.rs b/src/unix_socket.rs index aa7104df..860f72f8 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -30,35 +30,30 @@ impl UnixClient { } } - pub fn execute(&self, request: Request) -> Result { - self.rt.block_on(async { - // TODO: Add support for Windows named pipes by replacing UnixStream with namedPipeClient. - // See https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.ClientOptions.html#method.open - - // TODO: connection timeout?? - let stream = tokio::net::UnixStream::connect(&self.socket_path).await?; - - // TODO: connection timeout - let (mut sender, conn) = hyper::client::conn::http1::Builder::new() - .title_case_headers(true) - .handshake(hyper_util::rt::TokioIo::new(stream)) - .await?; + async fn connect(&self) -> Result> { + // TODO: Add support for Windows named pipes by replacing UnixStream with namedPipeClient. + // See https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.ClientOptions.html#method.open + let stream = tokio::net::UnixStream::connect(&self.socket_path).await?; + let (sender, conn) = hyper::client::conn::http1::Builder::new() + .title_case_headers(true) + .handshake(hyper_util::rt::TokioIo::new(stream)) + .await?; + + tokio::task::spawn(async move { + if let Err(err) = conn.await { + log::error!("Connection failed: {:?}", err); + } + }); - tokio::task::spawn(async move { - if let Err(err) = conn.await { - log::error!("Connection failed: {:?}", err); - } - }); + Ok(sender) + } + pub fn execute(&self, request: Request) -> Result { + self.rt.block_on(async { let http_request = into_async_request(request)?; - let response = if let Some(timeout) = self.timeout { - tokio::time::timeout(timeout, sender.send_request(http_request)) - .await - .map_err(|_| anyhow!(TimeoutError))? - } else { - sender.send_request(http_request).await - }?; + let mut sender = with_timeout(self.connect(), self.timeout).await??; + let response = with_timeout(sender.send_request(http_request), self.timeout).await??; Ok(Response::from(response.map(|body| { if let Some(timeout) = self.timeout { @@ -98,6 +93,19 @@ fn into_async_request(mut request: Request) -> Result(fut: F, timeout: Option) -> Result +where + F: std::future::IntoFuture, +{ + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, fut) + .await + .map_err(|_| anyhow!(TimeoutError)) + } else { + Ok(fut.await) + } +} + #[derive(Debug, Clone)] pub struct TimeoutError; From e362e7fe79ce283280e913a431ee40ae2d8197a2 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sat, 11 Jan 2025 10:25:28 +0200 Subject: [PATCH 20/21] switch from read timeout to total timeout the former is not yet supported in blocking client --- src/main.rs | 1 - src/unix_socket.rs | 47 ++++++++++++++++------------------------------ 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/src/main.rs b/src/main.rs index 73f78305..ab809919 100644 --- a/src/main.rs +++ b/src/main.rs @@ -164,7 +164,6 @@ fn run(args: Cli) -> Result { .http1_title_case_headers() .http2_adaptive_window(true) .redirect(reqwest::redirect::Policy::none()) - // TODO: replace with connect_timeout + read_timeout .timeout(args.timeout.and_then(|t| t.as_duration())) .no_gzip() .no_deflate() diff --git a/src/unix_socket.rs b/src/unix_socket.rs index 860f72f8..baae32f5 100644 --- a/src/unix_socket.rs +++ b/src/unix_socket.rs @@ -8,7 +8,6 @@ use anyhow::{anyhow, Result}; use pin_project_lite::pin_project; use reqwest::blocking::{Request, Response}; use reqwest::header::{HeaderValue, HOST}; -use tokio::time::Sleep; pub struct UnixClient { rt: tokio::runtime::Runtime, @@ -57,7 +56,7 @@ impl UnixClient { Ok(Response::from(response.map(|body| { if let Some(timeout) = self.timeout { - reqwest::Body::wrap(ReadTimeoutBody::new(body, timeout)) + reqwest::Body::wrap(TotalTimeoutBody::new(body, timeout)) } else { reqwest::Body::wrap(body) } @@ -115,29 +114,27 @@ impl std::fmt::Display for TimeoutError { } } -// Copied from https://github.com/seanmonstar/reqwest/blob/8b8fdd2552ad645c7e9dd494930b3e95e2aedef2/src/async_impl/body.rs#L347 +// Copied from https://github.com/seanmonstar/reqwest/blob/8b8fdd2552ad645c7e9dd494930b3e95e2aedef2/src/async_impl/body.rs#L314 // with some slight tweaks pin_project! { - pub(crate) struct ReadTimeoutBody { + pub(crate) struct TotalTimeoutBody { #[pin] inner: B, - #[pin] - sleep: Option, - timeout: Duration, + timeout: Pin>, } } -impl ReadTimeoutBody { - fn new(body: B, timeout: Duration) -> ReadTimeoutBody { - ReadTimeoutBody { +impl TotalTimeoutBody { + fn new(body: B, timeout: Duration) -> TotalTimeoutBody { + let total_timeout = Box::pin(tokio::time::sleep(timeout)); + TotalTimeoutBody { inner: body, - sleep: None, - timeout, + timeout: total_timeout, } } } -impl hyper::body::Body for ReadTimeoutBody +impl hyper::body::Body for TotalTimeoutBody where B: hyper::body::Body, B::Error: Into>, @@ -149,26 +146,14 @@ where self: Pin<&mut Self>, cx: &mut Context, ) -> Poll, Self::Error>>> { - let mut this = self.project(); - - // Start the `Sleep` if not active. - let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { - some - } else { - this.sleep.set(Some(tokio::time::sleep(*this.timeout))); - this.sleep.as_mut().as_pin_mut().unwrap() - }; - - // Error if the timeout has expired. - if let Poll::Ready(()) = sleep_pinned.poll(cx) { + let this = self.project(); + if let Poll::Ready(()) = this.timeout.as_mut().poll(cx) { return Poll::Ready(Some(Err(anyhow!(TimeoutError)))); } - - let item = futures_core::ready!(this.inner.poll_frame(cx)) - .map(|opt_chunk| opt_chunk.map_err(|e| anyhow!(e.into()))); - // a ready frame means timeout is reset - this.sleep.set(None); - Poll::Ready(item) + Poll::Ready( + futures_core::ready!(this.inner.poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(|e| anyhow!(e.into()))), + ) } #[inline] From 92df488c3a966a24880acf42142a82bcaae74af9 Mon Sep 17 00:00:00 2001 From: Mohamed Daahir Date: Sat, 11 Jan 2025 10:27:59 +0200 Subject: [PATCH 21/21] revert disabling badssl tests --- tests/cli.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cli.rs b/tests/cli.rs index 306afce2..c79dbbb4 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -1282,7 +1282,6 @@ fn native_tls_works() { } #[cfg(feature = "online-tests")] -#[ignore = "404 errors"] #[test] fn good_tls_version() { get_command() @@ -1293,7 +1292,6 @@ fn good_tls_version() { } #[cfg(all(feature = "native-tls", feature = "online-tests"))] -#[ignore = "404 errors"] #[test] fn good_tls_version_nativetls() { get_command()