Skip to content

Commit

Permalink
remove DurationRateLimiter, DurationRateLimiterAlgorithm extends Rate…
Browse files Browse the repository at this point in the history
…LimiterAlgorithm
  • Loading branch information
Kamil-Lontkowski committed Dec 16, 2024
1 parent d4cb818 commit e4b796d
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 190 deletions.
88 changes: 0 additions & 88 deletions core/src/main/scala/ox/resilience/DurationRateLimiter.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,9 @@ import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.duration.FiniteDuration

trait DurationRateLimiterAlgorithm extends RateLimiterAlgorithm:

def startOperation(permits: Int): Unit

def endOperation(permits: Int): Unit

final def startOperation(): Unit = startOperation(1)

final def endOperation(): Unit = endOperation(1)

end DurationRateLimiterAlgorithm

object DurationRateLimiterAlgorithm:
case class FixedWindow(rate: Int, per: FiniteDuration) extends DurationRateLimiterAlgorithm:
/** Fixed window algorithm: allows to run at most `rate` operations in consecutively segments of duration `per`. */
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val lastUpdate = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(rate)
private val runningOperations = new AtomicInteger(0)
Expand All @@ -42,32 +31,25 @@ object DurationRateLimiterAlgorithm:
semaphore.release(rate - semaphore.availablePermits() - runningOperations.get())
end update

def startOperation(permits: Int): Unit =
def runOperation[T](operation: => T, permits: Int): T =
runningOperations.updateAndGet(_ + permits)
()

def endOperation(permits: Int): Unit =
val result = operation
runningOperations.updateAndGet(current => (current - permits).max(0))
()
result

end FixedWindow

/** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */
case class SlidingWindow(rate: Int, per: FiniteDuration) extends DurationRateLimiterAlgorithm:
// stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully
/** Sliding window algorithm: allows to run at most `rate` operations in the lapse of `per` before current time. */
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
// stores the timestamp and the number of permits acquired after finishing running operation
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
private val semaphore = new Semaphore(rate)
private val runningOperations = new AtomicInteger(0)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)
addTimestampToLog(permits)

def tryAcquire(permits: Int): Boolean =
if semaphore.tryAcquire(permits) then
addTimestampToLog(permits)
true
else false
semaphore.tryAcquire(permits)

private def addTimestampToLog(permits: Int): Unit =
val now = System.nanoTime()
Expand All @@ -87,14 +69,11 @@ object DurationRateLimiterAlgorithm:
if waitTime > 0 then waitTime else 0L
end getNextUpdate

def startOperation(permits: Int): Unit =
runningOperations.updateAndGet(_ + permits)
()

def endOperation(permits: Int): Unit =
runningOperations.updateAndGet(current => (current - permits).max(0))
def runOperation[T](operation: => T, permits: Int): T =
val result = operation
// Consider end of operation as a point to release permit after `per` passes
addTimestampToLog(permits)
()
result

def update(): Unit =
val now = System.nanoTime()
Expand All @@ -103,11 +82,12 @@ object DurationRateLimiterAlgorithm:
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
val _ = log.updateAndGet(qNew =>
log.updateAndGet(qNew =>
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
()
end update

@tailrec
Expand All @@ -117,42 +97,12 @@ object DurationRateLimiterAlgorithm:
case Some((head, tail)) =>
if head._1 + per.toNanos < now then
val (_, permits) = head
semaphore.release(0.max(permits - runningOperations.get()))
semaphore.release(permits)
removeRecords(tail, now)
else q
end match
end removeRecords

end SlidingWindow

/** Token/leaky bucket algorithm It adds a token to start a new operation each `per` with a maximum number of tokens of `rate`. */
case class LeakyBucket(rate: Int, per: FiniteDuration) extends DurationRateLimiterAlgorithm:
private val refillInterval = per.toNanos
private val lastRefillTime = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(1)
private val runningOperations = AtomicInteger(0)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastRefillTime.set(now)
if (semaphore.availablePermits() + runningOperations.get()) < rate then semaphore.release()

def startOperation(permits: Int): Unit =
runningOperations.updateAndGet(_ + permits)
()

def endOperation(permits: Int): Unit =
runningOperations.updateAndGet(current => (current - permits).max(0))
()

end LeakyBucket

end DurationRateLimiterAlgorithm
31 changes: 6 additions & 25 deletions core/src/main/scala/ox/resilience/RateLimiter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,17 @@ import scala.annotation.tailrec
class RateLimiter private (algorithm: RateLimiterAlgorithm):
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
def runBlocking[T](operation: => T): T =
algorithm match
case alg: DurationRateLimiterAlgorithm =>
alg.acquire()
alg.startOperation()
val result = operation
alg.endOperation()
result
case alg: RateLimiterAlgorithm =>
alg.acquire()
operation
algorithm.acquire()
algorithm.runOperation(operation)

/** Runs or drops the operation, if the rate limit is reached.
*
* @return
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
*/
def runOrDrop[T](operation: => T): Option[T] =
algorithm match
case alg: DurationRateLimiterAlgorithm =>
if alg.tryAcquire() then
alg.startOperation()
val result = operation
alg.endOperation()
Some(result)
else None
case alg: RateLimiterAlgorithm =>
if alg.tryAcquire() then Some(operation)
else None
if algorithm.tryAcquire() then Some(algorithm.runOperation(operation))
else None

end RateLimiter

Expand Down Expand Up @@ -101,12 +84,10 @@ object RateLimiter:
* @param refillInterval
* Interval of time between adding a single token to the bucket.
*/
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration, operationMode: RateLimiterMode = RateLimiterMode.OperationStart)(using
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using
Ox
): RateLimiter =
operationMode match
case RateLimiterMode.OperationStart => apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
case RateLimiterMode.OperationDuration => apply(DurationRateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))

end RateLimiter

Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ trait RateLimiterAlgorithm:
/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */
def getNextUpdate: Long

def rate: Int
/** Runs operation. For cases where execution time is not needed it just returns result */
final def runOperation[T](operation: => T): T = runOperation(operation, 1)

/** Runs operation. For cases where execution time is not needed it just returns result */
def runOperation[T](operation: => T, permits: Int): T

end RateLimiterAlgorithm

Expand All @@ -56,6 +60,8 @@ object RateLimiterAlgorithm:
semaphore.release(rate - semaphore.availablePermits())
end update

def runOperation[T](operation: => T, permits: Int): T = operation

end FixedWindow

/** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */
Expand Down Expand Up @@ -99,11 +105,12 @@ object RateLimiterAlgorithm:
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
val _ = log.updateAndGet(qNew =>
log.updateAndGet(qNew =>
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
()
end update

@tailrec
Expand All @@ -117,6 +124,8 @@ object RateLimiterAlgorithm:
removeRecords(tail, now)
else q

def runOperation[T](operation: => T, permits: Int): T = operation

end SlidingWindow

/** Token/leaky bucket algorithm It adds a token to start a new operation each `per` with a maximum number of tokens of `rate`. */
Expand All @@ -140,5 +149,7 @@ object RateLimiterAlgorithm:
lastRefillTime.set(now)
if semaphore.availablePermits() < rate then semaphore.release()

def runOperation[T](operation: => T, permits: Int): T = operation

end LeakyBucket
end RateLimiterAlgorithm
Loading

0 comments on commit e4b796d

Please sign in to comment.