Skip to content

Commit

Permalink
fix: need libcurl
Browse files Browse the repository at this point in the history
  • Loading branch information
jetsung committed Dec 11, 2024
1 parent baefe61 commit f5e91fa
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 95 deletions.
6 changes: 5 additions & 1 deletion Containerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ LABEL maintainer="Jetsung Chan<[email protected]>"
# RUN apt-get update && apt-get install -y locales libssl3 libssl-dev && rm -rf /var/lib/apt/lists/* \
# && localedef -i zh_CN -c -f UTF-8 -A /usr/share/locale/locale.alias zh_CN.UTF-8
# ENV LANG zh_CN.utf8
RUN apt update && apt install -y openssl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN apt update && \
apt install -y deborphan openssl libcurl4 libssl-dev && \
rm -rf /var/lib/apt/lists/* && \
rm -rf /var/cache/apt/archives/* && \
deborphan | xargs apt -y remove --purge

WORKDIR /app

Expand Down
197 changes: 103 additions & 94 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,144 +1,153 @@
use axum::{
body::Body, extract::Path, http::{header, StatusCode}, response::{Html, IntoResponse, Response}, Router
body::Body,
extract::Path,
http::{header, StatusCode},
response::{Html, IntoResponse, Response},
routing::get,
Router,
};
use axum::routing::get;
use reqwest::Client;
use std::{env, net::SocketAddr};
use std::{env, net::SocketAddr, sync::Arc};
use tokio::sync::OnceCell;
use url::Url;

static FILE_EXT: &str = ""; // 文件扩展名(逗号分隔)
// 静态文件扩展名,用于判断是否为下载链接
static FILE_EXT: &str = ""; // 逗号分隔的扩展名

// 共享的 HTTP 客户端
static CLIENT: OnceCell<Arc<Client>> = OnceCell::const_new();

// 初始化 HTTP 客户端
async fn get_client() -> Arc<Client> {
CLIENT
.get_or_init(|| async { Arc::new(Client::new()) })
.await
.clone()
}

// 判断是否为下载链接
fn is_download_url(url: &str) -> bool {
if let Ok(parsed_url) = Url::parse(url) {
if let Some(path) = parsed_url.path_segments() {
let last_segment = path.last().unwrap_or_default();
if let Some(extension) = last_segment.split('.').last() {
return FILE_EXT
.split(',')
.any(|ext| ext.eq_ignore_ascii_case(extension));
if let Some(segments) = parsed_url.path_segments() {
if let Some(last_segment) = segments.last() {
return last_segment
.split('.')
.last()
.map(|extension| FILE_EXT.split(',').any(|ext| ext.eq_ignore_ascii_case(extension)))
.unwrap_or(false);
}
}
}
false
}

// 代理路由: 处理所有 url 请求
async fn proxy(Path(uri): Path<String>) -> impl IntoResponse {
// 若 URL 不含 'https://', 'http://',则添加 https://
let target_url = if !uri.starts_with("http://") && !uri.starts_with("https://") {
if FILE_EXT.is_empty() || !is_download_url(&uri) {

let info = "下载链接必须含 \"https://\"\"http://\"";
// 加载错误模板并返回响应
fn error_response(info: &str, status: StatusCode) -> impl IntoResponse {
let html_content = std::fs::read_to_string("templates/error.html")
.unwrap_or_else(|_| info.to_string());
let response_body = html_content.replace("{{ info }}", info);
(status, Html(response_body))
}

let html_content = std::fs::read_to_string("templates/error.html")
.unwrap_or_else(|_| info.to_string());

let html_with_info = html_content.replace("{{ info }}", &info);
return Html(html_with_info).into_response();
// 代理请求处理
async fn proxy(Path(uri): Path<String>) -> impl IntoResponse {
let target_url = match normalize_url(&uri) {
Some(url) => url,
None => {
return error_response("下载链接必须含 \"https://\"\"http://\"", StatusCode::BAD_REQUEST).into_response();
}
format!("https://{}", uri)
} else {
uri
};

let client = Client::new();
let client = get_client().await;

match client.get(&target_url).send().await {
Ok(resp) => {
let status = resp.status();

// 使用流式返回响应体,而不是加载到内存
let stream = resp.bytes_stream();
let body = Body::from_stream(stream);

let mut response = Response::new(body);

// 设置额外的 CORS 头
let response_headers = response.headers_mut();
response_headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
response_headers.insert(
header::ACCESS_CONTROL_ALLOW_METHODS,
"GET".parse().unwrap(),
);
response_headers.insert(
header::ACCESS_CONTROL_ALLOW_HEADERS,
"Content-Type, Authorization".parse().unwrap(),
);

*response.status_mut() = status;
response
}
Err(_) => {
// 返回错误响应
let error_body = Body::from("无法访问目标地址,请检查链接是否正确");
let mut response = Response::new(error_body);
*response.status_mut() = StatusCode::BAD_GATEWAY;
response
}
Ok(resp) => stream_response(resp).await,
Err(_) => error_response("无法访问目标地址,请检查链接是否正确", StatusCode::BAD_REQUEST).into_response(),
}
}

// 规范化 URL(若无协议,添加 https://)
fn normalize_url(uri: &str) -> Option<String> {
if uri.starts_with("http://") || uri.starts_with("https://") {
Some(uri.to_string())
} else if !FILE_EXT.is_empty() && is_download_url(uri) {
Some(format!("https://{}", uri))
} else {
None
}
}

// favicon.ico 路由
// 流式传输响应体
async fn stream_response(resp: reqwest::Response) -> Response {
let status = resp.status();
let stream = resp.bytes_stream();
let body = Body::from_stream(stream);

let mut response = Response::new(body);
let headers = response.headers_mut();

headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, "GET".parse().unwrap());
headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type, Authorization".parse().unwrap());

*response.status_mut() = status;
response
}



// 处理 favicon.ico 请求
async fn favicon_ico() -> impl IntoResponse {
match tokio::fs::read("static/favicon.ico").await {
serve_static_file("static/favicon.ico", "image/x-icon").await
}

// 处理 robots.txt 请求
async fn robots_txt() -> impl IntoResponse {
(StatusCode::OK, [(header::CONTENT_TYPE, "text/plain")], "User-agent:*\nDisallow:/")
}

// 提供静态文件async fn serve_static_file(path: &str, content_type: &str) -> Response {
async fn serve_static_file(path: &str, content_type: &str) -> Response {
match tokio::fs::read(path).await {
Ok(content) => (
axum::http::StatusCode::OK,
[(axum::http::header::CONTENT_TYPE, "image/x-icon")],
StatusCode::OK,
[(header::CONTENT_TYPE, content_type.to_string())],
content,
),
)
.into_response(),
Err(_) => (
axum::http::StatusCode::NOT_FOUND,
[(axum::http::header::CONTENT_TYPE, "text/plain")],
"404 Not Found".into(),
),
StatusCode::NOT_FOUND,
[(header::CONTENT_TYPE, "text/plain")],
axum::body::Body::from("404 Not Found"),
)
.into_response(),
}
}


// robots.txt 路由
async fn robots_txt() -> impl IntoResponse {
let content = "User-agent:*\nDisallow:/";
(
axum::http::StatusCode::OK,
[(axum::http::header::CONTENT_TYPE, "text/plain")],
content,
)
}

// 处理根路径的请求,返回 index.html 内容
// 主页处理
async fn index_handler() -> Html<String> {
let title = env::var("TITLE").unwrap_or("文件加速下载".to_string());
// 假设 `index.html` 存放在项目根目录的 "templates" 文件夹下
let title = env::var("TITLE").unwrap_or_else(|_| "文件加速下载".to_string());
let html_content = std::fs::read_to_string("templates/index.html")
.unwrap_or_else(|_| "Error: Could not load index.html".to_string());

// 将 "title" 插入到模板中(简单替换逻辑)
let html_with_title = html_content.replace("{{ title }}", &title);

Html(html_with_title)
Html(html_content.replace("{{ title }}", &title))
}

#[tokio::main]
async fn main() {
// 读取环境变量 (可选)
let host = env::var("HOST").unwrap_or("0.0.0.0".to_string());
let port: u16 = env::var("PORT")
.unwrap_or("8000".to_string())
.parse()
.unwrap_or(8000);

// 创建路由
let host = env::var("HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
let port: u16 = env::var("PORT").ok().and_then(|p| p.parse().ok()).unwrap_or(8000);

let app = Router::new()
.route("/favicon.ico", get(favicon_ico))
.route("/robots.txt", get(robots_txt))
.route("/*uri", get(proxy))
.route("/", get(index_handler));

// 绑定并启动服务
let addr = SocketAddr::from(([0, 0, 0, 0], port));
println!("Listening on {}:{}", host, port);

// run our app with hyper, listening globally on port addr
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
}

0 comments on commit f5e91fa

Please sign in to comment.