From 1a814c238adc41c1801b383ef51a4fb26ef3f3bf Mon Sep 17 00:00:00 2001 From: james58899 Date: Fri, 27 Oct 2023 07:52:13 +0000 Subject: [PATCH] Add proxy option --- Cargo.lock | 13 +++++++++++++ Cargo.toml | 2 +- src/gallery_downloader.rs | 2 +- src/main.rs | 24 ++++++++++++++++++++++-- src/route/cache.rs | 17 +++++++++++++---- src/route/server_command.rs | 9 +++------ src/rpc.rs | 2 +- src/util.rs | 15 ++++++++++----- 8 files changed, 64 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 60b7c61..33d1730 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1642,6 +1642,7 @@ dependencies = [ "system-configuration", "tokio", "tokio-rustls", + "tokio-socks", "tokio-util", "tower-service", "url", @@ -2110,6 +2111,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-socks" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" diff --git a/Cargo.toml b/Cargo.toml index 38dad17..07c739d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ parking_lot = { version = "0.12", features = ["hardware-lock-elision", "deadlock pin-project-lite = "0.2" rand = { version = "0.8", default-features = false, features = ["small_rng"] } regex = "1.10" -reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "stream"] } +reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "stream", "socks"] } tempfile = "3.8" tokio = { version = "1", features = ["full", "parking_lot"] } tokio-stream = { version = "0.1", default-features = false, features = ["fs"] } diff --git a/src/gallery_downloader.rs b/src/gallery_downloader.rs index b6b16ec..eee2b1e 100644 --- a/src/gallery_downloader.rs +++ b/src/gallery_downloader.rs @@ -35,7 +35,7 @@ impl GalleryDownloader { pub fn new>(client: Arc, download_dir: P) -> GalleryDownloader { GalleryDownloader { client, - reqwest: util::create_http_client(Duration::from_secs(300)), + reqwest: util::create_http_client(Duration::from_secs(300), None), download_dir: download_dir.as_ref().to_path_buf(), } } diff --git a/src/main.rs b/src/main.rs index f7e3f82..c69ef83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,7 @@ use openssl::{ }; use parking_lot::{Mutex, RwLock}; use regex::Regex; +use reqwest::Proxy; use tempfile::TempPath; #[cfg(not(target_env = "msvc"))] use tikv_jemallocator::Jemalloc; @@ -130,6 +131,10 @@ struct Args { /// Disable server command ip check #[arg(long, default_value_t = false)] disable_ip_origin_check: bool, + + /// Configure proxy for fetch cache + #[arg(long)] + proxy: Option, } type DownloadState = RwLock, watch::Receiver)>>; @@ -141,6 +146,7 @@ struct AppState { download_state: DownloadState, cache_manager: Arc, command_channel: Sender, + has_proxy: bool, } pub enum Command { @@ -192,18 +198,32 @@ async fn main() -> Result<(), Box> { ) .await?; - // command channel + // Proxy + let proxy = match args.proxy.as_ref().map(Proxy::all) { + Some(Ok(proxy)) => { + info!("Using proxy for fetch cache: {}", args.proxy.unwrap()); + Some(proxy) + } + Some(Err(err)) => { + error!("Parser proxy setting error: {}", err); + None + } + None => None, + }; + let has_proxy = proxy.is_some(); + // Command channel let (tx, mut rx) = mpsc::channel::(1); let (server, cert_changer) = create_server( args.port.unwrap_or_else(|| init_settings.client_port()), client.get_cert().await.unwrap(), AppState { runtime: Handle::current(), - reqwest: create_http_client(Duration::from_secs(30)), + reqwest: create_http_client(Duration::from_secs(30), proxy), rpc: client.clone(), download_state: Default::default(), cache_manager: cache_manager.clone(), command_channel: tx.clone(), + has_proxy, }, ); let server_handle = server.handle(); diff --git a/src/route/cache.rs b/src/route/cache.rs index 2c71ad6..74791da 100644 --- a/src/route/cache.rs +++ b/src/route/cache.rs @@ -1,4 +1,4 @@ -use std::{io::SeekFrom, ops::RangeInclusive, sync::Arc}; +use std::{io::SeekFrom, ops::RangeInclusive, sync::Arc, time::Duration}; use actix_files::NamedFile; use actix_web::{ @@ -25,7 +25,7 @@ use tokio::{ use crate::{ cache_manager::CacheFileInfo, route::{forbidden, parse_additional}, - util::string_to_hash, + util::{create_http_client, string_to_hash}, AppState, }; @@ -120,8 +120,17 @@ async fn hath( }; let mut download = 0; - let request = data.reqwest.get(sources.next().unwrap()); - if let Ok(mut stream) = request.send().await.and_then(|r| r.error_for_status()).map(|r| r.bytes_stream()) { + let source = sources.next().unwrap(); + let mut request = data.reqwest.get(source).send().await; + if let Err(ref err) = request { + error!("Cache download error: {}", err); + + // Retry without proxy + if data.has_proxy && err.is_connect() { + request = create_http_client(Duration::from_secs(30), None).get(source).send().await; + } + }; + if let Ok(mut stream) = request.and_then(|r| r.error_for_status()).map(|r| r.bytes_stream()) { while let Some(bytes) = stream.next().await { let bytes = match &bytes { Ok(it) => it, diff --git a/src/route/server_command.rs b/src/route/server_command.rs index 53b8700..8511ee8 100644 --- a/src/route/server_command.rs +++ b/src/route/server_command.rs @@ -15,7 +15,7 @@ use reqwest::{ use crate::{ route::{forbidden, parse_additional, speed_test::random_response}, - util::string_to_hash, + util::{create_http_client, string_to_hash}, AppState, Command, MAX_KEY_TIME_DRIFT, }; @@ -88,16 +88,13 @@ async fn servercmd( ) .unwrap(); debug!("Speedtest thread start: {}", url); - let reqwest = data.reqwest.clone(); + let reqwest = create_http_client(Duration::from_secs(60), None); // No proxy http client requests.push(tokio::spawn(async move { for retry in 0..3 { if retry > 0 { debug!("Retrying.. ({} tries left)", 3 - retry); } - let request = reqwest - .get(url.clone()) - .header(CONNECTION, HeaderValue::from_static("Close")) - .timeout(Duration::from_secs(60)); + let request = reqwest.get(url.clone()).header(CONNECTION, HeaderValue::from_static("Close")); match request.send().await.and_then(|r| r.error_for_status()) { Ok(res) => { let start = Instant::now(); diff --git a/src/rpc.rs b/src/rpc.rs index 4fe2fe0..2c02a55 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -126,7 +126,7 @@ impl RPCClient { clock_offset: AtomicI64::new(0), id, key: key.to_string(), - reqwest: create_http_client(Duration::from_secs(600)), + reqwest: create_http_client(Duration::from_secs(600), None), rpc_servers: RwLock::new(vec![]), running: AtomicBool::new(false), settings: Arc::new(Settings { diff --git a/src/util.rs b/src/util.rs index 2b06025..0ddd80b 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,6 +2,7 @@ use std::time::Duration; use futures::future::try_join_all; use openssl::sha::Sha1; +use reqwest::Proxy; use tokio::fs::create_dir_all; use crate::CLIENT_VERSION; @@ -12,8 +13,8 @@ pub fn string_to_hash(str: String) -> String { hex::encode(hasher.finish()) } -pub fn create_http_client(timeout: Duration) -> reqwest::Client { - reqwest::ClientBuilder::new() +pub fn create_http_client(timeout: Duration, proxy: Option) -> reqwest::Client { + let mut builder = reqwest::ClientBuilder::new() .user_agent(format!("Hentai@Home {CLIENT_VERSION}")) .tcp_keepalive(Duration::from_secs(75)) // Linux default keepalive inverval .connect_timeout(Duration::from_secs(5)) @@ -21,9 +22,13 @@ pub fn create_http_client(timeout: Duration) -> reqwest::Client { .pool_idle_timeout(Duration::from_secs(3600)) .pool_max_idle_per_host(8) .http1_title_case_headers() - .http1_only() - .build() - .unwrap() + .http1_only(); + + if let Some(proxy) = proxy { + builder = builder.proxy(proxy); + } + + builder.build().unwrap() } pub async fn create_dirs(dirs: Vec<&str>) -> Result, std::io::Error> {