diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/GlobalContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/GlobalContext.kt index 2e4e5a0d..fc0dc881 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/GlobalContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/GlobalContext.kt @@ -9,8 +9,6 @@ import cmu.pasta.fray.core.concurrency.locks.SemaphoreManager import cmu.pasta.fray.core.concurrency.operations.* import cmu.pasta.fray.core.logger.LoggerBase import cmu.pasta.fray.core.scheduler.Choice -import cmu.pasta.fray.core.scheduler.FifoScheduler -import cmu.pasta.fray.core.scheduler.Scheduler import cmu.pasta.fray.instrumentation.memory.VolatileManager import cmu.pasta.fray.runtime.DeadlockException import cmu.pasta.fray.runtime.Delegate @@ -22,6 +20,7 @@ import java.io.PrintWriter import java.io.StringWriter import java.lang.Thread.UncaughtExceptionHandler import java.util.concurrent.CountDownLatch +import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.Semaphore import java.util.concurrent.locks.Condition @@ -32,14 +31,11 @@ import java.util.concurrent.locks.ReentrantReadWriteLock.ReadLock import java.util.concurrent.locks.ReentrantReadWriteLock.WriteLock import kotlin.system.exitProcess -// TODO(aoli): make this a class maybe? @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") -object GlobalContext { +class GlobalContext(val config: Configuration) { val registeredThreads = mutableMapOf() var currentThreadId: Long = -1 var mainThreadId: Long = -1 - var scheduler: Scheduler = FifoScheduler() - var config: Configuration? = null var bugFound: Throwable? = null var mainExiting = false var nanoTime = System.nanoTime() @@ -51,7 +47,7 @@ object GlobalContext { private var step = 0 val syncManager = SynchronizationManager() val loggers = mutableListOf() - var executor = + var executor: ExecutorService = Executors.newSingleThreadExecutor { r -> object : HelperThread() { override fun run() { @@ -72,7 +68,7 @@ object GlobalContext { } fun reportError(e: Throwable) { - if (bugFound == null && !config!!.executionInfo.ignoreUnhandledExceptions) { + if (bugFound == null && !config.executionInfo.ignoreUnhandledExceptions) { bugFound = e val sw = StringWriter() sw.append("Error found: ${e}\n") @@ -92,7 +88,7 @@ object GlobalContext { for (logger in loggers) { logger.applicationEvent(sw.toString()) } - if (config!!.exploreMode || config!!.noExitWhenBugFound) { + if (config.exploreMode || config.noExitWhenBugFound) { return } loggers.forEach { it.executionDone(bugFound != null) } @@ -165,13 +161,13 @@ object GlobalContext { } fun done() { - loggers.forEach { it.executionDone(bugFound != null && config!!.exploreMode) } + loggers.forEach { it.executionDone(bugFound != null && config.exploreMode) } loggers.clear() assert(lockManager.waitingThreads.isEmpty()) assert(syncManager.synchronizationPoints.isEmpty()) lockManager.done() registeredThreads.clear() - scheduler.done() + config.scheduler.done() } fun shutDown() { @@ -281,7 +277,7 @@ object GlobalContext { var size = 0 lockManager.getLockContext(t).wakingThreads.let { for (thread in it) { - registeredThreads[thread]!!.state = ThreadState.Enabled + thread.value.state = ThreadState.Enabled } size = it.size } @@ -323,7 +319,7 @@ object GlobalContext { unlockImpl(lockObject, t, true, true, lockObject == waitingObject) checkDeadlock { context.pendingOperation = ThreadResumeOperation() - assert(lockManager.lock(lockObject, t, false, true, false)) + assert(lockManager.lock(lockObject, context, false, true, false)) syncManager.removeWait(lockObject) context.state = ThreadState.Running } @@ -375,7 +371,7 @@ object GlobalContext { } } // If a thread is enabled, the lock must be available. - assert(lockManager.lock(lockObject, t.id, false, true, false)) + assert(lockManager.lock(lockObject, context, false, true, false)) if (canInterrupt) { context.checkInterrupt() } @@ -461,7 +457,7 @@ object GlobalContext { val t = it.removeFirst() lockManager.threadWaitsFor.remove(t) val context = registeredThreads[t]!! - lockManager.addWakingThread(lockObject, context.thread) + lockManager.addWakingThread(lockObject, context) if (waitingObject == lockObject) { context.pendingOperation = ObjectWakeBlocking(waitingObject) } else { @@ -497,7 +493,7 @@ object GlobalContext { } else { context.pendingOperation = ConditionWakeBlocking(waitingObject as Condition) } - lockManager.addWakingThread(lockObject, context.thread) + lockManager.addWakingThread(lockObject, context) } lockManager.waitingThreads.remove(id) } @@ -545,7 +541,7 @@ object GlobalContext { // synchronized(lock) { // lock.unlock(); // } - while (!lockManager.lock(lock, t, shouldBlock, false, canInterrupt) && shouldBlock) { + while (!lockManager.lock(lock, context, shouldBlock, false, canInterrupt) && shouldBlock) { context.state = ThreadState.Paused context.pendingOperation = LockBlocking(lock) // We want to block current thread because we do @@ -588,7 +584,7 @@ object GlobalContext { isMonitorLock: Boolean ) { var waitingThreads = - if (lockManager.unlock(lock, tid, unlockBecauseOfWait)) { + if (lockManager.unlock(lock, tid, unlockBecauseOfWait, bugFound != null)) { lockManager.getNumThreadsBlockBy(lock, isMonitorLock) } else { 0 @@ -649,7 +645,8 @@ object GlobalContext { context.state = ThreadState.Enabled scheduleNextOperation(true) - while (!semaphoreManager.acquire(sem, permits, shouldBlock, canInterrupt) && shouldBlock) { + while (!semaphoreManager.acquire(sem, permits, shouldBlock, canInterrupt, context) && + shouldBlock) { context.state = ThreadState.Paused scheduleNextOperation(true) @@ -672,7 +669,7 @@ object GlobalContext { } fun fieldOperation(obj: Any?, owner: String, name: String, type: MemoryOpType) { - if (!config!!.executionInfo.interleaveMemoryOps && !volatileManager.isVolatile(owner, name)) + if (!config.executionInfo.interleaveMemoryOps && !volatileManager.isVolatile(owner, name)) return val objIds = mutableListOf() if (obj != null) { @@ -690,7 +687,7 @@ object GlobalContext { } fun arrayOperation(obj: Any, index: Int, type: MemoryOpType) { - if (!config!!.executionInfo.interleaveMemoryOps) return + if (!config.executionInfo.interleaveMemoryOps) return val objId = System.identityHashCode(obj) memoryOperation((31 * objId) + index, type) } @@ -708,9 +705,9 @@ object GlobalContext { } fun latchAwait(latch: CountDownLatch) { - if (latchManager.await(latch, true)) { - val t = Thread.currentThread().id - val context = registeredThreads[t]!! + val t = Thread.currentThread().id + val context = registeredThreads[t]!! + if (latchManager.await(latch, true, context)) { context.pendingOperation = CountDownLatchAwaitBlocking(latch) context.state = ThreadState.Paused checkDeadlock { @@ -847,7 +844,7 @@ object GlobalContext { } step += 1 - if (config!!.executionInfo.maxScheduledStep in 1 ..< step && + if (config.executionInfo.maxScheduledStep in 1 ..< step && !currentThread.isExiting && Thread.currentThread() !is HelperThread && !(mainExiting && currentThreadId == mainThreadId)) { @@ -857,11 +854,11 @@ object GlobalContext { throw e } - val nextThread = scheduler.scheduleNextOperation(enabledOperations) + val nextThread = config.scheduler.scheduleNextOperation(enabledOperations) val index = enabledOperations.indexOf(nextThread) currentThreadId = nextThread.thread.id - if (enabledOperations.size > 1 || config!!.fullSchedule) { + if (enabledOperations.size > 1 || config.fullSchedule) { loggers.forEach { it.newOperationScheduled( nextThread.pendingOperation, diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/RuntimeDelegate.kt b/core/src/main/kotlin/cmu/pasta/fray/core/RuntimeDelegate.kt index 0c2311ca..0c81d4ff 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/RuntimeDelegate.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/RuntimeDelegate.kt @@ -12,7 +12,7 @@ import java.util.concurrent.locks.Lock import java.util.concurrent.locks.LockSupport import java.util.concurrent.locks.ReentrantReadWriteLock -class RuntimeDelegate : Delegate() { +class RuntimeDelegate(val context: GlobalContext) : Delegate() { var entered = ThreadLocal.withInitial { false } var skipFunctionEntered = ThreadLocal.withInitial { 0 } @@ -32,7 +32,7 @@ class RuntimeDelegate : Delegate() { return true } // We do not process threads created outside of application. - if (!GlobalContext.registeredThreads.containsKey(Thread.currentThread().id)) { + if (!context.registeredThreads.containsKey(Thread.currentThread().id)) { entered.set(false) return true } @@ -41,7 +41,7 @@ class RuntimeDelegate : Delegate() { override fun onMainExit() { if (checkEntered()) return - GlobalContext.mainExit() + context.mainExit() entered.set(false) } @@ -50,7 +50,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("thread.start") return } - GlobalContext.threadStart(t) + context.threadStart(t) onSkipMethod("thread.start") entered.set(false) } @@ -58,25 +58,25 @@ class RuntimeDelegate : Delegate() { override fun onThreadStartDone(t: Thread) { onSkipMethodDone("thread.start") if (checkEntered()) return - GlobalContext.threadStartDone(t) + context.threadStartDone(t) entered.set(false) } override fun onThreadRun() { if (checkEntered()) return - GlobalContext.threadRun() + context.threadRun() entered.set(false) } override fun onThreadEnd() { if (checkEntered()) return - GlobalContext.threadCompleted(Thread.currentThread()) + context.threadCompleted(Thread.currentThread()) entered.set(false) } override fun onThreadGetState(t: Thread, state: Thread.State): Thread.State { if (checkEntered()) return state - val result = GlobalContext.threadGetState(t, state) + val result = context.threadGetState(t, state) entered.set(false) return result } @@ -84,7 +84,7 @@ class RuntimeDelegate : Delegate() { override fun onObjectWait(o: Any) { if (checkEntered()) return try { - GlobalContext.objectWait(o) + context.objectWait(o) } finally { entered.set(false) } @@ -93,7 +93,7 @@ class RuntimeDelegate : Delegate() { override fun onObjectWaitDone(o: Any) { if (checkEntered()) return try { - GlobalContext.objectWaitDone(o) + context.objectWaitDone(o) } finally { entered.set(false) } @@ -101,13 +101,13 @@ class RuntimeDelegate : Delegate() { override fun onObjectNotify(o: Any) { if (checkEntered()) return - GlobalContext.objectNotify(o) + context.objectNotify(o) entered.set(false) } override fun onObjectNotifyAll(o: Any) { if (checkEntered()) return - GlobalContext.objectNotifyAll(o) + context.objectNotifyAll(o) entered.set(false) } @@ -117,7 +117,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.lockLock(l, true) + context.lockLock(l, true) } finally { entered.set(false) onSkipMethod("Lock.lock") @@ -129,7 +129,7 @@ class RuntimeDelegate : Delegate() { entered.set(false) return result } - val result = GlobalContext.lockHasQueuedThreads(l) + val result = context.lockHasQueuedThreads(l) entered.set(false) return result } @@ -139,7 +139,7 @@ class RuntimeDelegate : Delegate() { entered.set(false) return result } - val result = GlobalContext.lockHasQueuedThread(l, t) + val result = context.lockHasQueuedThread(l, t) entered.set(false) return result } @@ -150,7 +150,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.lockTryLock(l, false) + context.lockTryLock(l, false) } finally { entered.set(false) onSkipMethod("Lock.tryLock") @@ -167,7 +167,7 @@ class RuntimeDelegate : Delegate() { return timeout } try { - GlobalContext.lockTryLock(l, true) + context.lockTryLock(l, true) } finally { entered.set(false) onSkipMethod("Lock.tryLock") @@ -185,7 +185,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.lockLock(l, false) + context.lockLock(l, false) } finally { onSkipMethod("Lock.lock") entered.set(false) @@ -199,7 +199,7 @@ class RuntimeDelegate : Delegate() { override fun onAtomicOperation(o: Any, type: MemoryOpType) { if (checkEntered()) return try { - GlobalContext.atomicOperation(o, type) + context.atomicOperation(o, type) } finally { entered.set(false) } @@ -211,7 +211,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.lockUnlock(l) + context.lockUnlock(l) } finally { entered.set(false) onSkipMethod("Lock.unlock") @@ -221,14 +221,14 @@ class RuntimeDelegate : Delegate() { override fun onLockUnlockDone(l: Lock) { onSkipMethodDone("Lock.unlock") if (checkEntered()) return - GlobalContext.lockUnlockDone(l) + context.lockUnlockDone(l) entered.set(false) } override fun onMonitorEnter(o: Any) { if (checkEntered()) return try { - GlobalContext.monitorEnter(o) + context.monitorEnter(o) } finally { entered.set(false) } @@ -236,19 +236,19 @@ class RuntimeDelegate : Delegate() { override fun onMonitorExit(o: Any) { if (checkEntered()) return - GlobalContext.monitorExit(o) + context.monitorExit(o) entered.set(false) } override fun onMonitorExitDone(o: Any) { if (checkEntered()) return - GlobalContext.monitorEnterDone(o) + context.monitorEnterDone(o) entered.set(false) } override fun onLockNewCondition(c: Condition, l: Lock) { if (checkEntered()) return - GlobalContext.lockNewCondition(c, l) + context.lockNewCondition(c, l) entered.set(false) } @@ -258,7 +258,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.conditionAwait(o, true) + context.conditionAwait(o, true) } finally { entered.set(false) onSkipMethod("Condition.await") @@ -271,7 +271,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.conditionAwait(o, false) + context.conditionAwait(o, false) } finally { entered.set(false) onSkipMethod("Condition.await") @@ -284,7 +284,7 @@ class RuntimeDelegate : Delegate() { } if (checkEntered()) return try { - GlobalContext.conditionAwaitDone(o, true) + context.conditionAwaitDone(o, true) } finally { entered.set(false) } @@ -296,7 +296,7 @@ class RuntimeDelegate : Delegate() { } if (checkEntered()) return try { - GlobalContext.conditionAwaitDone(o, false) + context.conditionAwaitDone(o, false) } finally { entered.set(false) } @@ -307,7 +307,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("Condition.signal") return } - GlobalContext.conditionSignal(o) + context.conditionSignal(o) entered.set(false) onSkipMethod("Condition.signal") } @@ -321,7 +321,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("Condition.signal") return } - GlobalContext.conditionSignalAll(o) + context.conditionSignalAll(o) entered.set(false) onSkipMethod("Condition.signal") } @@ -330,7 +330,7 @@ class RuntimeDelegate : Delegate() { if (o == null) return if (checkEntered()) return try { - GlobalContext.unsafeOperation(o, offset, MemoryOpType.MEMORY_READ) + context.unsafeOperation(o, offset, MemoryOpType.MEMORY_READ) } finally { entered.set(false) } @@ -340,7 +340,7 @@ class RuntimeDelegate : Delegate() { if (o == null) return if (checkEntered()) return try { - GlobalContext.unsafeOperation(o, offset, MemoryOpType.MEMORY_WRITE) + context.unsafeOperation(o, offset, MemoryOpType.MEMORY_WRITE) } finally { entered.set(false) } @@ -350,7 +350,7 @@ class RuntimeDelegate : Delegate() { if (o == null) return if (checkEntered()) return try { - GlobalContext.fieldOperation(o, owner, name, MemoryOpType.MEMORY_READ) + context.fieldOperation(o, owner, name, MemoryOpType.MEMORY_READ) } finally { entered.set(false) } @@ -360,7 +360,7 @@ class RuntimeDelegate : Delegate() { if (o == null) return if (checkEntered()) return try { - GlobalContext.fieldOperation(o, owner, name, MemoryOpType.MEMORY_WRITE) + context.fieldOperation(o, owner, name, MemoryOpType.MEMORY_WRITE) } finally { entered.set(false) } @@ -369,7 +369,7 @@ class RuntimeDelegate : Delegate() { override fun onStaticFieldRead(owner: String, name: String, descriptor: String) { if (checkEntered()) return try { - GlobalContext.fieldOperation(null, owner, name, MemoryOpType.MEMORY_READ) + context.fieldOperation(null, owner, name, MemoryOpType.MEMORY_READ) } finally { entered.set(false) } @@ -378,7 +378,7 @@ class RuntimeDelegate : Delegate() { override fun onStaticFieldWrite(owner: String, name: String, descriptor: String) { if (checkEntered()) return try { - GlobalContext.fieldOperation(null, owner, name, MemoryOpType.MEMORY_WRITE) + context.fieldOperation(null, owner, name, MemoryOpType.MEMORY_WRITE) } finally { entered.set(false) } @@ -387,7 +387,7 @@ class RuntimeDelegate : Delegate() { override fun onExit(status: Int) { if (checkEntered()) return if (status != 0) { - GlobalContext.reportError(RuntimeException("Exit with status $status")) + context.reportError(RuntimeException("Exit with status $status")) } entered.set(false) } @@ -395,14 +395,14 @@ class RuntimeDelegate : Delegate() { override fun onYield() { if (checkEntered()) return try { - GlobalContext.yield() + context.yield() } finally { entered.set(false) } } override fun onSkipMethod(signature: String) { - if (!GlobalContext.registeredThreads.containsKey(Thread.currentThread().id)) { + if (!context.registeredThreads.containsKey(Thread.currentThread().id)) { return } stackTrace.get().add(signature) @@ -410,7 +410,7 @@ class RuntimeDelegate : Delegate() { } override fun onSkipMethodDone(signature: String): Boolean { - if (!GlobalContext.registeredThreads.containsKey(Thread.currentThread().id)) { + if (!context.registeredThreads.containsKey(Thread.currentThread().id)) { return false } if (stackTrace.get().isEmpty()) { @@ -427,7 +427,7 @@ class RuntimeDelegate : Delegate() { override fun onThreadPark() { if (checkEntered()) return try { - GlobalContext.threadPark() + context.threadPark() } finally { entered.set(false) } @@ -436,7 +436,7 @@ class RuntimeDelegate : Delegate() { override fun onThreadParkDone() { if (checkEntered()) return try { - GlobalContext.threadParkDone() + context.threadParkDone() } finally { entered.set(false) } @@ -447,21 +447,21 @@ class RuntimeDelegate : Delegate() { if (checkEntered()) { return } - GlobalContext.threadUnpark(t) + context.threadUnpark(t) entered.set(false) } override fun onThreadUnparkDone(t: Thread?) { if (t == null) return if (checkEntered()) return - GlobalContext.threadUnparkDone(t) + context.threadUnparkDone(t) entered.set(false) } override fun onThreadInterrupt(t: Thread) { if (checkEntered()) return try { - GlobalContext.threadInterrupt(t) + context.threadInterrupt(t) } finally { entered.set(false) } @@ -469,26 +469,26 @@ class RuntimeDelegate : Delegate() { override fun onThreadInterruptDone(t: Thread) { if (checkEntered()) return - GlobalContext.threadInterruptDone(t) + context.threadInterruptDone(t) entered.set(false) } override fun onThreadClearInterrupt(origin: Boolean, t: Thread): Boolean { if (checkEntered()) return origin - val o = GlobalContext.threadClearInterrupt(t) + val o = context.threadClearInterrupt(t) entered.set(false) return o } override fun onReentrantReadWriteLockInit(lock: ReentrantReadWriteLock) { if (checkEntered()) return - GlobalContext.reentrantReadWriteLockInit(lock.readLock(), lock.writeLock()) + context.reentrantReadWriteLockInit(lock.readLock(), lock.writeLock()) entered.set(false) } override fun onSemaphoreInit(sem: Semaphore) { if (checkEntered()) return - GlobalContext.semaphoreInit(sem) + context.semaphoreInit(sem) entered.set(false) } @@ -498,7 +498,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.semaphoreAcquire(sem, permits, false, true) + context.semaphoreAcquire(sem, permits, false, true) } finally { onSkipMethod("Semaphore.acquire") entered.set(false) @@ -511,7 +511,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.semaphoreAcquire(sem, permits, true, true) + context.semaphoreAcquire(sem, permits, true, true) } finally { onSkipMethod("Semaphore.acquire") entered.set(false) @@ -524,7 +524,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.semaphoreAcquire(sem, permits, true, false) + context.semaphoreAcquire(sem, permits, true, false) } finally { entered.set(false) onSkipMethod("Semaphore.acquire") @@ -540,7 +540,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("Semaphore.release") return } - GlobalContext.semaphoreRelease(sem, permits) + context.semaphoreRelease(sem, permits) entered.set(false) onSkipMethod("Semaphore.release") } @@ -554,7 +554,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("Semaphore.drainPermits") return } - GlobalContext.semaphoreDrainPermits(sem) + context.semaphoreDrainPermits(sem) entered.set(false) onSkipMethod("Semaphore.drainPermits") } @@ -568,7 +568,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("Semaphore.reducePermits") return } - GlobalContext.semaphoreReducePermits(sem, permits) + context.semaphoreReducePermits(sem, permits) entered.set(false) onSkipMethod("Semaphore.reducePermits") } @@ -583,7 +583,7 @@ class RuntimeDelegate : Delegate() { return } try { - GlobalContext.latchAwait(latch) + context.latchAwait(latch) } finally { entered.set(false) onSkipMethod("Latch.await") @@ -591,7 +591,7 @@ class RuntimeDelegate : Delegate() { } override fun onLatchAwaitTimeout(latch: CountDownLatch, timeout: Long, unit: TimeUnit): Boolean { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() return false } else { @@ -603,7 +603,7 @@ class RuntimeDelegate : Delegate() { override fun onLatchAwaitDone(latch: CountDownLatch) { onSkipMethodDone("Latch.await") if (checkEntered()) return - GlobalContext.latchAwaitDone(latch) + context.latchAwaitDone(latch) entered.set(false) } @@ -612,7 +612,7 @@ class RuntimeDelegate : Delegate() { onSkipMethod("Latch.countDown") return } - GlobalContext.latchCountDown(latch) + context.latchCountDown(latch) entered.set(false) onSkipMethod("Latch.countDown") } @@ -620,13 +620,13 @@ class RuntimeDelegate : Delegate() { override fun onLatchCountDownDone(latch: CountDownLatch) { onSkipMethodDone("Latch.countDown") if (checkEntered()) return - GlobalContext.latchCountDownDone(latch) + context.latchCountDownDone(latch) entered.set(false) } override fun onReportError(e: Throwable) { if (checkEntered()) return - GlobalContext.reportError(e) + context.reportError(e) entered.set(false) } @@ -634,7 +634,7 @@ class RuntimeDelegate : Delegate() { if (o == null) return if (checkEntered()) return try { - GlobalContext.arrayOperation(o, index, MemoryOpType.MEMORY_READ) + context.arrayOperation(o, index, MemoryOpType.MEMORY_READ) } finally { entered.set(false) } @@ -644,7 +644,7 @@ class RuntimeDelegate : Delegate() { if (o == null) return if (checkEntered()) return try { - GlobalContext.arrayOperation(o, index, MemoryOpType.MEMORY_WRITE) + context.arrayOperation(o, index, MemoryOpType.MEMORY_WRITE) } finally { entered.set(false) } @@ -655,7 +655,7 @@ class RuntimeDelegate : Delegate() { // Therefor we cannot call `checkEntered` here. try { entered.set(true) - GlobalContext.start() + context.start() entered.set(false) } catch (e: Throwable) { e.printStackTrace() @@ -663,7 +663,7 @@ class RuntimeDelegate : Delegate() { } override fun onThreadParkNanos(nanos: Long) { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() } else { LockSupport.park() @@ -671,7 +671,7 @@ class RuntimeDelegate : Delegate() { } override fun onThreadParkUntil(nanos: Long) { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() } else { LockSupport.park() @@ -679,7 +679,7 @@ class RuntimeDelegate : Delegate() { } override fun onThreadParkNanosWithBlocker(blocker: Any?, nanos: Long) { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() } else { LockSupport.park(blocker) @@ -687,7 +687,7 @@ class RuntimeDelegate : Delegate() { } override fun onThreadParkUntilWithBlocker(blocker: Any?, nanos: Long) { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() } else { LockSupport.park(blocker) @@ -695,7 +695,7 @@ class RuntimeDelegate : Delegate() { } override fun onConditionAwaitTime(o: Condition, time: Long, unit: TimeUnit): Boolean { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() return false } else { @@ -705,7 +705,7 @@ class RuntimeDelegate : Delegate() { } override fun onConditionAwaitNanos(o: Condition, nanos: Long): Long { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() return 0 } else { @@ -715,7 +715,7 @@ class RuntimeDelegate : Delegate() { } override fun onConditionAwaitUntil(o: Condition, deadline: Date): Boolean { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() return false } else { @@ -726,13 +726,13 @@ class RuntimeDelegate : Delegate() { override fun onThreadIsInterrupted(result: Boolean, t: Thread): Boolean { if (checkEntered()) return result - val isInterrupted = GlobalContext.threadIsInterrupted(t, result) + val isInterrupted = context.threadIsInterrupted(t, result) entered.set(false) return isInterrupted } override fun onLockTryLockTimeout(l: Lock, timeout: Long, unit: TimeUnit): Boolean { - if (GlobalContext.config!!.executionInfo.timedOpAsYield) { + if (context.config!!.executionInfo.timedOpAsYield) { onYield() return false } else { @@ -741,12 +741,12 @@ class RuntimeDelegate : Delegate() { } override fun onNanoTime(): Long { - return GlobalContext.nanoTime() + return context.nanoTime() } override fun onThreadHashCode(t: Any): Int { if (t is Thread) { - val context = GlobalContext.registeredThreads[t.id] + val context = context.registeredThreads[t.id] if (context != null) { return 0 } else { @@ -755,4 +755,8 @@ class RuntimeDelegate : Delegate() { } return t.hashCode() } + + override fun log(format: String, vararg args: Any) { + context.log(format, *args) + } } diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/TestRunner.kt b/core/src/main/kotlin/cmu/pasta/fray/core/TestRunner.kt index 06b95854..5da20fb1 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/TestRunner.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/TestRunner.kt @@ -1,7 +1,5 @@ package cmu.pasta.fray.core -import cmu.pasta.fray.core.GlobalContext.bugFound -import cmu.pasta.fray.core.GlobalContext.loggers import cmu.pasta.fray.core.command.Configuration import cmu.pasta.fray.core.logger.ConsoleLogger import cmu.pasta.fray.core.scheduler.ReplayScheduler @@ -22,30 +20,30 @@ class TestRunner(val config: Configuration) { println("Report is available at: ${path.toAbsolutePath()}") } - fun setup() { + fun createContext(): GlobalContext { + val context = GlobalContext(config) if (config.scheduler !is ReplayScheduler) { prepareReportPath(config.report) - config.loggers.forEach(GlobalContext::registerLogger) + config.loggers.forEach(context::registerLogger) } - GlobalContext.registerLogger(ConsoleLogger()) - GlobalContext.scheduler = config.scheduler - GlobalContext.config = config - GlobalContext.bootstrap() + context.registerLogger(ConsoleLogger()) + context.bootstrap() + return context } fun run(): Throwable? { config.executionInfo.executor.beforeExecution() + val context = createContext() if (config.noFray) { config.executionInfo.executor.execute() } else { - setup() val timeSource = TimeSource.Monotonic val start = timeSource.markNow() var i = 0 while (i != config.iter) { println("Starting iteration $i") try { - Runtime.DELEGATE = RuntimeDelegate() + Runtime.DELEGATE = RuntimeDelegate(context) Runtime.start() config.executionInfo.executor.execute() Runtime.onMainExit() @@ -53,7 +51,7 @@ class TestRunner(val config: Configuration) { Runtime.onReportError(e) Runtime.onMainExit() } - if (bugFound != null) { + if (context.bugFound != null) { println( "Error found at iter: $i, Elapsed time: ${(timeSource.markNow() - start).inWholeMilliseconds}ms") if (!config.exploreMode) { @@ -62,9 +60,9 @@ class TestRunner(val config: Configuration) { } i++ } - GlobalContext.shutDown() + context.shutDown() } config.executionInfo.executor.afterExecution() - return GlobalContext.bugFound + return context.bugFound } } diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchContext.kt index 4c27ac1b..4ff2e049 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchContext.kt @@ -1,18 +1,18 @@ package cmu.pasta.fray.core.concurrency.locks -import cmu.pasta.fray.core.GlobalContext +import cmu.pasta.fray.core.ThreadContext import cmu.pasta.fray.core.ThreadState import cmu.pasta.fray.core.concurrency.operations.ThreadResumeOperation class CountDownLatchContext(var count: Long) : Interruptible { - val latchWaiters = mutableMapOf() + val latchWaiters = mutableMapOf() - fun await(canInterrupt: Boolean): Boolean { + fun await(canInterrupt: Boolean, thread: ThreadContext): Boolean { if (count > 0) { if (canInterrupt) { - GlobalContext.registeredThreads[Thread.currentThread().id]?.checkInterrupt() + thread.checkInterrupt() } - latchWaiters[Thread.currentThread().id] = canInterrupt + latchWaiters[Thread.currentThread().id] = LockWaiter(canInterrupt, thread) return true } assert(count == 0L) @@ -20,9 +20,10 @@ class CountDownLatchContext(var count: Long) : Interruptible { } override fun interrupt(tid: Long) { - if (latchWaiters[tid] == true) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + val lockWaiter = latchWaiters[tid] ?: return + if (lockWaiter.canInterrupt) { + lockWaiter.thread.pendingOperation = ThreadResumeOperation() + lockWaiter.thread.state = ThreadState.Enabled latchWaiters.remove(tid) } } @@ -33,9 +34,9 @@ class CountDownLatchContext(var count: Long) : Interruptible { } count = 0 var threads = 0 - for (tid in latchWaiters.keys) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + for (lockWaiter in latchWaiters.values) { + lockWaiter.thread.pendingOperation = ThreadResumeOperation() + lockWaiter.thread.state = ThreadState.Enabled threads += 1 } return threads @@ -53,9 +54,9 @@ class CountDownLatchContext(var count: Long) : Interruptible { count -= 1 if (count == 0L) { var threads = 0 - for (tid in latchWaiters.keys) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + for (lockWaiter in latchWaiters.values) { + lockWaiter.thread.pendingOperation = ThreadResumeOperation() + lockWaiter.thread.state = ThreadState.Enabled threads += 1 } return threads diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchManager.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchManager.kt index adf49d15..e7f08953 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchManager.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/CountDownLatchManager.kt @@ -1,5 +1,6 @@ package cmu.pasta.fray.core.concurrency.locks +import cmu.pasta.fray.core.ThreadContext import java.util.concurrent.CountDownLatch class CountDownLatchManager { @@ -11,8 +12,8 @@ class CountDownLatchManager { } } - fun await(latch: CountDownLatch, canInterrupt: Boolean): Boolean { - return latchStore.getLockContext(latch).await(canInterrupt) + fun await(latch: CountDownLatch, canInterrupt: Boolean, thread: ThreadContext): Boolean { + return latchStore.getLockContext(latch).await(canInterrupt, thread) } /* diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockContext.kt index e38e9bc4..d50e893c 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockContext.kt @@ -1,21 +1,23 @@ package cmu.pasta.fray.core.concurrency.locks +import cmu.pasta.fray.core.ThreadContext + interface LockContext : Interruptible { - val wakingThreads: MutableSet + val wakingThreads: MutableMap - fun addWakingThread(lockObject: Any, t: Thread) + fun addWakingThread(lockObject: Any, t: ThreadContext) fun canLock(tid: Long): Boolean fun lock( lock: Any, - tid: Long, + lockThread: ThreadContext, shouldBlock: Boolean, lockBecauseOfWait: Boolean, canInterrupt: Boolean ): Boolean - fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean): Boolean + fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean, earlyExit: Boolean): Boolean fun hasQueuedThreads(): Boolean diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockManager.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockManager.kt index 00dfce29..7400169f 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockManager.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockManager.kt @@ -39,12 +39,12 @@ class LockManager { /** Return true if [lock] is acquired by the current thread. */ fun lock( lock: Any, - tid: Long, + lockThread: ThreadContext, shouldBlock: Boolean, lockBecauseOfWait: Boolean, canInterrupt: Boolean ): Boolean { - return getLockContext(lock).lock(lock, tid, shouldBlock, lockBecauseOfWait, canInterrupt) + return getLockContext(lock).lock(lock, lockThread, shouldBlock, lockBecauseOfWait, canInterrupt) } fun hasQueuedThreads(lock: Any): Boolean { @@ -55,7 +55,7 @@ class LockManager { return getLockContext(lock).hasQueuedThread(t.id) } - fun addWakingThread(lockObject: Any, t: Thread) { + fun addWakingThread(lockObject: Any, t: ThreadContext) { getLockContext(lockObject).addWakingThread(lockObject, t) } @@ -82,7 +82,7 @@ class LockManager { if (waitingThreads[id]?.isEmpty() == true) { waitingThreads.remove(id) } - addWakingThread(lockObject, context.thread) + addWakingThread(lockObject, context) if (waitingObject == lockObject) { context.pendingOperation = ObjectWakeBlocking(waitingObject) } else { @@ -123,8 +123,8 @@ class LockManager { } } - fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean): Boolean { - return getLockContext(lock).unlock(lock, tid, unlockBecauseOfWait) + fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean, earlyExit: Boolean): Boolean { + return getLockContext(lock).unlock(lock, tid, unlockBecauseOfWait, earlyExit) } fun done() { diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockWaiter.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockWaiter.kt new file mode 100644 index 00000000..1309b790 --- /dev/null +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/LockWaiter.kt @@ -0,0 +1,5 @@ +package cmu.pasta.fray.core.concurrency.locks + +import cmu.pasta.fray.core.ThreadContext + +class LockWaiter(val canInterrupt: Boolean, val thread: ThreadContext) {} diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantLockContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantLockContext.kt index 94e973a0..f0737b52 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantLockContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantLockContext.kt @@ -1,6 +1,6 @@ package cmu.pasta.fray.core.concurrency.locks -import cmu.pasta.fray.core.GlobalContext +import cmu.pasta.fray.core.ThreadContext import cmu.pasta.fray.core.ThreadState import cmu.pasta.fray.core.concurrency.operations.ThreadResumeOperation @@ -8,11 +8,11 @@ class ReentrantLockContext : LockContext { private var lockHolder: Long? = null private val lockTimes = mutableMapOf() // Mapping from thread id to whether the thread is interruptible. - private val lockWaiters = mutableMapOf() - override val wakingThreads: MutableSet = mutableSetOf() + private val lockWaiters = mutableMapOf() + override val wakingThreads = mutableMapOf() - override fun addWakingThread(lockObject: Any, t: Thread) { - wakingThreads.add(t.id) + override fun addWakingThread(lockObject: Any, t: ThreadContext) { + wakingThreads[t.thread.id] = t } override fun canLock(tid: Long) = lockHolder == null || lockHolder == tid @@ -30,11 +30,12 @@ class ReentrantLockContext : LockContext { override fun lock( lock: Any, - tid: Long, + lockThread: ThreadContext, shouldBlock: Boolean, lockBecauseOfWait: Boolean, canInterrupt: Boolean, ): Boolean { + val tid = lockThread.thread.id if (lockHolder == null || lockHolder == tid) { lockHolder = tid if (!lockBecauseOfWait) { @@ -43,24 +44,29 @@ class ReentrantLockContext : LockContext { wakingThreads.remove(tid) // TODO(aoli): I don't like the design that we need to access GlobalContext here. - for (thread in wakingThreads) { - GlobalContext.registeredThreads[thread]!!.state = ThreadState.Paused + for (thread in wakingThreads.values) { + thread.state = ThreadState.Paused } return true } else { if (canInterrupt) { - GlobalContext.registeredThreads[tid]?.checkInterrupt() + lockThread.checkInterrupt() } if (shouldBlock) { - lockWaiters[tid] = canInterrupt + lockWaiters[tid] = LockWaiter(canInterrupt, lockThread) } } return false } - override fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean): Boolean { - assert(lockHolder == tid || GlobalContext.bugFound != null) - if (lockHolder != tid && GlobalContext.bugFound != null) { + override fun unlock( + lock: Any, + tid: Long, + unlockBecauseOfWait: Boolean, + earlyExit: Boolean + ): Boolean { + assert(lockHolder == tid || earlyExit) + if (lockHolder != tid && earlyExit) { return false } if (!unlockBecauseOfWait) { @@ -72,17 +78,15 @@ class ReentrantLockContext : LockContext { lockTimes.remove(tid) } lockHolder = null - for (thread in wakingThreads) { - val context = GlobalContext.registeredThreads[thread]!! - if (context.state != ThreadState.Completed) { - context.state = ThreadState.Enabled + for (thread in wakingThreads.values) { + if (thread.state != ThreadState.Completed) { + thread.state = ThreadState.Enabled } } - for (thread in lockWaiters.keys) { - val context = GlobalContext.registeredThreads[thread]!! - if (context.state != ThreadState.Completed) { - context.pendingOperation = ThreadResumeOperation() - context.state = ThreadState.Enabled + for (lockWaiter in lockWaiters.values) { + if (lockWaiter.thread.state != ThreadState.Completed) { + lockWaiter.thread.pendingOperation = ThreadResumeOperation() + lockWaiter.thread.state = ThreadState.Enabled } } lockWaiters.clear() @@ -92,9 +96,10 @@ class ReentrantLockContext : LockContext { } override fun interrupt(tid: Long) { - if (lockWaiters[tid] == true) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + val lockWaiter = lockWaiters[tid] ?: return + if (lockWaiter.canInterrupt) { + lockWaiter.thread.pendingOperation = ThreadResumeOperation() + lockWaiter.thread.state = ThreadState.Enabled lockWaiters.remove(tid) } } diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantReadWriteLockContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantReadWriteLockContext.kt index 94d28465..c8ecab9b 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantReadWriteLockContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/ReentrantReadWriteLockContext.kt @@ -1,6 +1,6 @@ package cmu.pasta.fray.core.concurrency.locks -import cmu.pasta.fray.core.GlobalContext +import cmu.pasta.fray.core.ThreadContext import cmu.pasta.fray.core.ThreadState import cmu.pasta.fray.core.concurrency.operations.ThreadResumeOperation import java.util.concurrent.locks.ReentrantReadWriteLock.ReadLock @@ -10,10 +10,10 @@ class ReentrantReadWriteLockContext : LockContext { private var writeLockHolder: Long? = null private val readLockTimes = mutableMapOf() private val writeLockTimes = mutableMapOf() - override val wakingThreads: MutableSet = mutableSetOf() + override val wakingThreads = mutableMapOf() - private val readLockWaiters = mutableMapOf() - private val writeLockWaiters = mutableMapOf() + private val readLockWaiters = mutableMapOf() + private val writeLockWaiters = mutableMapOf() override fun hasQueuedThreads(): Boolean { return writeLockWaiters.any() || wakingThreads.any() || readLockWaiters.any() @@ -25,8 +25,8 @@ class ReentrantReadWriteLockContext : LockContext { readLockWaiters.containsKey(tid) } - override fun addWakingThread(lockObject: Any, t: Thread) { - wakingThreads.add(t.id) + override fun addWakingThread(lockObject: Any, t: ThreadContext) { + wakingThreads[t.thread.id] = t } override fun canLock(tid: Long) = @@ -43,32 +43,39 @@ class ReentrantReadWriteLockContext : LockContext { override fun lock( lock: Any, - tid: Long, + lockThread: ThreadContext, shouldBlock: Boolean, lockBecauseOfWait: Boolean, canInterrupt: Boolean, ): Boolean { return if (lock is ReadLock) { - readLockLock(tid, shouldBlock, lockBecauseOfWait, canInterrupt) + readLockLock(lockThread, shouldBlock, lockBecauseOfWait, canInterrupt) } else { - writeLockLock(tid, shouldBlock, lockBecauseOfWait, canInterrupt) + writeLockLock(lockThread, shouldBlock, lockBecauseOfWait, canInterrupt) } } override fun interrupt(tid: Long) { - if (writeLockWaiters[tid] == true) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + val writeLockWaiter = writeLockWaiters[tid] + if (writeLockWaiter != null && writeLockWaiter.canInterrupt) { + writeLockWaiter.thread.pendingOperation = ThreadResumeOperation() + writeLockWaiter.thread.state = ThreadState.Enabled writeLockWaiters.remove(tid) } - if (readLockWaiters[tid] == true) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + val readLockWaiter = readLockWaiters[tid] + if (readLockWaiter != null && readLockWaiter.canInterrupt) { + readLockWaiter.thread.pendingOperation = ThreadResumeOperation() + readLockWaiter.thread.state = ThreadState.Enabled readLockWaiters.remove(tid) } } - override fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean): Boolean { + override fun unlock( + lock: Any, + tid: Long, + unlockBecauseOfWait: Boolean, + earlyExit: Boolean + ): Boolean { return if (lock is ReadLock) { readLockUnlock(tid, unlockBecauseOfWait) } else { @@ -77,18 +84,19 @@ class ReentrantReadWriteLockContext : LockContext { } fun readLockLock( - tid: Long, + lockThread: ThreadContext, shouldBlock: Boolean, lockBecauseOfWait: Boolean, canInterrupt: Boolean ): Boolean { assert(!lockBecauseOfWait) // Read lock does not have `Condition` + val tid = lockThread.thread.id if (writeLockHolder != null && writeLockHolder != tid) { if (canInterrupt) { - GlobalContext.registeredThreads[tid]?.checkInterrupt() + lockThread.checkInterrupt() } if (shouldBlock) { - readLockWaiters[tid] = canInterrupt + readLockWaiters[tid] = LockWaiter(canInterrupt, lockThread) } return false } @@ -98,17 +106,18 @@ class ReentrantReadWriteLockContext : LockContext { } fun writeLockLock( - tid: Long, + lockThread: ThreadContext, shouldBlock: Boolean, lockBecauseOfWait: Boolean, canInterrupt: Boolean ): Boolean { + val tid = lockThread.thread.id if ((writeLockHolder != null && writeLockHolder != tid) || readLockHolder.isNotEmpty()) { if (canInterrupt) { - GlobalContext.registeredThreads[tid]?.checkInterrupt() + lockThread.checkInterrupt() } if (shouldBlock) { - writeLockWaiters[tid] = canInterrupt + writeLockWaiters[tid] = LockWaiter(canInterrupt, lockThread) } return false } @@ -154,21 +163,21 @@ class ReentrantReadWriteLockContext : LockContext { } fun unlockWriteWaiters() { - for (writeLockWaiter in writeLockWaiters.keys) { - GlobalContext.registeredThreads[writeLockWaiter]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[writeLockWaiter]!!.state = ThreadState.Enabled + for (writeLockWaiter in writeLockWaiters.values) { + writeLockWaiter.thread.pendingOperation = ThreadResumeOperation() + writeLockWaiter.thread.state = ThreadState.Enabled } // Waking threads are write waiters as well. - for (thread in wakingThreads) { - GlobalContext.registeredThreads[thread]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[thread]!!.state = ThreadState.Enabled + for (thread in wakingThreads.values) { + thread.pendingOperation = ThreadResumeOperation() + thread.state = ThreadState.Enabled } } fun unlockReadWaiters() { - for (readLockWaiter in readLockWaiters.keys) { - GlobalContext.registeredThreads[readLockWaiter]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[readLockWaiter]!!.state = ThreadState.Enabled + for (readLockWaiter in readLockWaiters.values) { + readLockWaiter.thread.pendingOperation = ThreadResumeOperation() + readLockWaiter.thread.state = ThreadState.Enabled } } diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreContext.kt index b93466e2..29b788be 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreContext.kt @@ -1,22 +1,27 @@ package cmu.pasta.fray.core.concurrency.locks -import cmu.pasta.fray.core.GlobalContext +import cmu.pasta.fray.core.ThreadContext import cmu.pasta.fray.core.ThreadState import cmu.pasta.fray.core.concurrency.operations.ThreadResumeOperation class SemaphoreContext(var totalPermits: Int) : Interruptible { - private val lockWaiters = mutableMapOf>() + private val lockWaiters = mutableMapOf>() - fun acquire(permits: Int, shouldBlock: Boolean, canInterrupt: Boolean): Boolean { + fun acquire( + permits: Int, + shouldBlock: Boolean, + canInterrupt: Boolean, + thread: ThreadContext + ): Boolean { if (totalPermits >= permits) { totalPermits -= permits return true } else { if (canInterrupt) { - GlobalContext.registeredThreads[Thread.currentThread().id]?.checkInterrupt() + thread.checkInterrupt() } if (shouldBlock) { - lockWaiters[Thread.currentThread().id] = Pair(permits, canInterrupt) + lockWaiters[Thread.currentThread().id] = Pair(permits, LockWaiter(canInterrupt, thread)) } } return false @@ -35,8 +40,8 @@ class SemaphoreContext(var totalPermits: Int) : Interruptible { while (it.hasNext()) { val (tid, p) = it.next() if (totalPermits >= p.first) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + p.second.thread.pendingOperation = ThreadResumeOperation() + p.second.thread.state = ThreadState.Enabled lockWaiters.remove(tid) } } @@ -48,9 +53,10 @@ class SemaphoreContext(var totalPermits: Int) : Interruptible { } override fun interrupt(tid: Long) { - if (lockWaiters[tid]?.second == true) { - GlobalContext.registeredThreads[tid]!!.pendingOperation = ThreadResumeOperation() - GlobalContext.registeredThreads[tid]!!.state = ThreadState.Enabled + val lockWaiter = lockWaiters[tid] ?: return + if (lockWaiter.second.canInterrupt) { + lockWaiter.second.thread.pendingOperation = ThreadResumeOperation() + lockWaiter.second.thread.state = ThreadState.Enabled lockWaiters.remove(tid) } } diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreManager.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreManager.kt index 9807ceac..54729468 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreManager.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/SemaphoreManager.kt @@ -1,5 +1,6 @@ package cmu.pasta.fray.core.concurrency.locks +import cmu.pasta.fray.core.ThreadContext import java.util.concurrent.Semaphore class SemaphoreManager { @@ -16,8 +17,16 @@ class SemaphoreManager { lockContextManager.addContext(sem, context) } - fun acquire(sem: Semaphore, permits: Int, shouldBlock: Boolean, canInterrupt: Boolean): Boolean { - return lockContextManager.getLockContext(sem).acquire(permits, shouldBlock, canInterrupt) + fun acquire( + sem: Semaphore, + permits: Int, + shouldBlock: Boolean, + canInterrupt: Boolean, + threadContext: ThreadContext + ): Boolean { + return lockContextManager + .getLockContext(sem) + .acquire(permits, shouldBlock, canInterrupt, threadContext) } fun release(sem: Semaphore, permits: Int) { diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/StampedLockContext.kt b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/StampedLockContext.kt index 7808f9e7..1df9f398 100644 --- a/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/StampedLockContext.kt +++ b/core/src/main/kotlin/cmu/pasta/fray/core/concurrency/locks/StampedLockContext.kt @@ -1,41 +1 @@ package cmu.pasta.fray.core.concurrency.locks - -class StampedLockContext : LockContext { - override val wakingThreads: MutableSet = mutableSetOf() - - override fun hasQueuedThread(tid: Long): Boolean { - TODO("Not yet implemented") - } - - override fun hasQueuedThreads(): Boolean { - TODO("Not yet implemented") - } - - override fun addWakingThread(lockObject: Any, t: Thread) { - TODO("Not yet implemented") - } - - override fun canLock(tid: Long): Boolean { - TODO("Not yet implemented") - } - - override fun lock( - lock: Any, - tid: Long, - shouldBlock: Boolean, - lockBecauseOfWait: Boolean, - canInterrupt: Boolean - ): Boolean { - TODO("Not yet implemented") - } - - override fun interrupt(tid: Long) {} - - override fun unlock(lock: Any, tid: Long, unlockBecauseOfWait: Boolean): Boolean { - TODO("Not yet implemented") - } - - override fun isEmpty(): Boolean { - TODO("Not yet implemented") - } -} diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/randomness/DefaultRandomnessProvider.kt b/core/src/main/kotlin/cmu/pasta/fray/core/randomness/DefaultRandomnessProvider.kt new file mode 100644 index 00000000..ed16d459 --- /dev/null +++ b/core/src/main/kotlin/cmu/pasta/fray/core/randomness/DefaultRandomnessProvider.kt @@ -0,0 +1,45 @@ +package cmu.pasta.fray.core.randomness + +class DefaultRandomnessProvider : RandomnessProvider { + val random = java.util.Random() + val integers = mutableListOf() + val doubles = mutableListOf() + val longs = mutableListOf() + val bytes = mutableListOf() + + override fun nextInt(): Int { + val next = random.nextInt() + integers.add(next) + return next + } + + override fun nextInt(bound: Int): Int { + val next = random.nextInt(bound) + integers.add(next) + return next + } + + override fun nextLong(): Long { + val next = random.nextLong() + longs.add(next) + return next + } + + override fun nextBoolean(): Boolean { + val next = random.nextInt() + integers.add(next) + return next % 2 == 0 + } + + override fun nextDouble(): Double { + val next = random.nextDouble() + doubles.add(next) + return next + } + + override fun nextGaussian(): Double { + val next = random.nextGaussian() + doubles.add(next) + return next + } +} diff --git a/core/src/main/kotlin/cmu/pasta/fray/core/randomness/RandomnessProvider.kt b/core/src/main/kotlin/cmu/pasta/fray/core/randomness/RandomnessProvider.kt new file mode 100644 index 00000000..192cad18 --- /dev/null +++ b/core/src/main/kotlin/cmu/pasta/fray/core/randomness/RandomnessProvider.kt @@ -0,0 +1,15 @@ +package cmu.pasta.fray.core.randomness + +interface RandomnessProvider { + fun nextInt(): Int + + fun nextInt(bound: Int): Int + + fun nextLong(): Long + + fun nextBoolean(): Boolean + + fun nextDouble(): Double + + fun nextGaussian(): Double +} diff --git a/integration-tests/src/main/java/example/BadParallelMergeSort.java b/integration-tests/src/main/java/example/BadParallelMergeSort.java index 2cd40be9..2260597b 100644 --- a/integration-tests/src/main/java/example/BadParallelMergeSort.java +++ b/integration-tests/src/main/java/example/BadParallelMergeSort.java @@ -9,7 +9,7 @@ public class BadParallelMergeSort { static final Object MSLOCK = "MERGESORTLOCK"; private static void log (String format, Object... args) { - GlobalContext.INSTANCE.log(format, args); +// GlobalContext.INSTANCE.log(format, args); } /** diff --git a/integration-tests/src/main/java/example/ConcurrentQueue.java b/integration-tests/src/main/java/example/ConcurrentQueue.java index ac968147..f9268ede 100644 --- a/integration-tests/src/main/java/example/ConcurrentQueue.java +++ b/integration-tests/src/main/java/example/ConcurrentQueue.java @@ -14,7 +14,7 @@ class ConcurrentQueue { private Integer maxSize; private void log (String format, Object... args) { - GlobalContext.INSTANCE.log(format, args); +// GlobalContext.INSTANCE.log(format, args); } public ConcurrentQueue(int size) { diff --git a/integration-tests/src/main/java/example/CounterMap.java b/integration-tests/src/main/java/example/CounterMap.java index e43050d1..20a43e51 100644 --- a/integration-tests/src/main/java/example/CounterMap.java +++ b/integration-tests/src/main/java/example/CounterMap.java @@ -14,7 +14,7 @@ public CounterMap() { } private void log (String format, Object... args) { - GlobalContext.INSTANCE.log(format, args); +// GlobalContext.INSTANCE.log(format, args); } public void putOrIncrement(String s) { diff --git a/integration-tests/src/main/java/example/ParallelMergeSort.java b/integration-tests/src/main/java/example/ParallelMergeSort.java index 29f22161..2942497b 100644 --- a/integration-tests/src/main/java/example/ParallelMergeSort.java +++ b/integration-tests/src/main/java/example/ParallelMergeSort.java @@ -14,7 +14,7 @@ public class ParallelMergeSort { static final Object MSLOCK = "MERGESORTLOCK"; private static void log (String format, Object... args) { - GlobalContext.INSTANCE.log(format, args); +// GlobalContext.INSTANCE.log(format, args); } /** diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/IntegrationTestRunner.java b/integration-tests/src/test/java/cmu/pasta/fray/it/IntegrationTestRunner.java deleted file mode 100644 index edb26ed6..00000000 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/IntegrationTestRunner.java +++ /dev/null @@ -1,71 +0,0 @@ -package cmu.pasta.fray.it; - - -import cmu.pasta.fray.core.*; -import cmu.pasta.fray.core.command.Configuration; -import cmu.pasta.fray.core.command.ExecutionInfo; -import cmu.pasta.fray.core.command.LambdaExecutor; -import cmu.pasta.fray.core.logger.JsonLogger; -import cmu.pasta.fray.core.scheduler.FifoScheduler; -import cmu.pasta.fray.core.scheduler.Scheduler; -import kotlin.Unit; -import kotlin.jvm.functions.Function0; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class IntegrationTestRunner { - public String runTest(Function0 exec) { - return runTest(exec, new FifoScheduler(), 1); - } - - public String runTest(Function0 exec, Scheduler scheduler, int iter) { - String testName = this.getClass().getSimpleName(); - EventLogger logger = new EventLogger(); - GlobalContext.INSTANCE.getLoggers().add(logger); - Configuration config = new Configuration( - new ExecutionInfo( - new LambdaExecutor(() -> { - exec.invoke(); - return null; - }), - false, - true, - false, - 10000 - ), - "/tmp/report", - iter, - scheduler, - true, - List.of(new JsonLogger("/tmp/report", false)), - false, - true, - false - ); - TestRunner runner = new TestRunner(config); - runner.run(); - return logger.sb.toString(); - } - - public String getResourceAsString(String path) { - try(InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(path)) { - BufferedReader reader = new BufferedReader(new InputStreamReader(is)); - StringBuffer sb = new StringBuffer(); - String line; - while ((line = reader.readLine()) != null) { - sb.append(line); - } - return sb.toString(); - } catch (IOException e) { - e.printStackTrace(); - } - return ""; - - } -} \ No newline at end of file diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/Utils.java b/integration-tests/src/test/java/cmu/pasta/fray/it/Utils.java index 28936974..41163630 100644 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/Utils.java +++ b/integration-tests/src/test/java/cmu/pasta/fray/it/Utils.java @@ -1,9 +1,9 @@ package cmu.pasta.fray.it; -import cmu.pasta.fray.core.GlobalContext; +import cmu.pasta.fray.runtime.Runtime; public class Utils { public static void log(String format, Object... args) { - GlobalContext.INSTANCE.log(format, args); + Runtime.log(format, args); } } diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/core/CountDownLatchTest.java b/integration-tests/src/test/java/cmu/pasta/fray/it/core/CountDownLatchTest.java index 8e97735a..4186d15f 100644 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/core/CountDownLatchTest.java +++ b/integration-tests/src/test/java/cmu/pasta/fray/it/core/CountDownLatchTest.java @@ -3,9 +3,7 @@ import cmu.edu.pasta.fray.junit.annotations.Analyze; import cmu.edu.pasta.fray.junit.annotations.FrayTest; import cmu.pasta.fray.core.scheduler.FifoScheduler; -import cmu.pasta.fray.it.IntegrationTestRunner; import cmu.pasta.fray.it.Utils; -import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.concurrent.CountDownLatch; diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/core/LivenessTest.java b/integration-tests/src/test/java/cmu/pasta/fray/it/core/LivenessTest.java index 6b85cd12..76e6b4ea 100644 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/core/LivenessTest.java +++ b/integration-tests/src/test/java/cmu/pasta/fray/it/core/LivenessTest.java @@ -2,17 +2,14 @@ import cmu.edu.pasta.fray.junit.annotations.Analyze; import cmu.edu.pasta.fray.junit.annotations.FrayTest; -import cmu.pasta.fray.core.scheduler.ControlledRandom; import cmu.pasta.fray.core.scheduler.PCTScheduler; -import cmu.pasta.fray.it.IntegrationTestRunner; import cmu.pasta.fray.runtime.DeadlockException; -import org.junit.jupiter.api.Test; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @FrayTest -public class LivenessTest extends IntegrationTestRunner { +public class LivenessTest { static int i = 0; @Analyze( diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/core/ReentrantReadWriteLockTest.java b/integration-tests/src/test/java/cmu/pasta/fray/it/core/ReentrantReadWriteLockTest.java index b378f612..3fbd0841 100644 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/core/ReentrantReadWriteLockTest.java +++ b/integration-tests/src/test/java/cmu/pasta/fray/it/core/ReentrantReadWriteLockTest.java @@ -2,11 +2,8 @@ import cmu.edu.pasta.fray.junit.annotations.Analyze; import cmu.edu.pasta.fray.junit.annotations.FrayTest; -import cmu.pasta.fray.core.command.Fifo; import cmu.pasta.fray.core.scheduler.FifoScheduler; -import cmu.pasta.fray.it.IntegrationTestRunner; import cmu.pasta.fray.it.Utils; -import org.junit.jupiter.api.Test; import java.util.concurrent.Semaphore; import java.util.concurrent.locks.ReentrantReadWriteLock; @@ -14,7 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @FrayTest -public class ReentrantReadWriteLockTest extends IntegrationTestRunner { +public class ReentrantReadWriteLockTest { private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(true); diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/core/SemaphoreTest.java b/integration-tests/src/test/java/cmu/pasta/fray/it/core/SemaphoreTest.java index 4e5ec20a..74007e4c 100644 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/core/SemaphoreTest.java +++ b/integration-tests/src/test/java/cmu/pasta/fray/it/core/SemaphoreTest.java @@ -3,9 +3,7 @@ import cmu.edu.pasta.fray.junit.annotations.Analyze; import cmu.edu.pasta.fray.junit.annotations.FrayTest; import cmu.pasta.fray.core.scheduler.FifoScheduler; -import cmu.pasta.fray.it.IntegrationTestRunner; import cmu.pasta.fray.it.Utils; -import org.junit.jupiter.api.Test; import java.util.concurrent.Semaphore; @@ -78,7 +76,7 @@ public void run() { // Driver class @FrayTest -public class SemaphoreTest extends IntegrationTestRunner { +public class SemaphoreTest { @Analyze( expectedLog = "[1]: Starting A\n" + diff --git a/integration-tests/src/test/java/cmu/pasta/fray/it/core/ThreadInterruptTest.java b/integration-tests/src/test/java/cmu/pasta/fray/it/core/ThreadInterruptTest.java index 9f7e3d64..cf1494da 100644 --- a/integration-tests/src/test/java/cmu/pasta/fray/it/core/ThreadInterruptTest.java +++ b/integration-tests/src/test/java/cmu/pasta/fray/it/core/ThreadInterruptTest.java @@ -4,14 +4,12 @@ import cmu.edu.pasta.fray.junit.annotations.Analyze; import cmu.edu.pasta.fray.junit.annotations.FrayTest; import cmu.pasta.fray.core.scheduler.FifoScheduler; -import cmu.pasta.fray.it.IntegrationTestRunner; import cmu.pasta.fray.it.Utils; -import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.*; @FrayTest -public class ThreadInterruptTest extends IntegrationTestRunner { +public class ThreadInterruptTest { @Analyze( scheduler = FifoScheduler.class ) diff --git a/runtime/src/main/java/cmu/pasta/fray/runtime/Delegate.java b/runtime/src/main/java/cmu/pasta/fray/runtime/Delegate.java index 0e7a5848..5ad8c487 100644 --- a/runtime/src/main/java/cmu/pasta/fray/runtime/Delegate.java +++ b/runtime/src/main/java/cmu/pasta/fray/runtime/Delegate.java @@ -272,5 +272,8 @@ public long onNanoTime() { public int onThreadHashCode(Object t) { return t.hashCode(); } + + public void log(String format, Object... args) { + } } diff --git a/runtime/src/main/java/cmu/pasta/fray/runtime/Runtime.java b/runtime/src/main/java/cmu/pasta/fray/runtime/Runtime.java index 53f84696..bcd68f42 100644 --- a/runtime/src/main/java/cmu/pasta/fray/runtime/Runtime.java +++ b/runtime/src/main/java/cmu/pasta/fray/runtime/Runtime.java @@ -362,4 +362,8 @@ public static long onNanoTime() { public static int onThreadHashCode(Object t) { return DELEGATE.onThreadHashCode(t); } + + public static void log(String format, Object... args) { + DELEGATE.log(format, args); + } }