Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RateLimiter - consider whole operation execution time #251

Merged
merged 11 commits into from
Dec 17, 2024
88 changes: 88 additions & 0 deletions core/src/main/scala/ox/resilience/DurationRateLimiter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package ox.resilience

import ox.{Ox, forkDiscard}
import scala.annotation.tailrec
import scala.concurrent.duration.FiniteDuration

/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. operationMode decides
* if whole time of execution should be considered or just the start.
*/
class DurationRateLimiter private (algorithm: DurationRateLimiterAlgorithm):
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
def runBlocking[T](operation: => T): T =
algorithm.acquire()
algorithm.startOperation()
val result = operation
algorithm.endOperation()
result

/** Runs or drops the operation, if the rate limit is reached.
*
* @return
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
*/
def runOrDrop[T](operation: => T): Option[T] =
if algorithm.tryAcquire() then
algorithm.startOperation()
val result = operation
algorithm.endOperation()
Some(result)
else None

end DurationRateLimiter

object DurationRateLimiter:
def apply(algorithm: DurationRateLimiterAlgorithm)(using Ox): DurationRateLimiter =
@tailrec
def update(): Unit =
val waitTime = algorithm.getNextUpdate
val millis = waitTime / 1000000
val nanos = waitTime % 1000000
Thread.sleep(millis, nanos.toInt)
algorithm.update()
update()
end update

forkDiscard(update())
new DurationRateLimiter(algorithm)
end apply

/** Creates a rate limiter using a fixed window algorithm.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxOperations
* Maximum number of operations that are allowed to **start** within a time [[window]].
* @param window
* Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the next
* time window.
*/
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using
Ox
): DurationRateLimiter =
apply(DurationRateLimiterAlgorithm.FixedWindow(maxOperations, window))

/** Creates a rate limiter using a sliding window algorithm.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxOperations
* Maximum number of operations that are allowed to **start** within any [[window]] of time.
* @param window
* Length of the window.
*/
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): DurationRateLimiter =
apply(DurationRateLimiterAlgorithm.SlidingWindow(maxOperations, window))

/** Creates a rate limiter with token/leaky bucket algorithm.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxTokens
* Max capacity of tokens in the algorithm, limiting the operations that are allowed to **start** concurrently.
* @param refillInterval
* Interval of time between adding a single token to the bucket.
*/
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): DurationRateLimiter =
apply(DurationRateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
end DurationRateLimiter
158 changes: 158 additions & 0 deletions core/src/main/scala/ox/resilience/DurationRateLimiterAlgorithm.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package ox.resilience

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.duration.FiniteDuration

trait DurationRateLimiterAlgorithm extends RateLimiterAlgorithm:

def startOperation(permits: Int): Unit

def endOperation(permits: Int): Unit

final def startOperation(): Unit = startOperation(1)

final def endOperation(): Unit = endOperation(1)

end DurationRateLimiterAlgorithm

object DurationRateLimiterAlgorithm:
Kamil-Lontkowski marked this conversation as resolved.
Show resolved Hide resolved
case class FixedWindow(rate: Int, per: FiniteDuration) extends DurationRateLimiterAlgorithm:
private val lastUpdate = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(rate)
private val runningOperations = new AtomicInteger(0)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastUpdate.set(now)
semaphore.release(rate - semaphore.availablePermits() - runningOperations.get())
Kamil-Lontkowski marked this conversation as resolved.
Show resolved Hide resolved
end update

def startOperation(permits: Int): Unit =
runningOperations.updateAndGet(_ + permits)
()

def endOperation(permits: Int): Unit =
runningOperations.updateAndGet(current => (current - permits).max(0))
Kamil-Lontkowski marked this conversation as resolved.
Show resolved Hide resolved
()

end FixedWindow

/** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */
case class SlidingWindow(rate: Int, per: FiniteDuration) extends DurationRateLimiterAlgorithm:
// stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
private val semaphore = new Semaphore(rate)
private val runningOperations = new AtomicInteger(0)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)
addTimestampToLog(permits)

def tryAcquire(permits: Int): Boolean =
if semaphore.tryAcquire(permits) then
addTimestampToLog(permits)
true
else false

private def addTimestampToLog(permits: Int): Unit =
val now = System.nanoTime()
log.updateAndGet { q =>
q.enqueue((now, permits))
}
()

def getNextUpdate: Long =
log.get().headOption match
case None =>
// no logs so no need to update until `per` has passed
per.toNanos
case Some(record) =>
// oldest log provides the new updating point
val waitTime = record._1 + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L
end getNextUpdate

def startOperation(permits: Int): Unit =
runningOperations.updateAndGet(_ + permits)
()

def endOperation(permits: Int): Unit =
runningOperations.updateAndGet(current => (current - permits).max(0))
addTimestampToLog(permits)
()

def update(): Unit =
val now = System.nanoTime()
// retrieving current queue to append it later if some elements were added concurrently
val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
val _ = log.updateAndGet(qNew =>
emil-bar marked this conversation as resolved.
Show resolved Hide resolved
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
end update

@tailrec
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
q.dequeueOption match
case None => q
case Some((head, tail)) =>
if head._1 + per.toNanos < now then
val (_, permits) = head
semaphore.release(0.max(permits - runningOperations.get()))
removeRecords(tail, now)
else q

end SlidingWindow

/** Token/leaky bucket algorithm It adds a token to start a new operation each `per` with a maximum number of tokens of `rate`. */
case class LeakyBucket(rate: Int, per: FiniteDuration) extends DurationRateLimiterAlgorithm:
private val refillInterval = per.toNanos
private val lastRefillTime = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(1)
private val runningOperations = AtomicInteger(0)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastRefillTime.set(now)
if (semaphore.availablePermits() + runningOperations.get()) < rate then semaphore.release()

def startOperation(permits: Int): Unit =
runningOperations.updateAndGet(_ + permits)
()

def endOperation(permits: Int): Unit =
runningOperations.updateAndGet(current => (current - permits).max(0))
()

end LeakyBucket

end DurationRateLimiterAlgorithm
65 changes: 52 additions & 13 deletions core/src/main/scala/ox/resilience/RateLimiter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,43 @@ package ox.resilience
import scala.concurrent.duration.FiniteDuration
import ox.*

import java.util.concurrent.Semaphore
import scala.annotation.tailrec

/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. */
/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. operationMode decides
* if whole time of execution should be considered or just the start.
*/
class RateLimiter private (algorithm: RateLimiterAlgorithm):
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
def runBlocking[T](operation: => T): T =
algorithm.acquire()
operation
algorithm match
case alg: DurationRateLimiterAlgorithm =>
emil-bar marked this conversation as resolved.
Show resolved Hide resolved
alg.acquire()
alg.startOperation()
val result = operation
alg.endOperation()
result
case alg: RateLimiterAlgorithm =>
alg.acquire()
operation

/** Runs or drops the operation, if the rate limit is reached.
*
* @return
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
*/
def runOrDrop[T](operation: => T): Option[T] =
if algorithm.tryAcquire() then Some(operation)
else None
algorithm match
case alg: DurationRateLimiterAlgorithm =>
if alg.tryAcquire() then
alg.startOperation()
val result = operation
alg.endOperation()
Some(result)
else None
case alg: RateLimiterAlgorithm =>
if alg.tryAcquire() then Some(operation)
else None

end RateLimiter

Expand All @@ -46,11 +66,15 @@ object RateLimiter:
* @param maxOperations
* Maximum number of operations that are allowed to **start** within a time [[window]].
* @param window
* Interval of time between replenishing the rate limiter. THe rate limiter is replenished to allow up to [[maxOperations]] in the next
* Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the next
* time window.
*/
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))
def fixedWindow(maxOperations: Int, window: FiniteDuration, operationMode: RateLimiterMode = RateLimiterMode.OperationStart)(using
Kamil-Lontkowski marked this conversation as resolved.
Show resolved Hide resolved
Ox
): RateLimiter =
operationMode match
case RateLimiterMode.OperationStart => apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))
case RateLimiterMode.OperationDuration => apply(DurationRateLimiterAlgorithm.FixedWindow(maxOperations, window))

/** Creates a rate limiter using a sliding window algorithm.
*
Expand All @@ -61,10 +85,14 @@ object RateLimiter:
* @param window
* Length of the window.
*/
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))
def slidingWindow(maxOperations: Int, window: FiniteDuration, operationMode: RateLimiterMode = RateLimiterMode.OperationStart)(using
Ox
): RateLimiter =
operationMode match
case RateLimiterMode.OperationStart => apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))
case RateLimiterMode.OperationDuration => apply(DurationRateLimiterAlgorithm.SlidingWindow(maxOperations, window))

/** Rate limiter with token/leaky bucket algorithm.
/** Creates a rate limiter with token/leaky bucket algorithm.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
Expand All @@ -73,6 +101,17 @@ object RateLimiter:
* @param refillInterval
* Interval of time between adding a single token to the bucket.
*/
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration, operationMode: RateLimiterMode = RateLimiterMode.OperationStart)(using
Ox
): RateLimiter =
operationMode match
case RateLimiterMode.OperationStart => apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
case RateLimiterMode.OperationDuration => apply(DurationRateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))

end RateLimiter

/** Decides if RateLimiter should consider only start of an operation or whole time of execution.
*/
enum RateLimiterMode:
case OperationStart
case OperationDuration
4 changes: 3 additions & 1 deletion core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ trait RateLimiterAlgorithm:
/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */
def getNextUpdate: Long

def rate: Int

end RateLimiterAlgorithm

object RateLimiterAlgorithm:
Expand Down Expand Up @@ -117,7 +119,7 @@ object RateLimiterAlgorithm:

end SlidingWindow

/** Token/leaky bucket algorithm It adds a token to start an new operation each `per` with a maximum number of tokens of `rate`. */
/** Token/leaky bucket algorithm It adds a token to start a new operation each `per` with a maximum number of tokens of `rate`. */
case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val refillInterval = per.toNanos
private val lastRefillTime = new AtomicLong(System.nanoTime())
Expand Down
Loading
Loading