From c6c822ad3427484bbdf2cff0b018b1f916220196 Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff Date: Fri, 6 Oct 2023 22:52:08 +0300 Subject: [PATCH] fix: use `axum-client-ip` to get the real client IP --- Cargo.lock | 3 ++- Cargo.toml | 1 + src/handlers/push_message.rs | 10 +++++----- src/handlers/register_client.rs | 22 +++++++++++++--------- src/handlers/single_tenant_wrappers.rs | 12 ++++++------ src/lib.rs | 4 +++- 6 files changed, 30 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a1f65403..c2616ff9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1187,7 +1187,7 @@ checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" [[package]] name = "echo-server" -version = "0.34.7" +version = "0.35.0" dependencies = [ "a2", "async-recursion", @@ -1196,6 +1196,7 @@ dependencies = [ "aws-config", "aws-sdk-s3", "axum", + "axum-client-ip", "base64 0.21.4", "build-info", "build-info-build", diff --git a/Cargo.toml b/Cargo.toml index 8eb9965a..a6d6ad14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ wc = { git = "https://github.com/WalletConnect/utils-rs.git", tag = "v0.5.1", fe tokio = { version = "1", features = ["full"] } axum = { version = "0.6", features = ["json", "multipart", "tokio"] } +axum-client-ip = "0.4" tower = "0.4" tower-http = { version = "0.4", features = ["trace", "cors", "request-id", "propagate-header", "catch-panic"] } hyper = "0.14" diff --git a/src/handlers/push_message.rs b/src/handlers/push_message.rs index 3063c9ff..b411463a 100644 --- a/src/handlers/push_message.rs +++ b/src/handlers/push_message.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "analytics")] +use axum_client_ip::SecureClientIp; use { crate::{ analytics::message_info::MessageInfo, @@ -23,8 +25,6 @@ use { serde::{Deserialize, Serialize}, std::sync::Arc, }; -#[cfg(feature = "analytics")] -use {axum::extract::ConnectInfo, std::net::SocketAddr}; #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] pub struct MessagePayload { @@ -46,7 +46,7 @@ pub struct PushMessageBody { } pub async fn handler( - #[cfg(feature = "analytics")] ConnectInfo(addr): ConnectInfo, + #[cfg(feature = "analytics")] SecureClientIp(client_ip): SecureClientIp, Path((tenant_id, id)): Path<(String, String)>, StateExtractor(state): StateExtractor>, headers: HeaderMap, @@ -104,7 +104,7 @@ pub async fn handler( tokio::spawn(async move { if let Some(analytics) = &state.analytics { let (country, continent, region) = analytics - .lookup_geo_data(addr.ip()) + .lookup_geo_data(client_ip) .map_or((None, None, None), |geo| { (geo.country, geo.continent, geo.region) }); @@ -113,7 +113,7 @@ pub async fn handler( %request_id, %tenant_id, client_id = %id, - ip = %addr.ip(), + ip = %client_ip, "loaded geo data" ); diff --git a/src/handlers/register_client.rs b/src/handlers/register_client.rs index fe4e3fd8..f9688908 100644 --- a/src/handlers/register_client.rs +++ b/src/handlers/register_client.rs @@ -1,5 +1,5 @@ #[cfg(feature = "analytics")] -use {crate::analytics::client_info::ClientInfo, axum::extract::ConnectInfo, std::net::SocketAddr}; +use {crate::analytics::client_info::ClientInfo, axum_client_ip::SecureClientIp}; use { crate::{ error::{ @@ -31,7 +31,7 @@ pub struct RegisterBody { } pub async fn handler( - #[cfg(feature = "analytics")] ConnectInfo(addr): ConnectInfo, + #[cfg(feature = "analytics")] SecureClientIp(client_ip): SecureClientIp, Path(tenant_id): Path, StateExtractor(state): StateExtractor>, headers: HeaderMap, @@ -89,11 +89,15 @@ pub async fn handler( state .client_store - .create_client(&tenant_id, &client_id, Client { - tenant_id: tenant_id.clone(), - push_type, - token: body.token, - }) + .create_client( + &tenant_id, + &client_id, + Client { + tenant_id: tenant_id.clone(), + push_type, + token: body.token, + }, + ) .await?; info!( @@ -108,7 +112,7 @@ pub async fn handler( tokio::spawn(async move { if let Some(analytics) = &state.analytics { let (country, continent, region) = analytics - .lookup_geo_data(addr.ip()) + .lookup_geo_data(client_ip) .map_or((None, None, None), |geo| { (geo.country, geo.continent, geo.region) }); @@ -117,7 +121,7 @@ pub async fn handler( %request_id, %tenant_id, %client_id, - ip = %addr.ip(), + ip = %client_ip, "loaded geo data" ); diff --git a/src/handlers/single_tenant_wrappers.rs b/src/handlers/single_tenant_wrappers.rs index 8d49388f..d65a09d3 100644 --- a/src/handlers/single_tenant_wrappers.rs +++ b/src/handlers/single_tenant_wrappers.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "analytics")] +use axum_client_ip::SecureClientIp; use { crate::{ error::Result, @@ -13,8 +15,6 @@ use { hyper::HeaderMap, std::sync::Arc, }; -#[cfg(feature = "analytics")] -use {axum::extract::ConnectInfo, std::net::SocketAddr}; #[cfg(feature = "multitenant")] use crate::error::Error::MissingTenantId; @@ -37,7 +37,7 @@ pub async fn delete_handler( } pub async fn push_handler( - #[cfg(feature = "analytics")] addr: ConnectInfo, + #[cfg(feature = "analytics")] SecureClientIp(client_ip): SecureClientIp, Path(id): Path, state: StateExtractor>, headers: HeaderMap, @@ -48,7 +48,7 @@ pub async fn push_handler( #[cfg(all(not(feature = "multitenant"), feature = "analytics"))] return crate::handlers::push_message::handler( - addr, + SecureClientIp(client_ip), Path((DEFAULT_TENANT_ID.to_string(), id)), state, headers, @@ -67,7 +67,7 @@ pub async fn push_handler( } pub async fn register_handler( - #[cfg(feature = "analytics")] addr: ConnectInfo, + #[cfg(feature = "analytics")] SecureClientIp(client_ip): SecureClientIp, state: StateExtractor>, headers: HeaderMap, body: Json, @@ -77,7 +77,7 @@ pub async fn register_handler( #[cfg(all(not(feature = "multitenant"), feature = "analytics"))] return crate::handlers::register_client::handler( - addr, + SecureClientIp(client_ip), Path(DEFAULT_TENANT_ID.to_string()), state, headers, diff --git a/src/lib.rs b/src/lib.rs index fad5b0c2..e96deb06 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use { routing::{delete, get, post}, Router, }, + axum_client_ip::SecureClientIpSource, config::Config, hyper::http::Method, opentelemetry::{sdk::Resource, KeyValue}, @@ -210,7 +211,8 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) -> hyper::http::header::CONTENT_TYPE, hyper::http::header::AUTHORIZATION, ]), - ); + ) + .layer(SecureClientIpSource::ConnectInfo.into_extension()); #[cfg(feature = "multitenant")] let app = {