Skip to content

Commit

Permalink
save threadlocal lookup
Browse files Browse the repository at this point in the history
Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
  • Loading branch information
binmahone committed Oct 10, 2024
1 parent 0967cb8 commit 2bebd2b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ object GpuSemaphore {
* this is considered to be okay as there are other mechanisms in place, and it should be rather
* rare.
*/
private final class SemaphoreTaskInfo() extends Logging {
private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
/**
* This holds threads that are not on the GPU yet. Most of the time they are
* blocked waiting for the semaphore to let them on, but it may hold one
Expand Down Expand Up @@ -253,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits, lastHeld)
semaphore.acquire(numPermits, lastHeld, taskAttemptId)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
Expand Down Expand Up @@ -333,7 +333,7 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
if (taskInfo.tryAcquire(semaphore)) {
GpuDeviceManager.initializeFromTask()
Expand All @@ -357,7 +357,7 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
taskInfo.blockUntilReady(semaphore)
GpuDeviceManager.initializeFromTask()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package com.nvidia.spark.rapids
import java.util.PriorityQueue
import java.util.concurrent.locks.{Condition, ReentrantLock}

import org.apache.spark.TaskContext

class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) {
// This lock is used to generate condition variables, which affords us the flexibility to notify
// specific threads at a time. If we use the regular synchronized pattern, we have to either
Expand Down Expand Up @@ -58,12 +56,12 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
}
}

def acquire(numPermits: Int, priority: T): Unit = {
def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = {
lock.lock()
try {
if (!tryAcquire(numPermits, priority)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, numPermits, TaskContext.get().taskAttemptId())
val info = ThreadInfo(priority, condition, numPermits, taskAttemptId)
try {
waitingQueue.add(info)
while (!info.signaled) {
Expand Down

0 comments on commit 2bebd2b

Please sign in to comment.