Skip to content

Commit

Permalink
Implement our own rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew4699 committed Sep 16, 2024
1 parent 7512182 commit d8093fc
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

/** An implementation of our Clock interface using opentelemetry's Clock implementation */
public class ClockImpl implements Clock {
public ClockImpl() {
}
public ClockImpl() {}

@Override
public long nanoTime() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ public class DefaultRateLimiterFactory implements RateLimiterFactory {

@Override
public Future<RateLimiter> createRateLimiter(String key) {
return CompletableFuture.supplyAsync(
() ->
new OpenTelemetryRateLimiter(
requestsPerSecond, requestsPerSecond * windowSeconds, new ClockImpl()));
return CompletableFuture.completedFuture(
new TokenBucketRateLimiter(
requestsPerSecond, requestsPerSecond * windowSeconds, new ClockImpl()));
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class RateLimiterFilter implements Filter {
private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiterFilter.class);
private static final RateLimiter NO_OP_LIMITER = new NoOpRateLimiter();
private static final RateLimiter ALWAYS_REJECT_LIMITER =
new OpenTelemetryRateLimiter(0, 0, new ClockImpl());
new TokenBucketRateLimiter(0, 0, new ClockImpl());
private static final Clock CLOCK = new ClockImpl();

private final RateLimiterConfig config;
Expand All @@ -69,9 +69,9 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha

private RateLimiter maybeBlockToGetRateLimiter(String realm) {
try {
return perRealmLimiters.computeIfAbsent(realm, (key) -> config
.getRateLimiterFactory()
.createRateLimiter(key)).get(config.getConstructionTimeoutMillis(), TimeUnit.MILLISECONDS);
return perRealmLimiters
.computeIfAbsent(realm, (key) -> config.getRateLimiterFactory().createRateLimiter(key))
.get(config.getConstructionTimeoutMillis(), TimeUnit.MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
return getDefaultRateLimiterOnConstructionFailed(e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.polaris.service.ratelimiter;

/** Token bucket implementation of a Polaris RateLimiter. */
public class TokenBucketRateLimiter implements RateLimiter {
private final double tokensPerNano;
private final double maxTokens;
private final Clock clock;

private double tokens;
private long lastAcquireNanos;

public TokenBucketRateLimiter(double tokensPerSecond, double maxTokens, Clock clock) {
this.tokensPerNano = tokensPerSecond / 1e9;
this.maxTokens = maxTokens;
this.clock = clock;

tokens = maxTokens;
lastAcquireNanos = clock.nanoTime();
}

@Override
public synchronized boolean tryAcquire() {
// Grant tokens for the time that has passed since our last tryAcquire()
long t = clock.nanoTime();
long nanosPassed = t - lastAcquireNanos;
lastAcquireNanos = t;
tokens = Math.min(maxTokens, tokens + (nanosPassed * tokensPerNano));

// Take a token if they have one available
if (tokens >= 1) {
tokens--;
return true;
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ public void setMillis(long millis) {
nanos = millis * 1_000_000;
}

public void setSeconds(long seconds) {
setMillis(seconds * 1000);
}

@Override
public long nanoTime() {
return nanos;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public Future<RateLimiter> createRateLimiter(String key) {
}
return CompletableFuture.supplyAsync(
() ->
new OpenTelemetryRateLimiter(
new TokenBucketRateLimiter(
requestsPerSecond, requestsPerSecond * windowSeconds, CLOCK));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
PolarisConnectionExtension.class,
SnowmanCredentialsExtension.class
})
public class RateLimiterTest {
public class RateLimiterFilterTest {
private static final DropwizardAppExtension<PolarisApplicationConfig> EXT =
new DropwizardAppExtension<>(
PolarisApplication.class,
Expand All @@ -56,7 +56,7 @@ public class RateLimiterTest {
@BeforeAll
public static void setup(PolarisConnectionExtension.PolarisToken userToken) {
realm = PolarisConnectionExtension.getTestRealm(PolarisApplicationIntegrationTest.class);
RateLimiterTest.userToken = userToken.token();
RateLimiterFilterTest.userToken = userToken.token();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.polaris.service.ratelimiter;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/** Main unit test class for TokenBucketRateLimiter */
public class TokenBucketRateLimiterTest {
@Test
void testBasic() {
MockClock clock = new MockClock();
clock.setSeconds(5);

RateLimitResultAsserter asserter =
new RateLimitResultAsserter(new TokenBucketRateLimiter(10, 100, clock));

asserter.canAcquire(100);
asserter.cantAcquire();

clock.setSeconds(6);
asserter.canAcquire(10);
asserter.cantAcquire();

clock.setSeconds(16);
asserter.canAcquire(100);
asserter.cantAcquire();
}

/**
* Starts several threads that try to query the rate limiter at the same time, ensuring that we
* only allow "maxTokens" requests
*/
@Test
void testConcurrent() throws InterruptedException {
int maxTokens = 100;
int numTasks = 50000;
int tokensPerSecond = 10; // Can be anything above 0
int sleepPerNThreads = 100; // Making this too low will result in the test taking a long time
int maxSleepMillis = 5;

TokenBucketRateLimiter rl =
new TokenBucketRateLimiter(tokensPerSecond, maxTokens, new MockClock());
AtomicInteger numAcquired = new AtomicInteger();
CountDownLatch startLatch = new CountDownLatch(numTasks);
CountDownLatch endLatch = new CountDownLatch(numTasks);

try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
for (int i = 0; i < numTasks; i++) {
int i_ = i;
executor.submit(
() -> {
try {
// Enforce that tasks pause until all tasks are submitted
startLatch.countDown();
startLatch.await();

// Make some threads sleep
if (i_ % sleepPerNThreads == 0) {
Thread.sleep((int) (Math.random() * (maxSleepMillis + 1)));
}

if (rl.tryAcquire()) {
numAcquired.incrementAndGet();
}
endLatch.countDown();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
}
}

endLatch.await();
Assertions.assertEquals(maxTokens, numAcquired.get());
}

static class RateLimitResultAsserter {
private final RateLimiter rateLimiter;

RateLimitResultAsserter(RateLimiter rateLimiter) {
this.rateLimiter = rateLimiter;
}

private void canAcquire(int times) {
for (int i = 0; i < times; i++) {
Assertions.assertTrue(rateLimiter.tryAcquire());
}
}

private void cantAcquire() {
for (int i = 0; i < 5; i++) {
Assertions.assertFalse(rateLimiter.tryAcquire());
}
}
}
}

0 comments on commit d8093fc

Please sign in to comment.