From 6bdd3d0435157c3f64ca78253458eb5b580bec45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Femen=C3=ADa?= <131800808+pablf@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:54:28 +0100 Subject: [PATCH] Add rate limiter primitives (#235) Co-authored-by: pablf Co-authored-by: adamw --- core/src/main/scala/ox/fork.scala | 7 +- .../scala/ox/resilience/RateLimiter.scala | 78 +++ .../ox/resilience/RateLimiterAlgorithm.scala | 142 ++++++ .../resilience/RateLimiterInterfaceTest.scala | 126 +++++ .../scala/ox/resilience/RateLimiterTest.scala | 470 ++++++++++++++++++ doc/index.md | 1 + doc/utils/rate-limiter.md | 57 +++ 7 files changed, 880 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/ox/resilience/RateLimiter.scala create mode 100644 core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala create mode 100644 core/src/test/scala/ox/resilience/RateLimiterInterfaceTest.scala create mode 100644 core/src/test/scala/ox/resilience/RateLimiterTest.scala create mode 100644 doc/utils/rate-limiter.md diff --git a/core/src/main/scala/ox/fork.scala b/core/src/main/scala/ox/fork.scala index cd1681f0..6fc1337f 100644 --- a/core/src/main/scala/ox/fork.scala +++ b/core/src/main/scala/ox/fork.scala @@ -184,11 +184,16 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] = end new end forkCancellable -/** Same as [[fork]], but discards the resulting [[Fork]], to avoid compiler warnings. That is, the fork is run only for its side-effects, +/** Same as [[fork]], but discards the resulting [[Fork]], to avoid compiler warnings. That is, the fork is run only for its side effects, * it's not possible to join it. */ inline def forkDiscard[T](inline f: T)(using Ox): Unit = fork(f).discard +/** Same as [[forkUser]], but discards the resulting [[Fork]], to avoid compiler warnings. That is, the fork is run only for its side + * effects, it's not possible to join it. 
+ */ +inline def forkUserDiscard[T](inline f: T)(using Ox): Unit = forkUser(f).discard + private trait ForkUsingResult[T](result: CompletableFuture[T]) extends Fork[T]: override def join(): T = unwrapExecutionException(result.get()) override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = diff --git a/core/src/main/scala/ox/resilience/RateLimiter.scala b/core/src/main/scala/ox/resilience/RateLimiter.scala new file mode 100644 index 00000000..cdae50a1 --- /dev/null +++ b/core/src/main/scala/ox/resilience/RateLimiter.scala @@ -0,0 +1,78 @@ +package ox.resilience + +import scala.concurrent.duration.FiniteDuration +import ox.* + +import scala.annotation.tailrec + +/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. */ +class RateLimiter private (algorithm: RateLimiterAlgorithm): + /** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */ + def runBlocking[T](operation: => T): T = + algorithm.acquire() + operation + + /** Runs or drops the operation, if the rate limit is reached. + * + * @return + * `Some` if the operation has been allowed to run, `None` if the operation has been dropped. + */ + def runOrDrop[T](operation: => T): Option[T] = + if algorithm.tryAcquire() then Some(operation) + else None + +end RateLimiter + +object RateLimiter: + def apply(algorithm: RateLimiterAlgorithm)(using Ox): RateLimiter = + @tailrec + def update(): Unit = + val waitTime = algorithm.getNextUpdate + val millis = waitTime / 1000000 + val nanos = waitTime % 1000000 + Thread.sleep(millis, nanos.toInt) + algorithm.update() + update() + end update + + forkDiscard(update()) + new RateLimiter(algorithm) + end apply + + /** Creates a rate limiter using a fixed window algorithm. + * + * Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter. 
+ * + * @param maxOperations + * Maximum number of operations that are allowed to **start** within a time [[window]]. + * @param window + * Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the next + * time window. + */ + def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter = + apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window)) + + /** Creates a rate limiter using a sliding window algorithm. + * + * Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter. + * + * @param maxOperations + * Maximum number of operations that are allowed to **start** within any [[window]] of time. + * @param window + * Length of the window. + */ + def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter = + apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window)) + + /** Creates a rate limiter using a token/leaky bucket algorithm. + * + * Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter. + * + * @param maxTokens + * Max capacity of tokens in the algorithm, limiting the operations that are allowed to **start** concurrently. + * @param refillInterval + * Interval of time between adding a single token to the bucket. 
+ */ + def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter = + apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval)) +end RateLimiter diff --git a/core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala b/core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala new file mode 100644 index 00000000..bb667cf5 --- /dev/null +++ b/core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala @@ -0,0 +1,142 @@ +package ox.resilience + +import scala.concurrent.duration.FiniteDuration +import scala.collection.immutable.Queue +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.Semaphore +import scala.annotation.tailrec + +/** Determines the algorithm to use for the rate limiter */ +trait RateLimiterAlgorithm: + + /** Acquires a permit to execute the operation. This method should block until a permit is available. */ + final def acquire(): Unit = + acquire(1) + + /** Acquires permits to execute the operation. This method should block until the requested permits are available. */ + def acquire(permits: Int): Unit + + /** Tries to acquire a permit to execute the operation. This method should not block. */ + final def tryAcquire(): Boolean = + tryAcquire(1) + + /** Tries to acquire permits to execute the operation. This method should not block. */ + def tryAcquire(permits: Int): Boolean + + /** Updates the internal state of the rate limiter to check whether new operations can be accepted. */ + def update(): Unit + + /** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */ + def getNextUpdate: Long + +end RateLimiterAlgorithm + +object RateLimiterAlgorithm: + /** Fixed window algorithm: allows starting at most `rate` operations in consecutive segments of duration `per`. 
*/ + case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm: + private val lastUpdate = new AtomicLong(System.nanoTime()) + private val semaphore = new Semaphore(rate) + + def acquire(permits: Int): Unit = + semaphore.acquire(permits) + + def tryAcquire(permits: Int): Boolean = + semaphore.tryAcquire(permits) + + def getNextUpdate: Long = + val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime() + if waitTime > 0 then waitTime else 0L + + def update(): Unit = + val now = System.nanoTime() + lastUpdate.set(now) + semaphore.release(rate - semaphore.availablePermits()) + end update + + end FixedWindow + + /** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */ + case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm: + // stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully + private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]()) + private val semaphore = new Semaphore(rate) + + def acquire(permits: Int): Unit = + semaphore.acquire(permits) + addTimestampToLog(permits) + + def tryAcquire(permits: Int): Boolean = + if semaphore.tryAcquire(permits) then + addTimestampToLog(permits) + true + else false + + private def addTimestampToLog(permits: Int): Unit = + val now = System.nanoTime() + log.updateAndGet { q => + q.enqueue((now, permits)) + } + () + + def getNextUpdate: Long = + log.get().headOption match + case None => + // no logs so no need to update until `per` has passed + per.toNanos + case Some(record) => + // oldest log provides the new updating point + val waitTime = record._1 + per.toNanos - System.nanoTime() + if waitTime > 0 then waitTime else 0L + end getNextUpdate + + def update(): Unit = + val now = System.nanoTime() + // retrieving current queue to append it later if some elements were added concurrently + val q = log.getAndUpdate(_ => Queue[(Long, 
Int)]()) + // remove records older than window size + val qUpdated = removeRecords(q, now) + // merge old records with the ones concurrently added + val _ = log.updateAndGet(qNew => + qNew.foldLeft(qUpdated) { case (queue, record) => + queue.enqueue(record) + } + ) + end update + + @tailrec + private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] = + q.dequeueOption match + case None => q + case Some((head, tail)) => + if head._1 + per.toNanos < now then + val (_, permits) = head + semaphore.release(permits) + removeRecords(tail, now) + else q + + end SlidingWindow + + /** Token/leaky bucket algorithm. It adds a token to start a new operation each `per`, with a maximum number of tokens of `rate`. */ + case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm: + private val refillInterval = per.toNanos + private val lastRefillTime = new AtomicLong(System.nanoTime()) + private val semaphore = new Semaphore(1) + + def acquire(permits: Int): Unit = + semaphore.acquire(permits) + + def tryAcquire(permits: Int): Boolean = + semaphore.tryAcquire(permits) + + def getNextUpdate: Long = + val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime() + if waitTime > 0 then waitTime else 0L + + def update(): Unit = + val now = System.nanoTime() + lastRefillTime.set(now) + if semaphore.availablePermits() < rate then semaphore.release() + + end LeakyBucket +end RateLimiterAlgorithm diff --git a/core/src/test/scala/ox/resilience/RateLimiterInterfaceTest.scala b/core/src/test/scala/ox/resilience/RateLimiterInterfaceTest.scala new file mode 100644 index 00000000..047106df --- /dev/null +++ b/core/src/test/scala/ox/resilience/RateLimiterInterfaceTest.scala @@ -0,0 +1,126 @@ +package ox.resilience + +import ox.* +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.{EitherValues, TryValues} +import scala.concurrent.duration.* + +class RateLimiterInterfaceTest extends 
AnyFlatSpec with Matchers with EitherValues with TryValues: + behavior of "RateLimiter interface" + + it should "drop or block operation depending on method used for fixed rate algorithm" in { + supervised: + val rateLimiter = RateLimiter.fixedWindow(2, FiniteDuration(1, "second")) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + val result4 = rateLimiter.runBlocking(operation) + val result5 = rateLimiter.runBlocking(operation) + val result6 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe Some(0) + result3 shouldBe None + result4 shouldBe 0 + result5 shouldBe 0 + result6 shouldBe None + executions shouldBe 4 + } + + it should "drop or block operation depending on method used for sliding window algorithm" in { + supervised: + val rateLimiter = RateLimiter.slidingWindow(2, FiniteDuration(1, "second")) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + val result4 = rateLimiter.runBlocking(operation) + val result5 = rateLimiter.runBlocking(operation) + val result6 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe Some(0) + result3 shouldBe None + result4 shouldBe 0 + result5 shouldBe 0 + result6 shouldBe None + executions shouldBe 4 + } + + it should "drop or block operation depending on method used for bucket algorithm" in { + supervised: + val rateLimiter = RateLimiter.leakyBucket(2, FiniteDuration(1, "second")) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + val result4 = rateLimiter.runBlocking(operation) + val 
result5 = rateLimiter.runBlocking(operation) + val result6 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe None + result3 shouldBe None + result4 shouldBe 0 + result5 shouldBe 0 + result6 shouldBe None + executions shouldBe 3 + } + + it should "drop or block operation concurrently" in { + supervised: + val rateLimiter = RateLimiter.fixedWindow(2, FiniteDuration(1, "second")) + + def operation = 0 + + var result1: Option[Int] = Some(-1) + var result2: Option[Int] = Some(-1) + var result3: Option[Int] = Some(-1) + var result4: Int = -1 + var result5: Int = -1 + var result6: Int = -1 + + // run two operations to block the rate limiter + rateLimiter.runOrDrop(operation).discard + rateLimiter.runOrDrop(operation).discard + + // operations with runOrDrop should be dropped while operations with runBlocking should wait + supervised: + forkUserDiscard: + result1 = rateLimiter.runOrDrop(operation) + forkUserDiscard: + result2 = rateLimiter.runOrDrop(operation) + forkUserDiscard: + result3 = rateLimiter.runOrDrop(operation) + forkUserDiscard: + result4 = rateLimiter.runBlocking(operation) + forkUserDiscard: + result5 = rateLimiter.runBlocking(operation) + forkUserDiscard: + result6 = rateLimiter.runBlocking(operation) + + result1 shouldBe None + result2 shouldBe None + result3 shouldBe None + result4 shouldBe 0 + result5 shouldBe 0 + result6 shouldBe 0 + } +end RateLimiterInterfaceTest diff --git a/core/src/test/scala/ox/resilience/RateLimiterTest.scala b/core/src/test/scala/ox/resilience/RateLimiterTest.scala new file mode 100644 index 00000000..17858148 --- /dev/null +++ b/core/src/test/scala/ox/resilience/RateLimiterTest.scala @@ -0,0 +1,470 @@ +package ox.resilience + +import ox.* +import ox.util.ElapsedTime +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.{EitherValues, TryValues} +import scala.concurrent.duration.* +import java.util.concurrent.atomic.AtomicReference + 
+class RateLimiterTest extends AnyFlatSpec with Matchers with EitherValues with TryValues with ElapsedTime: + behavior of "fixed rate RateLimiter" + + it should "drop operation when rate limit is exceeded" in { + supervised: + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.FixedWindow(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe Some(0) + result3 shouldBe None + executions shouldBe 2 + } + + it should "restart rate limiter after given duration" in { + supervised: + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.FixedWindow(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + ox.sleep(1.second) + ox.sleep(100.milliseconds) // make sure the rate limiter is replenished + val result4 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe Some(0) + result3 shouldBe None + result4 shouldBe Some(0) + executions shouldBe 3 + } + + it should "block operation when rate limit is exceeded" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.FixedWindow(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val before = System.currentTimeMillis() + val result1 = rateLimiter.runBlocking(operation) + val result2 = rateLimiter.runBlocking(operation) + val result3 = rateLimiter.runBlocking(operation) + val after = System.currentTimeMillis() + + result1 shouldBe 0 + result2 shouldBe 0 + result3 shouldBe 0 + (after - before) should be >= 1000L + executions shouldBe 3 + } + } + + it should "respect time constraints when blocking" in 
{ + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.FixedWindow(2, FiniteDuration(1, "second")) + ) + + var order = List.empty[Int] + def operationN(n: Int) = + rateLimiter.runBlocking { + order = n :: order + n + } + + val time1 = System.currentTimeMillis() // 0 seconds + val result1 = operationN(1) + val result2 = operationN(2) + val result3 = operationN(3) // blocks until 1 second elapsed + val time2 = System.currentTimeMillis() // 1 second + val result4 = operationN(4) + val result5 = operationN(5) // blocks until 2 seconds elapsed + val time3 = System.currentTimeMillis() + val result6 = operationN(6) + val result7 = operationN(7) // blocks until 3 seconds elapsed + val time4 = System.currentTimeMillis() + val result8 = operationN(8) + val result9 = operationN(9) // blocks until 4 seconds elapsed + val time5 = System.currentTimeMillis() + + result1 shouldBe 1 + result2 shouldBe 2 + result3 shouldBe 3 + result4 shouldBe 4 + result5 shouldBe 5 + result6 shouldBe 6 + result7 shouldBe 7 + result8 shouldBe 8 + result9 shouldBe 9 + (time2 - time1) should be >= 1000L - 10 + (time3 - time1) should be >= 2000L - 10 + (time4 - time1) should be >= 3000L - 10 + (time5 - time1) should be >= 4000L - 10 + (time2 - time1) should be <= 1200L + (time3 - time1) should be <= 2200L + (time4 - time1) should be <= 3200L + (time5 - time1) should be <= 4200L + order should be(List(9, 8, 7, 6, 5, 4, 3, 2, 1)) + } + } + + it should "respect time constraints when blocking concurrently" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.FixedWindow(2, FiniteDuration(1, "second")) + ) + + val order = new AtomicReference(List.empty[Int]) + def operationN(n: Int) = + rateLimiter.runBlocking { + order.updateAndGet(ord => n :: ord) + n + } + + val before = System.currentTimeMillis() // 0 seconds + supervised { + forkUserDiscard: + operationN(1) + forkUserDiscard: + sleep(50.millis) + operationN(2) + forkUserDiscard: + sleep(100.millis) + 
operationN(3) + forkUserDiscard: + sleep(150.millis) + operationN(4) + forkUserDiscard: + sleep(200.millis) + operationN(5) + forkUserDiscard: + sleep(250.millis) + operationN(6) + forkUserDiscard: + sleep(300.millis) + operationN(7) + forkUserDiscard: + sleep(350.millis) + operationN(8) + forkUserDiscard: + sleep(400.millis) + operationN(9) + } + val after = System.currentTimeMillis() + + (after - before) should be >= 4000L - 10 + (after - before) should be <= 4200L + } + } + + behavior of "sliding window RateLimiter" + it should "drop operation when rate limit is exceeded" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.SlidingWindow(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe Some(0) + result3 shouldBe None + executions shouldBe 2 + } + } + + it should "restart rate limiter after given duration" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.SlidingWindow(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + val result3 = rateLimiter.runOrDrop(operation) + ox.sleep(1.second) + ox.sleep(100.milliseconds) // make sure the rate limiter is replenished + val result4 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe Some(0) + result3 shouldBe None + result4 shouldBe Some(0) + executions shouldBe 3 + } + } + + it should "block operation when rate limit is exceeded" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.SlidingWindow(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val ((result1, result2, result3), 
timeElapsed) = measure { + val r1 = rateLimiter.runBlocking(operation) + val r2 = rateLimiter.runBlocking(operation) + val r3 = rateLimiter.runBlocking(operation) + (r1, r2, r3) + } + + result1 shouldBe 0 + result2 shouldBe 0 + result3 shouldBe 0 + timeElapsed.toMillis should be >= 1000L - 10 + executions shouldBe 3 + } + } + + it should "respect time constraints when blocking" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.SlidingWindow(2, FiniteDuration(1, "second")) + ) + + var order = List.empty[Int] + def operationN(n: Int) = + rateLimiter.runBlocking { + order = n :: order + n + } + + val time1 = System.currentTimeMillis() // 0 seconds + val result1 = operationN(1) + ox.sleep(500.millis) + val result2 = operationN(2) + val result3 = operationN(3) // blocks until 1 second elapsed + val time2 = System.currentTimeMillis() // 1 second + val result4 = operationN(4) + val time3 = System.currentTimeMillis() // blocks until 1.5 seconds elapsed + + result1 shouldBe 1 + result2 shouldBe 2 + result3 shouldBe 3 + result4 shouldBe 4 + (time2 - time1) should be >= 1000L - 10 + (time3 - time1) should be >= 1500L - 10 + (time2 - time1) should be <= 1200L + (time3 - time1) should be <= 1700L + order should be(List(4, 3, 2, 1)) + } + } + + it should "respect time constraints when blocking concurrently" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.SlidingWindow(2, FiniteDuration(1, "second")) + ) + + val order = new AtomicReference(List.empty[Int]) + def operationN(n: Int) = + rateLimiter.runBlocking { + order.updateAndGet(ord => n :: ord) + n + } + + val before = System.currentTimeMillis() // 0 seconds + supervised { + forkUserDiscard: + operationN(1) + forkUserDiscard: + sleep(300.millis) + operationN(2) + forkUserDiscard: + sleep(400.millis) + operationN(3) + forkUserDiscard: + sleep(700.millis) + operationN(4) + } + val after = System.currentTimeMillis + + (after - before) should be >= 1300L - 10 + (after - before) 
should be <= 1400L + } + } + + behavior of "bucket RateLimiter" + + it should "drop operation when rate limit is exceeded" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.LeakyBucket(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + val result2 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe None + executions shouldBe 1 + } + } + + it should "refill token after time elapsed from last refill and not before" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.LeakyBucket(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val result1 = rateLimiter.runOrDrop(operation) + ox.sleep(500.millis) + val result2 = rateLimiter.runOrDrop(operation) + ox.sleep(600.millis) + val result3 = rateLimiter.runOrDrop(operation) + + result1 shouldBe Some(0) + result2 shouldBe None + result3 shouldBe Some(0) + executions shouldBe 2 + } + } + + it should "block operation when rate limit is exceeded" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.LeakyBucket(2, FiniteDuration(1, "second")) + ) + + var executions = 0 + def operation = + executions += 1 + 0 + + val ((result1, result2, result3), timeElapsed) = measure { + val r1 = rateLimiter.runBlocking(operation) + val r2 = rateLimiter.runBlocking(operation) + val r3 = rateLimiter.runBlocking(operation) + (r1, r2, r3) + } + + result1 shouldBe 0 + result2 shouldBe 0 + timeElapsed.toMillis should be >= 1000L - 10 + executions shouldBe 3 + } + } + + it should "respect time constraints when blocking" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.LeakyBucket(2, FiniteDuration(1, "second")) + ) + + var order = List.empty[Int] + def operationN(n: Int) = + rateLimiter.runBlocking { + order = n :: order + n + } + + val time1 = System.currentTimeMillis() // 
0 seconds + val result1 = operationN(1) + val result2 = operationN(2) + val time2 = System.currentTimeMillis() // 1 second + sleep(2.seconds) // add 2 tokens + val result3 = operationN(3) // blocks until 1 second elapsed + val result4 = operationN(4) // blocks until 2 seconds elapsed + val time3 = System.currentTimeMillis() + val result5 = operationN(5) // blocks until 2 seconds elapsed + val time4 = System.currentTimeMillis() + + result1 shouldBe 1 + result2 shouldBe 2 + result3 shouldBe 3 + result4 shouldBe 4 + result5 shouldBe 5 + (time2 - time1) should be >= 1000L - 10 + (time3 - time1) should be >= 3000L - 10 + (time4 - time1) should be >= 4000L - 10 + (time2 - time1) should be <= 1200L + (time3 - time1) should be <= 3200L + (time4 - time1) should be <= 4200L + order should be(List(5, 4, 3, 2, 1)) + } + } + + it should "respect time constraints when blocking concurrently" in { + supervised { + val rateLimiter = RateLimiter( + RateLimiterAlgorithm.LeakyBucket(2, FiniteDuration(1, "second")) + ) + + val order = new AtomicReference(List.empty[Int]) + def operationN(n: Int) = + rateLimiter.runBlocking { + order.updateAndGet(ord => n :: ord) + n + } + + val before = System.currentTimeMillis() + supervised { + forkUserDiscard: + operationN(1) + forkUserDiscard: + sleep(50.millis) + operationN(2) + forkUserDiscard: + sleep(100.millis) + operationN(3) + forkUserDiscard: + sleep(150.millis) + operationN(4) + } + val after = System.currentTimeMillis() + + (after - before) should be >= 3000L - 10 + (after - before) should be <= 3200L + } + } + +end RateLimiterTest diff --git a/doc/index.md b/doc/index.md index 6ba81497..bbfdc96c 100644 --- a/doc/index.md +++ b/doc/index.md @@ -68,6 +68,7 @@ In addition to this documentation, ScalaDocs can be browsed at [https://javadoc. 
utils/oxapp utils/retries + utils/rate-limiter utils/repeat utils/scheduled utils/resources diff --git a/doc/utils/rate-limiter.md b/doc/utils/rate-limiter.md new file mode 100644 index 00000000..5885c1fb --- /dev/null +++ b/doc/utils/rate-limiter.md @@ -0,0 +1,57 @@ +# Rate limiter + +The rate limiter mechanism allows controlling the rate at which operations are executed. It ensures that at most a certain number of operations are started within a specified time frame, preventing system overload and ensuring fair resource usage. Note that the implemented limiting mechanism only takes into account the start of execution and not the whole execution of an operation. + +## API + +Basic rate limiter usage: + +```scala mdoc:compile-only +import ox.supervised +import ox.resilience.* +import scala.concurrent.duration.* + +val algorithm = RateLimiterAlgorithm.FixedWindow(2, 1.second) + +supervised: + val rateLimiter = RateLimiter(algorithm) + + type T + def operation: T = ??? + + val blockedOperation: T = rateLimiter.runBlocking(operation) + val droppedOperation: Option[T] = rateLimiter.runOrDrop(operation) +``` + +`runBlocking` will block the operation until the algorithm allows it to be executed. Therefore, the return type is the same as the operation. On the other hand, if the algorithm doesn't allow execution of more operations, `runOrDrop` will drop the operation, returning `None`; when the operation is allowed to run, its result is wrapped in `Some`. + +A rate limiter must be created within an `Ox` [concurrency scope](../structured-concurrency/fork-join.md), as a background fork is created, to replenish the rate limiter. Once the scope ends, the rate limiter stops as well. + +## Operation definition + +The `operation` can be provided directly using a by-name parameter, i.e. `f: => T`. + +## Configuration + +The configuration of a `RateLimiter` depends on an underlying algorithm that controls whether an operation can be executed or not. 
The following algorithms are available: +- `RateLimiterAlgorithm.FixedWindow(rate: Int, dur: FiniteDuration)` - where `rate` is the maximum number of operations to be executed in fixed windows of `dur` duration. +- `RateLimiterAlgorithm.SlidingWindow(rate: Int, dur: FiniteDuration)` - where `rate` is the maximum number of operations to be executed in any window of time of duration `dur`. +- `RateLimiterAlgorithm.LeakyBucket(maximum: Int, dur: FiniteDuration)` - where `maximum` is the maximum capacity of tokens available in the token bucket algorithm and one token is added each `dur`. It can represent both the leaky bucket algorithm and the token bucket algorithm. + +### API shorthands + +You can use one of the following shorthands to define a Rate Limiter with the corresponding algorithm: + +- `RateLimiter.fixedWindow(rate: Int, dur: FiniteDuration)`, +- `RateLimiter.slidingWindow(rate: Int, dur: FiniteDuration)`, +- `RateLimiter.leakyBucket(maximum: Int, dur: FiniteDuration)`. + +See the tests in `ox.resilience.*` for more. + +## Custom rate limiter algorithms + +The `RateLimiterAlgorithm` employed by `RateLimiter` can be extended to implement new algorithms or modify existing ones. Its interface is modelled like that of a `Semaphore` although the underlying implementation could be different. For best compatibility with the existing interface of `RateLimiter`, methods `acquire` and `tryAcquire` should offer the same guarantees as Java's `Semaphore`. + +Additionally, there are two methods employed by the `RateLimiter` for updating its internal state automatically: +- `def update(): Unit`: Updates the internal state of the rate limiter to reflect its current situation. Invoked in a background fork repeatedly, when a rate limiter is created. +- `def getNextUpdate: Long`: Returns the time in nanoseconds after which a new `update` needs to be called.