Skip to content

Commit

Permalink
avoid long tail tasks due to PrioritySemaphore (NVIDIA#11574)
Browse files Browse the repository at this point in the history
* use task id as tie breaker

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>

* save threadlocal lookup

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>

---------

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
  • Loading branch information
binmahone committed Oct 11, 2024
1 parent b715ef2 commit b6fbac5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ object GpuSemaphore {
* this is considered to be okay as there are other mechanisms in place, and it should be rather
* rare.
*/
private final class SemaphoreTaskInfo() extends Logging {
private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
/**
* This holds threads that are not on the GPU yet. Most of the time they are
* blocked waiting for the semaphore to let them on, but it may hold one
Expand Down Expand Up @@ -253,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits, lastHeld)
semaphore.acquire(numPermits, lastHeld, taskAttemptId)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
Expand Down Expand Up @@ -333,7 +333,7 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
if (taskInfo.tryAcquire(semaphore)) {
GpuDeviceManager.initializeFromTask()
Expand All @@ -357,7 +357,7 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
taskInfo.blockUntilReady(semaphore)
GpuDeviceManager.initializeFromTask()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
private val lock = new ReentrantLock()
private var occupiedSlots: Int = 0

private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int) {
private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int, taskId: Long) {
var signaled: Boolean = false
}

// We expect a relatively small number of threads to be contending for this lock at any given
// time, therefore we are not concerned with the insertion/removal time complexity.
private val waitingQueue: PriorityQueue[ThreadInfo] =
new PriorityQueue[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority).reverse)
new PriorityQueue[ThreadInfo](
// use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
Ordering.by[ThreadInfo, T](_.priority).reverse.
thenComparing((a, b) => a.taskId.compareTo(b.taskId))
)

def tryAcquire(numPermits: Int, priority: T): Boolean = {
lock.lock()
Expand All @@ -52,12 +56,12 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
}
}

def acquire(numPermits: Int, priority: T): Unit = {
def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = {
lock.lock()
try {
if (!tryAcquire(numPermits, priority)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, numPermits)
val info = ThreadInfo(priority, condition, numPermits, taskAttemptId)
try {
waitingQueue.add(info)
while (!info.signaled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

val t = new Thread(() => {
try {
semaphore.acquire(1, 1)
semaphore.acquire(1, 1, 0)
fail("Should not acquire permit")
} catch {
case _: InterruptedException =>
semaphore.acquire(1, 1)
semaphore.acquire(1, 1, 0)
}
})
t.start()
Expand All @@ -62,7 +62,7 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

def taskWithPriority(priority: Int) = new Runnable {
override def run(): Unit = {
semaphore.acquire(1, priority)
semaphore.acquire(1, priority, 0)
results.add(priority)
semaphore.release(1)
}
Expand All @@ -84,9 +84,9 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

test("low priority thread cannot surpass high priority thread") {
val semaphore = new TestPrioritySemaphore(10)
semaphore.acquire(5, 0)
semaphore.acquire(5, 0, 0)
val t = new Thread(() => {
semaphore.acquire(10, 2)
semaphore.acquire(10, 2, 0)
semaphore.release(10)
})
t.start()
Expand Down

0 comments on commit b6fbac5

Please sign in to comment.