From d8093fc717f0a8a1e2fc0dc1615a8c00026b2ee6 Mon Sep 17 00:00:00 2001 From: andrew4699 Date: Mon, 16 Sep 2024 15:10:20 -0700 Subject: [PATCH] Implement our own rate limiter --- .../service/ratelimiter/ClockImpl.java | 3 +- .../DefaultRateLimiterFactory.java | 7 +- .../ratelimiter/OpenTelemetryRateLimiter.java | 38 ------ .../ratelimiter/RateLimiterFilter.java | 8 +- .../ratelimiter/TokenBucketRateLimiter.java | 54 ++++++++ .../service/ratelimiter/MockClock.java | 4 + .../ratelimiter/MockRateLimiterFactory.java | 2 +- ...erTest.java => RateLimiterFilterTest.java} | 4 +- .../TokenBucketRateLimiterTest.java | 117 ++++++++++++++++++ 9 files changed, 186 insertions(+), 51 deletions(-) delete mode 100644 polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/OpenTelemetryRateLimiter.java create mode 100644 polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiter.java rename polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/{RateLimiterTest.java => RateLimiterFilterTest.java} (97%) create mode 100644 polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiterTest.java diff --git a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/ClockImpl.java b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/ClockImpl.java index b7585540f..255538cb6 100644 --- a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/ClockImpl.java +++ b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/ClockImpl.java @@ -20,8 +20,7 @@ /** An implementation of our Clock interface using opentelemetry's Clock implementation */ public class ClockImpl implements Clock { - public ClockImpl() { - } + public ClockImpl() {} @Override public long nanoTime() { diff --git a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/DefaultRateLimiterFactory.java b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/DefaultRateLimiterFactory.java index ffc1815bf..54fc431f9 100644 --- a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/DefaultRateLimiterFactory.java +++ b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/DefaultRateLimiterFactory.java @@ -37,9 +37,8 @@ public class DefaultRateLimiterFactory implements RateLimiterFactory { @Override public Future createRateLimiter(String key) { - return CompletableFuture.supplyAsync( - () -> - new OpenTelemetryRateLimiter( - requestsPerSecond, requestsPerSecond * windowSeconds, new ClockImpl())); + return CompletableFuture.completedFuture( + new TokenBucketRateLimiter( + requestsPerSecond, requestsPerSecond * windowSeconds, new ClockImpl())); } } diff --git a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/OpenTelemetryRateLimiter.java b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/OpenTelemetryRateLimiter.java deleted file mode 100644 index 88418a44e..000000000 --- a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/OpenTelemetryRateLimiter.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.polaris.service.ratelimiter; - -/** - * Wrapper around the opentelemetry RateLimiter that implements the Polaris RateLimiter interface - * The opentelemetry limiter uses a credits/balance system. We treat 1 request as 1 credit. - */ -public class OpenTelemetryRateLimiter implements RateLimiter { - private final io.opentelemetry.sdk.internal.RateLimiter rateLimiter; - - public OpenTelemetryRateLimiter(double creditsPerSecond, double maxBalance, Clock clock) { - rateLimiter = - new io.opentelemetry.sdk.internal.RateLimiter( - creditsPerSecond, maxBalance, new OpenTelemetryClock(clock)); - } - - @Override - public boolean tryAcquire() { - return rateLimiter.trySpend(1); - } -} diff --git a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/RateLimiterFilter.java b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/RateLimiterFilter.java index 241d5a12e..1e1425baa 100644 --- a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/RateLimiterFilter.java +++ b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/RateLimiterFilter.java @@ -44,7 +44,7 @@ public class RateLimiterFilter implements Filter { private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiterFilter.class); private static final RateLimiter NO_OP_LIMITER = new NoOpRateLimiter(); private static final RateLimiter ALWAYS_REJECT_LIMITER = - new OpenTelemetryRateLimiter(0, 0, new ClockImpl()); + new TokenBucketRateLimiter(0, 0, new ClockImpl()); private static final Clock CLOCK = new ClockImpl(); private final RateLimiterConfig config; @@ -69,9 +69,9 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha private RateLimiter maybeBlockToGetRateLimiter(String realm) { try { - return perRealmLimiters.computeIfAbsent(realm, (key) -> config - .getRateLimiterFactory() - .createRateLimiter(key)).get(config.getConstructionTimeoutMillis(), TimeUnit.MILLISECONDS); + return perRealmLimiters + .computeIfAbsent(realm, (key) -> config.getRateLimiterFactory().createRateLimiter(key)) + .get(config.getConstructionTimeoutMillis(), TimeUnit.MILLISECONDS); } catch (InterruptedException | ExecutionException | TimeoutException e) { return getDefaultRateLimiterOnConstructionFailed(e); } diff --git a/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiter.java b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiter.java new file mode 100644 index 000000000..587526738 --- /dev/null +++ b/polaris-service/src/main/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiter.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.polaris.service.ratelimiter; + +/** Token bucket implementation of a Polaris RateLimiter. */ +public class TokenBucketRateLimiter implements RateLimiter { + private final double tokensPerNano; + private final double maxTokens; + private final Clock clock; + + private double tokens; + private long lastAcquireNanos; + + public TokenBucketRateLimiter(double tokensPerSecond, double maxTokens, Clock clock) { + this.tokensPerNano = tokensPerSecond / 1e9; + this.maxTokens = maxTokens; + this.clock = clock; + + tokens = maxTokens; + lastAcquireNanos = clock.nanoTime(); + } + + @Override + public synchronized boolean tryAcquire() { + // Grant tokens for the time that has passed since our last tryAcquire() + long t = clock.nanoTime(); + long nanosPassed = t - lastAcquireNanos; + lastAcquireNanos = t; + tokens = Math.min(maxTokens, tokens + (nanosPassed * tokensPerNano)); + + // Take a token if they have one available + if (tokens >= 1) { + tokens--; + return true; + } + return false; + } +} diff --git a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockClock.java b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockClock.java index c4e750a0b..ccce0a406 100644 --- a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockClock.java +++ b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockClock.java @@ -28,6 +28,10 @@ public void setMillis(long millis) { nanos = millis * 1_000_000; } + public void setSeconds(long seconds) { + setMillis(seconds * 1000); + } + @Override public long nanoTime() { return nanos; diff --git a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockRateLimiterFactory.java b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockRateLimiterFactory.java index 8ecf409ba..635fab5da 100644 --- a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockRateLimiterFactory.java +++ b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/MockRateLimiterFactory.java @@ -48,7 +48,7 @@ public Future createRateLimiter(String key) { } return CompletableFuture.supplyAsync( () -> - new OpenTelemetryRateLimiter( + new TokenBucketRateLimiter( requestsPerSecond, requestsPerSecond * windowSeconds, CLOCK)); } } diff --git a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/RateLimiterTest.java b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/RateLimiterFilterTest.java similarity index 97% rename from polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/RateLimiterTest.java rename to polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/RateLimiterFilterTest.java index 3f4d160cf..2ac4fdea5 100644 --- a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/RateLimiterTest.java +++ b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/RateLimiterFilterTest.java @@ -39,7 +39,7 @@ PolarisConnectionExtension.class, SnowmanCredentialsExtension.class }) -public class RateLimiterTest { +public class RateLimiterFilterTest { private static final DropwizardAppExtension EXT = new DropwizardAppExtension<>( PolarisApplication.class, @@ -56,7 +56,7 @@ public class RateLimiterTest { @BeforeAll public static void setup(PolarisConnectionExtension.PolarisToken userToken) { realm = PolarisConnectionExtension.getTestRealm(PolarisApplicationIntegrationTest.class); - RateLimiterTest.userToken = userToken.token(); + RateLimiterFilterTest.userToken = userToken.token(); } @Test diff --git a/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiterTest.java b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiterTest.java new file mode 100644 index 000000000..2c8de0031 --- /dev/null +++ b/polaris-service/src/test/java/org/apache/polaris/service/ratelimiter/TokenBucketRateLimiterTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.polaris.service.ratelimiter; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** Main unit test class for TokenBucketRateLimiter */ +public class TokenBucketRateLimiterTest { + @Test + void testBasic() { + MockClock clock = new MockClock(); + clock.setSeconds(5); + + RateLimitResultAsserter asserter = + new RateLimitResultAsserter(new TokenBucketRateLimiter(10, 100, clock)); + + asserter.canAcquire(100); + asserter.cantAcquire(); + + clock.setSeconds(6); + asserter.canAcquire(10); + asserter.cantAcquire(); + + clock.setSeconds(16); + asserter.canAcquire(100); + asserter.cantAcquire(); + } + + /** + * Starts several threads that try to query the rate limiter at the same time, ensuring that we + * only allow "maxTokens" requests + */ + @Test + void testConcurrent() throws InterruptedException { + int maxTokens = 100; + int numTasks = 50000; + int tokensPerSecond = 10; // Can be anything above 0 + int sleepPerNThreads = 100; // Making this too low will result in the test taking a long time + int maxSleepMillis = 5; + + TokenBucketRateLimiter rl = + new TokenBucketRateLimiter(tokensPerSecond, maxTokens, new MockClock()); + AtomicInteger numAcquired = new AtomicInteger(); + CountDownLatch startLatch = new CountDownLatch(numTasks); + CountDownLatch endLatch = new CountDownLatch(numTasks); + + try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < numTasks; i++) { + int i_ = i; + executor.submit( + () -> { + try { + // Enforce that tasks pause until all tasks are submitted + startLatch.countDown(); + startLatch.await(); + + // Make some threads sleep + if (i_ % sleepPerNThreads == 0) { + Thread.sleep((int) (Math.random() * (maxSleepMillis + 1))); + } + + if (rl.tryAcquire()) { + numAcquired.incrementAndGet(); + } + endLatch.countDown(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + } + + endLatch.await(); + Assertions.assertEquals(maxTokens, numAcquired.get()); + } + + static class RateLimitResultAsserter { + private final RateLimiter rateLimiter; + + RateLimitResultAsserter(RateLimiter rateLimiter) { + this.rateLimiter = rateLimiter; + } + + private void canAcquire(int times) { + for (int i = 0; i < times; i++) { + Assertions.assertTrue(rateLimiter.tryAcquire()); + } + } + + private void cantAcquire() { + for (int i = 0; i < 5; i++) { + Assertions.assertFalse(rateLimiter.tryAcquire()); + } + } + } +}