diff --git a/Cargo.lock b/Cargo.lock index 3c8f8a2..a8016cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,30 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "actix" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f728064aca1c318585bf4bb04ffcfac9e75e508ab4e8b1bd9ba5dfe04e2cbed5" -dependencies = [ - "actix-rt", - "actix_derive", - "bitflags 1.3.2", - "bytes", - "crossbeam-channel", - "futures-core", - "futures-sink", - "futures-task", - "futures-util", - "log", - "once_cell", - "parking_lot 0.12.0", - "pin-project-lite", - "smallvec", - "tokio", - "tokio-util", -] - [[package]] name = "actix-codec" version = "0.5.0" @@ -242,17 +218,6 @@ dependencies = [ "syn 1.0.103", ] -[[package]] -name = "actix_derive" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d44b8fee1ced9671ba043476deddef739dd0959bf77030b26b738cc591737a7" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.103", -] - [[package]] name = "adler" version = "1.0.2" @@ -528,16 +493,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53" -dependencies = [ - "cfg-if", - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.8" @@ -769,6 +724,12 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.21" @@ -808,6 +769,24 @@ dependencies = [ "wasi 0.10.2+wasi-snapshot-preview1", ] +[[package]] +name = "governor" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4" +dependencies = [ + "cfg-if", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot 0.12.0", + "quanta", + "rand", + "smallvec", +] + [[package]] name = "h2" version = "0.3.13" @@ -1003,6 +982,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "mach2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +dependencies = [ + "libc", +] + [[package]] name = "matches" version = "0.1.9" @@ -1053,6 +1041,18 @@ dependencies = [ "winapi", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "ntapi" version = "0.3.7" @@ -1233,6 +1233,22 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quote" version = "1.0.29" @@ -1283,6 +1299,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.2.13" @@ -1514,7 +1539,6 @@ dependencies = [ name = "testaustime-rs" version = "0.3.0" dependencies = [ - "actix", "actix-cors", "actix-web", "argon2", @@ -1526,6 +1550,7 @@ dependencies = [ "env_logger", "futures", "futures-util", + "governor", "http", "itertools", "log", diff --git a/Cargo.toml b/Cargo.toml index 5101b0b..a21e66c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ lto = true [dependencies] actix-web = { version = "4.2.1", features = ["macros", "rustls"] } -actix = "0.13" awc = { version = "3.0.0", features = ["rustls"], optional = true } actix-cors = "0.6" http = "0.2" @@ -45,3 +44,4 @@ dotenv = "0.15" url = "2.2" itertools = "0.10.3" +governor = "0.6.0" diff --git a/src/main.rs b/src/main.rs index c4e1dd0..b5ef721 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,8 @@ mod requests; mod schema; mod utils; -use actix::Actor; +use std::{num::NonZeroU32, sync::Arc}; + use actix_cors::Cors; use actix_web::{ dev::{ServiceRequest, ServiceResponse}, @@ -29,7 +30,8 @@ use diesel::{ r2d2::{ConnectionManager, Pool}, PgConnection, }; -use ratelimiter::{RateLimiter, RateLimiterStorage}; +use governor::{Quota, RateLimiter}; +use ratelimiter::TestaustimeRateLimiter; use serde_derive::Deserialize; use tracing::Span; use tracing_actix_web::{root_span, RootSpanBuilder, TracingLogger}; @@ -99,7 +101,12 @@ async fn main() -> std::io::Result<()> { storage: DashMap::new(), }); - let ratelimiter = RateLimiterStorage::new(config.max_requests_per_min, 60).start(); + let ratelimiter = Arc::new( + RateLimiter::keyed(Quota::per_minute( + NonZeroU32::new(config.max_requests_per_min as u32).unwrap(), + )) + .with_middleware(), + ); let heartbeat_store = Data::new(api::activity::HeartBeatMemoryStore::new()); let leaderboard_cache = Data::new(api::leaderboards::LeaderboardCache::new()); @@ -134,12 +141,10 @@ async fn main() -> std::io::Result<()> { let scope = web::scope("") .wrap(tracing) .wrap(AuthMiddleware) - .wrap(RateLimiter { - storage: ratelimiter.clone(), + .wrap(TestaustimeRateLimiter { + limiter: Arc::clone(&ratelimiter), use_peer_addr: config.ratelimit_by_peer_ip, - maxrpm: config.max_requests_per_min, bypass_token: config.bypass_token.clone(), - reset_interval: 60, }) .service({ web::scope("/activity") diff --git a/src/ratelimiter.rs b/src/ratelimiter.rs index 0ce0a82..2ffe1f5 100644 --- a/src/ratelimiter.rs +++ b/src/ratelimiter.rs @@ -1,133 +1,27 @@ -use std::{ - collections::HashMap, - rc::Rc, - time::{Duration, Instant}, -}; +use std::{net::IpAddr, rc::Rc, sync::Arc}; -use actix::prelude::*; use actix_web::{ body::EitherBody, dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, - http::header::{HeaderName, HeaderValue}, Error, HttpResponse, }; -use futures_util::{future::LocalBoxFuture, stream::once}; - -pub struct RateLimitInfo { - pub request_count: usize, - pub last_reset: Instant, -} - -pub struct RateLimiterStorage { - pub clients: HashMap, - pub maxrpm: usize, - pub reset_interval: usize, - event_count: usize, -} - -impl Actor for RateLimiterStorage { - type Context = Context; -} - -impl RateLimiterStorage { - pub fn new(maxrpm: usize, reset_interval: usize) -> Self { - RateLimiterStorage { - clients: HashMap::new(), - maxrpm, - reset_interval, - event_count: 0, - } - } -} - -struct ConfigRequest; - -impl Message for ConfigRequest { - type Result = Result<(usize, usize), std::io::Error>; -} - -impl Handler for RateLimiterStorage { - type Result = Result<(usize, usize), std::io::Error>; - - fn handle(&mut self, _: ConfigRequest, _: &mut Context) -> Self::Result { - Ok((self.maxrpm.to_owned(), self.reset_interval.to_owned())) - } -} - -#[derive(Message)] -#[rtype(result = "()")] -struct ClearRequest; - -impl Handler for RateLimiterStorage { - type Result = (); - - fn handle(&mut self, _: ClearRequest, _: &mut Context) { - let cur_time = Instant::now(); - self.clients - .retain(|_, i| cur_time.duration_since(i.last_reset) < Duration::from_secs(1800)); - } -} - -struct IpRequest { - pub ip: String, -} - -impl Message for IpRequest { - type Result = Result<(Option, Duration), std::io::Error>; -} - -impl Handler for RateLimiterStorage { - type Result = Result<(Option, Duration), std::io::Error>; +use futures_util::future::LocalBoxFuture; +use governor::{ + clock::DefaultClock, middleware::StateInformationMiddleware, + state::keyed::DefaultKeyedStateStore, RateLimiter, +}; +use http::{header::HeaderName, HeaderValue}; - fn handle(&mut self, req: IpRequest, ctx: &mut Context) -> Self::Result { - if self.event_count > 1000 { - ctx.add_message_stream(once(async { ClearRequest })); - self.event_count = 0; - } else { - self.event_count += 1; - }; - if let Some(rlinfo) = self.clients.get_mut(&req.ip) { - let time = Instant::now(); - let duration = (rlinfo.last_reset).duration_since(time); - if duration == Duration::from_secs(0) { - rlinfo.request_count = 1; - rlinfo.last_reset = time + Duration::from_secs(self.reset_interval as u64); - Ok(( - Some(self.maxrpm - rlinfo.request_count), - Duration::from_secs(self.reset_interval as u64), - )) - } else if rlinfo.request_count >= self.maxrpm { - Ok((None, duration)) - } else { - rlinfo.request_count += 1; - Ok((Some(self.maxrpm - rlinfo.request_count), duration)) - } - } else { - self.clients.insert( - req.ip, - RateLimitInfo { - request_count: 1, - last_reset: std::time::Instant::now() - + Duration::from_secs(self.reset_interval as u64), - }, - ); - Ok(( - Some(self.maxrpm - 1), - Duration::from_secs(self.reset_interval as u64), - )) - } - } -} +type SharedRateLimiter = + Arc, DefaultClock, M>>; -pub struct RateLimiter { - pub storage: Addr, +pub struct TestaustimeRateLimiter { + pub limiter: SharedRateLimiter, pub use_peer_addr: bool, - pub maxrpm: usize, pub bypass_token: String, - pub reset_interval: usize, } -impl Transform for RateLimiter +impl Transform for TestaustimeRateLimiter where S: Service, Error = Error> + 'static, S::Future: 'static, @@ -136,33 +30,29 @@ where type Response = ServiceResponse>; type Error = Error; type InitError = (); - type Transform = RateLimiterTransform; + type Transform = TestaustimeRateLimiterTransform; type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { - let ratelimiter = RateLimiterTransform { + let transform = Ok(Self::Transform { service: Rc::new(service), - ratelimiter: self.storage.clone(), + limiter: Arc::clone(&self.limiter), use_peer_addr: self.use_peer_addr, bypass_token: self.bypass_token.clone(), - maxrpm: self.maxrpm, - reset_interval: self.reset_interval, - }; + }); - Box::pin(async { Ok(ratelimiter) }) + Box::pin(async move { transform }) } } -pub struct RateLimiterTransform { - pub service: Rc, - pub ratelimiter: Addr, - pub use_peer_addr: bool, - pub maxrpm: usize, - pub bypass_token: String, - pub reset_interval: usize, +pub struct TestaustimeRateLimiterTransform { + service: Rc, + limiter: SharedRateLimiter, + use_peer_addr: bool, + bypass_token: String, } -impl Service for RateLimiterTransform +impl Service for TestaustimeRateLimiterTransform where S: Service, Error = Error> + 'static, S::Future: 'static, @@ -182,7 +72,7 @@ where .get("bypass-token") .is_some_and(|token| token.to_str().is_ok_and(|token| self.bypass_token == token)); - if bypass { + let addr = if bypass { req.headers() .get("client-ip") .and_then(|ip| ip.to_str().ok()) @@ -190,43 +80,60 @@ where conn_info.peer_addr() } else { conn_info.realip_remote_addr() - } + }; + + addr.and_then(|addr| addr.parse::().ok()) } { - let res = self.ratelimiter.send(IpRequest { ip: ip.to_owned() }); - let service = Rc::clone(&self.service); - let maxrpm = self.maxrpm; + match self.limiter.check_key(&ip) { + Ok(state) => { + let res = self.service.call(req); - Box::pin(async move { - let (remaining, reset) = res - .await - .map_err(actix_web::error::ErrorInternalServerError)??; - if let Some(remaining) = remaining { - let mut resp = service.call(req).await?; - let headers = resp.headers_mut(); - headers.insert( - HeaderName::from_static("ratelimit-limit"), - HeaderValue::from_str(&maxrpm.to_string()).unwrap(), - ); - headers.insert( - HeaderName::from_static("ratelimit-remaining"), - HeaderValue::from_str(&remaining.to_string()).unwrap(), - ); - headers.insert( - HeaderName::from_static("ratelimit-reset"), - HeaderValue::from_str(&reset.as_secs().to_string()).unwrap(), - ); - Ok(resp.map_into_left_body()) - } else { + Box::pin(async move { + let mut res = res.await?; + + let headers = res.headers_mut(); + + let quota = state.quota(); + + headers.insert( + HeaderName::from_static("ratelimit-limit"), + HeaderValue::from_str("a.burst_size().to_string())?, + ); + + headers.insert( + HeaderName::from_static("ratelimit-remaining"), + HeaderValue::from_str(&state.remaining_burst_capacity().to_string())?, + ); + + headers.insert( + HeaderName::from_static("ratelimit-reset"), + HeaderValue::from_str( + "a.replenish_interval().as_secs().to_string(), + )?, + ); + + Ok(res.map_into_left_body()) + }) + } + Err(denied) => Box::pin(async move { let response = HttpResponse::TooManyRequests() - .insert_header(("ratelimit-limit", maxrpm.to_string())) - .insert_header(("ratelimit-remaining", 0usize.to_string())) - .insert_header(("ratelimit-reset", reset.as_secs().to_string())) + .insert_header(("ratelimit-limit", denied.quota().burst_size().to_string())) + .insert_header(("ratelimit-remaining", "0")) + .insert_header(( + "ratelimit-reset", + denied.quota().replenish_interval().as_secs().to_string(), + )) .finish(); + Ok(req.into_response(response.map_into_right_body())) - } - }) + }), + } } else { - Box::pin(async move { Err(actix_web::error::ErrorInternalServerError("wtf")) }) + Box::pin(async move { + Err(actix_web::error::ErrorInternalServerError( + "Failed to get request ip (?)", + )) + }) } } }