diff --git a/Cargo.lock b/Cargo.lock index 64023df..b9eb7c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1171,6 +1171,7 @@ dependencies = [ "regex", "reqwest", "scopeguard", + "socket2", "tao", "tempfile", "tikv-jemallocator", @@ -1316,6 +1317,7 @@ dependencies = [ "pin-project-lite", "smallvec", "tokio", + "want", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 2fe491b..131c8f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ futures = "0.3" hex = "0.4" http-body = "1.0" http-body-util = "0.1" -hyper = { version = "1.2", features = ["http1", "server"] } +hyper = { version = "1.2", features = ["client", "http1", "server"] } hyper-util = { version = "0.1", features = ["tokio"] } inquire = "0.7" log = { version = "0.4", features = ["std"] } @@ -30,6 +30,7 @@ rand = { version = "0.8", features = ["small_rng"] } regex = "1.10" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "stream", "socks"] } scopeguard = "1.2" +socket2 = "0.5" tempfile = "3.10" tokio = { version = "1", features = ["full", "parking_lot"] } tokio-openssl = "0.6" diff --git a/src/error.rs b/src/error.rs index 494aa0a..d977bb1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,6 @@ use core::fmt; +#[allow(clippy::enum_variant_names)] #[derive(Debug)] pub enum Error { VersionTooOld, @@ -7,6 +8,7 @@ pub enum Error { ConnectTestFail, InitSettingsMissing(String), HashMismatch { expected: [u8; 20], actual: [u8; 20] }, + ServerError { status: u16, body: Option }, } impl Error { @@ -28,6 +30,7 @@ impl fmt::Display for Error { Error::ConnectTestFail => write!(f, "Connect test failed"), Error::InitSettingsMissing(settings) => write!(f, "Missing init settings: {settings}"), Error::HashMismatch { expected, actual } => write!(f, "Hash missmatch. Expected={expected:?}, Actual={actual:?}"), + Error::ServerError { status, body } => write!(f, "Status={status}, Body={}", body.clone().unwrap_or_default()), } } } diff --git a/src/main.rs b/src/main.rs index 748b956..85b28c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,6 +41,7 @@ mod logger; mod middleware; mod route; mod rpc; +mod rpc_http_client; mod server; mod util; diff --git a/src/rpc.rs b/src/rpc.rs index 59d5b38..2e70c75 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -23,11 +23,7 @@ use parking_lot::{RwLock, RwLockUpgradableReadGuard}; use rand::prelude::SliceRandom; use reqwest::{IntoUrl, Url}; -use crate::{ - error::Error, - gallery_downloader::GalleryMeta, - util::{create_http_client, string_to_hash}, -}; +use crate::{error::Error, gallery_downloader::GalleryMeta, rpc_http_client::RPCHttpClient, util::string_to_hash}; const API_VERSION: i32 = 160; // For server check capabilities. const DEFAULT_SERVER: &str = "rpc.hentaiathome.net"; @@ -39,7 +35,7 @@ pub struct RPCClient { clock_offset: AtomicI64, id: i32, key: String, - reqwest: reqwest::Client, + http_client: RPCHttpClient, rpc_servers: RwLock>, running: AtomicBool, settings: Arc, @@ -128,7 +124,7 @@ impl RPCClient { clock_offset: AtomicI64::new(0), id, key: key.to_string(), - reqwest: create_http_client(Duration::from_secs(600), None), + http_client: RPCHttpClient::new(Duration::from_secs(600)), rpc_servers: RwLock::new(vec![]), running: AtomicBool::new(false), settings: Arc::new(Settings { @@ -206,11 +202,9 @@ impl RPCClient { pub async fn get_cert(&self) -> Option { let _provider = Provider::try_load(None, "legacy", true).unwrap(); let cert = self - .reqwest + .http_client .get(self.build_url("get_cert", "", None)) - .send() - .and_then(|res| async { res.error_for_status() }) - .and_then(|res| res.bytes()) + .and_then(|res| self.http_client.to_bytes(res)) .await .ok() .and_then(|data| Pkcs12::from_der(&data[..]).ok()) @@ -475,10 +469,9 @@ The program will now terminate. return Ok(response); } Err(err) => { - if err.is_connect() || err.is_timeout() || err.status().map_or(false, |s| s.is_server_error()) { - self.change_server(); - } - error = Box::new(err); + error!("Send request error: {}", err); + self.change_server(); + error = err; } } retry -= 1; @@ -487,17 +480,14 @@ The program will now terminate. Err(error) } - async fn send_request(&self, url: U) -> Result { - let res = self.reqwest.get(url).timeout(Duration::from_secs(600)).send().await?; - - if let Err(err) = res.error_for_status_ref() { - let status = res.status(); - let body = res.text().await.unwrap_or_default(); - warn!("Server response error: code={}, body={}", status, body); - return Err(err); + async fn send_request(&self, url: U) -> Result { + match self.http_client.get(url).await { + Ok(res) => self.http_client.to_text(res).await, + Err(err) => { + warn!("Server response error: {}", err); + Err(err) + } } - - res.text().await } fn build_url(&self, action: &str, additional: &str, endpoint: Option<&str>) -> Url { diff --git a/src/rpc_http_client.rs b/src/rpc_http_client.rs new file mode 100644 index 0000000..31c3414 --- /dev/null +++ b/src/rpc_http_client.rs @@ -0,0 +1,169 @@ +use std::{ + io::{ + Error as IoError, + ErrorKind::{AddrNotAvailable, InvalidInput, TimedOut}, + }, + net::ToSocketAddrs, + sync::Arc, + time::Duration, +}; + +use bytes::Bytes; +use http_body_util::{BodyExt, Collected, Empty}; +use hyper::{body::Incoming, client::conn::http1::handshake, http::Error as HttpError, Request, Response}; +use hyper_util::rt::TokioIo; +use log::{debug, error}; +use parking_lot::Mutex; +use rand::{seq::IteratorRandom, thread_rng}; +use reqwest::{IntoUrl, Url}; +use socket2::{SockRef, TcpKeepalive}; +use tokio::{ + net::{TcpSocket, TcpStream}, + task::AbortHandle, + time::timeout, +}; + +use crate::CLIENT_VERSION; + +type Connection = Arc)>>>; +type RequestError = Box; + +#[derive(Default)] +pub struct RPCHttpClient { + timeout: Duration, + preconnect: Connection, + connecting: Mutex>, +} + +impl RPCHttpClient { + pub fn new(timeout: Duration) -> Self { + Self { + timeout, + ..Default::default() + } + } + + pub async fn get(&self, url: U) -> Result, RequestError> { + let url = url.into_url()?; + let host = url.host().ok_or(IoError::new(InvalidInput, "uri has no host"))?; + let port = url.port().unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 }); + let server = format!("{host}:{port}"); + + let conn = self.preconnect.lock().take(); + let conn = match conn.and_then(|(key, stream)| if key == server { Some(stream) } else { None }) { + Some(v) => { + // Connection timeout 5s + match timeout(Duration::from_secs(5), v.inner().writable()).await { + Ok(Ok(_)) => v, + _ => create_stream(&server).await?, + } + } + None => create_stream(&server).await?, + }; + let (mut sender, conn) = handshake(conn).await?; + // spawn a task to poll the connection and drive the HTTP state + tokio::spawn(async move { + if let Err(e) = conn.await { + error!("Error in connection: {}", e); + } + }); + + match timeout(self.timeout, sender.send_request(build_request(&url)?)).await.ok() { + Some(v) => { + self.preconnect(&server); + match v { + Ok(r) => self.check_status(r).await, + Err(e) => Err(e.into()), + } + } + None => Err(Box::new(IoError::new(TimedOut, format!("Request timeout: url={}", url)))), + } + } + + pub fn preconnect(&self, server: &str) { + // Connected + if self.preconnect.lock().is_some() { + return; + } + // Connecting + let mut job = self.connecting.lock(); + if !job.as_ref().map_or(true, |job| job.is_finished()) { + return; + } + + let server = server.to_owned(); + let preconnect = Arc::downgrade(&self.preconnect); + job.replace( + tokio::spawn(async move { + debug!("Preconnecting to {}", server); + match create_stream(&server).await { + Ok(stream) => { + debug!("Preconnected to {}", server); + if let Some(preconnect) = preconnect.upgrade() { + preconnect.lock().replace((server.to_owned(), stream)); + } + } + Err(err) => debug!("Pre connect error: {}", err), + } + }) + .abort_handle(), + ); + } + + async fn check_status(&self, res: Response) -> Result, RequestError> { + let status = res.status(); + if status.is_client_error() || status.is_server_error() { + Err(crate::error::Error::ServerError { + status: res.status().as_u16(), + body: self.to_bytes(res).await.ok().and_then(|b| String::from_utf8(b.to_vec()).ok()), + } + .into()) + } else { + Ok(res) + } + } + + pub async fn to_text(&self, res: Response) -> Result { + Ok(String::from_utf8(self.to_bytes(res).await?.to_vec())?) + } + + pub async fn to_bytes(&self, res: Response) -> Result { + match timeout(self.timeout, res.into_body().collect()).await.ok() { + Some(v) => Ok(v.map(Collected::to_bytes)?), + None => Err(Box::new(IoError::new(TimedOut, "Read response timeout".to_string()))), + } + } +} + +async fn create_stream(server: &str) -> Result, IoError> { + if let Ok(addrs) = server.to_socket_addrs() { + let addr = addrs + .choose(&mut thread_rng()) + .ok_or(IoError::new(AddrNotAvailable, format!("Fail to resolve {}", server)))?; + + // Socket settings + let socket = if addr.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?; + let socket2 = SockRef::from(&socket); + let keepalive = TcpKeepalive::new() + .with_time(Duration::from_secs(30)) + .with_interval(Duration::from_secs(15)); + let _ = socket2.set_tcp_keepalive(&keepalive); + let _ = socket2.set_nodelay(true); + + // Connect + match timeout(Duration::from_secs(5), socket.connect(addr)).await.ok() { + Some(r) => r.map(TokioIo::new), + None => Err(IoError::new(TimedOut, format!("Connect timeout: addr={addr}"))), + } + } else { + Err(IoError::new(InvalidInput, format!("Fail parse host {}", server))) + } +} + +fn build_request(url: &Url) -> Result>, HttpError> { + Request::builder() + .header("Host", url.host_str().unwrap()) + .header("User-Agent", format!("Hentai@Home {CLIENT_VERSION}")) + .uri(url.as_str()) + .body(Empty::::new()) +}