diff --git a/Containerfile b/Containerfile index f54c4e8..124e2ad 100644 --- a/Containerfile +++ b/Containerfile @@ -10,7 +10,11 @@ LABEL maintainer="Jetsung Chan" # RUN apt-get update && apt-get install -y locales libssl3 libssl-dev && rm -rf /var/lib/apt/lists/* \ # && localedef -i zh_CN -c -f UTF-8 -A /usr/share/locale/locale.alias zh_CN.UTF-8 # ENV LANG zh_CN.utf8 -RUN apt update && apt install -y openssl libssl-dev && rm -rf /var/lib/apt/lists/* +RUN apt update && \ + apt install -y deborphan openssl libcurl4 libssl-dev && \ + rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/cache/apt/archives/* && \ + deborphan | xargs apt -y remove --purge WORKDIR /app diff --git a/src/main.rs b/src/main.rs index 056a54d..9a7bfd8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,144 +1,153 @@ use axum::{ - body::Body, extract::Path, http::{header, StatusCode}, response::{Html, IntoResponse, Response}, Router + body::Body, + extract::Path, + http::{header, StatusCode}, + response::{Html, IntoResponse, Response}, + routing::get, + Router, }; -use axum::routing::get; use reqwest::Client; -use std::{env, net::SocketAddr}; +use std::{env, net::SocketAddr, sync::Arc}; +use tokio::sync::OnceCell; use url::Url; -static FILE_EXT: &str = ""; // 文件扩展名(逗号分隔) +// 静态文件扩展名,用于判断是否为下载链接 +static FILE_EXT: &str = ""; // 逗号分隔的扩展名 + +// 共享的 HTTP 客户端 +static CLIENT: OnceCell> = OnceCell::const_new(); + +// 初始化 HTTP 客户端 +async fn get_client() -> Arc { + CLIENT + .get_or_init(|| async { Arc::new(Client::new()) }) + .await + .clone() +} // 判断是否为下载链接 fn is_download_url(url: &str) -> bool { if let Ok(parsed_url) = Url::parse(url) { - if let Some(path) = parsed_url.path_segments() { - let last_segment = path.last().unwrap_or_default(); - if let Some(extension) = last_segment.split('.').last() { - return FILE_EXT - .split(',') - .any(|ext| ext.eq_ignore_ascii_case(extension)); + if let Some(segments) = parsed_url.path_segments() { + if let Some(last_segment) = segments.last() { + return last_segment + .split('.') + .last() + .map(|extension| FILE_EXT.split(',').any(|ext| ext.eq_ignore_ascii_case(extension))) + .unwrap_or(false); } } } false } -// 代理路由: 处理所有 url 请求 -async fn proxy(Path(uri): Path) -> impl IntoResponse { - // 若 URL 不含 'https://', 'http://',则添加 https:// - let target_url = if !uri.starts_with("http://") && !uri.starts_with("https://") { - if FILE_EXT.is_empty() || !is_download_url(&uri) { - let info = "下载链接必须含 \"https://\" 或 \"http://\""; +// 加载错误模板并返回响应 +fn error_response(info: &str, status: StatusCode) -> impl IntoResponse { + let html_content = std::fs::read_to_string("templates/error.html") + .unwrap_or_else(|_| info.to_string()); + let response_body = html_content.replace("{{ info }}", info); + (status, Html(response_body)) +} - let html_content = std::fs::read_to_string("templates/error.html") - .unwrap_or_else(|_| info.to_string()); - - let html_with_info = html_content.replace("{{ info }}", &info); - return Html(html_with_info).into_response(); +// 代理请求处理 +async fn proxy(Path(uri): Path) -> impl IntoResponse { + let target_url = match normalize_url(&uri) { + Some(url) => url, + None => { + return error_response("下载链接必须含 \"https://\" 或 \"http://\"", StatusCode::BAD_REQUEST).into_response(); } - format!("https://{}", uri) - } else { - uri }; - let client = Client::new(); + let client = get_client().await; match client.get(&target_url).send().await { - Ok(resp) => { - let status = resp.status(); - - // 使用流式返回响应体,而不是加载到内存 - let stream = resp.bytes_stream(); - let body = Body::from_stream(stream); - - let mut response = Response::new(body); - - // 设置额外的 CORS 头 - let response_headers = response.headers_mut(); - response_headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap()); - response_headers.insert( - header::ACCESS_CONTROL_ALLOW_METHODS, - "GET".parse().unwrap(), - ); - response_headers.insert( - header::ACCESS_CONTROL_ALLOW_HEADERS, - "Content-Type, Authorization".parse().unwrap(), - ); - - *response.status_mut() = status; - response - } - Err(_) => { - // 返回错误响应 - let error_body = Body::from("无法访问目标地址,请检查链接是否正确"); - let mut response = Response::new(error_body); - *response.status_mut() = StatusCode::BAD_GATEWAY; - response - } + Ok(resp) => stream_response(resp).await, + Err(_) => error_response("无法访问目标地址,请检查链接是否正确", StatusCode::BAD_REQUEST).into_response(), + } +} + +// 规范化 URL(若无协议,添加 https://) +fn normalize_url(uri: &str) -> Option { + if uri.starts_with("http://") || uri.starts_with("https://") { + Some(uri.to_string()) + } else if !FILE_EXT.is_empty() && is_download_url(uri) { + Some(format!("https://{}", uri)) + } else { + None } } -// favicon.ico 路由 +// 流式传输响应体 +async fn stream_response(resp: reqwest::Response) -> Response { + let status = resp.status(); + let stream = resp.bytes_stream(); + let body = Body::from_stream(stream); + + let mut response = Response::new(body); + let headers = response.headers_mut(); + + headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap()); + headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, "GET".parse().unwrap()); + headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, Authorization".parse().unwrap()); + + *response.status_mut() = status; + response +} + + + +// 处理 favicon.ico 请求 async fn favicon_ico() -> impl IntoResponse { - match tokio::fs::read("static/favicon.ico").await { + serve_static_file("static/favicon.ico", "image/x-icon").await +} + +// 处理 robots.txt 请求 +async fn robots_txt() -> impl IntoResponse { + (StatusCode::OK, [(header::CONTENT_TYPE, "text/plain")], "User-agent:*\nDisallow:/") +} + +// 提供静态文件async fn serve_static_file(path: &str, content_type: &str) -> Response { +async fn serve_static_file(path: &str, content_type: &str) -> Response { + match tokio::fs::read(path).await { Ok(content) => ( - axum::http::StatusCode::OK, - [(axum::http::header::CONTENT_TYPE, "image/x-icon")], + StatusCode::OK, + [(header::CONTENT_TYPE, content_type.to_string())], content, - ), + ) + .into_response(), Err(_) => ( - axum::http::StatusCode::NOT_FOUND, - [(axum::http::header::CONTENT_TYPE, "text/plain")], - "404 Not Found".into(), - ), + StatusCode::NOT_FOUND, + [(header::CONTENT_TYPE, "text/plain")], + axum::body::Body::from("404 Not Found"), + ) + .into_response(), } } + -// robots.txt 路由 -async fn robots_txt() -> impl IntoResponse { - let content = "User-agent:*\nDisallow:/"; - ( - axum::http::StatusCode::OK, - [(axum::http::header::CONTENT_TYPE, "text/plain")], - content, - ) -} - -// 处理根路径的请求,返回 index.html 内容 +// 主页处理 async fn index_handler() -> Html { - let title = env::var("TITLE").unwrap_or("文件加速下载".to_string()); - // 假设 `index.html` 存放在项目根目录的 "templates" 文件夹下 + let title = env::var("TITLE").unwrap_or_else(|_| "文件加速下载".to_string()); let html_content = std::fs::read_to_string("templates/index.html") .unwrap_or_else(|_| "Error: Could not load index.html".to_string()); - - // 将 "title" 插入到模板中(简单替换逻辑) - let html_with_title = html_content.replace("{{ title }}", &title); - - Html(html_with_title) + Html(html_content.replace("{{ title }}", &title)) } #[tokio::main] async fn main() { - // 读取环境变量 (可选) - let host = env::var("HOST").unwrap_or("0.0.0.0".to_string()); - let port: u16 = env::var("PORT") - .unwrap_or("8000".to_string()) - .parse() - .unwrap_or(8000); - - // 创建路由 + let host = env::var("HOST").unwrap_or_else(|_| "0.0.0.0".to_string()); + let port: u16 = env::var("PORT").ok().and_then(|p| p.parse().ok()).unwrap_or(8000); + let app = Router::new() .route("/favicon.ico", get(favicon_ico)) .route("/robots.txt", get(robots_txt)) .route("/*uri", get(proxy)) .route("/", get(index_handler)); - // 绑定并启动服务 let addr = SocketAddr::from(([0, 0, 0, 0], port)); println!("Listening on {}:{}", host, port); - // run our app with hyper, listening globally on port addr let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); -} \ No newline at end of file +}