Skip to content

Commit

Permalink
refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
aoli-al committed Jun 21, 2024
1 parent d7f2a4c commit 6fbf56d
Show file tree
Hide file tree
Showing 16 changed files with 191 additions and 105 deletions.
209 changes: 127 additions & 82 deletions core/src/main/kotlin/cmu/pasta/fray/core/GlobalContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@ object GlobalContext {
for (thread in registeredThreads.values) {
if (thread.state == ThreadState.Paused) {
thread.state = ThreadState.Enabled
val pendingOperation = thread.pendingOperation
thread.pendingOperation =
when (pendingOperation) {
is ObjectWaitBlocking -> {
ObjectWakeBlocking(pendingOperation.o)
}
is ConditionAwaitBlocking -> {
ConditionWakeBlocking(pendingOperation.condition)
}
is ObjectWakeBlocking -> {
pendingOperation
}
is ConditionWakeBlocking -> {
pendingOperation
}
else -> {
ThreadResumeOperation()
}
}
lockManager.threadUnblockedDueToDeadlock(thread.thread)
break
}
Expand Down Expand Up @@ -170,7 +189,7 @@ object GlobalContext {
val context = registeredThreads[t.id]!!

context.state = ThreadState.Enabled
context.pendingOperation = ParkOperation()
context.pendingOperation = ParkBlocking()
scheduleNextOperation(true)

context.checkInterrupt()
Expand All @@ -184,7 +203,7 @@ object GlobalContext {
val context = registeredThreads[t.id]!!

if (!context.unparkSignaled) {
context.pendingOperation = ParkOperation()
context.pendingOperation = ParkBlocking()
context.state = ThreadState.Paused
scheduleNextOperation(true)
if (context.unparkSignaled) {
Expand All @@ -193,19 +212,11 @@ object GlobalContext {
} else {
context.unparkSignaled = false
}
val state = context.state
val operation = context.pendingOperation
if (state != ThreadState.Enabled) {
println(state)
println(operation)
return
}
assert(state == ThreadState.Enabled)
}

fun threadUnpark(t: Thread) {
val context = registeredThreads[t.id]!!
if (context.state == ThreadState.Paused && context.pendingOperation is ParkOperation) {
if (context.state == ThreadState.Paused && context.pendingOperation is ParkBlocking) {
context.state = ThreadState.Enabled
context.pendingOperation = ThreadResumeOperation()
} else if (context.state == ThreadState.Enabled || context.state == ThreadState.Running) {
Expand Down Expand Up @@ -277,15 +288,18 @@ object GlobalContext {
if (canInterrupt) {
context.checkInterrupt()
}
context.blockedBy = waitingObject
// No matter if an interrupt is signaled, we need to enter the `wait` method
// first which will unlock the reentrant lock and tries to reacquire it.
context.pendingOperation = PausedOperation()
if (lockObject == waitingObject) {
context.pendingOperation = ObjectWaitBlocking(waitingObject)
} else {
context.pendingOperation = ConditionAwaitBlocking(waitingObject as Condition, canInterrupt)
}
context.state = ThreadState.Paused
lockManager.addWaitingThread(waitingObject, Thread.currentThread(), canInterrupt)
lockManager.addWaitingThread(waitingObject, Thread.currentThread())
unlockImpl(lockObject, t, true, true, lockObject == waitingObject)
checkDeadlock {
context.blockedBy = null
context.pendingOperation = ThreadResumeOperation()
assert(lockManager.lock(lockObject, t, false, true, false))
syncManager.removeWait(lockObject)
context.state = ThreadState.Running
Expand Down Expand Up @@ -348,48 +362,58 @@ object GlobalContext {
val context = registeredThreads[t.id]!!
context.interruptSignaled = true

// A thread is interrupted during wait/await.
if (context.blockedBy != null) {
val lock =
if (context.blockedBy is Condition) {
lockManager.lockFromCondition(context.blockedBy as Condition)
} else {
context.blockedBy!!
}
if (lockManager.threadInterruptDuringObjectWait(context.blockedBy!!, lock, context)) {
syncManager.createWait(lock, 1)
}
if (context.state == ThreadState.Running || context.state == ThreadState.Enabled) {
return
}

// A thread is interrupted during park.
if (context.state == ThreadState.Paused && context.pendingOperation is ParkOperation) {
context.pendingOperation = ThreadResumeOperation()
context.state = ThreadState.Enabled
val pendingOperation = context.pendingOperation
var waitingObject: Any? = null
when (pendingOperation) {
is ObjectWaitBlocking -> {
if (lockManager.threadInterruptDuringObjectWait(
pendingOperation.o, pendingOperation.o, context)) {
syncManager.createWait(pendingOperation.o, 1)
waitingObject = pendingOperation.o
}
}
is ConditionAwaitBlocking -> {
if (pendingOperation.canInterrupt) {
val lock = lockManager.lockFromCondition(pendingOperation.condition)
if (lockManager.threadInterruptDuringObjectWait(
pendingOperation.condition, lock, context)) {
syncManager.createWait(lock, 1)
waitingObject = lock
}
}
}
is CountDownLatchAwaitBlocking -> {
context.pendingOperation = ThreadResumeOperation()
context.state = ThreadState.Enabled
syncManager.createWait(pendingOperation.latch, 1)
waitingObject = pendingOperation.latch
}
is ParkBlocking -> {
context.pendingOperation = ThreadResumeOperation()
context.state = ThreadState.Enabled
}
is LockBlocking -> {
lockManager.getLockContext(pendingOperation.lock).interrupt(t.id)
}
}

// A thread is interrupted during lockInterruptibly.
if (context.waitingForLock != null) {
val lock = context.waitingForLock!!
lockManager.getLockContext(lock).interrupt(t.id)
if (waitingObject != null) {
registeredThreads[Thread.currentThread().id]!!.pendingOperation =
InterruptPendingOperation(waitingObject)
}
}

fun threadInterruptDone(t: Thread) {
val context = registeredThreads[t.id]!!
context.interruptSignaled = true

// A thread is interrupted during wait/await.
if (context.blockedBy != null) {
val lock =
if (context.blockedBy is Condition) {
lockManager.lockFromCondition(context.blockedBy as Condition)
} else {
context.blockedBy!!
}
syncManager.wait(lock)
} else {
syncManager.wait(t)
val context = registeredThreads[Thread.currentThread().id]!!
val pendingOperation = context.pendingOperation
if (pendingOperation is InterruptPendingOperation) {
syncManager.wait(pendingOperation.waitingObject)
}
context.pendingOperation = ThreadResumeOperation()
}

fun threadClearInterrupt(t: Thread): Boolean {
Expand All @@ -415,7 +439,11 @@ object GlobalContext {
lockManager.threadWaitsFor.remove(t)
val context = registeredThreads[t]!!
lockManager.addWakingThread(lockObject, context.thread)
context.blockedBy = waitingObject
if (waitingObject == lockObject) {
context.pendingOperation = ObjectWakeBlocking(waitingObject)
} else {
context.pendingOperation = ConditionWakeBlocking(waitingObject as Condition)
}
it.remove(t)
if (it.size == 0) {
lockManager.waitingThreads.remove(id)
Expand All @@ -441,7 +469,11 @@ object GlobalContext {
lockManager.threadWaitsFor.remove(t)
// We cannot enable the thread immediately because
// the thread is still waiting for the monitor lock.
context.blockedBy = waitingObject
if (waitingObject == lockObject) {
context.pendingOperation = ObjectWakeBlocking(waitingObject)
} else {
context.pendingOperation = ConditionWakeBlocking(waitingObject as Condition)
}
lockManager.addWakingThread(lockObject, context.thread)
}
lockManager.waitingThreads.remove(id)
Expand Down Expand Up @@ -488,16 +520,12 @@ object GlobalContext {
// }
while (!lockManager.lock(lock, t, shouldBlock, false, canInterrupt) && shouldBlock) {
context.state = ThreadState.Paused
context.waitingForLock = lock
try {
// We want to block current thread because we do
// not want to rely on ReentrantLock. This allows
// us to pick which Thread to run next if multiple
// threads hold the same lock.
scheduleNextOperation(true)
} finally {
context.waitingForLock = null
}
context.pendingOperation = LockBlocking(lock)
// We want to block current thread because we do
// not want to rely on ReentrantLock. This allows
// us to pick which Thread to run next if multiple
// threads hold the same lock.
scheduleNextOperation(true)
if (canInterrupt) {
context.checkInterrupt()
}
Expand Down Expand Up @@ -656,9 +684,12 @@ object GlobalContext {
if (latchManager.await(latch, true)) {
val t = Thread.currentThread().id
val context = registeredThreads[t]!!
context.pendingOperation = PausedOperation()
context.pendingOperation = CountDownLatchAwaitBlocking(latch)
context.state = ThreadState.Paused
checkDeadlock { context.state = ThreadState.Running }
checkDeadlock {
context.state = ThreadState.Running
context.pendingOperation = ThreadResumeOperation()
}
executor.submit {
while (registeredThreads[t]!!.thread.state == Thread.State.RUNNABLE) {
Thread.yield()
Expand Down Expand Up @@ -712,6 +743,23 @@ object GlobalContext {
if (thread.state == ThreadState.Paused) {
thread.state = ThreadState.Enabled
lockManager.threadUnblockedDueToDeadlock(thread.thread)
val pendingOperation = thread.pendingOperation
when (pendingOperation) {
is ObjectWaitBlocking -> {
thread.pendingOperation = ObjectWakeBlocking(pendingOperation.o)
}
is ConditionAwaitBlocking -> {
thread.pendingOperation = ConditionWakeBlocking(pendingOperation.condition)
}
is CountDownLatchAwaitBlocking -> {
val releasedThreads = latchManager.release(pendingOperation.latch)
syncManager.createWait(pendingOperation.latch, releasedThreads)
while (pendingOperation.latch.count > 0) {
pendingOperation.latch.countDown()
}
syncManager.wait(pendingOperation.latch)
}
}
scheduleNextOperation(shouldBlockCurrentThread)
break
}
Expand Down Expand Up @@ -797,34 +845,31 @@ object GlobalContext {
}
}
nextThread.state = ThreadState.Running
if (currentThread != nextThread || currentThread.blockedBy != null) {
unblockThread(nextThread)
}
unblockThread(currentThread, nextThread)
if (currentThread != nextThread && shouldBlockCurrentThread) {
currentThread.block()
}
}

fun unblockThread(t: ThreadContext) {
// If this object is blocked through JDK locks,
// the thread is waiting for monitor locks.
// We first need to give the thread lock
// and then wakes it up through `notifyAll`.
val blockedBy = t.blockedBy
t.blockedBy = null
if (blockedBy != null) {
// FIXME(aoli): relying on type check is not 100% correct,
// because a thread can still be blocked by `condition.wait()`.
if (blockedBy is Condition) {
val lock = lockManager.lockFromCondition(blockedBy)
fun unblockThread(currentThread: ThreadContext, nextThread: ThreadContext) {
val pendingOperation = nextThread.pendingOperation
when (pendingOperation) {
is ConditionWakeBlocking -> {
nextThread.pendingOperation = ThreadResumeOperation()
val lock = lockManager.lockFromCondition(pendingOperation.condition)
lock.lock()
blockedBy.signalAll()
pendingOperation.condition.signalAll()
lock.unlock()
} else {
synchronized(blockedBy) { (blockedBy as Object).notifyAll() }
return
}
} else {
t.unblock()
is ObjectWakeBlocking -> {
nextThread.pendingOperation = ThreadResumeOperation()
synchronized(pendingOperation.o) { (pendingOperation.o as Object).notifyAll() }
return
}
}
if (currentThread != nextThread) {
nextThread.unblock()
}
}
}
5 changes: 0 additions & 5 deletions core/src/main/kotlin/cmu/pasta/fray/core/ThreadContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ class ThreadContext(val thread: Thread, val index: Int) {
var interruptSignaled = false
var isExiting = false

// This field is set when a thread is resumed by `o.notify()`
// but hasn't acquire the monitor lock.
var blockedBy: Any? = null
var waitingForLock: Any? = null

// Pending operation is null if a thread is just resumed/blocked.
var pendingOperation: Operation = ThreadStartOperation()
val sync = Sync(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ class CountDownLatchContext(var count: Long) : Interruptible {
}
}

fun release(): Int {
if (count == 0L) {
return 0
}
count = 0
var threads = 0
for (tid in latchWaiters.keys) {
GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation()
GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled
threads += 1
}
return threads
}

/*
* Returns number of unblocked threads.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ class CountDownLatchManager {
fun countDown(latch: CountDownLatch): Int {
return latchStore.getLockContext(latch).countDown()
}

fun release(latch: CountDownLatch): Int {
return latchStore.getLockContext(latch).release()
}
}
Loading

0 comments on commit 6fbf56d

Please sign in to comment.