diff --git a/java/org/apache/catalina/filters/RateLimitFilter.java b/java/org/apache/catalina/filters/RateLimitFilter.java index f96dc805dda8..80315948bb33 100644 --- a/java/org/apache/catalina/filters/RateLimitFilter.java +++ b/java/org/apache/catalina/filters/RateLimitFilter.java @@ -17,7 +17,6 @@ package org.apache.catalina.filters; import java.io.IOException; -import java.util.concurrent.ScheduledExecutorService; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; @@ -30,7 +29,6 @@ import org.apache.juli.logging.Log; import org.apache.juli.logging.LogFactory; import org.apache.tomcat.util.res.StringManager; -import org.apache.tomcat.util.threads.ScheduledThreadPoolExecutor; /** *

@@ -202,21 +200,16 @@ protected boolean isConfigProblemFatal() { public void init(FilterConfig filterConfig) throws ServletException { super.init(filterConfig); - ScheduledExecutorService utilityExecutor = (ScheduledExecutorService) filterConfig.getServletContext() - .getAttribute(ScheduledThreadPoolExecutor.class.getName()); - if (utilityExecutor == null) { - if (newExecutorService == null) { - newExecutorService = new java.util.concurrent.ScheduledThreadPoolExecutor(1); - } - utilityExecutor = newExecutorService; - } - try { rateLimiter = (RateLimiter) Class.forName(rateLimitClassName).getConstructor().newInstance(); } catch (ReflectiveOperationException e) { throw new ServletException(e); } + rateLimiter.setDuration(bucketDuration); + rateLimiter.setRequests(bucketRequests); + rateLimiter.setFilterConfig(filterConfig); + if (policyName != null) { String trimmedName = policyName.trim(); if (!trimmedName.isEmpty()) { @@ -224,8 +217,6 @@ public void init(FilterConfig filterConfig) throws ServletException { } } - rateLimiter.initialize(utilityExecutor, bucketDuration, bucketRequests); - filterName = filterConfig.getFilterName(); log.info(sm.getString("rateLimitFilter.initialized", filterName, Integer.valueOf(bucketRequests), @@ -262,18 +253,9 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha chain.doFilter(request, response); } - private ScheduledExecutorService newExecutorService = null; - @Override public void destroy() { rateLimiter.destroy(); - if (newExecutorService != null) { - try { - newExecutorService.shutdown(); - } catch (SecurityException e) { - // ignore - } - } super.destroy(); } diff --git a/java/org/apache/catalina/util/FastRateLimiter.java b/java/org/apache/catalina/util/FastRateLimiter.java index f20d69c2e906..d4feb93a4a68 100644 --- a/java/org/apache/catalina/util/FastRateLimiter.java +++ b/java/org/apache/catalina/util/FastRateLimiter.java @@ -33,4 +33,9 @@ protected String getDefaultPolicyName() { protected TimeBucketCounterBase newCounterInstance(ScheduledExecutorService executorService, int duration) { return new TimeBucketCounter(executorService, duration); } + + @Override + public TimeBucketCounter getBucketCounter() { + return (TimeBucketCounter)bucketCounter; + } } diff --git a/java/org/apache/catalina/util/RateLimiter.java b/java/org/apache/catalina/util/RateLimiter.java index 297dea89e238..8ca41937a85b 100644 --- a/java/org/apache/catalina/util/RateLimiter.java +++ b/java/org/apache/catalina/util/RateLimiter.java @@ -17,7 +17,7 @@ package org.apache.catalina.util; -import java.util.concurrent.ScheduledExecutorService; +import jakarta.servlet.FilterConfig; public interface RateLimiter { @@ -26,11 +26,25 @@ public interface RateLimiter { */ int getDuration(); + /** + * Sets the configured duration value in seconds. + * + * @param duration The duration of the time window in seconds + */ + void setDuration(int duration); + /** * @return the maximum number of requests allowed per time window */ int getRequests(); + /** + * Sets the configured number of requests allowed per time window. + * + * @param requests The number of requests per time window + */ + void setRequests(int requests); + /** * Increments the number of requests by the given identifier in the current time window. * @@ -46,13 +60,11 @@ public interface RateLimiter { void destroy(); /** - * Initialize with parameters, start {@link TimeBucketCounterBase}. + * Pass the FilterConfig to configure the filter. * - * @param executorService the executor - * @param duration the duration of the time window in seconds - * @param requests the configured number of requests allowed per time window + * @param filterConfig The FilterConfig used to configure the associated filter */ - void initialize(ScheduledExecutorService executorService, int duration, int requests); + void setFilterConfig(FilterConfig filterConfig); /** * @return name of RateLimit policy @@ -80,7 +92,7 @@ default String getPolicy() { /** * Provide the quota header for this rate limit for a given request count within the current time window. * - * @param requestCount The request count within the current time window + * @param requestCount The request count within the current time window * * @return the quota header for the given value of request count * diff --git a/java/org/apache/catalina/util/RateLimiterBase.java b/java/org/apache/catalina/util/RateLimiterBase.java index ecedb3731b5c..589d80b1a84c 100644 --- a/java/org/apache/catalina/util/RateLimiterBase.java +++ b/java/org/apache/catalina/util/RateLimiterBase.java @@ -18,8 +18,11 @@ import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.atomic.AtomicInteger; +import jakarta.servlet.FilterConfig; + /** * Basic implementation of {@link RateLimiter}, provides runtime data maintenance mechanism monitor. */ @@ -29,8 +32,10 @@ public abstract class RateLimiterBase implements RateLimiter { TimeBucketCounterBase bucketCounter; + int requests; int actualRequests; + int duration; int actualDuration; // Initial policy name can be rewritten by setPolicyName() @@ -63,11 +68,21 @@ public int getDuration() { return actualDuration; } + @Override + public void setDuration(int duration) { + this.duration=duration; + } + @Override public int getRequests() { return actualRequests; } + @Override + public void setRequests(int requests) { + this.requests=requests; + } + @Override public int increment(String identifier) { return bucketCounter.increment(identifier); @@ -76,6 +91,13 @@ public int increment(String identifier) { @Override public void destroy() { bucketCounter.destroy(); + if (newExecutorService != null) { + try { + newExecutorService.shutdown(); + } catch (SecurityException e) { + // ignore + } + } } /** @@ -90,17 +112,20 @@ public void destroy() { protected abstract TimeBucketCounterBase newCounterInstance(ScheduledExecutorService utilityExecutor, int duration); @Override - public void initialize(ScheduledExecutorService utilityExecutor, int duration, int requests) { - if (bucketCounter != null) { - bucketCounter.destroy(); - } + public void setFilterConfig(FilterConfig filterConfig) { + + ScheduledExecutorService executorService = (ScheduledExecutorService) filterConfig.getServletContext() + .getAttribute(ScheduledThreadPoolExecutor.class.getName()); - bucketCounter = newCounterInstance(utilityExecutor, duration); + if (executorService == null) { + newExecutorService = new java.util.concurrent.ScheduledThreadPoolExecutor(1); + executorService = newExecutorService; + } + bucketCounter = newCounterInstance(executorService, duration); actualDuration = bucketCounter.getBucketDuration(); actualRequests = (int) Math.round(bucketCounter.getRatio() * requests); } - /** * Returns the internal instance of {@link TimeBucketCounterBase} * @@ -109,4 +134,10 @@ public void initialize(ScheduledExecutorService utilityExecutor, int duration, i public TimeBucketCounterBase getBucketCounter() { return bucketCounter; } + + /** + * The self-owned utility executor, will be instantiated only when ScheduledThreadPoolExecutor is absent during + * filter configure phase. + */ + private ScheduledThreadPoolExecutor newExecutorService = null; }