Skip to content

Commit

Permalink
Rate limiting by path
Browse files Browse the repository at this point in the history
  • Loading branch information
filiptronicek committed Jan 7, 2024
1 parent 844fa62 commit 175456c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
35 changes: 29 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use rocket::http::Status;
use rocket::response::status::Custom;
use utils::id::gen_id;
use utils::log::setup_logger;
use utils::rate_limit::RateLimitConfig;

use std::result::Result;
use std::result::Result::Ok;
use std::string::String;
use std::time::Duration;

use rocket::form::Form;
use rocket::serde::json::Json;
Expand Down Expand Up @@ -41,7 +43,7 @@ fn status(_rate_limiter: RateLimiter) -> Result<Json<APIResponse>, Custom<Json<A
result: "A problem with the database has occurred".to_string(),
};
return Err(Custom(Status::InternalServerError, Json(response)));
},
}
};

let code = "test".to_string();
Expand Down Expand Up @@ -145,7 +147,7 @@ fn set_clip(
result: "A problem with the database has occurred".to_string(),
};
return Err(Custom(Status::InternalServerError, Json(response)));
},
}
};

// Check for existence of the URL in the database
Expand Down Expand Up @@ -211,7 +213,7 @@ fn get_clip(
result: "A problem with the database has occurred".to_string(),
};
return Err(Custom(Status::InternalServerError, Json(response)));
},
}
};

let result = db::get_clip(&mut db_connection, code);
Expand All @@ -221,7 +223,7 @@ fn get_clip(
status: APIStatus::Success,
result: clip.url,
};
Ok(Custom(Status::Created, Json(response)))
Ok(Custom(Status::Ok, Json(response)))
}
Ok(None) => {
let response = APIResponse {
Expand Down Expand Up @@ -281,7 +283,7 @@ fn version(_rate_limiter: RateLimiter) -> Json<Version> {
}

#[launch]
fn rocket() -> _ {
async fn rocket() -> _ {
match setup_logger() {
Ok(path) => {
println!("Logger setup at {}", path);
Expand All @@ -290,6 +292,27 @@ fn rocket() -> _ {
println!("Error whilst setting up logger: {}", e);
}
};

let rate_limiter = RateLimiter::new();
rate_limiter
.add_config(
"/api/get",
RateLimitConfig::new(Duration::from_secs(30), 100),
)
.await;
rate_limiter
.add_config(
"/api/set",
RateLimitConfig::new(Duration::from_secs(60), 20),
)
.await;
rate_limiter
.add_config(
"/api/status",
RateLimitConfig::new(Duration::from_secs(30), 20),
)
.await;

rocket::build()
.mount(
"/api",
Expand All @@ -303,5 +326,5 @@ fn rocket() -> _ {
],
)
.register("/", catchers![too_many_requests, not_found])
.manage(RateLimiter::new())
.manage(rate_limiter)
}
38 changes: 37 additions & 1 deletion src/utils/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use rocket::request::{self, FromRequest, Outcome};

use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};

extern crate serde;
Expand All @@ -10,10 +11,26 @@ use std::time::{Duration, Instant};

use async_lock::RwLock;

#[derive(Clone)]
pub struct RateLimitConfig {
interval: Duration,
max_requests: u32,
}

impl RateLimitConfig {
pub fn new(interval: Duration, max_requests: u32) -> Self {
RateLimitConfig {
interval,
max_requests,
}
}
}

#[derive(Clone)]
pub struct RateLimiter {
requests: Arc<AtomicU32>,
reset_time: Arc<RwLock<Instant>>,
config: Arc<RwLock<HashMap<String, RateLimitConfig>>>,
}

impl RateLimiter {
Expand All @@ -22,9 +39,15 @@ impl RateLimiter {
RateLimiter {
requests: Arc::new(AtomicU32::new(0)),
reset_time: Arc::new(RwLock::new(Instant::now())),
config: Arc::new(RwLock::new(HashMap::new())),
}
}

pub async fn add_config(&self, path: &str, config: RateLimitConfig) {
let mut configs = self.config.write().await;
configs.insert(path.to_string(), config);
}

async fn should_limit(&self, interval: Duration, max_requests: u32) -> bool {
let mut reset_time = self.reset_time.write().await;
let requests = self.requests.load(Ordering::Relaxed);
Expand Down Expand Up @@ -54,7 +77,20 @@ impl<'r> FromRequest<'r> for RateLimiter {
.state::<RateLimiter>()
.expect("RateLimiter registered as state");

if rate_limiter.should_limit(Duration::from_secs(10), 15).await {
let uri = request.uri();
let path = uri.path().to_string();

let config = {
let configs = rate_limiter.config.read().await;
configs.get(&path).cloned().unwrap_or_else(
|| RateLimitConfig::new(Duration::from_secs(60), 20), // By default, allow 20 requests per minute
)
};

if rate_limiter
.should_limit(config.interval, config.max_requests)
.await
{
Outcome::Error((rocket::http::Status::TooManyRequests, ()))
} else {
Outcome::Success(rate_limiter.clone())
Expand Down

0 comments on commit 175456c

Please sign in to comment.