Skip to content

Commit

Permalink
utils: Add custom header support
Browse files Browse the repository at this point in the history
Related: #17
  • Loading branch information
taoky committed Nov 21, 2024
1 parent 1070ccd commit f1912fc
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ Options:
Ignore 404 NOT FOUND as error when downloading files
--auto-fallback
Allow automatically choose fallback parser when ParseError occurred
--header <HEADER>
Custom header for HTTP(S) requests in format "Headerkey: headervalue". Supports multiple
-h, --help
Print help
-V, --version
Expand All @@ -104,6 +106,7 @@ Options:
--exclude <EXCLUDE> Excluded relative path regex. Supports multiple
--include <INCLUDE> Included relative path regex (even if excluded). Supports multiple
--upstream-base <UPSTREAM_BASE> The upstream base starting with "/" [default: /]
--header <HEADER> Custom header for HTTP(S) requests in format "Headerkey: headervalue". Supports multiple
-h, --help Print help
-V, --version Print version
```
Expand Down
18 changes: 18 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tracing_subscriber::EnvFilter;
use url::Url;

use shadow_rs::shadow;
use utils::{headers_to_headermap, Header};
shadow!(build);

mod cli;
Expand Down Expand Up @@ -61,6 +62,7 @@ enum Commands {

trait SharedArgs {
fn user_agent(&self) -> &str;
fn headers(&self) -> reqwest::header::HeaderMap;
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -157,12 +159,20 @@ pub struct SyncArgs {
/// Allow automatically choose fallback parser when ParseError occurred.
#[clap(long)]
auto_fallback: bool,

/// Custom header for HTTP(S) requests in format "Headerkey: headervalue". Supports multiple.
#[clap(long, value_parser)]
header: Vec<Header>,
}

impl SharedArgs for &SyncArgs {
fn user_agent(&self) -> &str {
&self.user_agent
}

fn headers(&self) -> reqwest::header::HeaderMap {
headers_to_headermap(&self.header)
}
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -190,12 +200,20 @@ pub struct ListArgs {
/// The upstream base starting with "/".
#[clap(long, default_value = "/")]
upstream_base: String,

/// Custom header for HTTP(S) requests in format "Headerkey: headervalue". Supports multiple.
#[clap(long, value_parser)]
header: Vec<Header>,
}

impl SharedArgs for &ListArgs {
fn user_agent(&self) -> &str {
&self.user_agent
}

fn headers(&self) -> reqwest::header::HeaderMap {
headers_to_headermap(&self.header)
}
}

pub struct AsyncContext {
Expand Down
52 changes: 52 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,57 @@ fn proxy_precheck() {
}
}

// Helper structs for custom header support
#[derive(Debug, Clone)]
pub struct Header {
pub name: reqwest::header::HeaderName,
pub value: reqwest::header::HeaderValue,
}

pub fn headers_to_headermap(value: &[Header]) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
for header in value.iter() {
headers.insert(header.name.clone(), header.value.clone());
}
headers
}

#[derive(Debug)]
pub struct HeaderParseError;

impl std::fmt::Display for HeaderParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Failed to parse header")
}
}

impl std::error::Error for HeaderParseError {}

impl std::str::FromStr for Header {
type Err = HeaderParseError;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let parts: Vec<&str> = s.splitn(2, ':').collect();

let name = parts[0].trim();
let value = parts[1].trim();

if parts.len() != 2 {
return Err(HeaderParseError);
}

let header_name =
reqwest::header::HeaderName::from_str(name).map_err(|_| HeaderParseError)?;
let header_value =
reqwest::header::HeaderValue::from_str(value).map_err(|_| HeaderParseError)?;

Ok(Header {
name: header_name,
value: header_value,
})
}
}

pub fn build_client(
args: impl SharedArgs,
redirect: bool,
Expand All @@ -82,6 +133,7 @@ pub fn build_client(
let mut builder = reqwest::Client::builder()
.user_agent(args.user_agent())
.local_address(bind_address.map(|x| x.parse::<std::net::IpAddr>().unwrap()))
.default_headers(args.headers())
// hard code 1min connect/read timeout currently
.connect_timeout(minute)
.read_timeout(minute)
Expand Down

0 comments on commit f1912fc

Please sign in to comment.