From b6fbac57091ff4a11c68a3b41d2141f06f9e7500 Mon Sep 17 00:00:00 2001
From: "Hongbin Ma (Mahone)"
Date: Thu, 10 Oct 2024 22:53:17 +0800
Subject: [PATCH] avoid long tail tasks due to PrioritySemaphore (#11574)

* use task id as tie breaker

Signed-off-by: Hongbin Ma (Mahone)

* save threadlocal lookup

Signed-off-by: Hongbin Ma (Mahone)

---------

Signed-off-by: Hongbin Ma (Mahone)
---
 .../scala/com/nvidia/spark/rapids/GpuSemaphore.scala |  8 ++++----
 .../com/nvidia/spark/rapids/PrioritySemaphore.scala  | 12 ++++++++----
 .../nvidia/spark/rapids/PrioritySemaphoreSuite.scala | 10 +++++-----
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
index fab30853596..c11d1fcb250 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
@@ -162,7 +162,7 @@ object GpuSemaphore {
  * this is considered to be okay as there are other mechanisms in place, and it should be rather
  * rare.
  */
-private final class SemaphoreTaskInfo() extends Logging {
+private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
   /**
    * This holds threads that are not on the GPU yet. Most of the time they are
    * blocked waiting for the semaphore to let them on, but it may hold one
@@ -253,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
     if (!done && shouldBlockOnSemaphore) {
       // We cannot be in a synchronized block and wait on the semaphore
       // so we have to release it and grab it again afterwards.
-      semaphore.acquire(numPermits, lastHeld)
+      semaphore.acquire(numPermits, lastHeld, taskAttemptId)
       synchronized {
         // We now own the semaphore so we need to wake up all of the other tasks that are
         // waiting.
@@ -333,7 +333,7 @@ private final class GpuSemaphore() extends Logging {
     val taskAttemptId = context.taskAttemptId()
     val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
       onTaskCompletion(context, completeTask)
-      new SemaphoreTaskInfo()
+      new SemaphoreTaskInfo(taskAttemptId)
     })
     if (taskInfo.tryAcquire(semaphore)) {
       GpuDeviceManager.initializeFromTask()
@@ -357,7 +357,7 @@ private final class GpuSemaphore() extends Logging {
     val taskAttemptId = context.taskAttemptId()
     val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
       onTaskCompletion(context, completeTask)
-      new SemaphoreTaskInfo()
+      new SemaphoreTaskInfo(taskAttemptId)
     })
     taskInfo.blockUntilReady(semaphore)
     GpuDeviceManager.initializeFromTask()
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala
index 6fdadf10e72..cdee5ab1c79 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala
@@ -27,14 +27,18 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
   private val lock = new ReentrantLock()
   private var occupiedSlots: Int = 0
 
-  private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int) {
+  private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int, taskId: Long) {
     var signaled: Boolean = false
   }
 
   // We expect a relatively small number of threads to be contending for this lock at any given
   // time, therefore we are not concerned with the insertion/removal time complexity.
   private val waitingQueue: PriorityQueue[ThreadInfo] =
-    new PriorityQueue[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority).reverse)
+    new PriorityQueue[ThreadInfo](
+      // use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
+      Ordering.by[ThreadInfo, T](_.priority).reverse.
+        thenComparing((a, b) => a.taskId.compareTo(b.taskId))
+    )
 
   def tryAcquire(numPermits: Int, priority: T): Boolean = {
     lock.lock()
@@ -52,12 +56,12 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
     }
   }
 
-  def acquire(numPermits: Int, priority: T): Unit = {
+  def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = {
     lock.lock()
     try {
       if (!tryAcquire(numPermits, priority)) {
         val condition = lock.newCondition()
-        val info = ThreadInfo(priority, condition, numPermits)
+        val info = ThreadInfo(priority, condition, numPermits, taskAttemptId)
         try {
           waitingQueue.add(info)
           while (!info.signaled) {
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala
index 0ba125f60ab..cd9660a5de5 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala
@@ -39,11 +39,11 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
 
     val t = new Thread(() => {
       try {
-        semaphore.acquire(1, 1)
+        semaphore.acquire(1, 1, 0)
         fail("Should not acquire permit")
       } catch {
         case _: InterruptedException =>
-          semaphore.acquire(1, 1)
+          semaphore.acquire(1, 1, 0)
       }
     })
     t.start()
@@ -62,7 +62,7 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
 
     def taskWithPriority(priority: Int) = new Runnable {
       override def run(): Unit = {
-        semaphore.acquire(1, priority)
+        semaphore.acquire(1, priority, 0)
         results.add(priority)
         semaphore.release(1)
       }
@@ -84,9 +84,9 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
 
   test("low priority thread cannot surpass high priority thread") {
     val semaphore = new TestPrioritySemaphore(10)
-    semaphore.acquire(5, 0)
+    semaphore.acquire(5, 0, 0)
     val t = new Thread(() => {
-      semaphore.acquire(10, 2)
+      semaphore.acquire(10, 2, 0)
       semaphore.release(10)
     })
     t.start()
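
The heart of the patch is the tie-breaking comparator added to PrioritySemaphore's waiting queue. Below is a minimal, standalone Scala sketch of that ordering, assuming only the standard library and java.util.PriorityQueue; the TieBreakerSketch and WaiterInfo names are illustrative stand-ins (the plugin's real type is the private ThreadInfo above), not part of the spark-rapids code.

import java.util.PriorityQueue

object TieBreakerSketch {

  // Stand-in for the patch's private ThreadInfo; priority is fixed to Long here
  // instead of the semaphore's generic T just to keep the example self-contained.
  final case class WaiterInfo(priority: Long, taskId: Long)

  def main(args: Array[String]): Unit = {
    val queue = new PriorityQueue[WaiterInfo](
      // Same shape as the patch: reverse the priority ordering (java.util.PriorityQueue
      // polls the smallest element under its comparator, so reversing puts the highest
      // priority at the head), then break ties by ascending task id.
      Ordering.by[WaiterInfo, Long](_.priority).reverse
        .thenComparing((a, b) => java.lang.Long.compare(a.taskId, b.taskId))
    )

    queue.add(WaiterInfo(priority = 0, taskId = 7))
    queue.add(WaiterInfo(priority = 0, taskId = 3))
    queue.add(WaiterInfo(priority = 5, taskId = 9))

    // Expected poll order: (5, 9) first because of its higher priority, then
    // (0, 3) before (0, 7) because the lower task id wins the tie.
    while (!queue.isEmpty) println(queue.poll())
  }
}

Because the queue hands back the highest-priority waiter first and, among equally prioritized waiters, the one with the smallest task attempt id, earlier tasks are woken before later ones instead of being repeatedly passed over and turning into long-tail stragglers.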