From e3122bac807dd3e0b676b5754f7b3900eb684bc0 Mon Sep 17 00:00:00 2001
From: Max Kalashnikoff <geekmaks@gmail.com>
Date: Thu, 16 May 2024 15:50:12 +0200
Subject: [PATCH] fix: applying rate limiting middleware to all except push
 endpoints

---
 src/lib.rs | 53 ++++++++++++++++++++++++++++-------------------------
 1 file changed, 28 insertions(+), 25 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index cf1f7769..88512ec7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -235,62 +235,65 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Config) ->
                 ),
             );
 
-        let app = Router::new()
-            .route("/health", get(handlers::health::handler))
-            .nest("/tenants", tenancy_routes)
+        Router::new()
+            .route("/health", get(handlers::health::handler).layer(
+                axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+            ))
+            .nest("/tenants", tenancy_routes.layer(
+                axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+            ))
             .route(
                 "/:tenant_id/clients",
-                post(handlers::register_client::handler),
+                post(handlers::register_client::handler).layer(
+                    axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+                ),
             )
             .route(
                 "/:tenant_id/clients/:id",
-                delete(handlers::delete_client::handler),
+                delete(handlers::delete_client::handler).layer(
+                    axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+                ),
             )
+            // Rate limiting middleware is not applying to push_handler because it is used by the relay
             .route(
                 "/:tenant_id/clients/:id",
                 post(handlers::push_message::handler),
             )
-            .layer(global_middleware);
-
-        let app = if let Some(geoblock) = state_arc.geoblock.clone() {
-            app.layer(geoblock)
-        } else {
-            app
-        };
-        let app = app.route_layer(axum::middleware::from_fn_with_state(
-            state_arc.clone(),
-            rate_limit_middleware,
-        ));
-        app.with_state(state_arc.clone())
+            .layer(global_middleware)
     };
 
     #[cfg(not(feature = "multitenant"))]
     let app = Router::new()
-        .route("/health", get(handlers::health::handler))
+        .route("/health", get(handlers::health::handler).layer(
+            axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+        ))
         .route(
             "/clients",
-            post(handlers::single_tenant_wrappers::register_handler),
+            post(handlers::single_tenant_wrappers::register_handler).layer(
+                axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+            ),
         )
         .route(
             "/clients/:id",
-            delete(handlers::single_tenant_wrappers::delete_handler),
+            delete(handlers::single_tenant_wrappers::delete_handler).layer(
+                axum::middleware::from_fn_with_state(state_arc.clone(), rate_limit_middleware),
+            ),
         )
+        // Rate limiting middleware is not applying to push_handler because it is used by the relay
         .route(
             "/clients/:id",
             post(handlers::single_tenant_wrappers::push_handler),
         )
         .layer(global_middleware);
+
+    // If geoblock is enabled, add the geoblock middleware to the app
     let app = if let Some(geoblock) = state_arc.geoblock.clone() {
         app.layer(geoblock)
     } else {
         app
     };
-    let app = app.route_layer(axum::middleware::from_fn_with_state(
-        state_arc.clone(),
-        rate_limit_middleware,
-    ));
-    let app = app.with_state(state_arc.clone());
 
+    let app = app.with_state(state_arc.clone());
     let private_app = Router::new()
         .route("/metrics", get(handlers::metrics::handler))
         .with_state(state_arc);