diff --git a/bootstrap/src/sun/nio/ch/lincheck/EventTracker.java b/bootstrap/src/sun/nio/ch/lincheck/EventTracker.java index d41ada7a5..7d939ba07 100644 --- a/bootstrap/src/sun/nio/ch/lincheck/EventTracker.java +++ b/bootstrap/src/sun/nio/ch/lincheck/EventTracker.java @@ -31,23 +31,30 @@ public interface EventTracker { void beforeNewObjectCreation(String className); void afterNewObjectCreation(Object obj); + void afterObjectInitialization(Object obj); - boolean beforeReadField(Object obj, String className, String fieldName, int codeLocation, + boolean beforeReadField(Object obj, String className, String fieldName, String typeDescriptor, int codeLocation, boolean isStatic, boolean isFinal); - boolean beforeReadArrayElement(Object array, int index, int codeLocation); + boolean beforeReadArrayElement(Object array, int index, String typeDescriptor, int codeLocation); void afterRead(Object value); - boolean beforeWriteField(Object obj, String className, String fieldName, Object value, int codeLocation, + Object interceptReadResult(); + + boolean beforeWriteField(Object obj, String className, String fieldName, String typeDescriptor, Object value, int codeLocation, boolean isStatic, boolean isFinal); - boolean beforeWriteArrayElement(Object array, int index, Object value, int codeLocation); + boolean beforeWriteArrayElement(Object array, int index, String typeDescriptor, Object value, int codeLocation); void afterWrite(); void afterReflectiveSetter(Object receiver, Object value); - void beforeMethodCall(Object owner, String className, String methodName, int codeLocation, int methodId, Object[] params); + void onArrayCopy(Object srcArray, int srcPos, Object dstArray, int dstPos, int length); + + boolean beforeMethodCall(Object owner, String className, String methodName, int codeLocation, int methodId, Object[] params); void onMethodCallReturn(Object result); void onMethodCallException(Throwable t); + Object interceptMethodCallResult(); + Random getThreadLocalRandom(); int randomNextInt(); diff --git a/bootstrap/src/sun/nio/ch/lincheck/Injections.java b/bootstrap/src/sun/nio/ch/lincheck/Injections.java index b4fddc0f1..ea4f3bb8d 100644 --- a/bootstrap/src/sun/nio/ch/lincheck/Injections.java +++ b/bootstrap/src/sun/nio/ch/lincheck/Injections.java @@ -186,10 +186,10 @@ public static boolean isRandom(Object any) { * * @return whether the trace point was created */ - public static boolean beforeReadField(Object obj, String className, String fieldName, int codeLocation, + public static boolean beforeReadField(Object obj, String className, String fieldName, String typeDescriptor, int codeLocation, boolean isStatic, boolean isFinal) { if (!isStatic && obj == null) return false; // Ignore, NullPointerException will be thrown - return getEventTracker().beforeReadField(obj, className, fieldName, codeLocation, isStatic, isFinal); + return getEventTracker().beforeReadField(obj, className, fieldName, typeDescriptor, codeLocation, isStatic, isFinal); } /** @@ -197,9 +197,13 @@ public static boolean beforeReadField(Object obj, String className, String field * * @return whether the trace point was created */ - public static boolean beforeReadArray(Object array, int index, int codeLocation) { + public static boolean beforeReadArray(Object array, int index, String typeDescriptor, int codeLocation) { if (array == null) return false; // Ignore, NullPointerException will be thrown - return getEventTracker().beforeReadArrayElement(array, index, codeLocation); + return getEventTracker().beforeReadArrayElement(array, index, typeDescriptor, codeLocation); + } + + public static Object interceptReadResult() { + return getEventTracker().interceptReadResult(); } /** @@ -214,10 +218,10 @@ public static void afterRead(Object value) { * * @return whether the trace point was created */ - public static boolean beforeWriteField(Object obj, String className, String fieldName, Object value, int codeLocation, + public static boolean beforeWriteField(Object obj, String className, String fieldName, String typeDescriptor, Object value, int codeLocation, boolean isStatic, boolean isFinal) { if (!isStatic && obj == null) return false; // Ignore, NullPointerException will be thrown - return getEventTracker().beforeWriteField(obj, className, fieldName, value, codeLocation, isStatic, isFinal); + return getEventTracker().beforeWriteField(obj, className, fieldName, typeDescriptor, value, codeLocation, isStatic, isFinal); } /** @@ -225,9 +229,9 @@ public static boolean beforeWriteField(Object obj, String className, String fiel * * @return whether the trace point was created */ - public static boolean beforeWriteArray(Object array, int index, Object value, int codeLocation) { + public static boolean beforeWriteArray(Object array, int index, String typeDescriptor, Object value, int codeLocation) { if (array == null) return false; // Ignore, NullPointerException will be thrown - return getEventTracker().beforeWriteArrayElement(array, index, value, codeLocation); + return getEventTracker().beforeWriteArrayElement(array, index, typeDescriptor, value, codeLocation); } /** @@ -250,13 +254,27 @@ public static void afterReflectiveSetter(Object receiver, Object value) { getEventTracker().afterReflectiveSetter(receiver, value); } + public static void onArrayCopy(Object srcArray, int srcPos, Object dstArray, int dstPos, int length) { + getEventTracker().onArrayCopy(srcArray, srcPos, dstArray, dstPos, length); + } + /** * Called from the instrumented code before any method call. * * @param owner is `null` for public static methods. + * @return true if the method result should be intercepted. TODO: revisit this API decision */ - public static void beforeMethodCall(Object owner, String className, String methodName, int codeLocation, int methodId, Object[] params) { - getEventTracker().beforeMethodCall(owner, className, methodName, codeLocation, methodId, params); + public static boolean beforeMethodCall(Object owner, String className, String methodName, int codeLocation, int methodId, Object[] params) { + return getEventTracker().beforeMethodCall(owner, className, methodName, codeLocation, methodId, params); + } + + /** + * Intercepts the result of a method call. + * + * @return The intercepted result of the method call. + */ + public static Object interceptMethodCallResult() { + return getEventTracker().interceptMethodCallResult(); } /** @@ -294,6 +312,10 @@ public static void afterNewObjectCreation(Object obj) { getEventTracker().afterNewObjectCreation(obj); } + public static void afterObjectInitialization(Object obj) { + getEventTracker().afterObjectInitialization(obj); + } + /** * Called from the instrumented code to replace [java.lang.Object.hashCode] method call with some * deterministic value. diff --git a/build.gradle.kts b/build.gradle.kts index 7bfd83ff4..d28d97fbb 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -26,7 +26,7 @@ repositories { kotlin { @OptIn(ExperimentalKotlinGradlePluginApi::class) compilerOptions { - allWarningsAsErrors = true + // allWarningsAsErrors = true } jvm { @@ -119,7 +119,16 @@ tasks { if (instrumentAllClasses.toBoolean()) { systemProperty("lincheck.instrumentAllClasses", "true") } - val extraArgs = mutableListOf() + val extraArgs = mutableListOf( + // flags to import Unsafe module; + // it is used in some tests to check handling of unsafe APIs by Lincheck + "--add-opens", "java.base/jdk.internal.misc=ALL-UNNAMED", + "--add-exports", "java.base/jdk.internal.util=ALL-UNNAMED", + ) + val useExperimentalModelChecking: String by project + if (useExperimentalModelChecking.toBoolean()) { + extraArgs.add("-Dlincheck.useExperimentalModelChecking=true") + } val withEventIdSequentialCheck: String by project if (withEventIdSequentialCheck.toBoolean()) { extraArgs.add("-Dlincheck.debug.withEventIdSequentialCheck=true") @@ -157,6 +166,8 @@ tasks { ideaActive -> 10 else -> 0 } + // temporarily ignore representation tests, because they are unsupported in the new strategy + // exclude("org/jetbrains/kotlinx/lincheck_test/representation") } val jvmTestIsolated = register("jvmTestIsolated") { diff --git a/gradle.properties b/gradle.properties index 21c2eefee..3faceda17 100644 --- a/gradle.properties +++ b/gradle.properties @@ -18,6 +18,7 @@ lastCopyrightYear=2023 jdkToolchainVersion=17 runAllTestsInSeparateJVMs=false instrumentAllClasses=false +useExperimentalModelChecking=false withEventIdSequentialCheck=false kotlinVersion=1.9.21 diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/CTestConfiguration.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/CTestConfiguration.kt index af395ccbb..5caa4f5d5 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/CTestConfiguration.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/CTestConfiguration.kt @@ -102,7 +102,8 @@ internal fun createFromTestClassAnnotations(testClass: Class<*>): List, object } } -private fun readField(obj: Any?, field: Field): Any? { - if (!field.type.isPrimitive) { - return readFieldViaUnsafe(obj, field, Unsafe::getObject) - } - return when (field.type) { - Boolean::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getBoolean) - Byte::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getByte) - Char::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getChar) - Short::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getShort) - Int::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getInt) - Long::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getLong) - Double::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getDouble) - Float::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getFloat) - else -> error("No more types expected") - } -} - private fun isAtomic(value: Any?): Boolean { if (value == null) return false return value.javaClass.canonicalName.let { diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/Reporter.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/Reporter.kt index da64b09ee..5f43aa7f3 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/Reporter.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/Reporter.kt @@ -45,7 +45,7 @@ class Reporter(private val logLevel: LoggingLevel) { } } -@JvmField val DEFAULT_LOG_LEVEL = WARN +@JvmField val DEFAULT_LOG_LEVEL = INFO enum class LoggingLevel { INFO, WARN } diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/Utils.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/Utils.kt index a88706336..4529a6722 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/Utils.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/Utils.kt @@ -182,7 +182,7 @@ internal fun CancellableContinuation.cancelByLincheck(promptCancellation: } } -internal enum class CancellationResult { CANCELLED_BEFORE_RESUMPTION, CANCELLED_AFTER_RESUMPTION, CANCELLATION_FAILED } +enum class CancellationResult { CANCELLED_BEFORE_RESUMPTION, CANCELLED_AFTER_RESUMPTION, CANCELLATION_FAILED } /** * Returns `true` if the continuation was cancelled by [CancellableContinuation.cancel]. @@ -233,17 +233,21 @@ internal val Class<*>.allDeclaredFieldWithSuperclasses get(): List { @Suppress("DEPRECATION") internal fun findFieldNameByOffset(targetType: Class<*>, offset: Long): String? { // Extract the private offset value and find the matching field. - for (field in targetType.declaredFields) { + for (field in targetType.allDeclaredFieldWithSuperclasses) { try { if (Modifier.isNative(field.modifiers)) continue - val fieldOffset = if (Modifier.isStatic(field.modifiers)) UnsafeHolder.UNSAFE.staticFieldOffset(field) - else UnsafeHolder.UNSAFE.objectFieldOffset(field) - if (fieldOffset == offset) return field.name + val fieldOffset = + if (Modifier.isStatic(field.modifiers)) + UnsafeHolder.UNSAFE.staticFieldOffset(field) + else + UnsafeHolder.UNSAFE.objectFieldOffset(field) + if (fieldOffset == offset) { + return field.name + } } catch (t: Throwable) { t.printStackTrace() } } - return null // Field not found } diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/execution/HBClock.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/execution/HBClock.kt index ca4cd17c4..5dfb098cf 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/execution/HBClock.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/execution/HBClock.kt @@ -13,15 +13,32 @@ import org.jetbrains.kotlinx.lincheck.Result data class HBClock(val clock: IntArray) { val threads: Int get() = clock.size - + val empty: Boolean get() = clock.all { it == 0 } operator fun get(i: Int) = clock[i] + fun set(other: HBClock) { + check(clock.size == other.clock.size) + for (i in clock.indices) { + clock[i] = other.clock[i] + } + } + + fun reset() { + for (i in clock.indices) { + clock[i] = 0 + } + } + /** * Checks whether the clock contains information for any thread * excluding the one this clock is associated with. */ fun isEmpty(clockThreadId: Int) = clock.filterIndexed { t, _ -> t != clockThreadId }.all { it == 0 } + fun copy(): HBClock { + return HBClock(clock.copyOf()) + } + override fun toString() = clock.joinToString(prefix = "[", separator = ",", postfix = "]") override fun equals(other: Any?): Boolean { diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/InvocationResult.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/InvocationResult.kt index d2b4c1c8e..4ecf4cf52 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/InvocationResult.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/InvocationResult.kt @@ -11,6 +11,7 @@ package org.jetbrains.kotlinx.lincheck.runner import org.jetbrains.kotlinx.lincheck.execution.* import org.jetbrains.kotlinx.lincheck.strategy.managed.ManagedStrategy +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency.Inconsistency /** * Represents results for invocations, see [Runner.run]. @@ -65,7 +66,29 @@ class ObstructionFreedomViolationInvocationResult( val results: ExecutionResult ) : InvocationResult() +class InconsistentInvocationResult( + val inconsistency: Inconsistency +) : InvocationResult() + +/** + * Invocation is aborted due to one of the threads reaching + * the bound on the number of spin-loop iterations. + */ +class SpinLoopBoundInvocationResult : InvocationResult() + /** * Indicates that spin-cycle has been found for the first time and replay of current interleaving is required. */ -data object SpinCycleFoundAndReplayRequired: InvocationResult() \ No newline at end of file +data object SpinCycleFoundAndReplayRequired: InvocationResult() + +fun InvocationResult.isAbortedInvocation(): Boolean = + when (this) { + is ManagedDeadlockInvocationResult, + is RunnerTimeoutInvocationResult, + is SpinLoopBoundInvocationResult, + is UnexpectedExceptionInvocationResult, + is ObstructionFreedomViolationInvocationResult, + is InconsistentInvocationResult + -> true + else -> false + } \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt index 3859ba290..226efcb73 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt @@ -17,8 +17,7 @@ import org.jetbrains.kotlinx.lincheck.runner.ExecutionPart.* import org.jetbrains.kotlinx.lincheck.runner.ParallelThreadsRunner.Completion.* import org.jetbrains.kotlinx.lincheck.runner.UseClocks.* import org.jetbrains.kotlinx.lincheck.strategy.* -import org.jetbrains.kotlinx.lincheck.strategy.managed.ManagedStrategy -import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingStrategy +import org.jetbrains.kotlinx.lincheck.strategy.managed.* import org.jetbrains.kotlinx.lincheck.transformation.LincheckJavaAgent import org.jetbrains.kotlinx.lincheck.util.* import sun.nio.ch.lincheck.* @@ -101,7 +100,9 @@ internal open class ParallelThreadsRunner( protected inner class Completion(private val iThread: Int, private val actorId: Int) : Continuation { val resWithCont = SuspensionPointResultWithContinuation(null) - override var context = ParallelThreadRunnerInterceptor(resWithCont) + StoreExceptionHandler() + Job() + private val interceptor = ParallelThreadRunnerInterceptor(resWithCont) + + override val context = interceptor + StoreExceptionHandler() + Job() // We need to run this code in an ignored section, // as it is called in the testing code but should not be analyzed. @@ -118,12 +119,13 @@ internal open class ParallelThreadsRunner( } // write function's final result suspensionPointResults[iThread][actorId] = createLincheckResult(result) + onResumeCoroutine(iThread, actorId) } } fun reset() { resWithCont.set(null) - context = ParallelThreadRunnerInterceptor(resWithCont) + StoreExceptionHandler() + Job() + interceptor.reset(resWithCont) } /** @@ -136,9 +138,12 @@ internal open class ParallelThreadsRunner( private var resWithCont: SuspensionPointResultWithContinuation ) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor { + var continuation: Continuation? = null + // We need to run this code in an ignored section, // as it is called in the testing code but should not be analyzed. override fun interceptContinuation(continuation: Continuation): Continuation = runInIgnoredSection { + this.continuation = (continuation as Continuation) return Continuation(StoreExceptionHandler() + Job()) { result -> runInIgnoredSection { // decrement completed or suspended threads only if the operation was not cancelled @@ -149,15 +154,22 @@ internal open class ParallelThreadsRunner( completedOrSuspendedThreads.incrementAndGet() } @Suppress("UNCHECKED_CAST") - resWithCont.set(result to continuation as Continuation) + resWithCont.set(result to this.continuation as Continuation) + onResumeCoroutine(iThread, actorId) } } } } + + fun reset(resWithCont: SuspensionPointResultWithContinuation) { + this.resWithCont = resWithCont + this.continuation = null + } } } private fun resetState() { + currentExecutionPart = null suspensionPointResults.forEach { it.fill(NoResult) } completedOrSuspendedThreads.set(0) completions.forEach { @@ -185,7 +197,7 @@ internal open class ParallelThreadsRunner( private fun createTestInstance() { @Suppress("DEPRECATION") testInstance = testClass.newInstance() - if (strategy is ModelCheckingStrategy) { + if (strategy is ManagedStrategy) { // We pass the test instance to the strategy to initialize the call stack. // It should be done here as we create the test instance in the `run` method in the runner, after // `initializeInvocation` method call of ManagedStrategy. @@ -230,8 +242,6 @@ internal open class ParallelThreadsRunner( return finalResult } - override fun afterCoroutineCancelled(iThread: Int) {} - // We need to run this code in an ignored section, // as it is called in the testing code but should not be analyzed. private fun waitAndInvokeFollowUp(thread: TestThread, actorId: Int): Result = runInIgnoredSection { @@ -244,13 +254,19 @@ internal open class ParallelThreadsRunner( // wait for the final result of the method call otherwise. val completion = completions[threadId][actorId] // Check if the coroutine is already resumed and if not, enter the spin loop. + var blocked = false if (!isCoroutineResumed(threadId, actorId)) { spinners[threadId].spinWaitUntil { // Check whether the scenario is completed and the current suspended operation cannot be resumed. - if (currentExecutionPart == POST || isParallelExecutionCompleted) { + if (currentExecutionPart == POST || isParallelExecutionCompleted || blocked) { suspensionPointResults[threadId][actorId] = NoResult return Suspended } + if (strategy is ManagedStrategy) { + strategy.switchCurrentThread(threadId, SwitchReason.STRATEGY_SWITCH, mustSwitch = true) + blocked = strategy.isBlocked() + return@spinWaitUntil false + } // Wait until coroutine is resumed. isCoroutineResumed(threadId, actorId) } @@ -287,6 +303,10 @@ internal open class ParallelThreadsRunner( override fun afterCoroutineResumed(iThread: Int) {} + override fun afterCoroutineCancelled(iThread: Int, promptCancellation: Boolean, result: CancellationResult) {} + + override fun onResumeCoroutine(iResumedThread: Int, iResumedActor: Int) {} + // We cannot use `completionStatuses` here since // they are set _before_ the result is published. override fun isCoroutineResumed(iThread: Int, actorId: Int) = @@ -361,7 +381,6 @@ internal open class ParallelThreadsRunner( afterPostStateRepresentation = afterPostStateRepresentation ) - private fun createInitialPartExecution() = if (scenario.initExecution.isNotEmpty()) { TestThreadExecutionGenerator.create(this, INIT_THREAD_ID, diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/Runner.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/Runner.kt index b4fb3453f..35ae1331e 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/Runner.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/Runner.kt @@ -29,7 +29,7 @@ abstract class Runner protected constructor( protected val completedOrSuspendedThreads = AtomicInteger(0) var currentExecutionPart: ExecutionPart? = null - private set + protected set /** * Returns the current state representation of the test instance constructed via @@ -82,7 +82,12 @@ abstract class Runner protected constructor( * This method is invoked by the corresponding test thread * when the current coroutine is cancelled. */ - abstract fun afterCoroutineCancelled(iThread: Int) + abstract fun afterCoroutineCancelled(iThread: Int, promptCancellation: Boolean, result: CancellationResult) + + /** + * This method is invoked by a test thread that attempts to resume coroutine. + */ + abstract fun onResumeCoroutine(iResumedThread: Int, iResumedActor: Int) /** * Returns `true` if the coroutine corresponding to @@ -102,14 +107,14 @@ abstract class Runner protected constructor( * Is invoked after each actor execution from the specified thread, even if a legal exception was thrown. * The invocations are inserted into the generated code. */ - fun onActorFinish() { - strategy.onActorFinish() + fun onActorFinish(iThread: Int) { + strategy.onActorFinish(iThread) } fun beforePart(part: ExecutionPart) { + strategy.beforePart(part) completedOrSuspendedThreads.set(0) currentExecutionPart = part - strategy.beforePart(part) } /** diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/TestThreadExecutionGenerator.java b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/TestThreadExecutionGenerator.java index ecd80a027..4df2d5066 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/TestThreadExecutionGenerator.java +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/runner/TestThreadExecutionGenerator.java @@ -38,7 +38,7 @@ public class TestThreadExecutionGenerator { private static final Method RUNNER_ON_START_METHOD = new Method("onStart", VOID_TYPE, new Type[]{INT_TYPE}); private static final Method RUNNER_ON_FINISH_METHOD = new Method("onFinish", VOID_TYPE, new Type[]{INT_TYPE}); private static final Method RUNNER_ON_ACTOR_START = new Method("onActorStart", Type.VOID_TYPE, new Type[]{ Type.INT_TYPE }); - private static final Method RUNNER_ON_ACTOR_FINISH = new Method("onActorFinish", Type.VOID_TYPE, NO_ARGS); + private static final Method RUNNER_ON_ACTOR_FINISH = new Method("onActorFinish", Type.VOID_TYPE, new Type[]{ Type.INT_TYPE }); private static final Type TEST_THREAD_EXECUTION_TYPE = getType(TestThreadExecution.class); private static final Method TEST_THREAD_EXECUTION_CONSTRUCTOR; @@ -251,6 +251,7 @@ private static void generateRun(ClassVisitor cv, Type testType, int iThread, Lis // Invoke runner onActorFinish method mv.loadThis(); mv.getField(TEST_THREAD_EXECUTION_TYPE, "runner", RUNNER_TYPE); + mv.push(iThread); mv.invokeVirtual(RUNNER_TYPE, RUNNER_ON_ACTOR_FINISH); // Increment the clock mv.loadThis(); diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/Strategy.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/Strategy.kt index d388e70a5..0bab6cf41 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/Strategy.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/Strategy.kt @@ -12,6 +12,7 @@ package org.jetbrains.kotlinx.lincheck.strategy import org.jetbrains.kotlinx.lincheck.runner.* import org.jetbrains.kotlinx.lincheck.execution.ExecutionScenario import org.jetbrains.kotlinx.lincheck.strategy.managed.Trace +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.EventStructureStrategy import org.jetbrains.kotlinx.lincheck.verifier.Verifier import java.io.Closeable @@ -82,7 +83,7 @@ abstract class Strategy protected constructor( /** * Is invoked after each actor execution, even if a legal exception was thrown */ - open fun onActorFinish() {} + open fun onActorFinish(iThread: Int) {} /** * Closes the strategy and releases any resources associated with it. @@ -101,15 +102,19 @@ abstract class Strategy protected constructor( * @return the failure, if detected, null otherwise. */ fun Strategy.runIteration(invocations: Int, verifier: Verifier): LincheckFailure? { + var failure: LincheckFailure? = null for (invocation in 0 until invocations) { if (!nextInvocation()) - return null + break val result = runInvocation() - val failure = verify(result, verifier) + // TODO: should we count failed inconsistent executions as used invocations? + if (result is InconsistentInvocationResult) continue + failure = verify(result, verifier) if (failure != null) - return failure + break } - return null + printStatistics() + return failure } /** @@ -126,6 +131,14 @@ fun Strategy.verify(result: InvocationResult, verifier: Verifier): LincheckFailu if (!verifier.verifyResults(scenario, result.results)) { IncorrectResultsFailure(scenario, result.results, tryCollectTrace(result)) } else null - else -> - result.toLincheckFailure(scenario, tryCollectTrace(result)) + + is SpinLoopBoundInvocationResult -> null + + else -> result.toLincheckFailure(scenario, tryCollectTrace(result)) +} + +private fun Strategy.printStatistics() { + if (this is EventStructureStrategy) { + println(stats) + } } \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/AtomicFieldUpdaterNames.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/AtomicFieldUpdaterNames.kt index 53dbae8ec..a08bfe55e 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/AtomicFieldUpdaterNames.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/AtomicFieldUpdaterNames.kt @@ -25,6 +25,10 @@ internal object AtomicFieldUpdaterNames { @Suppress("DEPRECATION") internal fun getAtomicFieldUpdaterName(updater: Any): String? { + return getAtomicFieldUpdaterInfo(updater)?.fieldName + } + + internal fun getAtomicFieldUpdaterInfo(updater: Any): AtomicFieldUpdaterInfo? { if (updater !is AtomicIntegerFieldUpdater<*> && updater !is AtomicLongFieldUpdater<*> && updater !is AtomicReferenceFieldUpdater<*, *>) { throw IllegalArgumentException("Provided object is not a recognized Atomic*FieldUpdater type.") } @@ -32,16 +36,16 @@ internal object AtomicFieldUpdaterNames { try { // Cannot use neither reflection not MethodHandles.Lookup, as they lead to a warning. val tclassField = updater.javaClass.getDeclaredField("tclass") - val targetType = UNSAFE.getObject(updater, UNSAFE.objectFieldOffset(tclassField)) as Class<*> - + val tclass = UNSAFE.getObject(updater, UNSAFE.objectFieldOffset(tclassField)) as Class<*> val offsetField = updater.javaClass.getDeclaredField("offset") val offset = UNSAFE.getLong(updater, UNSAFE.objectFieldOffset(offsetField)) - - return findFieldNameByOffset(targetType, offset) + val fieldName = findFieldNameByOffset(tclass, offset) ?: return null + return AtomicFieldUpdaterInfo(tclass.name, fieldName) } catch (t: Throwable) { t.printStackTrace() } - return null // Field not found } -} \ No newline at end of file +} + +data class AtomicFieldUpdaterInfo(val className: String, val fieldName: String) \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ManagedStrategy.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ManagedStrategy.kt index df4ed8eec..593e32373 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ManagedStrategy.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ManagedStrategy.kt @@ -28,8 +28,10 @@ import org.jetbrains.kotlinx.lincheck.strategy.managed.ObjectLabelFactory.adorne import org.jetbrains.kotlinx.lincheck.strategy.managed.ObjectLabelFactory.cleanObjectNumeration import org.jetbrains.kotlinx.lincheck.strategy.managed.UnsafeName.* import org.jetbrains.kotlinx.lincheck.strategy.managed.VarHandleMethodType.* +import org.objectweb.asm.Type +import org.objectweb.asm.Type.* import java.lang.invoke.VarHandle -import java.lang.reflect.* +import java.lang.reflect.Method import java.util.* import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED @@ -46,7 +48,7 @@ abstract class ManagedStrategy( scenario: ExecutionScenario, private val validationFunction: Actor?, private val stateRepresentationFunction: Method?, - private val testCfg: ManagedCTestConfiguration + private val testCfg: ManagedCTestConfiguration, ) : Strategy(scenario), EventTracker { // The number of parallel threads. protected val nThreads: Int = scenario.nThreads @@ -62,7 +64,8 @@ abstract class ManagedStrategy( // Which thread is allowed to perform operations? @Volatile - protected var currentThread: Int = 0 + var currentThread: Int = 0 + protected set // Which threads finished all the operations? private val finished = BooleanArray(nThreads) { false } @@ -70,6 +73,9 @@ abstract class ManagedStrategy( // Which threads are suspended? private val isSuspended = BooleanArray(nThreads) { false } + // Which threads are spin-bound blocked? + private val isSpinBoundBlocked = BooleanArray(nThreads) { false } + // Current actor id for each thread. protected val currentActorId = IntArray(nThreads) @@ -78,11 +84,15 @@ abstract class ManagedStrategy( // Tracker of objects' allocations and object graph topology. protected abstract val objectTracker: ObjectTracker + // Tracker of shared memory accesses. + protected abstract val memoryTracker: MemoryTracker? // Tracker of the monitors' operations. protected abstract val monitorTracker: MonitorTracker // Tracker of the thread parking. protected abstract val parkingTracker: ParkingTracker + protected open val trackFinalFields: Boolean = false + // InvocationResult that was observed by the strategy during the execution (e.g., a deadlock). @Volatile protected var suddenInvocationResult: InvocationResult? = null @@ -192,7 +202,7 @@ abstract class ManagedStrategy( * Returns whether thread should switch at the switch point. * @param iThread the current thread */ - protected abstract fun shouldSwitch(iThread: Int): Boolean + protected abstract fun shouldSwitch(iThread: Int): ThreadSwitchDecision /** * Choose a thread to switch from thread [iThread]. @@ -200,19 +210,24 @@ abstract class ManagedStrategy( */ protected abstract fun chooseThread(iThread: Int): Int + enum class ThreadSwitchDecision { NOT, MAY, MUST } + /** * Resets all internal data to the initial state and initializes current invocation to be run. */ protected open fun initializeInvocation() { finished.fill(false) isSuspended.fill(false) + isSpinBoundBlocked.fill(false) currentActorId.fill(-1) traceCollector = if (collectTrace) TraceCollector() else null suddenInvocationResult = null callStackTrace.forEach { it.clear() } suspendedFunctionsStack.forEach { it.clear() } randoms.forEachIndexed { i, r -> r.setSeed(i + 239L) } + loopDetector.initialize() objectTracker.reset() + memoryTracker?.reset() monitorTracker.reset() parkingTracker.reset() } @@ -223,13 +238,17 @@ abstract class ManagedStrategy( override fun runInvocation(): InvocationResult { while (true) { initializeInvocation() - val result = runner.run() + val result = runInvocationImpl() // In case the runner detects a deadlock, some threads can still manipulate the current strategy, // so we're not interested in suddenInvocationResult in this case // and immediately return RunnerTimeoutInvocationResult. if (result is RunnerTimeoutInvocationResult) { return result } + // We also immediately return any inconsistent result + if (result is InconsistentInvocationResult) { + return result + } // If strategy has not detected a sudden invocation result, // then return, otherwise process the sudden result. val suddenResult = suddenInvocationResult ?: return result @@ -246,12 +265,25 @@ abstract class ManagedStrategy( } } + // TODO: better name? + protected open fun runInvocationImpl(): InvocationResult { + return runner.run() + } + protected open fun enableSpinCycleReplay() {} // == BASIC STRATEGY METHODS == - override fun beforePart(part: ExecutionPart) { + override fun beforePart(part: ExecutionPart) = runInIgnoredSection { traceCollector?.passCodeLocation(SectionDelimiterTracePoint(part)) + val nextThread = when (part) { + INIT -> 0 + PARALLEL -> chooseThread(0) + POST -> 0 + VALIDATION -> 0 + } + loopDetector.beforePart(nextThread) + currentThread = nextThread } /** @@ -279,8 +311,12 @@ abstract class ManagedStrategy( ) cleanObjectNumeration() - runner.close() - runner = createRunner() + // TODO: for the event structure strategy, we cannot re-create the runner object, + // since it may lead to non-determinism, due to different values that can be + // read from tracked memory locations; + // INVESTIGATE IT FURTHER! + // runner.close() + // runner = createRunner() val loggedResults = runInvocation() // In case the runner detects a deadlock, some threads can still be in an active state, @@ -362,7 +398,7 @@ abstract class ManagedStrategy( // check we are in the right thread check(iThread == currentThread) // check if we need to switch - val shouldSwitch = when { + val threadSwitchDecision = when { /* * When replaying executions, it's important to repeat the same thread switches * recorded in the loop detector history during the last execution. @@ -381,25 +417,31 @@ abstract class ManagedStrategy( * the spin cycle in thread 1, so no bug will appear. */ loopDetector.replayModeEnabled -> - loopDetector.shouldSwitchInReplayMode() + if (loopDetector.shouldSwitchInReplayMode()) + ThreadSwitchDecision.MUST + else + ThreadSwitchDecision.NOT /* * In the regular mode, we use loop detector only to determine should we * switch current thread or not due to new or early detection of spin locks. * Regular thread switches are dictated by the current interleaving. */ else -> - (runner.currentExecutionPart == PARALLEL) && shouldSwitch(iThread) + if (runner.currentExecutionPart == PARALLEL) + shouldSwitch(iThread) + else + ThreadSwitchDecision.NOT } // check if live-lock is detected - val decision = loopDetector.visitCodeLocation(iThread, codeLocation) + val loopDetectorDecision = loopDetector.visitCodeLocation(iThread, codeLocation) // if we reached maximum number of events threshold, then fail immediately - if (decision == LoopDetector.Decision.EventsThresholdReached) { + if (loopDetectorDecision == LoopDetector.Decision.EventsThresholdReached) { failDueToDeadlock() } // if any kind of live-lock was detected, check for obstruction-freedom violation - if (decision.isLivelockDetected) { + if (loopDetectorDecision.isLivelockDetected) { failIfObstructionFreedomIsRequired { - if (decision is LoopDetector.Decision.LivelockFailureDetected) { + if (loopDetectorDecision is LoopDetector.Decision.LivelockFailureDetected) { // if failure is detected, add a special obstruction-freedom violation // trace point to account for that traceCollector?.passObstructionFreedomViolationTracePoint(currentThread, beforeMethodCall = tracePoint is MethodCallTracePoint) @@ -411,18 +453,18 @@ abstract class ManagedStrategy( } } // if live-lock failure was detected, then fail immediately - if (decision is LoopDetector.Decision.LivelockFailureDetected) { + if (loopDetectorDecision is LoopDetector.Decision.LivelockFailureDetected) { traceCollector?.newSwitch(currentThread, SwitchReason.ACTIVE_LOCK, beforeMethodCallSwitch = tracePoint is MethodCallTracePoint) failDueToDeadlock() } // if live-lock was detected, and replay was requested, // then abort current execution and start the replay - if (decision.isReplayRequired) { + if (loopDetectorDecision.isReplayRequired) { suddenInvocationResult = SpinCycleFoundAndReplayRequired throw ForcibleExecutionFinishError } // if the current thread in a live-lock, then try to switch to another thread - if (decision is LoopDetector.Decision.LivelockThreadSwitch) { + if (loopDetectorDecision is LoopDetector.Decision.LivelockThreadSwitch) { val switchHappened = switchCurrentThread(iThread, SwitchReason.ACTIVE_LOCK, tracePoint = tracePoint) if (switchHappened) { loopDetector.initializeFirstCodeLocationAfterSwitch(codeLocation) @@ -431,7 +473,7 @@ abstract class ManagedStrategy( return } // if strategy requested thread switch, then do it - if (shouldSwitch) { + if (threadSwitchDecision != ThreadSwitchDecision.NOT) { val switchHappened = switchCurrentThread(iThread, SwitchReason.STRATEGY_SWITCH, tracePoint = tracePoint) if (switchHappened) { loopDetector.initializeFirstCodeLocationAfterSwitch(codeLocation) @@ -451,6 +493,9 @@ abstract class ManagedStrategy( */ open fun onStart(iThread: Int) { awaitTurn(iThread) + while (!isActive(iThread)) { + switchCurrentThread(iThread, mustSwitch = true) + } } /** @@ -490,7 +535,7 @@ abstract class ManagedStrategy( (Thread.currentThread() as TestThread).inTestingCode = true } - override fun onActorFinish() { + override fun onActorFinish(iThread: Int) { // This is a hack to guarantee correct stepping in the plugin. // When stepping out to the TestThreadExecution class, stepping continues unproductively. // With this method, we force the debugger to stop at the beginning of the next actor. @@ -502,17 +547,22 @@ abstract class ManagedStrategy( * Returns whether the specified thread is active and * can continue its execution (i.e. is not blocked/finished). */ - private fun isActive(iThread: Int): Boolean = + protected open fun isActive(iThread: Int): Boolean = !finished[iThread] && !(isSuspended[iThread] && !runner.isCoroutineResumed(iThread, currentActorId[iThread])) && + !isSpinBoundBlocked[iThread] && !monitorTracker.isWaiting(iThread) && !parkingTracker.isParked(iThread) + // TODO: refactor --- get rid of this!!! + internal fun isBlocked(): Boolean = + (0 until nThreads).all { !isActive(it) } + /** * Waits until the specified thread can continue * the execution according to the strategy decision. */ - private fun awaitTurn(iThread: Int) = runInIgnoredSection { + protected fun awaitTurn(iThread: Int) = runInIgnoredSection { spinners[iThread].spinWaitUntil { // Finish forcibly if an error occurred and we already have an `InvocationResult`. if (suddenInvocationResult != null) throw ForcibleExecutionFinishError @@ -525,12 +575,15 @@ abstract class ManagedStrategy( * * @return was this thread actually switched to another or not */ - private fun switchCurrentThread( + internal fun switchCurrentThread( iThread: Int, reason: SwitchReason = SwitchReason.STRATEGY_SWITCH, mustSwitch: Boolean = false, tracePoint: TracePoint? = null ): Boolean { + if (reason == SwitchReason.SPIN_BOUND) { + isSpinBoundBlocked[iThread] = true + } val nextThread = chooseThreadSwitch(iThread, mustSwitch) val switchHappened = (iThread != nextThread) if (switchHappened) { @@ -569,6 +622,12 @@ abstract class ManagedStrategy( if (suspendedThread != null) { return suspendedThread } + // if some threads (but not all of them!) are blocked due to spin-loop bounding, + // then finish the execution but do not count it as a deadlock; + if (isSpinBoundBlocked.any { it } && !isSpinBoundBlocked.all { it }) { + suddenInvocationResult = SpinLoopBoundInvocationResult() + throw ForcibleExecutionFinishError + } // any other situation is considered to be a deadlock suddenInvocationResult = ManagedDeadlockInvocationResult(runner.collectExecutionResults()) throw ForcibleExecutionFinishError @@ -741,7 +800,7 @@ abstract class ManagedStrategy( /** * Returns `true` if a switch point is created. */ - override fun beforeReadField(obj: Any?, className: String, fieldName: String, codeLocation: Int, + override fun beforeReadField(obj: Any?, className: String, fieldName: String, typeDescriptor: String, codeLocation: Int, isStatic: Boolean, isFinal: Boolean) = runInIgnoredSection { // We need to ensure all the classes related to the reading object are instrumented. // The following call checks all the static fields. @@ -749,7 +808,7 @@ abstract class ManagedStrategy( LincheckJavaAgent.ensureClassHierarchyIsTransformed(className.canonicalClassName) } // Optimization: do not track final field reads - if (isFinal) { + if (isFinal && !trackFinalFields) { return@runInIgnoredSection false } // Do not track accesses to untracked objects @@ -773,12 +832,20 @@ abstract class ManagedStrategy( lastReadTracePoint[iThread] = tracePoint } newSwitchPoint(iThread, codeLocation, tracePoint) + if (memoryTracker != null) { + val type = Type.getType(typeDescriptor) + val location = objectTracker.getFieldAccessMemoryLocation(obj, className, fieldName, type, + isStatic = isStatic, + isFinal = isFinal, + ) + memoryTracker!!.beforeRead(iThread, codeLocation, location) + } loopDetector.beforeReadField(obj) return@runInIgnoredSection true } /** Returns true if a switch point is created. */ - override fun beforeReadArrayElement(array: Any, index: Int, codeLocation: Int): Boolean = runInIgnoredSection { + override fun beforeReadArrayElement(array: Any, index: Int, typeDescriptor: String, codeLocation: Int): Boolean = runInIgnoredSection { if (!objectTracker.shouldTrackObjectAccess(array)) { return@runInIgnoredSection false } @@ -799,10 +866,20 @@ abstract class ManagedStrategy( lastReadTracePoint[iThread] = tracePoint } newSwitchPoint(iThread, codeLocation, tracePoint) + if (memoryTracker != null && typeDescriptor != VOID_TYPE.descriptor) { + val type = Type.getType(typeDescriptor) + val location = objectTracker.getArrayAccessMemoryLocation(array, index, type) + memoryTracker!!.beforeRead(iThread, codeLocation, location) + } loopDetector.beforeReadArrayElement(array, index) true } + override fun interceptReadResult(): Any? = runInIgnoredSection { + val iThread = currentThread + return memoryTracker?.interceptReadResult(iThread) + } + override fun afterRead(value: Any?) = runInIgnoredSection { if (collectTrace) { val iThread = currentThread @@ -812,14 +889,14 @@ abstract class ManagedStrategy( loopDetector.afterRead(value) } - override fun beforeWriteField(obj: Any?, className: String, fieldName: String, value: Any?, codeLocation: Int, + override fun beforeWriteField(obj: Any?, className: String, fieldName: String, typeDescriptor: String, value: Any?, codeLocation: Int, isStatic: Boolean, isFinal: Boolean): Boolean = runInIgnoredSection { objectTracker.registerObjectLink(fromObject = obj ?: StaticObject, toObject = value) if (!objectTracker.shouldTrackObjectAccess(obj ?: StaticObject)) { return@runInIgnoredSection false } // Optimization: do not track final field writes - if (isFinal) { + if (isFinal && !trackFinalFields) { return@runInIgnoredSection false } val iThread = currentThread @@ -838,11 +915,19 @@ abstract class ManagedStrategy( null } newSwitchPoint(iThread, codeLocation, tracePoint) + if (memoryTracker != null) { + val type = Type.getType(typeDescriptor) + val location = objectTracker.getFieldAccessMemoryLocation(obj, className, fieldName, type, + isStatic = isStatic, + isFinal = isFinal, + ) + memoryTracker!!.beforeWrite(iThread, codeLocation, location, value) + } loopDetector.beforeWriteField(obj, value) return@runInIgnoredSection true } - override fun beforeWriteArrayElement(array: Any, index: Int, value: Any?, codeLocation: Int): Boolean = runInIgnoredSection { + override fun beforeWriteArrayElement(array: Any, index: Int, typeDescriptor: String, value: Any?, codeLocation: Int): Boolean = runInIgnoredSection { objectTracker.registerObjectLink(fromObject = array, toObject = value) if (!objectTracker.shouldTrackObjectAccess(array)) { return@runInIgnoredSection false @@ -863,6 +948,11 @@ abstract class ManagedStrategy( null } newSwitchPoint(iThread, codeLocation, tracePoint) + if (memoryTracker != null) { + val type = Type.getType(typeDescriptor) + val location = objectTracker.getArrayAccessMemoryLocation(array, index, type) + memoryTracker!!.beforeWrite(iThread, codeLocation, location, value) + } loopDetector.beforeWriteArrayElement(array, index, value) true } @@ -879,6 +969,14 @@ abstract class ManagedStrategy( objectTracker.registerObjectLink(fromObject = receiver ?: StaticObject, toObject = value) } + override fun onArrayCopy(srcArray: Any?, srcPos: Int, dstArray: Any?, dstPos: Int, length: Int) = runInIgnoredSection { + val iThread = currentThread + val codeLocation = SYSTEM_ARRAYCOPY_CODE_LOCATION + if (memoryTracker != null) { + memoryTracker!!.interceptArrayCopy(iThread, codeLocation, srcArray, srcPos, dstArray, dstPos, length) + } + } + override fun getThreadLocalRandom(): Random = runInIgnoredSection { return randoms[currentThread] } @@ -910,6 +1008,10 @@ abstract class ManagedStrategy( } } + override fun afterObjectInitialization(obj: Any) = runInIgnoredSection { + objectTracker.initializeObject(obj) + } + private fun methodGuaranteeType(owner: Any?, className: String, methodName: String): ManagedGuaranteeType? = runInIgnoredSection { userDefinedGuarantees?.forEach { guarantee -> val ownerName = owner?.javaClass?.canonicalName ?: className @@ -928,7 +1030,8 @@ abstract class ManagedStrategy( codeLocation: Int, methodId: Int, params: Array - ) { + ): Boolean { + var shouldInterceptMethodResult = false val guarantee = runInIgnoredSection { val atomicMethodDescriptor = getAtomicMethodDescriptor(owner, methodName) val guarantee = when { @@ -945,6 +1048,11 @@ abstract class ManagedStrategy( if (guarantee == ManagedGuaranteeType.TREAT_AS_ATOMIC) { newSwitchPointOnAtomicMethodCall(codeLocation, params) } + if (atomicMethodDescriptor != null) { + shouldInterceptMethodResult = trackAtomicMethodMemoryAccess( + owner, className, methodName, codeLocation, params, atomicMethodDescriptor + ) + } if (guarantee == null) { loopDetector.beforeMethodCall(codeLocation, params) } @@ -957,6 +1065,85 @@ abstract class ManagedStrategy( // so enterIgnoredSection would have no effect enterIgnoredSection() } + return shouldInterceptMethodResult + } + + private fun trackAtomicMethodMemoryAccess( + owner: Any?, + className: String, + methodName: String, + codeLocation: Int, + params: Array, + methodDescriptor: AtomicMethodDescriptor, + ): Boolean { + if (memoryTracker == null) + return false + val iThread = currentThread + val location = objectTracker.getAtomicAccessMemoryLocation(className, methodName, owner, params) + ?: return false + var argOffset = 0 + // atomic reflection case (AFU, VarHandle or Unsafe) - the first argument is a reflection object + argOffset += if (!isAtomic(owner) && !isAtomicArray(owner)) 1 else 0 + // Unsafe has an additional offset argument + argOffset += if (isUnsafe(owner)) 1 else 0 + // array accesses (besides Unsafe) take index as an additional argument + argOffset += if (location is ArrayElementMemoryLocation && !isUnsafe(owner)) 1 else 0 + when (methodDescriptor.kind) { + AtomicMethodKind.SET -> { + memoryTracker!!.beforeWrite(iThread, codeLocation, location, + value = params[argOffset] + ) + } + AtomicMethodKind.GET -> { + memoryTracker!!.beforeRead(iThread, codeLocation, location) + } + AtomicMethodKind.GET_AND_SET -> { + memoryTracker!!.beforeGetAndSet(iThread, codeLocation, location, + newValue = params[argOffset] + ) + } + AtomicMethodKind.COMPARE_AND_SET, AtomicMethodKind.WEAK_COMPARE_AND_SET -> { + memoryTracker!!.beforeCompareAndSet(iThread, codeLocation, location, + expectedValue = params[argOffset], + newValue = params[argOffset + 1] + ) + } + AtomicMethodKind.COMPARE_AND_EXCHANGE -> { + memoryTracker!!.beforeCompareAndExchange(iThread, codeLocation, location, + expectedValue = params[argOffset], + newValue = params[argOffset + 1] + ) + } + AtomicMethodKind.GET_AND_ADD -> { + memoryTracker!!.beforeGetAndAdd(iThread, codeLocation, location, + delta = (params[argOffset] as Number) + ) + } + AtomicMethodKind.ADD_AND_GET -> { + memoryTracker!!.beforeAddAndGet(iThread, codeLocation, location, + delta = (params[argOffset] as Number) + ) + } + AtomicMethodKind.GET_AND_INCREMENT -> { + memoryTracker!!.beforeGetAndAdd(iThread, codeLocation, location, delta = 1.convert(location.type)) + } + AtomicMethodKind.INCREMENT_AND_GET -> { + memoryTracker!!.beforeAddAndGet(iThread, codeLocation, location, delta = 1.convert(location.type)) + } + AtomicMethodKind.GET_AND_DECREMENT -> { + memoryTracker!!.beforeGetAndAdd(iThread, codeLocation, location, delta = (-1).convert(location.type)) + } + AtomicMethodKind.DECREMENT_AND_GET -> { + memoryTracker!!.beforeAddAndGet(iThread, codeLocation, location, delta = (-1).convert(location.type)) + } + } + return (methodDescriptor.kind != AtomicMethodKind.SET) + } + + override fun interceptMethodCallResult(): Any? = runInIgnoredSection { + check(memoryTracker != null) + val iThread = currentThread + return memoryTracker?.interceptReadResult(iThread) } override fun onMethodCallReturn(result: Any?) { @@ -1052,7 +1239,7 @@ abstract class ManagedStrategy( * if a coroutine was suspended. * @param iThread number of invoking thread */ - internal fun afterCoroutineSuspended(iThread: Int) { + internal open fun afterCoroutineSuspended(iThread: Int) { check(currentThread == iThread) isSuspended[iThread] = true if (runner.isCoroutineResumed(iThread, currentActorId[iThread])) { @@ -1068,21 +1255,38 @@ abstract class ManagedStrategy( * This method is invoked by a test thread * if a coroutine was resumed. */ - internal fun afterCoroutineResumed() { - isSuspended[currentThread] = false + internal open fun afterCoroutineResumed(iThread: Int) { + check(currentThread == iThread) + isSuspended[iThread] = false } /** * This method is invoked by a test thread * if a coroutine was cancelled. */ - internal fun afterCoroutineCancelled() { - val iThread = currentThread + internal open fun afterCoroutineCancelled(iThread: Int, promptCancellation: Boolean, cancellationResult: CancellationResult) { + check(currentThread == iThread) + if (cancellationResult == CANCELLATION_FAILED) + return isSuspended[iThread] = false // method will not be resumed after suspension, so clear prepared for resume call stack suspendedFunctionsStack[iThread].clear() } + /** + * This method is invoked by a test thread that attempts to resume coroutine. + */ + internal open fun onResumeCoroutine(iThread: Int, iResumedThread: Int, iResumedActor: Int) { + check(currentThread == iThread) + } + + /** + * This method is invoked by a test thread to check if the coroutine was resumed. + */ + internal open fun isCoroutineResumed(iThread: Int, iActor: Int): Boolean { + return true + } + private fun addBeforeMethodCallTracePoint( owner: Any?, codeLocation: Int, @@ -1651,11 +1855,20 @@ internal class ManagedStrategyRunner( } override fun afterCoroutineResumed(iThread: Int) = runInIgnoredSection { - managedStrategy.afterCoroutineResumed() + managedStrategy.afterCoroutineResumed(iThread) + } + + override fun afterCoroutineCancelled(iThread: Int, promptCancellation: Boolean, result: CancellationResult) = runInIgnoredSection { + managedStrategy.afterCoroutineCancelled(iThread, promptCancellation, result) + } + + override fun onResumeCoroutine(iResumedThread: Int, iResumedActor: Int) = runInIgnoredSection { + super.onResumeCoroutine(iResumedThread, iResumedActor) + managedStrategy.onResumeCoroutine(managedStrategy.currentThread, iResumedThread, iResumedActor) } - override fun afterCoroutineCancelled(iThread: Int) = runInIgnoredSection { - managedStrategy.afterCoroutineCancelled() + override fun isCoroutineResumed(iThread: Int, actorId: Int): Boolean { + return super.isCoroutineResumed(iThread, actorId) && managedStrategy.isCoroutineResumed(iThread, actorId) } override fun constructStateRepresentation(): String? { @@ -1672,12 +1885,12 @@ internal class ManagedStrategyRunner( val cancellationTracePoint = managedStrategy.createAndLogCancellationTracePoint() try { // Call the `cancel` method. + val iThread = managedStrategy.currentThread val cancellationResult = super.cancelByLincheck(cont, promptCancellation) // Pass the result to `cancellationTracePoint`. cancellationTracePoint?.initializeCancellationResult(cancellationResult) // Invoke `strategy.afterCoroutineCancelled` if the coroutine was cancelled successfully. - if (cancellationResult != CANCELLATION_FAILED) - managedStrategy.afterCoroutineCancelled() + afterCoroutineCancelled(iThread, promptCancellation, cancellationResult) return cancellationResult } catch (e: Throwable) { cancellationTracePoint?.initializeException(e) @@ -1686,7 +1899,6 @@ internal class ManagedStrategyRunner( } } - /** * This exception is used to finish the execution correctly for managed strategies. * Otherwise, there is no way to do it in case of (e.g.) deadlocks. @@ -1704,6 +1916,8 @@ internal const val COROUTINE_SUSPENSION_CODE_LOCATION = -1 // when spin-loop is detected, we might need to replay the execution up to N times private const val MAX_SPIN_CYCLE_REPLAY_COUNT = 3 +internal const val SYSTEM_ARRAYCOPY_CODE_LOCATION = -1 // currently the exact place of System.arraycopy is not known + private const val OBSTRUCTION_FREEDOM_SPINLOCK_VIOLATION_MESSAGE = "The algorithm should be non-blocking, but an active lock is detected" diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/MemoryLocation.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/MemoryLocation.kt new file mode 100644 index 000000000..57bb86db6 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/MemoryLocation.kt @@ -0,0 +1,444 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed + +import org.jetbrains.kotlinx.lincheck.strategy.managed.AtomicFieldUpdaterNames.getAtomicFieldUpdaterInfo +import org.jetbrains.kotlinx.lincheck.canonicalClassName +import org.jetbrains.kotlinx.lincheck.util.* +import org.objectweb.asm.Type +import org.objectweb.asm.commons.InstructionAdapter.OBJECT_TYPE +import java.lang.invoke.VarHandle +import java.lang.reflect.* +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater +import java.util.concurrent.atomic.AtomicLongFieldUpdater +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater +import kotlin.reflect.KClass +import java.lang.reflect.Array as ReflectArray + + +typealias ValueMapper = (Type, ValueID) -> OpaqueValue? + +interface MemoryLocation { + val objID: ObjectID + + // TODO: decide if we really want to expose ASM Type here, + // or we should use some other type: + // - kClass (there is a problem with boxed and primitive types being represented by the same kClass) + // - custom enum class (?) + val type: Type + + fun read(valueMapper: ValueMapper): Any? + fun write(value: Any?, valueMapper: ValueMapper) +} + +val MemoryLocation.kClass: KClass<*> + get() = type.getKClass() + +fun ObjectTracker.getFieldAccessMemoryLocation(obj: Any?, className: String, fieldName: String, type: Type, + isStatic: Boolean, isFinal: Boolean): MemoryLocation { + if (isStatic) { + return StaticFieldMemoryLocation(className.canonicalClassName, fieldName, type) + } + val clazz = obj!!.javaClass + val id = getObjectId(obj) + return ObjectFieldMemoryLocation(clazz, id, clazz.name, fieldName, type) +} + +fun ObjectTracker.getArrayAccessMemoryLocation(array: Any, index: Int, type: Type): MemoryLocation { + val clazz = array.javaClass + val id = getObjectId(array) + return ArrayElementMemoryLocation(clazz, id, index, type) +} + +fun ObjectTracker.getAtomicAccessMemoryLocation( + className: String, + methodName: String, + receiver: Any?, + params: Array +): MemoryLocation? = when { + + receiver is AtomicReferenceFieldUpdater<*, *> -> { + val info = getAtomicFieldUpdaterInfo(receiver)!! + val obj = params[0] + getFieldAccessMemoryLocation( + obj = obj, + className = info.className, + fieldName = info.fieldName, + type = OBJECT_TYPE, + isStatic = (obj == null), + isFinal = false, // TODO: fixme? + ) + } + + receiver is AtomicIntegerFieldUpdater<*> -> { + val info = getAtomicFieldUpdaterInfo(receiver)!! + val obj = params[0] + getFieldAccessMemoryLocation( + obj = obj, + className = info.className, + fieldName = info.fieldName, + type = Type.INT_TYPE, + isStatic = (obj == null), + isFinal = false, // TODO: fixme? + ) + } + + receiver is AtomicLongFieldUpdater<*> -> { + val info = getAtomicFieldUpdaterInfo(receiver)!! + val obj = params[0] + getFieldAccessMemoryLocation( + obj = obj, + className = info.className, + fieldName = info.fieldName, + type = Type.LONG_TYPE, + isStatic = (obj == null), + isFinal = false, // TODO: fixme? + ) + } + + receiver is VarHandle -> { + val info = VarHandleNames.varHandleMethodType(receiver, params).ensure { + it !is VarHandleMethodType.TreatAsDefaultMethod + } + val obj = info.instance + when (info) { + is VarHandleMethodType.ArrayVarHandleMethod -> getArrayAccessMemoryLocation( + array = obj!!, + index = info.index, + type = info.type, + ) + else -> getFieldAccessMemoryLocation( + obj = obj, + className = info.className!!, + fieldName = info.fieldName.orEmpty(), + type = info.type, + isStatic = (obj == null), + isFinal = false, // TODO: fixme? + ) + } + } + + isUnsafe(receiver) -> { + val info = UnsafeNames.getMethodCallType(params).ensure { + it !is UnsafeName.TreatAsDefaultMethod + } + val obj = info.instance + when (info) { + is UnsafeName.UnsafeArrayMethod -> getArrayAccessMemoryLocation( + array = obj!!, + index = info.index, + type = parseUnsafeMethodAccessType(methodName)!!, + ) + else -> getFieldAccessMemoryLocation( + obj = obj, + className = info.className!!, + fieldName = info.fieldName.orEmpty(), + type = parseUnsafeMethodAccessType(methodName)!!, + isStatic = (obj == null), + isFinal = false, // TODO: fixme? + ) + } + } + + isAtomic(receiver) -> { + AtomicPrimitiveMemoryLocation( + clazz = receiver!!::class.java, + objID = getObjectId(receiver), + type = getAtomicType(receiver)!!, + ) + } + + isAtomicArray(receiver) -> { + getArrayAccessMemoryLocation( + array = receiver!!, + index = (params[0] as Int), + type = getAtomicType(receiver)!!, + ) + } + + else -> null +} + +class StaticFieldMemoryLocation( + val className: String, + val fieldName: String, + override val type: Type, +) : MemoryLocation { + + override val objID: ObjectID = STATIC_OBJECT_ID + + private val field: Field by lazy { + val resolvedClass = resolveClass(className = className) + resolveField(resolvedClass, className, fieldName) + // .apply { isAccessible = true } + } + + override fun read(valueMapper: ValueMapper): Any? { + // return field.get(null) + return readField(null, field) + } + + override fun write(value: Any?, valueMapper: ValueMapper) { + // field.set(null, value) + writeField(null, field, value) + } + + override fun equals(other: Any?): Boolean { + if (this === other) + return true + return (other is StaticFieldMemoryLocation) + && (className == other.className) + && (fieldName == other.fieldName) + && (kClass == other.kClass) + } + + override fun hashCode(): Int { + var result = className.hashCode() + result = 31 * result + fieldName.hashCode() + return result + } + + override fun toString(): String = + "$className::$fieldName" + +} + +class ObjectFieldMemoryLocation( + clazz: Class<*>, + override val objID: ObjectID, + val className: String, + val fieldName: String, + override val type: Type, +) : MemoryLocation { + + init { + check(objID != NULL_OBJECT_ID) + } + + val simpleClassName: String = clazz.simpleName + + private val field: Field by lazy { + val resolvedClass = resolveClass(clazz, className = className) + resolveField(resolvedClass, className, fieldName) + // .apply { isAccessible = true } + } + + override fun read(valueMapper: ValueMapper): Any? { + // return field.get(valueMapper(OBJECT_TYPE, objID)?.unwrap()) + return readField(valueMapper(OBJECT_TYPE, objID)?.unwrap(), field) + } + + override fun write(value: Any?, valueMapper: ValueMapper) { + // field.set(valueMapper(OBJECT_TYPE, objID)?.unwrap(), value) + writeField(valueMapper(OBJECT_TYPE, objID)?.unwrap(), field, value) + } + + override fun equals(other: Any?): Boolean { + if (this === other) + return true + return (other is ObjectFieldMemoryLocation) + && (objID == other.objID) + && (className == other.className) + && (fieldName == other.fieldName) + && (kClass == other.kClass) + } + + override fun hashCode(): Int { + var result = objID.hashCode() + result = 31 * result + className.hashCode() + result = 31 * result + fieldName.hashCode() + return result + } + + override fun toString(): String { + return "${objRepr(simpleClassName, objID)}::$fieldName" + } + +} + +class ArrayElementMemoryLocation( + clazz: Class<*>, + override val objID: ObjectID, + val index: Int, + override val type: Type, +) : MemoryLocation { + + init { + check(objID != NULL_OBJECT_ID) + } + + val className: String = clazz.simpleName + + private val isPlainArray = clazz.isArray + + private val getMethod: Method? by lazy { + if (isPlainArray) { + return@lazy null + } + val resolvedClass = resolveClass(clazz) + return@lazy resolvedClass.methods + // TODO: can we use getOpaque() for atomic arrays here? + .first { it.name == "get" } + .apply { isAccessible = true } + } + + private val setMethod by lazy { + if (isPlainArray) { + return@lazy null + } + val resolvedClass = resolveClass(clazz) + return@lazy resolvedClass.methods + // TODO: can we use setOpaque() for atomic arrays here? + .first { it.name == "set" } + .apply { isAccessible = true } + } + + override fun read(valueMapper: ValueMapper): Any? { + // TODO: also use unsafe? + if (isPlainArray) { + return ReflectArray.get(valueMapper(OBJECT_TYPE, objID)?.unwrap(), index) + } + return getMethod!!.invoke(valueMapper(OBJECT_TYPE, objID)?.unwrap(), index) + } + + override fun write(value: Any?, valueMapper: ValueMapper) { + if (isPlainArray) { + ReflectArray.set(valueMapper(OBJECT_TYPE, objID)?.unwrap(), index, value) + return + } + setMethod!!.invoke(valueMapper(OBJECT_TYPE, objID)?.unwrap(), index, value) + } + + override fun equals(other: Any?): Boolean { + if (this === other) + return true + return (other is ArrayElementMemoryLocation) + && (objID == other.objID) + && (index == other.index) + && (kClass == other.kClass) + } + + override fun hashCode(): Int { + var result = objID.hashCode() + result = 31 * result + index + return result + } + + override fun toString(): String { + return "${objRepr(className, objID)}[$index]" + } + +} + +class AtomicPrimitiveMemoryLocation( + clazz: Class<*>, + override val objID: ObjectID, + override val type: Type, +) : MemoryLocation { + + init { + require(objID != NULL_OBJECT_ID) + } + + val className: String = clazz.simpleName + + private val getMethod by lazy { + // TODO: can we use getOpaque() here? + resolveClass(clazz).methods + .first { it.name == "get" } + .apply { isAccessible = true } + } + + private val setMethod by lazy { + // TODO: can we use setOpaque() here? + resolveClass(clazz).methods + .first { it.name == "set" } + .apply { isAccessible = true } + } + + override fun read(valueMapper: ValueMapper): Any? { + // TODO: also use unsafe? + return getMethod.invoke(valueMapper(OBJECT_TYPE, objID)?.unwrap()) + } + + override fun write(value: Any?, valueMapper: ValueMapper) { + // TODO: also use unsafe? + setMethod.invoke(valueMapper(OBJECT_TYPE, objID)?.unwrap(), value) + } + + override fun equals(other: Any?): Boolean { + if (this === other) + return true + return (other is AtomicPrimitiveMemoryLocation) + && (objID == other.objID) + && (kClass == other.kClass) + } + + override fun hashCode(): Int { + return objID.hashCode() + } + + override fun toString(): String { + check(objID != NULL_OBJECT_ID) + return objRepr(className, objID) + } + +} + +internal fun objRepr(className: String, objID: ObjectID): String { + return when (objID) { + NULL_OBJECT_ID -> "null" + else -> "$className@$objID" + } +} + +private fun matchClassName(clazz: Class<*>, className: String) = + clazz.name.endsWith(className) || (clazz.canonicalName?.endsWith(className) ?: false) + +private fun resolveClass(clazz: Class<*>? = null, className: String? = null): Class<*> { + if (className == null) { + check(clazz != null) + return clazz + } + if (clazz == null) { + return Class.forName(className) + } + if (matchClassName(clazz, className)) { + return clazz + } + var superClass = clazz.superclass + while (superClass != null) { + if (matchClassName(superClass, className)) + return superClass + superClass = superClass.superclass + } + throw IllegalStateException("Cannot find class $className for object of class ${clazz.name}!") +} + +private fun resolveField(clazz: Class<*>, className: String, fieldName: String): Field { + var currentClass: Class<*>? = clazz + do { + currentClass?.fields?.firstOrNull { it.name == fieldName }?.let { return it } + currentClass?.declaredFields?.firstOrNull { it.name == fieldName }?.let { return it } + currentClass = currentClass?.superclass + } while (currentClass != null) + throw IllegalStateException("Cannot find field $className::$fieldName for class $clazz!") +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/MemoryTracker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/MemoryTracker.kt new file mode 100644 index 000000000..efe498bea --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/MemoryTracker.kt @@ -0,0 +1,52 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed + +/** + * Tracks memory operations with shared variables. + */ +interface MemoryTracker { + + fun beforeWrite(iThread: Int, codeLocation: Int, location: MemoryLocation, value: Any?) + + fun beforeRead(iThread: Int, codeLocation: Int, location: MemoryLocation) + + fun beforeGetAndSet(iThread: Int, codeLocation: Int, location: MemoryLocation, newValue: Any?) + + fun beforeCompareAndSet(iThread: Int, codeLocation: Int, location: MemoryLocation, expectedValue: Any?, newValue: Any?) + + fun beforeCompareAndExchange(iThread: Int, codeLocation: Int, location: MemoryLocation, expectedValue: Any?, newValue: Any?) + + // TODO: move increment kind enum here? + fun beforeGetAndAdd(iThread: Int, codeLocation: Int, location: MemoryLocation, delta: Number) + + fun beforeAddAndGet(iThread: Int, codeLocation: Int, location: MemoryLocation, delta: Number) + + fun interceptReadResult(iThread: Int): Any? + + fun interceptArrayCopy(iThread: Int, codeLocation: Int, srcArray: Any?, srcPos: Int, dstArray: Any?, dstPos: Int, length: Int) + + fun reset() + +} + +typealias MemoryInitializer = (MemoryLocation) -> OpaqueValue? +typealias MemoryIDInitializer = (MemoryLocation) -> ValueID \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ObjectTracker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ObjectTracker.kt index d7dd9dcbe..1a805fdaf 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ObjectTracker.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/ObjectTracker.kt @@ -32,6 +32,9 @@ interface ObjectTracker { */ fun registerObjectLink(fromObject: Any, toObject: Any?) + // TODO: add constructor name as parameter? + fun initializeObject(obj: Any) + /** * Determines whether accesses to the fields of the given object should be tracked. * @@ -40,13 +43,10 @@ interface ObjectTracker { */ fun shouldTrackObjectAccess(obj: Any): Boolean + fun getObjectId(obj: Any): ObjectID + /** * Resets the state of the object tracker. */ fun reset() -} - -/** - * Special auxiliary object used as an owner of static fields (instead of `null`). - */ -internal object StaticObject: Any() \ No newline at end of file +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/OpaqueValue.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/OpaqueValue.kt new file mode 100644 index 000000000..282da9d28 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/OpaqueValue.kt @@ -0,0 +1,136 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed + +import org.objectweb.asm.Type +import kotlin.reflect.KClass + +/** + * Auxiliary class to represent values of variables in managed strategies. + * + * [ManagedStrategy] can intercept reads and writes from/to shared variables + * and call special functions in order to model the behavior of shared memory + * (see [ManagedStrategy.onSharedVariableRead], [ManagedStrategy.onSharedVariableWrite] and others). + * These functions in turn may return or take as arguments read/written values + * and store them in some internal data-structures of the managed strategy. + * Because of the dynamic nature of the managed strategies memory is untyped, + * and thus values generally are passed and stored as objects of type [Any]. + * One can encounter several potential pitfalls when operating with these values inside managed strategy. + * + * 1. Values of primitive and reference types should be handled differently. + * For example, primitive values should be compared structurally, + * while reference values --- by reference. + * + * 2. It is dangerous to call any methods on a value of reference type, + * including methods such as [equals], [hashCode], and [toString]. + * This is because the class of the value is instrumented by [ManagedStrategyTransformer]. + * Thus calling methods of the value inside the implementation of managed strategy + * can in turn lead to interception of read/writes from/to shared memory + * and therefore to recursive calls to internal methods of managed strategy and so on. + * + * [OpaqueValue] is a wrapper around value of type [Any] that helps to avoid these problems. + * It provides safe implementations of [equals], [hashCode], [toString] methods that + * correctly distinguish values of primitive and reference types. + * It also provides other useful utility functions for working with values inside [ManagedStrategy]. + * + * TODO: use @JvmInline value class? + */ +class OpaqueValue private constructor(private val value: Any) { + + companion object { + fun fromAny(value: Any): OpaqueValue = + OpaqueValue(value) + + fun default(kClass: KClass<*>): OpaqueValue? = kClass.defaultValue() + } + + fun unwrap(): Any = value + + val isPrimitive: Boolean + get() = value.isPrimitive() + + operator fun plus(delta: Number): OpaqueValue = when (value) { + is Int -> (value + delta as Int).opaque() + is Long -> (value + delta as Long).opaque() + // TODO: handle other Numeric types? + else -> throw IllegalStateException() + } + + override fun equals(other: Any?): Boolean { + if (other !is OpaqueValue) + return false + return if (isPrimitive) { + other.isPrimitive && this.value == other.value + } else (this.value === other.value) + } + + override fun hashCode(): Int = + System.identityHashCode(value) + + override fun toString(): String = + if (isPrimitive) value.toString() else value.toOpaqueString() + +} + +fun Any.opaque(): OpaqueValue = + OpaqueValue.fromAny(this) + +fun OpaqueValue?.isInstanceOf(kClass: KClass<*>) = + this?.unwrap()?.let { kClass.isInstance(it) } ?: true + +fun KClass<*>.defaultValue(): OpaqueValue? = when(this) { + Int::class -> 0 + Byte::class -> 0.toByte() + Short::class -> 0.toShort() + Long::class -> 0.toLong() + Float::class -> 0.toFloat() + Double::class -> 0.toDouble() + Char::class -> 0.toChar() + Boolean::class -> false + else -> null +}?.opaque() + +fun Any?.toOpaqueString(): String { + if (this == null) + return "null" + val className = this::class.simpleName.orEmpty() + val objRepr = Integer.toHexString(System.identityHashCode(this)) + return "${className}@${objRepr}" +} + +fun Any.isPrimitive(): Boolean = + (this::class.javaPrimitiveType != null) + +typealias ValueID = Long +typealias ObjectID = Long + +internal object StaticObject : Any() + +// TODO: override `toString` ? +internal const val INVALID_OBJECT_ID = -2L +internal const val STATIC_OBJECT_ID = -1L +internal const val NULL_OBJECT_ID = 0L + +internal fun Int.convert(type: Type): Number = when (type.sort) { + Type.LONG -> toLong() + Type.INT -> this + else -> throw IllegalArgumentException("Expected Long or Int") +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/TracePoint.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/TracePoint.kt index 575066796..679aab16b 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/TracePoint.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/TracePoint.kt @@ -299,12 +299,13 @@ private fun StackTraceElement.shorten(): String { return stackTraceElement } -internal enum class SwitchReason(private val reason: String) { +enum class SwitchReason(private val reason: String) { MONITOR_WAIT("wait on monitor"), LOCK_WAIT("lock is already acquired"), PARK_WAIT("thread is parked"), ACTIVE_LOCK("active lock detected"), SUSPENDED("coroutine is suspended"), + SPIN_BOUND("spinning bound is reached"), STRATEGY_SWITCH(""); override fun toString() = reason diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/UnsafeNames.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/UnsafeNames.kt index 90912729e..470bb5355 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/UnsafeNames.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/UnsafeNames.kt @@ -12,6 +12,7 @@ package org.jetbrains.kotlinx.lincheck.strategy.managed import org.jetbrains.kotlinx.lincheck.findFieldNameByOffset import org.jetbrains.kotlinx.lincheck.strategy.managed.UnsafeName.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.UnsafeName.TreatAsDefaultMethod import org.jetbrains.kotlinx.lincheck.util.UnsafeHolder /** @@ -96,3 +97,27 @@ internal sealed interface UnsafeName { */ data object TreatAsDefaultMethod : UnsafeName } + +internal val UnsafeName.instance: Any? get() = when (this) { + is UnsafeArrayMethod -> array + is UnsafeInstanceMethod -> owner + else -> null +} + +internal val UnsafeName.className: String? get() = when (this) { + is UnsafeArrayMethod -> array.javaClass.name + is UnsafeInstanceMethod -> owner.javaClass.name + is UnsafeStaticMethod -> clazz.name + else -> null +} + +internal val UnsafeName.fieldName: String? get() = when (this) { + is UnsafeInstanceMethod -> fieldName + is UnsafeStaticMethod -> fieldName + else -> null +} + +internal val UnsafeName.index: Int get() = when (this) { + is UnsafeArrayMethod -> index + else -> -1 +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/VarHandleNames.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/VarHandleNames.kt index 5604b5498..4d73495e3 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/VarHandleNames.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/VarHandleNames.kt @@ -10,9 +10,11 @@ package org.jetbrains.kotlinx.lincheck.strategy.managed -import org.jetbrains.kotlinx.lincheck.findFieldNameByOffset +import org.jetbrains.kotlinx.lincheck.* import org.jetbrains.kotlinx.lincheck.strategy.managed.VarHandleMethodType.* -import org.jetbrains.kotlinx.lincheck.util.readFieldViaUnsafe +import org.jetbrains.kotlinx.lincheck.util.* +import org.objectweb.asm.Type +import org.objectweb.asm.commons.InstructionAdapter.OBJECT_TYPE import sun.misc.Unsafe import java.lang.invoke.VarHandle import java.lang.reflect.Field @@ -105,7 +107,7 @@ internal object VarHandleNames { val firstParameter = parameters.firstOrNull() ?: return TreatAsDefaultMethod if (!ownerType.isInstance(firstParameter)) return TreatAsDefaultMethod - return InstanceVarHandleMethod(firstParameter, fieldName, parameters.drop(1)) + return InstanceVarHandleMethod(firstParameter, fieldName, varHandle.getType(), parameters.drop(1)) } } @@ -125,7 +127,7 @@ internal object VarHandleNames { val fieldName = findFieldNameByOffset(ownerType, fieldOffset) ?: return TreatAsDefaultMethod - return StaticVarHandleMethod(ownerType, fieldName, parameters.toList()) + return StaticVarHandleMethod(ownerType, fieldName, varHandle.getType(), parameters.toList()) } } @@ -139,7 +141,7 @@ internal object VarHandleNames { val firstParameter = parameters[0] ?: return TreatAsDefaultMethod val index = parameters[1] as? Int ?: return TreatAsDefaultMethod - return ArrayVarHandleMethod(firstParameter, index, parameters.drop(2)) + return ArrayVarHandleMethod(firstParameter, index, varHandle.getType(), parameters.drop(2)) } } @@ -173,6 +175,25 @@ internal object VarHandleNames { } return null } + + private fun VarHandle.getType(): Type { + val className = this::class.java.name + return when { + "Int" in className -> Type.INT_TYPE + "Long" in className -> Type.LONG_TYPE + "Short" in className -> Type.SHORT_TYPE + "Byte" in className -> Type.BYTE_TYPE + "Char" in className -> Type.CHAR_TYPE + "Boolean" in className -> Type.BOOLEAN_TYPE + "Double" in className -> Type.DOUBLE_TYPE + "Float" in className -> Type.FLOAT_TYPE + "Reference" in className -> OBJECT_TYPE + "Object" in className -> OBJECT_TYPE + + else -> unreachable() + } + } + } /** @@ -187,17 +208,43 @@ internal sealed interface VarHandleMethodType { /** * Array cell access method call. */ - data class ArrayVarHandleMethod(val array: Any, val index: Int, val parameters: List) : VarHandleMethodType + data class ArrayVarHandleMethod(val array: Any, val index: Int, val type: Type, val parameters: List) : VarHandleMethodType /** * Method call affecting field [fieldName] of the [owner]. */ - data class InstanceVarHandleMethod(val owner: Any, val fieldName: String, val parameters: List) : + data class InstanceVarHandleMethod(val owner: Any, val fieldName: String, val type: Type, val parameters: List) : VarHandleMethodType /** * Method call affecting static field [fieldName] of the [ownerClass]. */ - data class StaticVarHandleMethod(val ownerClass: Class<*>, val fieldName: String, val parameters: List) : + data class StaticVarHandleMethod(val ownerClass: Class<*>, val fieldName: String, val type: Type, val parameters: List) : VarHandleMethodType +} + +internal val VarHandleMethodType.instance: Any? get() = when (this) { + is ArrayVarHandleMethod -> array + is InstanceVarHandleMethod -> owner + else -> null +} + +internal val VarHandleMethodType.className: String? get() = when (this) { + is ArrayVarHandleMethod -> array.javaClass.name + is InstanceVarHandleMethod -> owner.javaClass.name + is StaticVarHandleMethod -> ownerClass.name + else -> null +} + +internal val VarHandleMethodType.fieldName: String? get() = when (this) { + is InstanceVarHandleMethod -> fieldName + is StaticVarHandleMethod -> fieldName + else -> null +} + +internal val VarHandleMethodType.type: Type get() = when (this) { + is InstanceVarHandleMethod -> type + is StaticVarHandleMethod -> type + is ArrayVarHandleMethod -> type + TreatAsDefaultMethod -> throw IllegalArgumentException() } \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Event.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Event.kt new file mode 100644 index 000000000..5d6532c3d --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Event.kt @@ -0,0 +1,481 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.strategy.managed.MemoryLocation +import org.jetbrains.kotlinx.lincheck.util.* + +typealias EventID = Int + +interface Event : Comparable { + /** + * Event's ID. + */ + val id: EventID + + /** + * Event's label. + */ + val label: EventLabel + + /** + * List of event's dependencies. + */ + val dependencies: List + + override fun compareTo(other: Event): Int { + return id.compareTo(other.id) + } +} + +interface ThreadEvent : Event { + /** + * Event's thread + */ + val threadId: Int + + /** + * Event's position in a thread. + */ + val threadPosition: Int + + /** + * Event's parent in program order. + */ + // TODO: do we need to store it in `ThreadEvent` ? + val parent: ThreadEvent? + + /** + * List of event's dependencies. + */ + override val dependencies: List + + /** + * Vector clock to track causality relation. + */ + val causalityClock: VectorClock + + /** + * Returns n-th predecessor of the given event. + */ + fun predNth(n: Int): ThreadEvent? +} + +fun ThreadEvent.pred(inclusive: Boolean = false, predicate: (ThreadEvent) -> Boolean): ThreadEvent? { + if (inclusive && predicate(this)) + return this + var event: ThreadEvent? = parent + while (event != null && !predicate(event)) { + event = event.parent + } + return event +} + +val ThreadEvent.threadRoot: Event + get() = predNth(threadPosition)!! + +fun ThreadEvent.threadPrefix(inclusive: Boolean = false, reversed: Boolean = false): List { + val events = arrayListOf() + if (inclusive) { + events.add(this) + } + // obtain a list of predecessors of a given event + var event: ThreadEvent? = parent + while (event != null) { + events.add(event) + event = event.parent + } + // since we iterate from child to parent, the list by default is in reverse order; + // if the callee passed `reversed=true` leave it as is, otherwise reverse + // to get the list in the order from ancestors to descendants + if (!reversed) { + events.reverse() + } + return events +} + +private fun ThreadEvent?.calculateNextEventPosition(): Int = + 1 + (this?.threadPosition ?: -1) + + +interface SynchronizedEvent : Event { + /** + * List of events which synchronize into the given event. + */ + val synchronized: List +} + +fun SynchronizedEvent.resynchronize(algebra: SynchronizationAlgebra): EventLabel = + if (synchronized.isNotEmpty()) + algebra.synchronize(synchronized).ensureNotNull() + else label + +interface AtomicThreadEvent : ThreadEvent, SynchronizedEvent { + + override val parent: AtomicThreadEvent? + + /** + * Sender events corresponding to this event. + * Applicable only to response events. + */ + val senders: List + + /** + * The allocation event for the accessed object. + * Applicable only to object accessing events. + */ + val allocation: AtomicThreadEvent? + + /** + * The allocation event for the value produced by this label + * (for example, written value for write access label). + */ + // TODO: refactor! + val source: AtomicThreadEvent? +} + +/** + * Request event corresponding to this event. + * Applicable only to response and receive events. + */ +val AtomicThreadEvent.request: AtomicThreadEvent? get() = + if (label.isResponse && !label.isSpanLabel) parent!! else null + +val AtomicThreadEvent.syncFrom: AtomicThreadEvent get() = run { + require(label.isResponse) + require(senders.size == 1) + senders.first() +} + +val AtomicThreadEvent.readsFrom: AtomicThreadEvent get() = run { + require(label is ReadAccessLabel) + syncFrom +} + +val AtomicThreadEvent.locksFrom: AtomicThreadEvent get() = run { + require(label is LockLabel) + syncFrom +} + +val AtomicThreadEvent.notifiedBy: AtomicThreadEvent get() = run { + require(label is WaitLabel) + syncFrom +} + +val AtomicThreadEvent.exclusiveReadPart: AtomicThreadEvent get() = run { + require(label.satisfies { isExclusive }) + parent!! +} + +/** + * Checks whether this event is valid response to the [request] event. + * If this event is not a response or [request] is not a request returns false. + * + * Response is considered to be valid if: + * - request is a parent of response, + * - request-label can be synchronized-into response-label. + * + * @see EventLabel.isValidResponse + */ +fun AtomicThreadEvent.isValidResponse(request: ThreadEvent) = + label.isResponse && request.label.isRequest && parent == request && label.isValidResponse(request.label) + +/** + * Checks whether this event is a valid response to its parent request event. + * If this event is not a response or its parent is not a request returns false. + * + * @see isValidResponse + */ +fun AtomicThreadEvent.isValidResponse() = + parent?.let { isValidResponse(it) } ?: false + +/** + * Checks whether this event is valid write part of atomic read-modify-write, + * of which the [readResponse] is a read-response part. + * If this event is not an exclusive write or [readResponse] is not an exclusive read-response returns false. + * + * Write is considered to be valid write part of read-modify-write if: + * - read-response is a parent of write, + * - read-response and write access same location, + * - both have exclusive flag set. + * request-label can be synchronized-into response-label. + * + * @see MemoryAccessLabel.isExclusive + */ +fun AtomicThreadEvent.isWritePartOfAtomicUpdate(readResponse: ThreadEvent): Boolean { + val writeLabel = label.refine { isExclusive } + ?: return false + val readLabel = readResponse.label.refine { isResponse && isExclusive } + ?: return false + return (parent == readResponse) && readLabel.location == writeLabel.location +} + + +/** + * Hyper event is a composite event consisting of multiple atomic events. + * It allows viewing subset of events of an execution as an atomic event by itself. + * Some notable examples of hyper events are listed below: + * - pair of consecutive request and response events of the same operation + * can be viewed as a composite receive event; + * - pair of exclusive read and write events of the same atomic operation + * can be viewed as a composite read-modify-write event; + * - all the events between lock acquire and lock release events + * can be viewed as a composite critical section event; + * for other examples see subclasses of this class. + * + * We support only sequential hyper events --- that is set of events + * totally ordered by some criterion. + * + * This class of events is called "hyper" after term "hyper pomsets" from [1]. + * + * [1] Brunet, Paul, and David Pym. + * "Pomsets with Boxes: Protection, Separation, and Locality in Concurrent Kleene Algebra." + * 5th International Conference on Formal Structures for Computation and Deduction. 2020. + * + */ +interface HyperEvent : Event { + val events: List +} + +abstract class AbstractEvent(final override val label: EventLabel) : Event { + + companion object { + private var nextID: EventID = 0 + } + + final override val id: EventID = nextID++ + + protected open fun validate() {} + + override fun equals(other: Any?): Boolean { + // TODO: think again --- is it sound? seems so, as we only create single event per ID + return (this === other) + // return (other is AbstractEvent) && (id == other.id) + } + + override fun hashCode(): Int { + return id.hashCode() + } + +} + +abstract class AbstractThreadEvent( + label: EventLabel, + parent: AbstractThreadEvent?, + dependencies: List, +) : AbstractEvent(label), ThreadEvent { + + final override val threadId: Int = when (label) { + is InitializationLabel -> label.initThreadID + is ThreadStartLabel -> label.threadId + is ActorLabel -> label.threadId + else -> parent!!.threadId + } + + final override val threadPosition: Int = + parent.calculateNextEventPosition() + + final override val causalityClock: VectorClock = run { + dependencies.fold(parent?.causalityClock?.copy() ?: MutableVectorClock()) { clock, event -> + clock + event.causalityClock + }.apply { + set(threadId, threadPosition) + } + } + + override fun validate() { + super.validate() + require(threadPosition == parent.calculateNextEventPosition()) + // require((parent != null) implies { parent!! in dependencies }) + } + + override fun toString(): String { + return "#${id}: [${threadId}, ${threadPosition}] $label" + } + + override fun predNth(n: Int): ThreadEvent? { + return predNthOptimized(n) + // .also { check(it == predNthNaive(n)) } + } + + // naive implementation with O(N) complexity, just for testing and debugging + private fun predNthNaive(n : Int): ThreadEvent? { + var e: ThreadEvent = this + for (i in 0 until n) + e = e.parent ?: return null + return e + } + + // binary lifting search with O(lgN) complexity + // https://cp-algorithms.com/graph/lca_binary_lifting.html; + private fun predNthOptimized(n: Int): ThreadEvent? { + require(n >= 0) + var e = this + var r = n + while (r > MAX_JUMP) { + e = e.jumps[N_JUMPS - 1] ?: return null + r -= MAX_JUMP + } + while (r != 0) { + val k = 31 - Integer.numberOfLeadingZeros(r) + val jump = Integer.highestOneBit(r) + e = e.jumps[k] ?: return null + r -= jump + } + return e + } + + private val jumps = Array(N_JUMPS) { null } + + init { + calculateJumps(jumps, parent) + } + + companion object { + private const val N_JUMPS = 10 + private const val MAX_JUMP = 1 shl (N_JUMPS - 1) + + private fun calculateJumps(jumps: Array, parent: AbstractThreadEvent?) { + require(N_JUMPS > 0) + require(jumps.size >= N_JUMPS) + jumps[0] = parent + for (i in 1 until N_JUMPS) { + jumps[i] = jumps[i - 1]?.jumps?.get(i - 1) + } + } + } + +} + +class AtomicThreadEventImpl( + label: EventLabel, + override val parent: AtomicThreadEvent?, + /** + * Sender events corresponding to this event. + * Applicable only to response events. + */ + override val senders: List = listOf(), + /** + * The allocation event for the accessed object. + * Applicable only to object accessing events. + */ + override val allocation: AtomicThreadEvent? = null, + /** + * The allocation event for the value produced by this label + * (for example, written value for write access label). + */ + // TODO: refactor! + override val source: AtomicThreadEvent? = null, + /** + * List of event's dependencies + */ + override val dependencies: List = listOf(), +) : AtomicThreadEvent, AbstractThreadEvent(label, (parent as AbstractThreadEvent?), dependencies) { + + final override val synchronized: List = + if (label.isResponse && !label.isSpanLabel) (listOf(request!!) + senders) else listOf() + + override fun validate() { + super.validate() + // constraints for atomic non-span-related events + if (!label.isSpanLabel) { + // the request event should not follow another request event + // because the earlier request should first receive its response + require((label.isRequest && parent != null) implies { + !parent!!.label.isRequest || parent!!.label.isSpanLabel + }) + // only the response event should have a corresponding request part + require(label.isResponse equivalent (request != null)) + // response and receive events (and only them) should have a corresponding list + // of sender events, with which they synchronize-with + require((label.isResponse || label.isReceive) equivalent senders.isNotEmpty()) + } + // read-exclusive label should precede every write-exclusive label + if (label is WriteAccessLabel && label.isExclusive) { + require(parent != null) + require(parent!!.label.satisfies { + isResponse && isExclusive && location == label.location + }) + } + } + +} + +// TODO: rename to SpanEvent +class HyperThreadEvent( + label: EventLabel, + override val parent: HyperThreadEvent?, + override val dependencies: List, + override val events: List, +) : HyperEvent, AbstractThreadEvent(label, parent, dependencies) + + +val programOrder = Relation { x, y -> + if (x.threadId != y.threadId || x.threadPosition >= y.threadPosition) + false + else (x == y.predNth(y.threadPosition - x.threadPosition)) +} + +val causalityOrder = Relation { x, y -> + (x != y) && y.causalityClock.observes(x.threadId, x.threadPosition) +} + +val causalityCovering: Covering = Covering { it.dependencies } + +fun getLocationForSameLocationAccesses(x: Event, y: Event): MemoryLocation? { + val xloc = (x.label as? MemoryAccessLabel)?.location + val yloc = (y.label as? MemoryAccessLabel)?.location + val isSameLocation = when { + xloc != null && yloc != null -> xloc == yloc + xloc != null -> y.label.isMemoryAccessTo(xloc) + yloc != null -> x.label.isMemoryAccessTo(yloc) + else -> false + } + return if (isSameLocation) (xloc ?: yloc) else null +} + +fun getLocationForSameLocationWriteAccesses(x: Event, y: Event): MemoryLocation? { + val xloc = (x.label as? WriteAccessLabel)?.location + val yloc = (y.label as? WriteAccessLabel)?.location + val isSameLocation = when { + xloc != null && yloc != null -> xloc == yloc + xloc != null -> y.label.isWriteAccessTo(xloc) + yloc != null -> x.label.isWriteAccessTo(yloc) + else -> false + } + return if (isSameLocation) (xloc ?: yloc) else null +} + +fun List.getLocationForSameLocationMemoryAccesses(): MemoryLocation? { + val location = this.findMapped { (it.label as? MemoryAccessLabel)?.location } + ?: return null + return if (all { it.label.isWriteAccessTo(location) }) + location + else null +} + +fun List.getLocationForSameLocationWriteAccesses(): MemoryLocation? { + val location = this.findMapped { (it.label as? WriteAccessLabel)?.location } + ?: return null + return if (all { it.label.isWriteAccessTo(location) }) + location + else null +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventAggregation.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventAggregation.kt new file mode 100644 index 000000000..2e641f11e --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventAggregation.kt @@ -0,0 +1,308 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.util.* + +interface EventAggregator { + fun aggregate(events: List): List> + + fun label(events: List): EventLabel? + fun dependencies(events: List, remapping: EventRemapping): List + + fun isCoverable( + events: List, + covering: Covering, + clock: VectorClock + ): Boolean +} + +// TODO: unify with Remapping class +typealias EventRemapping = Map + +fun Execution.aggregate( + aggregator: EventAggregator +): Pair, EventRemapping> { + val clock = MutableVectorClock(1 + maxThreadID) + val result = MutableExecution(1 + maxThreadID) + val remapping = mutableMapOf() + val aggregated = threadMap.mapValues { (_, events) -> aggregator.aggregate(events) } + val aggregatedClock = MutableVectorClock(1 + maxThreadID).apply { + for (i in 0 .. maxThreadID) + this[i] = 0 + } + while (!clock.observes(this)) { + var position = -1 + var found = false + var events: List? = null + for ((tid, list) in aggregated.entries) { + position = aggregatedClock[tid] + events = list.getOrNull(position) ?: continue + if (aggregator.isCoverable(events, causalityCovering, clock)) { + found = true + break + } + } + if (!found) { + // error("Cannot aggregate events due to cyclic dependencies") + break + } + check(position >= 0) + check(events != null) + check(events.isNotEmpty()) + val tid = events.first().threadId + val label = aggregator.label(events) + val parent = result[tid]?.lastOrNull() + if (label != null) { + val dependencies = aggregator.dependencies(events, remapping) + val event = HyperThreadEvent( + label = label, + parent = parent, + dependencies = dependencies, + events = events, + ) + result.add(event) + events.forEach { + remapping.put(it, event).ensureNull() + } + } else if (parent != null) { + // effectively squash skipped events into previous hyper event, + // such representation is convenient for causality clock maintenance + // TODO: make sure dependencies of skipped events are propagated correctly + events.forEach { + remapping.put(it, parent).ensureNull() + } + } + clock.increment(tid, events.size) + aggregatedClock.increment(tid) + } + return result to remapping +} + +fun Relation.existsLifting() = Relation { x, y -> + x.events.any { ex -> ex !in y.events && y.events.any { ey -> this(ex, ey) } } +} + +fun Covering.aggregate(remapping: EventRemapping) = Covering { event -> + event.events + .flatMap { atomicEvent -> + this(atomicEvent).mapNotNull { remapping[it] } + } + .distinct() +} + +fun ActorAggregator(execution: Execution) = object : EventAggregator { + + override fun aggregate(events: List): List> { + var pos = 0 + val result = mutableListOf>() + while (pos < events.size) { + if ((events[pos].label as? ActorLabel)?.spanKind != SpanLabelKind.Start) { + result.add(listOf(events[pos++])) + continue + } + val start = events[pos] + val end = events.subList(fromIndex = start.threadPosition, toIndex = events.size).find { + (it.label as? ActorLabel)?.spanKind == SpanLabelKind.End + } ?: break + check((start.label as ActorLabel).actor == (end.label as ActorLabel).actor) + result.add(events.subList(fromIndex = start.threadPosition, toIndex = end.threadPosition + 1)) + pos = end.threadPosition + 1 + } + return result + } + + override fun label(events: List): EventLabel? { + val start = events.first().takeIf { + (it.label as? ActorLabel)?.spanKind == SpanLabelKind.Start + } ?: return null + val end = events.last().ensure { + (it.label as? ActorLabel)?.spanKind == SpanLabelKind.End + } + check((start.label as ActorLabel).actor == (end.label as ActorLabel).actor) + return ActorLabel(SpanLabelKind.Span, (start.label as ActorLabel).threadId, (start.label as ActorLabel).actor) + } + + override fun dependencies(events: List, remapping: EventRemapping): List { + return events + .flatMap { event -> + val causalEvents = execution.threadMap.entries.mapNotNull { (tid, thread) -> + if (tid != event.threadId) + thread.getOrNull(event.causalityClock[tid]) + else null + } + causalEvents.mapNotNull { remapping[it] } + } + // TODO: should use covering here instead of dependencies? + .filter { + // take the last event before ActorEnd event + val last = it.events[it.events.size - 2] + events.first().causalityClock.observes(last) + } + .distinct() + } + + override fun isCoverable( + events: List, + covering: Covering, + clock: VectorClock + ): Boolean { + return covering.firstCoverable(events, clock) + } + +} + +fun SynchronizationAlgebra.aggregator() = object : EventAggregator { + + override fun aggregate(events: List): List> = + events.squash { x, y -> synchronizable(x.label, y.label) } + + override fun label(events: List): EventLabel = + synchronize(events).ensureNotNull() + + override fun dependencies(events: List, remapping: EventRemapping): List { + return events + // TODO: should use covering here instead of dependencies? + .flatMap { event -> event.dependencies.mapNotNull { remapping[it] } } + .distinct() + } + + override fun isCoverable( + events: List, + covering: Covering, + clock: VectorClock + ): Boolean { + return covering.allCoverable(events, clock) + } + +} + +private val ReceiveAggregationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = + if (!label.isSpanLabel && (label.isRequest || label.isResponse)) SynchronizationType.Binary else null + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + label.isSpanLabel || other.isSpanLabel -> + null + label.isRequest && other.isResponse && other.isValidResponse(label) -> + other.getReceive() + else -> null + } + +} + +fun ThreadStartLabel.getReceive(): EventLabel = + copy(kind = LabelKind.Receive) + +fun ThreadJoinLabel.getReceive(): EventLabel? = + if (isUnblocked) copy(kind = LabelKind.Receive) else null + +fun ReadAccessLabel.getReceive(): ReadAccessLabel = + copy(kind = LabelKind.Receive) + +fun ReadModifyWriteAccessLabel.getReceive(): ReadModifyWriteAccessLabel = + copy(kind = LabelKind.Receive) + +fun ParkLabel.getReceive(): EventLabel = + copy(kind = LabelKind.Receive) + +fun CoroutineSuspendLabel.getReceive(): EventLabel = + copy(kind = LabelKind.Receive) + +fun EventLabel.getReceive(): EventLabel? = when (this) { + is ThreadStartLabel -> getReceive() + is ThreadJoinLabel -> getReceive() + is ReadAccessLabel -> getReceive() + is ReadModifyWriteAccessLabel -> getReceive() + is ParkLabel -> getReceive() + is CoroutineSuspendLabel -> getReceive() + else -> null +} + +private val MemoryAccessAggregationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when(label) { + is MemoryAccessLabel -> SynchronizationType.Binary + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + // read request synchronizes with read response + label is ReadAccessLabel && label.isRequest && other is ReadAccessLabel && other.isResponse + && other.isValidResponse(label) -> + other.getReceive() + + // exclusive read response/receive synchronizes with exclusive write + label is ReadAccessLabel && (label.isResponse || label.isReceive) && other is WriteAccessLabel -> + label.getReadModifyWrite(other) + + // exclusive read request synchronizes with read-modify-write response + label is ReadAccessLabel && label.isRequest && other is ReadModifyWriteAccessLabel && other.isResponse + && label.isValidReadPart(other) -> + other.getReceive() + + else -> null + } + +} + +val MutexAggregationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when (label) { + is MutexLabel -> SynchronizationType.Binary + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + // unlock label can be merged with the subsequent wait request + label is UnlockLabel && other is WaitLabel && other.isRequest && !other.isUnlocking + && label.mutexID == other.mutexID -> + WaitLabel(LabelKind.Request, label.mutexID, isUnlocking = true) + + // wait response label can be merged with the subsequent lock request + label is WaitLabel && label.isResponse && !label.isLocking && other is LockLabel && other.isRequest + && label.mutexID == other.mutexID -> + WaitLabel(LabelKind.Response, label.mutexID, isLocking = true) + + // TODO: do we need to merge lock request/response (?) + else -> null + } + +} + +val ThreadAggregationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when (label) { + is ActorLabel -> null + is MemoryAccessLabel -> MemoryAccessAggregationAlgebra.syncType(label) + is MutexLabel -> MutexAggregationAlgebra.syncType(label) + else -> ReceiveAggregationAlgebra.syncType(label) + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when (label) { + is ActorLabel -> null + is MemoryAccessLabel -> MemoryAccessAggregationAlgebra.synchronize(label, other) + is MutexLabel -> MutexAggregationAlgebra.synchronize(label, other) + else -> ReceiveAggregationAlgebra.synchronize(label, other) + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventIndexing.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventIndexing.kt new file mode 100644 index 000000000..9bb822dba --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventIndexing.kt @@ -0,0 +1,289 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.util.* + + +typealias EventIndexClassifier = (E) -> Pair? + +interface EventIndex, K : Any> { + + operator fun get(category: C, key: K): SortedList + + fun enumerator(category: C, key: K): Enumerator? +} + +interface MutableEventIndex, K : Any> : EventIndex { + + val classifier: EventIndexClassifier + + fun index(category: C, key: K, event: E) + + fun index(event: E) { + val (category, key) = classifier(event) ?: return + index(category, key, event) + } + + fun index(events: Collection) { + events.enumerationOrderSorted().forEach { index(it) } + } + + fun rebuild(events: Collection) { + reset() + index(events) + } + + fun reset() + +} + +inline fun , K : Any> EventIndex( + noinline classifier: EventIndexClassifier +): EventIndex { + return MutableEventIndex(classifier) +} + +inline fun , K : Any> MutableEventIndex( + noinline classifier: EventIndexClassifier +): MutableEventIndex { + return EventIndexImpl.create(classifier) +} + +// TODO: make this class private +class EventIndexImpl, K : Any> private constructor( + nCategories: Int, + override val classifier: EventIndexClassifier +) : MutableEventIndex { + + private val index = Array>>(nCategories) { mutableMapOf() } + + override operator fun get(category: C, key: K): SortedList { + return index[category.ordinal][key] ?: sortedListOf() + } + + override fun index(category: C, key: K, event: E) { + index[category.ordinal].updateInplace(key, default = sortedMutableListOf()) { add(event) } + } + + override fun enumerator(category: C, key: K): Enumerator? { + return index[category.ordinal][key]?.toEnumerator() + } + + override fun reset() { + index.forEach { it.clear() } + } + + companion object { + inline fun , K : Any> create( + noinline classifier: EventIndexClassifier + ): MutableEventIndex { + // We cannot directly call the private constructor taking the number of categories, + // because calling private methods from public inline methods is forbidden. + // Neither we want to expose this constructor, because passing + // an invalid number of categories will cause runtime exceptions. + // Thus, we introduce an auxiliary factory method taking an object of enum class as a witness. + // This constructor infers the correct total number of categories through the passed witness. + val categories = enumValues() + return if (categories.isNotEmpty()) + create(categories[0], classifier) + else + create() + } + + fun , K : Any> create( + witness: C, + classifier: (E) -> Pair? + ): MutableEventIndex { + return EventIndexImpl(witness.declaringJavaClass.enumConstants.size, classifier) + } + + fun , K : Any> create(): MutableEventIndex { + return EventIndexImpl(0) { throw UnsupportedOperationException() } + } + } +} + +enum class AtomicMemoryAccessCategory { + ReadRequest, + ReadResponse, + Write, +} + +interface AtomicMemoryAccessEventIndex : EventIndex { + + // TODO: move race status maintenance logic into separate class (?) + interface LocationInfo { + val isReadWriteRaceFree: Boolean + val isWriteWriteRaceFree: Boolean + } + + val locationInfo: Map + + fun getReadRequests(location: MemoryLocation) : SortedList = + get(AtomicMemoryAccessCategory.ReadRequest, location) + + fun getReadResponses(location: MemoryLocation): SortedList = + get(AtomicMemoryAccessCategory.ReadResponse, location) + + fun getWrites(location: MemoryLocation): SortedList = + get(AtomicMemoryAccessCategory.Write, location) + +} + +interface MutableAtomicMemoryAccessEventIndex : + AtomicMemoryAccessEventIndex, + MutableEventIndex + +val AtomicMemoryAccessEventIndex.locations: Set + get() = locationInfo.keys + +val AtomicMemoryAccessEventIndex.LocationInfo.isRaceFree: Boolean + get() = isReadWriteRaceFree && isWriteWriteRaceFree + +fun AtomicMemoryAccessEventIndex.isWriteWriteRaceFree(location: MemoryLocation): Boolean = + locationInfo[location]?.isWriteWriteRaceFree ?: true + +fun AtomicMemoryAccessEventIndex.isReadWriteRaceFree(location: MemoryLocation): Boolean = + locationInfo[location]?.isReadWriteRaceFree ?: true + +fun AtomicMemoryAccessEventIndex.isRaceFree(location: MemoryLocation): Boolean = + locationInfo[location]?.isRaceFree ?: true + +fun AtomicMemoryAccessEventIndex.getLastWrite(location: MemoryLocation): AtomicThreadEvent? = + getWrites(location).lastOrNull() + + +fun AtomicMemoryAccessEventIndex(): AtomicMemoryAccessEventIndex = + MutableAtomicMemoryAccessEventIndex() + +fun MutableAtomicMemoryAccessEventIndex(): MutableAtomicMemoryAccessEventIndex = + MutableAtomicMemoryAccessEventIndexImpl() + +typealias AtomicMemoryAccessEventClassifier = + EventIndexClassifier + +private class MutableAtomicMemoryAccessEventIndexImpl : MutableAtomicMemoryAccessEventIndex { + + private data class LocationInfoData( + override var isReadWriteRaceFree: Boolean, + override var isWriteWriteRaceFree: Boolean, + // TODO: also handle case when all accesses are totally ordered, + // i.e. there is no even read-read "races" + ) : AtomicMemoryAccessEventIndex.LocationInfo { + constructor(): this(isReadWriteRaceFree = true, isWriteWriteRaceFree = true) + } + + private val _locationInfo = mutableMapOf() + + override val locationInfo: Map + get() = _locationInfo + + override val classifier: AtomicMemoryAccessEventClassifier = { event -> + val label = (event.label as? MemoryAccessLabel) + when { + label is ReadAccessLabel && label.isRequest -> + (AtomicMemoryAccessCategory.ReadRequest to label.location) + + label is ReadAccessLabel && label.isResponse -> + (AtomicMemoryAccessCategory.ReadResponse to label.location) + + label is WriteAccessLabel -> + (AtomicMemoryAccessCategory.Write to label.location) + + else -> null + } + } + + private val index = MutableEventIndex(classifier) + + override fun get(category: AtomicMemoryAccessCategory, key: MemoryLocation): SortedList = + index[category, key] + + override fun index(category: AtomicMemoryAccessCategory, key: MemoryLocation, event: AtomicThreadEvent) = + index.index(category, key, event) + + override fun index(event: AtomicThreadEvent) { + val label = (event.label as? MemoryAccessLabel) ?: return + if (label.location !in locations) { + /* If the indexed event is the first memory access to a given location, + * then we also need to add the object allocation event + * to the index for this memory location. + */ + index.index(AtomicMemoryAccessCategory.Write, label.location, event.allocation!!) + /* Also initialize the race status data */ + _locationInfo[label.location] = LocationInfoData() + } + updateRaceStatus(label.location, event) + index.index(event) + } + + private fun updateRaceStatus(location: MemoryLocation, event: AtomicThreadEvent) { + val info = _locationInfo[location]!! + if (!info.isWriteWriteRaceFree && !info.isReadWriteRaceFree) + return + when { + event.label is WriteAccessLabel -> { + if (info.isWriteWriteRaceFree) { + // to detect write-write race, + // it is sufficient to check only against the latest write event + val lastWrite = getLastWrite(location)!! + info.isWriteWriteRaceFree = causalityOrder(lastWrite, event) + } + if (info.isReadWriteRaceFree) { + // to detect read-write race, + // we need to check against all the read-request events + info.isReadWriteRaceFree = getReadRequests(location).all { read -> + causalityOrder(read, event) + } + } + } + + // in race-free case, to detect read-write race + // it is sufficient to check only against the latest write event + event.label is ReadAccessLabel && info.isRaceFree -> { + val lastWrite = getLastWrite(location)!! + if (causalityOrder(lastWrite, event)) + return + check(causalityOrder.unordered(lastWrite, event)) + info.isReadWriteRaceFree = false + } + + // in case when there was already a write-write race, to detect read-write race, + // we need to check against all the write events + event.label is ReadAccessLabel && info.isReadWriteRaceFree -> { + info.isReadWriteRaceFree = getWrites(location).all { write -> + causalityOrder(write, event) + } + } + } + } + + override fun enumerator(category: AtomicMemoryAccessCategory, key: MemoryLocation): Enumerator? = + index.enumerator(category, key) + + override fun reset() { + _locationInfo.clear() + index.reset() + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventLabel.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventLabel.kt new file mode 100644 index 000000000..d486451d6 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventLabel.kt @@ -0,0 +1,1230 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.SpanLabelKind.* +import org.jetbrains.kotlinx.lincheck.util.* + +/** + * EventLabel is a base class for the hierarchy of classes + * representing semantic labels of various events performed by the program. + * + * It includes events such as + * - thread fork, join, start, or finish; + * - reads and writes from/to shared memory; + * - mutex locks and unlocks; + * and other (see subclasses of this class). + * + * Some events may form a _span_ - that is a sequence of events. + * For example, all events belonging to the execution of a method form a span. + * The start of the method execution and the end of method execution are denoted + * by two special events - method start event and method exit event respectively. + * All the events between these two events are considered to form a method's span. + * The span itself can be considered to be a composite event (see [HyperEvent]). + * Labels of span-related events, which denote the start of a span, its end, + * or the span as a whole, should set [spanKind] accordingly. + * + * + * @property kind The kind of this label (see [LabelKind]). + * @property spanKind For span-related events, denotes the kind of the span label (see [SpanLabelKind]), + * null for the non-span-related events. + * @property isBlocking Flag indicating that label is blocking. + * Mutex lock and thread join are examples of blocking labels. + * @property isUnblocked Flag indicating that blocking label is unblocked. + * For example, thread join-response is unblocked when all the threads it waits for have finished. + * For non-blocking labels, it should be set to true. + */ +sealed class EventLabel( + open val kind: LabelKind, + open val spanKind: SpanLabelKind? = null, + val isBlocking: Boolean = false, + val isUnblocked: Boolean = true, +) { + + /** + * Checks whether this label is send label. + */ + val isSend: Boolean + get() = (kind == LabelKind.Send) + + /** + * Checks whether this label is a request label. + */ + val isRequest: Boolean + get() = (kind == LabelKind.Request) + + /** + * Checks whether this label is a response label. + */ + val isResponse: Boolean + get() = (kind == LabelKind.Response) + + /** + * Checks whether this label is a receive label. + */ + val isReceive: Boolean + get() = (kind == LabelKind.Receive) + + /** + * Checks whether this label is span-related. + */ + val isSpanLabel: Boolean + get() = (spanKind != null) + + /** + * Object accesses by the operation represented by the label. + * For example, for object field memory access labels, this is the accessed object. + * If a particular subclass of labels does not access any object, + * then this property is equal to [NULL_OBJECT_ID]. + */ + open val objectID: ObjectID = NULL_OBJECT_ID + +} + +/** + * Kind of label. Can be one of the following: + * - [LabelKind.Send] --- send label. + * - [LabelKind.Request] --- request label. + * - [LabelKind.Response] --- response label. + * - [LabelKind.Receive] --- receive label. + * + * For example, [WriteAccessLabel] is an example of the send label, + * while [ReadAccessLabel] is split into request and response part. + * + * @see EventLabel + */ +enum class LabelKind { Send, Request, Response, Receive } + +val LabelKind.repr get() = when (this) { + LabelKind.Send -> "" + LabelKind.Request -> "^req" + LabelKind.Response -> "^rsp" + LabelKind.Receive -> "" +} + +/** + * Enumeration representing different kinds of span event labels. + * + * - [Start] is assigned to the label denoting the start of the span. + * - [End] is assigned to the label denoting the end of the span. + * - [Span] is assigned to the label of composite event + * containing the whole list of events of the span. + */ +enum class SpanLabelKind { Start, End, Span } + +/** + * Determines the corresponding label kind for a given span label kind. + */ +fun SpanLabelKind.toLabelKind(): LabelKind = when (this) { + SpanLabelKind.Start -> LabelKind.Request + SpanLabelKind.End -> LabelKind.Response + SpanLabelKind.Span -> LabelKind.Receive +} + +/** + * Enum class representing different types of event labels. + */ +enum class LabelType { + Initialization, + ObjectAllocation, + ThreadStart, + ThreadFinish, + ThreadFork, + ThreadJoin, + ReadAccess, + WriteAccess, + ReadModifyWriteAccess, + CoroutineSuspend, + CoroutineResume, + Lock, + Unlock, + Wait, + Notify, + Park, + Unpark, + Actor, + Random, +} + +/** + * Type of the label. + * + * @see LabelType + */ +val EventLabel.type: LabelType get() = when (this) { + is InitializationLabel -> LabelType.Initialization + is ObjectAllocationLabel -> LabelType.ObjectAllocation + is ThreadStartLabel -> LabelType.ThreadStart + is ThreadFinishLabel -> LabelType.ThreadFinish + is ThreadForkLabel -> LabelType.ThreadFork + is ThreadJoinLabel -> LabelType.ThreadJoin + is ReadAccessLabel -> LabelType.ReadAccess + is WriteAccessLabel -> LabelType.WriteAccess + is ReadModifyWriteAccessLabel -> LabelType.ReadModifyWriteAccess + is CoroutineSuspendLabel -> LabelType.CoroutineSuspend + is CoroutineResumeLabel -> LabelType.CoroutineResume + is LockLabel -> LabelType.Lock + is UnlockLabel -> LabelType.Unlock + is WaitLabel -> LabelType.Wait + is NotifyLabel -> LabelType.Notify + is ParkLabel -> LabelType.Park + is UnparkLabel -> LabelType.Unpark + is ActorLabel -> LabelType.Actor + is RandomLabel -> LabelType.Random +} + + +/* ************************************************************************* */ +/* Initialization labels */ +/* ************************************************************************* */ + + +/** + * A special label of the virtual root event of every execution. + * + * Initialization label stores the initial values for: + * - static memory locations, and + * - memory locations of external objects - these are the objects + * allocated outside the tracked code sections. + * + * @property initThreadID thread id used for the special initial thread; + * this thread should contain only the initialization event itself. + * @property mainThreadID thread id of the main thread, starting the execution of a program. + * @property memoryInitializer a callback performing a load of the initial values + * of a passed memory location. + */ +class InitializationLabel( + val initThreadID: ThreadID, + val mainThreadID: ThreadID, + val memoryInitializer: MemoryIDInitializer, +) : EventLabel(LabelKind.Send) { + + private val staticMemory = + HashMap() + + private val _objectsAllocations = + HashMap() + + val objectsAllocations: Map + get() = _objectsAllocations + + val externalObjects: Set + get() = objectsAllocations.keys + + fun getInitialValue(location: StaticFieldMemoryLocation): ValueID { + return staticMemory.computeIfAbsent(location) { memoryInitializer(it) } + } + + fun trackExternalObject(className: String, objID: ObjectID) { + _objectsAllocations[objID] = ObjectAllocationLabel(className, objID, memoryInitializer) + } + + override fun toString(): String = "Init" + +} + + +/** + * Represents a label for object allocation events. + * + * @property className The name of the class of the allocated object. + * @property objectID The ID of the allocated object. + * @property memoryInitializer a callback performing a load of the initial values + * of the allocated object's memory locations. + */ +data class ObjectAllocationLabel( + val className: String, + override val objectID: ObjectID, + val memoryInitializer: MemoryIDInitializer, +) : EventLabel(kind = LabelKind.Send) { + + init { + require(objectID != NULL_OBJECT_ID) + } + + private val initialValues = HashMap() + + fun getInitialValue(location: MemoryLocation): ValueID { + require(location.objID == objectID) + return initialValues.computeIfAbsent(location) { memoryInitializer(it) } + } + + override fun toString(): String = + "Alloc(${objRepr(className, objectID)})" + +} + + +/** + * Interprets the initialization label as an object allocation label. + * + * Initialization label is only responsible for storing information about the external objects --- + * these are the objects created outside the tracked code sections. + * Such objects should be registered via [InitializationLabel.trackExternalObject] method. + * + * @param objID The ObjectID of an external object. + * @return The object allocation label associated with the given ObjectID, + * or null if there is no external object with given id exists. + */ +fun InitializationLabel.asObjectAllocationLabel(objID: ObjectID): ObjectAllocationLabel? = + objectsAllocations[objID] + +/** + * Attempts to interpret a given event label as an object allocation label. + * + * @param objID The ObjectID of the object to match. + * @return The [ObjectAllocationLabel] if the label can be interpreted as it, null otherwise. + */ +fun EventLabel.asObjectAllocationLabel(objID: ObjectID): ObjectAllocationLabel? = when (this) { + is ObjectAllocationLabel -> takeIf { it.objectID == objID } + is InitializationLabel -> asObjectAllocationLabel(objID) + else -> null +} + + +/* ************************************************************************* */ +/* Thread events labels */ +/* ************************************************************************* */ + + +/** + * Base class for all thread event labels. + * + * @param kind the kind of this label. + * @param isBlocking flag indicating that label is blocking. + * @param isUnblocked flag indicating that blocking label is unblocked. + */ +sealed class ThreadEventLabel( + kind: LabelKind, + isBlocking: Boolean = false, + isUnblocked: Boolean = true, +): EventLabel( + kind = kind, + isBlocking = isBlocking, + isUnblocked = isUnblocked +) + +/** + * Label representing fork of a set of threads. + * + * @param forkThreadIds a set of thread ids this fork spawns. + */ +data class ThreadForkLabel( + val forkThreadIds: Set, +): ThreadEventLabel(kind = LabelKind.Send) { + + override fun toString(): String = + "ThreadFork(${forkThreadIds})" +} + +/** + * Interprets the initialization label as a thread fork of the main thread. + */ +fun InitializationLabel.asThreadForkLabel() = + ThreadForkLabel(setOf(mainThreadID)) + +/** + * Attempts to interpret a given event label as a thread fork label. + * + * @return The [ThreadForkLabel] if the label can be interpreted as it, null otherwise. + */ +fun EventLabel.asThreadForkLabel(): ThreadForkLabel? = when (this) { + is ThreadForkLabel -> this + is InitializationLabel -> asThreadForkLabel() + else -> null +} + +/** + * Label of a virtual event put into the beginning of each thread. + * + * @param kind the kind of this label: [LabelKind.Request], [LabelKind.Response] or [LabelKind.Receive]. + * @param threadId thread id of starting thread. + */ +data class ThreadStartLabel( + override val kind: LabelKind, + val threadId: Int, +): ThreadEventLabel(kind) { + + init { + require(isRequest || isResponse || isReceive) + } + + override fun toString(): String = + "ThreadStart${kind.repr}" +} + +/** + * Label of a virtual event put into the end of each thread. + * + * Thread finish label is considered to be always blocked. + * + * @param finishedThreadIds set of threads that have been finished. + */ +data class ThreadFinishLabel( + val finishedThreadIds: Set +): ThreadEventLabel( + kind = LabelKind.Send, + isBlocking = true, + isUnblocked = false, +) { + constructor(threadId: Int): this(setOf(threadId)) + + override fun toString(): String = + "ThreadFinish" +} + +/** + * Label representing join of a set of threads. + * + * Thread join label is blocking label. + * It is considered unblocked when the set of threads to-be-joined + * becomes empty (see [joinThreadIds]). + * + * @param kind the kind of this label: [LabelKind.Request], [LabelKind.Response] or [LabelKind.Receive]. + * @param joinThreadIds set of threads this label awaits to join. + */ +data class ThreadJoinLabel( + override val kind: LabelKind, + val joinThreadIds: Set, +): ThreadEventLabel( + kind = kind, + isBlocking = true, + isUnblocked = (kind != LabelKind.Request) implies joinThreadIds.isEmpty(), +) { + + init { + require(isRequest || isResponse || isReceive) + require(isReceive implies isUnblocked) + } + + override fun toString(): String = + "ThreadJoin${kind.repr}(${joinThreadIds})" +} + +/** + * Attempts to interpret a given event label as a thread event label. + * + * @return The [ThreadEventLabel] if the label can be interpreted as it, null otherwise. + */ +fun EventLabel.asThreadEventLabel(): ThreadEventLabel? = when (this) { + is ThreadEventLabel -> this + is InitializationLabel -> asThreadForkLabel() + else -> null +} + + +/* ************************************************************************* */ +/* Memory access labels */ +/* ************************************************************************* */ + + +/** + * Base class of shared memory access labels. + * + * It stores common information about memory accesses, + * such as the accessed memory location and read or written value. + * + * @param kind The kind of this label. + * @param location The accessed memory location. + * @param isExclusive flag indicating whether this access is exclusive. + * Memory accesses obtained as a result of executing atomic read-modify-write + * instructions (such as CAS) have this flag set. + * @param kClass class of written or read value. + * @param codeLocation the code location corresponding to the memory access. + */ +sealed class MemoryAccessLabel( + kind: LabelKind, + open val location: MemoryLocation, + open val readModifyWriteDescriptor: ReadModifyWriteDescriptor? = null, + open val codeLocation: Int = UNKNOWN_CODE_LOCATION, +): EventLabel(kind) { + + /** + * Read value for read access. + */ + abstract val readValue: ValueID + + /** + * Written value for write access. + */ + abstract val writeValue: ValueID + + /** + * Checks whether this memory access is a read access. + */ + val isRead: Boolean + get() = (accessKind == MemoryAccessKind.Read || accessKind == MemoryAccessKind.ReadModifyWrite) + + /** + * Checks whether this memory access is a write access. + */ + val isWrite: Boolean + get() = (accessKind == MemoryAccessKind.Write || accessKind == MemoryAccessKind.ReadModifyWrite) + + /** + * The id of the accessed object. + */ + override val objectID: ObjectID + get() = location.objID + + val isExclusive: Boolean + get() = (readModifyWriteDescriptor != null) + + override fun toString(): String { + val exclString = if (isExclusive) "_ex" else "" + val argsString = listOfNotNull( + "$location", + if (isRead && kind != LabelKind.Request) "$readValue" else null, + if (isWrite) "$writeValue" else null, + ).joinToString() + return "${accessKind}${kind.repr}${exclString}(${argsString})" + } + +} + + +/** + * Kind of memory access. + * + * Memory access can either be a read access, write access, or atomic read-modify-write access + * (for example, compare-and-set or atomic increment). + */ +enum class MemoryAccessKind { Read, Write, ReadModifyWrite } + +/** + * Kind of the memory access. + * + * @see MemoryAccessKind + */ +val MemoryAccessLabel.accessKind: MemoryAccessKind + get() = when(this) { + is WriteAccessLabel -> MemoryAccessKind.Write + is ReadAccessLabel -> MemoryAccessKind.Read + is ReadModifyWriteAccessLabel -> MemoryAccessKind.ReadModifyWrite + } + +sealed class ReadModifyWriteDescriptor { + data class GetAndSetDescriptor(val newValue: ValueID): ReadModifyWriteDescriptor() + data class CompareAndSetDescriptor(val expectedValue: ValueID, val newValue: ValueID): ReadModifyWriteDescriptor() + data class CompareAndExchangeDescriptor(val expectedValue: ValueID, val newValue: ValueID): ReadModifyWriteDescriptor() + data class FetchAndAddDescriptor(val delta: ValueID, val kind: IncrementKind): ReadModifyWriteDescriptor() + enum class IncrementKind { Pre, Post } +} + +/** + * Label denoting a read access to shared memory. + * + * @param kind The kind of this label: [LabelKind.Request], [LabelKind.Response], or [LabelKind.Receive]. + * @param location The memory location of this read. + * @param readValue The read value; for read-request label should be equal to null. + * @param isExclusive Exclusive access flag. + * @param kClass The class of read value. + * @param codeLocation the code location corresponding to the read access. + */ +data class ReadAccessLabel( + override val kind: LabelKind, + override val location: MemoryLocation, + override val readValue: ValueID, + override val readModifyWriteDescriptor: ReadModifyWriteDescriptor? = null, + override val codeLocation: Int = UNKNOWN_CODE_LOCATION, +): MemoryAccessLabel(kind, location, readModifyWriteDescriptor, codeLocation) { + + init { + require(isRequest || isResponse || isReceive) + require(isRequest implies (value == NULL_OBJECT_ID)) + } + + val value: ValueID + get() = readValue + + override val writeValue: ValueID = NULL_OBJECT_ID + + override fun toString(): String = + super.toString() +} + +/** + * Label denoting a write access to shared memory. + * + * @param location The memory location affected by this write access. + * @param writeValue The written value. + * @param isExclusive Exclusive access flag. + * @param kClass The class of written value. + * @param codeLocation the code location corresponding to the write access. + */ +data class WriteAccessLabel( + override val location: MemoryLocation, + override val writeValue: ValueID, + override val readModifyWriteDescriptor: ReadModifyWriteDescriptor? = null, + override val codeLocation: Int = UNKNOWN_CODE_LOCATION, +): MemoryAccessLabel(LabelKind.Send, location, readModifyWriteDescriptor, codeLocation) { + + val value: ValueID + get() = writeValue + + override val readValue: ValueID = NULL_OBJECT_ID + + override fun toString(): String = + super.toString() +} + +/** + * Label denoting read-modify-write (RMW) access to shared memory + * (for example, compare-and-swap or atomic increment). + * + * @param location The memory location affected by this access. + * @param readValue The read value. + * @param writeValue The written value. + * @param kClass the class of written value. + * @param codeLocation the code location corresponding to the memory access. + */ +data class ReadModifyWriteAccessLabel( + override val kind: LabelKind, + override val location: MemoryLocation, + override val readValue: ValueID, + override val writeValue: ValueID, + override val readModifyWriteDescriptor: ReadModifyWriteDescriptor, + override val codeLocation: Int = UNKNOWN_CODE_LOCATION, +): MemoryAccessLabel(kind, location, readModifyWriteDescriptor, codeLocation) { + + init { + require(kind == LabelKind.Response || kind == LabelKind.Receive) + } + + override fun toString(): String = + super.toString() +} + +/** + * Attempts to create a read-modify-write (RMW) access label based on read and write labels. + * + * @param read The read access label. + * @param write The write access label. + * @return The resulting read-modify-write access label if the read and write labels match, null otherwise. + */ +fun ReadModifyWriteAccessLabel(read: ReadAccessLabel, write: WriteAccessLabel): ReadModifyWriteAccessLabel? { + require(read.kind == LabelKind.Response || read.kind == LabelKind.Receive) + return if (read.isExclusive && + read.location == write.location && + read.readModifyWriteDescriptor == write.readModifyWriteDescriptor && + read.codeLocation == write.codeLocation) { + ReadModifyWriteAccessLabel( + kind = read.kind, + location = read.location, + readValue = read.value, + writeValue = write.value, + readModifyWriteDescriptor = read.readModifyWriteDescriptor!!, + codeLocation = read.codeLocation, + ) + } + else null +} + + +/** + * Attempts to create a read-modify-write (RMW) access label based on this read label and given write label. + * + * @param write The write access label. + * @return The resulting read-modify-write access label if the read and write labels match, null otherwise. + */ +fun ReadAccessLabel.getReadModifyWrite(write: WriteAccessLabel): ReadModifyWriteAccessLabel? = + ReadModifyWriteAccessLabel(this, write) + +/** + * Checks whether this read access label is a valid read part of the given read-modify-write access label. + * + * @param label The read-modify-write label to check against. + * @return true if the given label is a valid read part of the RMW, false otherwise. + */ +fun ReadAccessLabel.isValidReadPart(label: ReadModifyWriteAccessLabel): Boolean = + location == label.location && + isExclusive == label.isExclusive && + (isResponse implies (readValue == label.readValue)) + +/** + * Checks whether this write access label is a valid write part of the given read-modify-write access label. + * + * @param label The read-modify-write label to check against. + * @return true if the given label is a valid write part of the RMW, false otherwise. + */ +fun WriteAccessLabel.isValidWritePart(label: ReadModifyWriteAccessLabel): Boolean = + location == label.location && + isExclusive == label.isExclusive && + writeValue == label.writeValue + +/** + * Checks whether this label can be interpreted as a write access label. + */ +fun EventLabel.isWriteAccess(): Boolean = + this is InitializationLabel || this is ObjectAllocationLabel || this is WriteAccessLabel + +/** + * Checks whether this label is exclusive write access label. + */ +fun EventLabel.isExclusiveWriteAccess(): Boolean = + (this is WriteAccessLabel) && isExclusive + +/** + * Checks whether this label can be interpreted as an initializing write access label. + */ +fun EventLabel.isInitializingWriteAccess(): Boolean = + this is InitializationLabel || this is ObjectAllocationLabel + +/** + * Checks if the initialization label can be interpreted as a write access to the given memory location. + * + * Initialization label can represent initializing writes to static memory locations, + * as well as field memory locations of external objects. + * + * @param location The memory location to check for. + */ +fun InitializationLabel.isWriteAccessTo(location: MemoryLocation): Boolean = + location is StaticFieldMemoryLocation || (location.objID in externalObjects) + +/** + * Checks if the object allocation label can be interpreted as a write access to the given memory location. + * + * Object allocation label can represent initializing writes to memory locations of the allocated object. + * + * @param location The memory location to check for. + */ +fun ObjectAllocationLabel.isWriteAccessTo(location: MemoryLocation) = + (location.objID == objectID) + +/** + * Checks if the given event label can be interpreted as a write access to the given memory location. + * + * @param location The memory location to check for. + */ +fun EventLabel.isWriteAccessTo(location: MemoryLocation): Boolean = when (this) { + is WriteAccessLabel -> (this.location == location) + is ObjectAllocationLabel -> isWriteAccessTo(location) + is InitializationLabel -> isWriteAccessTo(location) + else -> false +} + +/** + * Attempts to interpret a given initialization label as a write access label. + * + * @param location The memory location to match. + * @return The [WriteAccessLabel] if the label can be interpreted as it, null otherwise. + */ +fun InitializationLabel.asWriteAccessLabel(location: MemoryLocation): WriteAccessLabel? = when { + location is StaticFieldMemoryLocation -> + WriteAccessLabel( + location = location, + writeValue = getInitialValue(location), + codeLocation = INIT_CODE_LOCATION, + ) + + else -> asObjectAllocationLabel(location.objID)?.asWriteAccessLabel(location) +} + +/** + * Attempts to interpret a given object allocation label as a write access label. + * + * @param location The memory location to match. + * @return The [WriteAccessLabel] if the label can be interpreted as it, null otherwise. + */ +fun ObjectAllocationLabel.asWriteAccessLabel(location: MemoryLocation): WriteAccessLabel? = + if (location.objID == objectID) + WriteAccessLabel( + location = location, + writeValue = getInitialValue(location), + // TODO: use actual allocation-site code location? + codeLocation = INIT_CODE_LOCATION, + ) + else null + +/** + * Attempts to interpret a given label as a write access label. + * + * @param location The memory location to match. + * @return The [WriteAccessLabel] if the label can be interpreted as it, null otherwise. + */ +fun EventLabel.asWriteAccessLabel(location: MemoryLocation): WriteAccessLabel? = when (this) { + is WriteAccessLabel -> this.takeIf { it.location == location } + is ObjectAllocationLabel -> asWriteAccessLabel(location) + is InitializationLabel -> asWriteAccessLabel(location) + else -> null +} + +/** + * Checks if the given event label can be interpreted as a memory access to the given memory location. + * + * @param location The memory location to check for. + */ +fun EventLabel.isMemoryAccessTo(location: MemoryLocation): Boolean = when (this) { + is MemoryAccessLabel -> (this.location == location) + is ObjectAllocationLabel -> isWriteAccessTo(location) + is InitializationLabel -> isWriteAccessTo(location) + else -> false +} + +/** + * Attempts to interpret a given label as a memory access label. + * + * @param location The memory location to match. + * @return The [MemoryAccessLabel] if the label can be interpreted as it, null otherwise. + */ +fun EventLabel.asMemoryAccessLabel(location: MemoryLocation): MemoryAccessLabel? = when (this) { + is MemoryAccessLabel -> this.takeIf { it.location == location } + is ObjectAllocationLabel -> asWriteAccessLabel(location) + is InitializationLabel -> asWriteAccessLabel(location) + else -> null +} + + +/* ************************************************************************* */ +/* Mutex events labels */ +/* ************************************************************************* */ + + +/** + * Base class of all mutex operations event labels. + * + * It stores common information about mutex operations, + * such as the accessed lock object. + * + * @param kind The kind of this label. + * @param mutexID The id of the mutex object to perform operation on. + * @param isBlocking Flag indicating whether this label is blocking. + * @param isUnblocked Flag indicating whether this blocking label is already unblocked. + */ +sealed class MutexLabel( + kind: LabelKind, + open val mutexID: ObjectID, + isBlocking: Boolean = false, + isUnblocked: Boolean = true, +): EventLabel( + kind = kind, + isBlocking = isBlocking, + isUnblocked = isUnblocked +) { + + /** + * The id of the accessed object. + */ + override val objectID: ObjectID + get() = mutexID + + override fun toString(): String { + return "${operationKind}${kind.repr}($mutexID)" + } + +} + + +/** + * Kind of mutex operation. + * + * Mutex operation can either be lock, unlock, wait or notify. + */ +enum class MutexOperationKind { Lock, Unlock, Wait, Notify } + +/** + * Kind of mutex operation. + * + * @see MutexOperationKind + */ +val MutexLabel.operationKind: MutexOperationKind + get() = when(this) { + is LockLabel -> MutexOperationKind.Lock + is UnlockLabel -> MutexOperationKind.Unlock + is WaitLabel -> MutexOperationKind.Wait + is NotifyLabel -> MutexOperationKind.Notify + } + + +/** + * Label denoting lock of a mutex. + * + * @param kind The kind of this label: [LabelKind.Request] or [LabelKind.Response]. + * @param mutexID The id of the locked mutex object. + * @param isReentry Flag indicating whether this lock operation is re-entry lock. + * @param reentrancyDepth The re-entrance depth of this lock operation. + * @param isSynthetic Flag indicating whether this lock operation is synthetic. + * For example, a wait-response operation can be represented as a wait-response event, + * followed by a synthetic lock operation. + */ +data class LockLabel( + override val kind: LabelKind, + override val mutexID: ObjectID, + val isReentry: Boolean = false, + val reentrancyDepth: Int = 1, + val isSynthetic: Boolean = false, +) : MutexLabel( + kind = kind, + mutexID = mutexID, + isBlocking = true, + isUnblocked = (kind != LabelKind.Request), +) { + init { + require(isRequest || isResponse) + } + + override fun toString(): String = + super.toString() +} + +/** + * Label denoting unlock of a mutex. + * + * @param mutexID The id of the locked mutex object. + * @param isReentry Flag indicating whether this unlock operation is re-entry lock. + * @param reentrancyDepth The re-entrance depth of this unlock operation. + * @param isSynthetic Flag indicating whether this lock operation is synthetic. + * For example, a wait-response operation can be represented as a wait-response event, + * followed by a synthetic lock operation. + */ +data class UnlockLabel( + override val mutexID: ObjectID, + val isReentry: Boolean = false, + val reentrancyDepth: Int = 1, + val isSynthetic: Boolean = false, +) : MutexLabel(LabelKind.Send, mutexID) { + + override fun toString(): String = + super.toString() +} + +/** + * Label denoting wait on a mutex. + * + * @param kind The kind of this label: [LabelKind.Request] or [LabelKind.Response]. + * @param mutexID The id of the mutex object to wait on. + * @param isLocking Flag indicating whether this wait operation also performs lock of the mutex. + * @param isUnlocking Flag indicating whether this wait operation also performs unlock of the mutex. + */ +data class WaitLabel( + override val kind: LabelKind, + override val mutexID: ObjectID, + val isLocking: Boolean = false, + val isUnlocking: Boolean = false, +) : MutexLabel( + kind = kind, + mutexID = mutexID, + isBlocking = true, + isUnblocked = (kind != LabelKind.Request), +) { + init { + require(isRequest || isResponse) + require(isRequest implies !isLocking) + require(isResponse implies !isUnlocking) + } + + override fun toString(): String = + super.toString() +} + +/** + * Label denoting notification of a mutex. + * + * @param mutexID The id of the mutex object to notify. + * @param isBroadcast Flag indicating that this notification is broadcast, + * that is created by a `notifyAll()` method call. + */ +data class NotifyLabel( + override val mutexID: ObjectID, + val isBroadcast: Boolean +) : MutexLabel(LabelKind.Send, mutexID) { + + override fun toString(): String = + super.toString() +} + +/** + * Checks whether this label can be interpreted as an initializing synthetic unlock label. + */ +fun EventLabel.isInitializingUnlock(): Boolean = + this is InitializationLabel || this is ObjectAllocationLabel + +/** + * Interprets the initialization label as an unlock label. + * + * Initialization label can represent the first synthetic unlock label of some external objects. + * + * @param mutexID The id of an external object on which unlock is performed. + * @return The unlock label associated with the given mutex, + * or null if there is no external mutex object with given id exists. + */ +fun InitializationLabel.asUnlockLabel(mutexID: ObjectID) = + asObjectAllocationLabel(mutexID)?.asUnlockLabel(mutexID) + +/** + * Interprets the object allocation label as an unlock label. + * + * Object allocation label can represent the first synthetic unlock label of + * the given mutex object. + * + * @param mutexID The id of an external object on which unlock is performed. + * @return The unlock label associated with the given mutex, + * or null if there is no external mutex object with given id exists. + */ +fun ObjectAllocationLabel.asUnlockLabel(mutexID: ObjectID) = + if (mutexID == objectID) UnlockLabel(mutexID = objectID, isSynthetic = true) else null + + +/** + * Attempts to interpret a given label as an unlock label. + * + * @param mutexID The id of the mutex object to match. + * @return The [UnlockLabel] if the label can be interpreted as it, null otherwise. + * + */ +fun EventLabel.asUnlockLabel(mutexID: ObjectID): UnlockLabel? = when (this) { + is UnlockLabel -> this.takeIf { it.mutexID == mutexID } + is ObjectAllocationLabel -> asUnlockLabel(mutexID) + is InitializationLabel -> asUnlockLabel(mutexID) + else -> null +} + +/** + * Attempts to interpret a given label a notify label. + * + * @param mutexID The id of the mutex object to match. + * @return The [NotifyLabel] if the label can be interpreted as it, null otherwise. + * + */ +fun EventLabel.asNotifyLabel(mutexID: ObjectID): NotifyLabel? = when (this) { + is NotifyLabel -> this.takeIf { it.mutexID == mutexID } + else -> null +} + + +/* ************************************************************************* */ +/* Parking labels */ +/* ************************************************************************* */ + + +/** + * Base class for park and unpark event labels. + * + * @param kind The kind of this label. + * @param threadId The thread id of parked or unparked thread. + * @param isBlocking Flag indicating that label is blocking. + * @param isUnblocked Flag indicating that blocking label is unblocked. + */ +sealed class ParkingEventLabel( + kind: LabelKind, + open val threadId: Int, + isBlocking: Boolean = false, + isUnblocked: Boolean = true, +): EventLabel( + kind = kind, + isBlocking = isBlocking, + isUnblocked = isUnblocked +) { + + override fun toString(): String { + val argsString = if (operationKind == ParkingOperationKind.Unpark) "($threadId)" else "" + return "${operationKind}${kind.repr}${argsString}" + } + +} + +/** + * Kind of parking operation. + */ +enum class ParkingOperationKind { Park, Unpark } + +/** + * Kind of parking operation. + * + * @see ParkingOperationKind + */ +val ParkingEventLabel.operationKind: ParkingOperationKind + get() = when (this) { + is ParkLabel -> ParkingOperationKind.Park + is UnparkLabel -> ParkingOperationKind.Unpark + } + +/** + * Label denoting park operation of a thread. + * + * @param kind the kind of this label: [LabelKind.Request] or [LabelKind.Response]. + * @param threadId the thread id of the parked thread. + */ +data class ParkLabel( + override val kind: LabelKind, + override val threadId: Int, +) : ParkingEventLabel( + kind = kind, + threadId = threadId, + isBlocking = true, + isUnblocked = (kind != LabelKind.Request), +) { + init { + require(isRequest || isResponse || isReceive) + } + + override fun toString(): String = + super.toString() +} + +/** + * Label denoting unpark operation of a thread. + * + * @param threadId the thread id of the thread to unpark. + */ +data class UnparkLabel( + override val threadId: Int, +) : ParkingEventLabel(LabelKind.Send, threadId) { + + override fun toString(): String = + super.toString() +} + + +/* ************************************************************************* */ +/* Coroutine labels */ +/* ************************************************************************* */ + + +/** + * Base class of all coroutine operations event labels. + * + * It stores common information about coroutine operation, + * such as the thread id and actor id of the coroutine. + * + * @param kind The kind of this label. + * @param threadId The thread id of the coroutine. + * @param actorId The actor id of the coroutine. + * @param isBlocking Flag indicating whether this label is blocking. + * @param isUnblocked Flag indicating whether this blocking label is already unblocked. + */ +sealed class CoroutineLabel( + override val kind: LabelKind, + open val threadId: Int, + open val actorId: Int, + isBlocking: Boolean = false, + isUnblocked: Boolean = true, +) : EventLabel( + kind = kind, + isBlocking = isBlocking, + isUnblocked = isUnblocked +) { + + override fun toString(): String { + val operationKind = when (this) { + is CoroutineSuspendLabel -> "Suspend" + is CoroutineResumeLabel -> "Resume" + } + val status = when { + this is CoroutineSuspendLabel && !cancelled -> ": resumed" + this is CoroutineSuspendLabel && cancelled -> ": cancelled" + else -> "" + } + return "${operationKind}${kind.repr}($threadId, $actorId)$status" + } + +} + +/** + * Label denoting coroutine suspend operation. + * + * @param kind The kind of this label: [LabelKind.Request] or [LabelKind.Response]. + * @param threadId The thread id of the coroutine. + * @param actorId The actor id of the coroutine. + * @param cancelled The flag indicating whether the coroutine was canceled. + * @param promptCancellation The flag indicating whether the coroutine is subject + * to the prompt cancellation guarantee. + */ +data class CoroutineSuspendLabel( + override val kind: LabelKind, + override val threadId: Int, + override val actorId: Int, + val cancelled: Boolean = false, + val promptCancellation: Boolean = false, + // TODO: should we also keep resume value? +) : CoroutineLabel( + kind = kind, + threadId = threadId, + actorId = actorId, + isBlocking = true, + isUnblocked = (kind != LabelKind.Request), +) { + init { + require(isRequest || isResponse || isReceive) + require(promptCancellation implies isRequest) + require(cancelled implies (isResponse || isReceive)) + } + + override fun toString(): String = + super.toString() + +} + +/** + * Label denoting coroutine resumption operation. + * + * @param threadId The thread id of the coroutine to be resumed. + * @param actorId The actor id of the coroutine to be resumed. + */ +data class CoroutineResumeLabel( + override val threadId: Int, + override val actorId: Int, + // TODO: should we also keep resume value? +) : CoroutineLabel(LabelKind.Send, threadId, actorId) { + + override fun toString(): String = + super.toString() + +} + + +/* ************************************************************************* */ +/* Miscellaneous */ +/* ************************************************************************* */ + + +/** + * Label denoting an actor operation: either start or an end of actor method execution. + * + * @property spanKind The kind of the actor label (see [SpanLabelKind]). + * @property threadId The id of the thread on which the actor is executing. + * @property actor The actor descriptor. + */ +// TODO: generalize actor labels to method call/return labels? +data class ActorLabel( + override val spanKind: SpanLabelKind, + val threadId: ThreadID, + val actor: Actor +) : EventLabel( + kind = spanKind.toLabelKind(), + spanKind = spanKind +) + +/** + * Label denoting a call of the random number generator. + * + * @property value The generated random value. + */ +data class RandomLabel(val value: Int): EventLabel(kind = LabelKind.Send) + + +// special code location used for initializing write events +private const val INIT_CODE_LOCATION = -1 + +// special code location denoting unknown code location +private const val UNKNOWN_CODE_LOCATION = -2 \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventStructure.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventStructure.kt new file mode 100644 index 000000000..dbe84aaae --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventStructure.kt @@ -0,0 +1,1233 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency.* +import org.jetbrains.kotlinx.lincheck.util.* + + +class EventStructure( + nParallelThreads: Int, + val memoryInitializer: MemoryInitializer, + // TODO: refactor --- avoid using callbacks! + private val reportInconsistencyCallback: ReportInconsistencyCallback, + private val internalThreadSwitchCallback: InternalThreadSwitchCallback, +) { + val mainThreadId = 0 + val initThreadId = nParallelThreads + val maxThreadId = initThreadId + val nThreads = maxThreadId + 1 + + /** + * Mutable list of the event structure events. + */ + private val _events = sortedMutableListOf() + + /** + * List of the event structure events. + */ + val events: SortedList = _events + + /** + * Root event of the whole event structure. + * Its label is [InitializationLabel]. + */ + @SuppressWarnings("WeakerAccess") + val root: AtomicThreadEvent + + /** + * The root event of the currently being explored execution. + * In other words, it is a choice point that has led to the current exploration. + * + * The label of this event should be of [LabelKind.Response] kind. + */ + @SuppressWarnings("WeakerAccess") + lateinit var currentExplorationRoot: Event + private set + + /** + * Mutable list of backtracking points. + */ + private val backtrackingPoints = sortedMutableListOf() + + + /** + * The mutable execution currently being explored. + */ + private var _execution = MutableExtendedExecution(this.nThreads) + + /** + * The execution currently being explored. + */ + val execution: ExtendedExecution + get() = _execution + + /** + * The frontier representing an already replayed part of the execution currently being explored. + */ + private var playedFrontier = MutableExecutionFrontier(this.nThreads) + + /** + * An object managing the replay process of the execution currently being explored. + */ + private var replayer = Replayer() + + /** + * Synchronization algebra used for synchronization of events. + */ + @SuppressWarnings("WeakerAccess") + val syncAlgebra: SynchronizationAlgebra = AtomicSynchronizationAlgebra + + /** + * The frontier encoding the subset of pinned events of the execution currently being explored. + * Pinned events cannot be revisited and thus do not participate in the synchronization + * with newly added events. + */ + private var pinnedEvents = ExecutionFrontier(this.nThreads) + + /** + * The object registry, storing information about all objects + * created during the execution currently being explored. + */ + val objectRegistry = ObjectRegistry() + + /** + * For each blocked thread, stores a descriptor of the blocked event. + * + * The thread may become blocked when it issues a request-event + * denoting some blocking operation (e.g., mutex lock), + * such that the corresponding response-event cannot be created immediately + * (for example, because the mutex is already acquired by another thread). + * + * Therefore, a blocked event is a composite event that may consist of either: + * - a blocked request event alone; + * - or a blocked request event followed by an unblocking response event. + * In the latter case, we say that the event becomes effectively unblocked. + * The unblocking response is typically created as a result + * of synchronization with another event when it is added + * to the currently explored [execution]. + * + * However, at the point when the unblocking response is created + * the unblocking response may not be added to + * the currently explored [execution] immediately. + * It is added later when the blocked thread is scheduled again. + * At this point the current [BlockedEventDescriptor] is removed + * from the [blockedEvents] mapping, and the thread becomes fully unblocked. + */ + private val blockedEvents: MutableThreadMap = + ArrayIntMap(this.nThreads) + + private val delayedConsistencyCheckBuffer = mutableListOf() + + private val readCodeLocationsCounter = mutableMapOf, Int>() + + init { + root = addRootEvent() + objectRegistry.initialize(root) + } + + /* ************************************************************************* */ + /* Exploration */ + /* ************************************************************************* */ + + fun startNextExploration(): Boolean { + loop@while (true) { + val backtrackingPoint = rollbackTo { !it.visited } + ?: return false + backtrackingPoint.visit() + resetExploration(backtrackingPoint) + return true + } + } + + fun initializeExploration() { + // reset re-played frontier + playedFrontier = MutableExecutionFrontier(nThreads) + playedFrontier[initThreadId] = execution[initThreadId]!!.last() + // reset replayer state + replayer.reset() + if (replayer.inProgress()) { + replayer.currentEvent.ensure { + it != null && it.label is InitializationLabel + } + replayer.setNextEvent() + } + // reset object indices --- retain only external events + objectRegistry.retain { it.isExternal } + // reset state of other auxiliary structures + delayedConsistencyCheckBuffer.clear() + readCodeLocationsCounter.clear() + } + + fun abortExploration() { + // we abort the current exploration by resetting the current execution to its replayed part; + // however, we need to handle blocking request in a special way --- we include their response part + // to detect potential blocking response uniqueness violations + // (e.g., two lock events unblocked by the same unlock event) + for ((tid, event) in playedFrontier.threadMap.entries) { + if (event == null) + continue + if (!(event.label.isRequest && event.label.isBlocking)) + continue + val response = execution[tid, event.threadPosition + 1] + ?: continue + check(response.label.isResponse) + // skip the response if it does not depend on any re-played event + if (response.dependencies.any { it !in playedFrontier }) + continue + playedFrontier.update(response) + } + _execution.reset(playedFrontier) + } + + private fun rollbackTo(predicate: (BacktrackingPoint) -> Boolean): BacktrackingPoint? { + val idx = backtrackingPoints.indexOfLast(predicate) + val backtrackingPoint = backtrackingPoints.getOrNull(idx) + val eventIdx = events.indexOfLast { it == backtrackingPoint?.event } + backtrackingPoints.subList(idx + 1, backtrackingPoints.size).clear() + _events.subList(eventIdx + 1, events.size).clear() + return backtrackingPoint + } + + private fun resetExploration(backtrackingPoint: BacktrackingPoint) { + // get the event to backtrack to + val event = backtrackingPoint.event.ensure { + it.label is InitializationLabel || it.label.isResponse + } + // reset blocked events + blockedEvents.clear() + // set current exploration root + currentExplorationRoot = event + // reset current execution + _execution.reset(backtrackingPoint.frontier) + // copy pinned events and pin current re-exploration root event + val pinnedEvents = backtrackingPoint.pinnedEvents.copy() + .apply { set(event.threadId, event) } + // add new event to current execution + _execution.add(event) + // do the same for blocked requests + for (blockedRequest in backtrackingPoint.blockedRequests) { + _execution.add(blockedRequest) + // additionally, pin blocked requests if all their predecessors are also blocked ... + if (blockedRequest.parent == pinnedEvents[blockedRequest.threadId]) { + pinnedEvents[blockedRequest.threadId] = blockedRequest + } + // ... and also block them + blockRequest(blockedRequest) + } + // set pinned events + this.pinnedEvents = pinnedEvents.ensure { + execution.containsAll(it.events) + } + // check consistency of the whole execution + _execution.checkConsistency() + // set the replayer state + val replayOrdering = _execution.executionOrderComputable.value.ordering + replayer = Replayer(replayOrdering) + } + + fun checkConsistency(): Inconsistency? { + // TODO: set suddenInvocationResult? + return _execution.checkConsistency() + } + + /* ************************************************************************* */ + /* Event creation */ + /* ************************************************************************* */ + + /** + * Class representing a backtracking point in the exploration of the program's executions. + * + * @property event The event at which to start a new exploration. + * @property frontier The execution frontier at the point when the event was created. + * @property pinnedEvents The frontier of pinned events that should not be + * considered for exploration branching. + * @property blockedRequests The list of blocked request events. + * @property visited Flag to indicate if this backtracking point has been visited. + */ + private class BacktrackingPoint( + val event: AtomicThreadEvent, + val frontier: ExecutionFrontier, + val pinnedEvents: ExecutionFrontier, + val blockedRequests: List, + ) : Comparable { + + var visited: Boolean = false + private set + + fun visit() { + visited = true + } + + override fun compareTo(other: BacktrackingPoint): Int { + return event.id.compareTo(other.event.id) + } + } + + private fun createBacktrackingPoint(event: AtomicThreadEvent, conflicts: List) { + val frontier = execution.toMutableFrontier().apply { + cut(conflicts) + // for already unblocked dangling requests, + // also put their responses into the frontier + addUnblockingResponses(conflicts) + } + val danglingRequests = frontier.getDanglingRequests() + val blockedRequests = danglingRequests + // TODO: perhaps, we should change this to the list of requests to conflicting response events? + .filter { it.label.isBlocking && it != event.parent && (it.label !is CoroutineSuspendLabel) } + frontier.apply { + cut(danglingRequests) + set(event.threadId, event.parent) + } + val pinnedEvents = pinnedEvents.copy().apply { + val causalityFrontier = execution.calculateFrontier(event.causalityClock) + merge(causalityFrontier) + cut(conflicts) + cut(getDanglingRequests()) + cut(event) + } + val backtrackingPoint = BacktrackingPoint( + event = event, + frontier = frontier, + pinnedEvents = pinnedEvents, + blockedRequests = blockedRequests, + ) + backtrackingPoints.add(backtrackingPoint) + } + + private fun createEvent( + iThread: Int, + label: EventLabel, + parent: AtomicThreadEvent?, + dependencies: List, + visit: Boolean = true, + ): AtomicThreadEvent? { + val conflicts = getConflictingEvents(iThread, label, parent, dependencies) + if (isCausalityViolated(parent, dependencies, conflicts)) + return null + val allocation = objectRegistry[label.objectID]?.allocation + val source = (label as? WriteAccessLabel)?.writeValue?.let { + objectRegistry[it]?.allocation + } + val event = AtomicThreadEventImpl( + label = label, + parent = parent, + senders = dependencies, + allocation = allocation, + source = source, + dependencies = listOfNotNull(allocation, source) + dependencies, + ) + _events.add(event) + // if the event is not visited immediately, + // then we create a backtracking point to visit it later + if (!visit) { + createBacktrackingPoint(event, conflicts) + } + return event + } + + private fun getConflictingEvents( + iThread: Int, + label: EventLabel, + parent: AtomicThreadEvent?, + dependencies: List + ): List { + val position = parent?.let { it.threadPosition + 1 } ?: 0 + val conflicts = mutableListOf() + // if the current execution already has an event in given position --- then it is conflict + execution[iThread, position]?.also { conflicts.add(it) } + // handle label specific cases + // TODO: unify this logic for various kinds of labels? + when { + // lock-response synchronizing with our unlock is conflict + label is LockLabel && label.isResponse && !label.isReentry -> run { + val unlock = dependencies.first { it.label.asUnlockLabel(label.mutexID) != null } + execution.forEach { event -> + if (event.label.satisfies { isResponse && mutexID == label.mutexID } + && event.locksFrom == unlock) { + conflicts.add(event) + } + } + } + // wait-response synchronizing with our notify is conflict + label is WaitLabel && label.isResponse -> run { + val notify = dependencies.first { it.label is NotifyLabel } + if ((notify.label as NotifyLabel).isBroadcast) + return@run + execution.forEach { event -> + if (event.label.satisfies { isResponse && mutexID == label.mutexID } + && event.notifiedBy == notify) { + conflicts.add(event) + } + } + } + // TODO: add similar rule for read-exclusive-response? + } + return conflicts + } + + private fun isCausalityViolated( + parent: AtomicThreadEvent?, + dependencies: List, + conflicts: List, + ): Boolean { + var causalityViolation = false + // Check that parent does not depend on conflicting events. + if (parent != null) { + causalityViolation = causalityViolation || conflicts.any { conflict -> + causalityOrder.orEqual(conflict, parent) + } + } + // Also check that dependencies do not causally depend on conflicting events. + causalityViolation = causalityViolation || conflicts.any { conflict -> dependencies.any { dependency -> + causalityOrder.orEqual(conflict, dependency) + }} + return causalityViolation + } + + private fun MutableExecutionFrontier.addUnblockingResponses(conflicts: List) { + for (descriptor in blockedEvents.values) { + val request = descriptor.request + val response = descriptor.response + if (request != this[request.threadId]) + continue + if (request in conflicts || response in conflicts) + continue + if (response != null && response.dependencies.all { it in this }) { + this.update(response) + } + } + } + + private fun addEventToCurrentExecution(event: AtomicThreadEvent) { + // Check if the added event is replayed event. + val isReplayedEvent = inReplayPhase(event.threadId) + // Update current execution and replayed frontier. + if (!isReplayedEvent) { + _execution.add(event) + } + playedFrontier.update(event) + // Unblock the thread if the unblocking response was added. + if (event.label.isResponse && event.label.isBlocking && isBlockedRequest(event.request!!)) { + unblockRequest(event.request!!) + } + // If we are still in replay phase, but the added event is not a replayed event, + // then save it to delayed events buffer to postpone its further processing. + if (inReplayPhase()) { + if (!isReplayedEvent) { + delayedConsistencyCheckBuffer.add(event) + } + return + } + // If we are not in replay phase anymore, but the current event is replayed event, + // it means that we just finished replay phase (i.e. the given event is the last replayed event). + // In this case, we need to proceed with all postponed non-replayed events. + if (isReplayedEvent) { + for (delayedEvent in delayedConsistencyCheckBuffer) { + if (delayedEvent.label.isSend) { + addSynchronizedEvents(delayedEvent) + } + } + delayedConsistencyCheckBuffer.clear() + return + } + // If we are not in the replay phase and the newly added event is not replayed, then proceed as usual. + // Add synchronized events. + if (event.label.isSend) { + addSynchronizedEvents(event) + } + // Check consistency of the new event. + val inconsistency = execution.inconsistency + if (inconsistency != null) { + reportInconsistencyCallback(inconsistency) + } + } + + /* ************************************************************************* */ + /* Replaying */ + /* ************************************************************************* */ + + private class Replayer(private val executionOrder: List) { + private var index: Int = 0 + private var size: Int = 0 + + constructor(): this(listOf()) + + fun inProgress(): Boolean = + (index < size) + + val currentEvent: AtomicThreadEvent? + get() = if (inProgress()) (executionOrder[index] as? AtomicThreadEvent) else null + + fun setNextEvent() { + index++ + } + + fun reset() { + index = 0 + size = executionOrder.size + } + } + + fun inReplayPhase(): Boolean = + replayer.inProgress() + + fun inReplayPhase(iThread: Int): Boolean { + val frontEvent = playedFrontier[iThread] + ?.ensure { it in _execution } + return (frontEvent != execution.lastEvent(iThread)) + } + + // should only be called in replay phase! + fun canReplayNextEvent(iThread: Int): Boolean { + return iThread == replayer.currentEvent?.threadId + } + + private fun tryReplayEvent(iThread: Int): AtomicThreadEvent? { + if (inReplayPhase() && !canReplayNextEvent(iThread)) { + // TODO: can we get rid of this? + // we can try to enforce more ordering invariants by grouping "atomic" events + // and also grouping events for which there is no reason to make switch in-between + // (e.g. `Alloc` followed by a `Write`). + do { + internalThreadSwitchCallback(iThread, SwitchReason.STRATEGY_SWITCH) + } while (inReplayPhase() && !canReplayNextEvent(iThread)) + } + return replayer.currentEvent + ?.ensure { event -> event.dependencies.all { it in playedFrontier } } + ?.also { replayer.setNextEvent() } + } + + /* ************************************************************************* */ + /* Object tracking */ + /* ************************************************************************* */ + + fun allocationEvent(id: ObjectID): AtomicThreadEvent? { + return objectRegistry[id]?.allocation + } + + /* ************************************************************************* */ + /* Synchronization */ + /* ************************************************************************* */ + + private val EventLabel.syncType + get() = syncAlgebra.syncType(this) + + private fun EventLabel.synchronize(other: EventLabel) = + syncAlgebra.synchronize(this, other) + + private fun synchronizationCandidates(label: EventLabel): Sequence { + // TODO: generalize the checks for arbitrary synchronization algebra? + return when { + // write can synchronize with read-request events + label is WriteAccessLabel -> + execution.memoryAccessEventIndex.getReadRequests(label.location).asSequence() + + // read-request can synchronize only with write events + label is ReadAccessLabel && label.isRequest -> + execution.memoryAccessEventIndex.getWrites(label.location).asSequence() + + // read-response cannot synchronize with anything + label is ReadAccessLabel && label.isResponse -> + sequenceOf() + + // re-entry lock-request synchronizes only with initializing unlock + (label is LockLabel && label.isReentry) -> + sequenceOf(allocationEvent(label.mutexID)!!) + + // re-entry unlock does not participate in synchronization + (label is UnlockLabel && label.isReentry) -> + sequenceOf() + + // random labels do not synchronize + label is RandomLabel -> sequenceOf() + + // otherwise we pessimistically assume that any event can potentially synchronize + else -> execution.asSequence() + } + } + + private fun synchronizationCandidates(event: AtomicThreadEvent): Sequence { + val label = event.label + // consider all the candidates and apply additional filters + val candidates = synchronizationCandidates(label) + // take only the events from the current execution + .filter { it in execution } + // for a send event we additionally filter out ... + .runIf(event.label.isSend) { + filter { + // (1) all of its causal predecessors, because an attempt to synchronize with + // these predecessors will result in a causality cycle + !causalityOrder(it, event) && + // (2) pinned events, because their response part is pinned, + // unless the pinned event is blocked event + (!pinnedEvents.contains(it) || isBlockedRequest(it)) + } + } + return when { + /* For read-request events, we search for the last write to + * the same memory location in the same thread. + * We then filter out all causal predecessors of this last write, + * because these events are "obsolete" --- + * reading from them will result in coherence cycle and will violate consistency + */ + label is ReadAccessLabel && label.isRequest -> { + if (execution.memoryAccessEventIndex.isRaceFree(label.location)) { + val lastWrite = execution.memoryAccessEventIndex.getLastWrite(label.location)!! + return sequenceOf(lastWrite) + } + val threadReads = execution[event.threadId]!!.filter { + it.label.isResponse && (it.label as? ReadAccessLabel)?.location == label.location + } + val lastSeenWrite = threadReads.lastOrNull()?.readsFrom + val staleWrites = threadReads + .map { it.readsFrom } + .filter { it != lastSeenWrite } + .distinct() + val eventFrontier = execution.calculateFrontier(event.causalityClock) + val racyWrites = calculateRacyWrites(label.location, eventFrontier) + candidates.filter { + // !causalityOrder.lessThan(it, threadLastWrite) && + !racyWrites.any { write -> causalityOrder(it, write) } && + !staleWrites.any { write -> causalityOrder.orEqual(it, write) } + } + } + + label is WriteAccessLabel -> { + if (execution.memoryAccessEventIndex.isReadWriteRaceFree(label.location)) { + return sequenceOf() + } + candidates + } + + // an allocation event, at the point when it is added to the execution, + // cannot synchronize with anything, because there are no events yet + // that access the allocated object + label is ObjectAllocationLabel -> { + return sequenceOf() + } + + label is CoroutineSuspendLabel && label.isRequest -> { + // filter-out InitializationLabel to prevent creating cancellation response + // TODO: refactor!!! + candidates.filter { it.label !is InitializationLabel } + } + + else -> candidates + } + } + + /** + * Adds to the event structure a list of events obtained as a result of synchronizing given [event] + * with the events contained in the current exploration. For example, if + * `e1 @ A` is the given event labeled by `A` and `e2 @ B` is some event in the event structure labeled by `B`, + * then the resulting list will contain event labeled by `C = A \+ B` if `C` is defined (i.e. not null), + * and the list of dependencies of this new event will be equal to `listOf(e1, e2)`. + * + * @return list of added events + */ + private fun addSynchronizedEvents(event: AtomicThreadEvent): List { + val candidates = synchronizationCandidates(event) + val syncEvents = when (event.label.syncType) { + SynchronizationType.Binary -> addBinarySynchronizedEvents(event, candidates) + SynchronizationType.Barrier -> addBarrierSynchronizedEvents(event, candidates) + else -> return listOf() + } + // if there are responses to blocked dangling requests, then set the response of one of these requests + for (syncEvent in syncEvents) { + val blockedRequest = syncEvent.parent?.takeIf { isBlockedRequest(it) } + ?: continue + if (!hasUnblockingResponse(blockedRequest)) { + setUnblockingResponse(syncEvent) + // mark corresponding backtracking point as visited; + // search from the end, because the search event was added recently, + // and thus should be located near the end of the list + backtrackingPoints.last { it.event == syncEvent }.apply { visit() } + break + } + } + return syncEvents + } + + private fun addBinarySynchronizedEvents( + event: AtomicThreadEvent, + candidates: Sequence + ): List { + require(event.label.syncType == SynchronizationType.Binary) + // TODO: sort resulting events according to some strategy? + return candidates + .asIterable() + .mapNotNull { other -> + val syncLab = event.label.synchronize(other.label) + ?: return@mapNotNull null + val (parent, dependency) = when { + event.label.isRequest -> event to other + other.label.isRequest -> other to event + else -> unreachable() + } + check(parent.label.isRequest && dependency.label.isSend && syncLab.isResponse) + Triple(syncLab, parent, dependency) + }.sortedBy { (_, _, dependency) -> + dependency + }.mapNotNull { (syncLab, parent, dependency) -> + createEvent( + iThread = parent.threadId, + label = syncLab, + parent = parent, + dependencies = listOf(dependency), + visit = false, + ) + } + } + + private fun addBarrierSynchronizedEvents( + event: AtomicThreadEvent, + candidates: Sequence + ): List { + require(event.label.syncType == SynchronizationType.Barrier) + val (syncLab, dependencies) = + candidates.fold(event.label to listOf(event)) { (lab, deps), candidateEvent -> + candidateEvent.label.synchronize(lab)?.let { + (it to deps + candidateEvent) + } ?: (lab to deps) + } + if (syncLab.isBlocking && !syncLab.isUnblocked) + return listOf() + // We assume that at most, one of the events participating into synchronization + // is a request event, and the result of synchronization is a response event. + check(syncLab.isResponse) + val parent = dependencies.first { it.label.isRequest } + val responseEvent = createEvent( + iThread = parent.threadId, + label = syncLab, + parent = parent, + dependencies = dependencies.filter { it != parent }, + visit = false, + ) + return listOfNotNull(responseEvent) + } + + /* ************************************************************************* */ + /* Generic event addition utilities (per event kind) */ + /* ************************************************************************* */ + + private fun addRootEvent(): AtomicThreadEvent { + // we do not mark root event as visited purposefully; + // this is just a trick to make the first call to `startNextExploration` + // to pick the root event as the next event to explore from. + val label = InitializationLabel(initThreadId, mainThreadId) { location -> + val value = memoryInitializer(location) + objectRegistry.getOrRegisterValueID(location.type, value) + } + return createEvent(initThreadId, label, parent = null, dependencies = emptyList(), visit = false)!! + .also { event -> + val id = STATIC_OBJECT_ID + val entry = ObjectEntry(id, StaticObject.opaque(), event) + objectRegistry.register(entry) + addEventToCurrentExecution(event) + } + } + + private fun addEvent(iThread: Int, label: EventLabel, dependencies: List): AtomicThreadEvent { + tryReplayEvent(iThread)?.let { event -> + check(event.label == label) + addEventToCurrentExecution(event) + return event + } + val parent = playedFrontier[iThread] + return createEvent(iThread, label, parent, dependencies)!!.also { event -> + addEventToCurrentExecution(event) + } + } + + private fun addSendEvent(iThread: Int, label: EventLabel): AtomicThreadEvent { + require(label.isSend) + return addEvent(iThread, label, listOf()) + } + + private fun addRequestEvent(iThread: Int, label: EventLabel): AtomicThreadEvent { + require(label.isRequest) + return addEvent(iThread, label, listOf()) + } + + private fun addResponseEvents(requestEvent: AtomicThreadEvent): Pair> { + require(requestEvent.label.isRequest) + tryReplayEvent(requestEvent.threadId)?.let { event -> + check(event.label.isResponse) + check(event.parent == requestEvent) + check(event.label == event.resynchronize(syncAlgebra)) + addEventToCurrentExecution(event) + return event to listOf(event) + } + if (isBlockedRequest(requestEvent)) { + val event = getUnblockingResponse(requestEvent) + ?: return (null to listOf()) + check(event.label.isResponse) + check(event.request == requestEvent) + addEventToCurrentExecution(event) + return event to listOf(event) + } + val responseEvents = addSynchronizedEvents(requestEvent) + if (responseEvents.isEmpty()) { + blockRequest(requestEvent) + return (null to listOf()) + } + // TODO: use some other strategy to select the next event in the current exploration? + // TODO: check consistency of chosen event! + val chosenEvent = responseEvents.last().also { event -> + check(event == backtrackingPoints.last().event) + backtrackingPoints.last().visit() + addEventToCurrentExecution(event) + } + return (chosenEvent to responseEvents) + } + + /* ************************************************************************* */ + /* Blocking events handling */ + /* ************************************************************************* */ + + /** + * Descriptor of a blocked event. + * + * @property request The request part of blocked event. + * @property response The response part of the blocked event. + * + * @see [blockedEvents] + */ + class BlockedEventDescriptor(val request: AtomicThreadEvent) { + + init { + require(request.label.isRequest) + require(request.label.isBlocking) + } + + var response: AtomicThreadEvent? = null + private set + + fun setResponse(response: AtomicThreadEvent) { + require(response.label.isResponse) + require(response.label.isBlocking) + require(this.request == response.request) + check(this.response == null) + this.response = response + } + } + + private fun isBlockedRequest(request: AtomicThreadEvent): Boolean { + return (request == blockedEvents[request.threadId]?.request) + } + + private fun blockRequest(request: AtomicThreadEvent) { + require(execution.isBlockedDanglingRequest(request)) + blockedEvents.put(request.threadId, BlockedEventDescriptor(request)).ensureNull() + } + + private fun unblockRequest(request: AtomicThreadEvent) { + require(request.label.isRequest && request.label.isBlocking) + require(!execution.isBlockedDanglingRequest(request)) + check(request == blockedEvents[request.threadId]!!.request) + blockedEvents.remove(request.threadId) + } + + private fun hasUnblockingResponse(request: AtomicThreadEvent): Boolean { + return (getUnblockingResponse(request) != null) + } + + private fun getUnblockingResponse(request: AtomicThreadEvent): AtomicThreadEvent? { + require(execution.isBlockedDanglingRequest(request)) + val descriptor = blockedEvents[request.threadId].ensure { + it != null && it.request == request + } + return descriptor!!.response + } + + private fun setUnblockingResponse(response: AtomicThreadEvent) { + require(response.label.isResponse && response.label.isBlocking) + require(execution.isBlockedDanglingRequest(response.request!!)) + val descriptor = blockedEvents[response.threadId]!! + descriptor.setResponse(response) + } + + fun getPendingBlockingRequest(iThread: Int): AtomicThreadEvent? = + playedFrontier[iThread]?.takeIf { it.label.isRequest && it.label.isBlocking } + + fun isPendingUnblockedRequest(request: AtomicThreadEvent): Boolean { + require(playedFrontier.isBlockedDanglingRequest(request)) + // if we are in replay phase, then the request is unblocked + // if we can replay its response part + if (inReplayPhase(request.threadId)) { + return canReplayNextEvent(request.threadId) + } + // otherwise, the request is unblocked if its response part was already created + val descriptor = blockedEvents[request.threadId] + if (descriptor != null) { + check(request == descriptor.request) + return (descriptor.response != null) + } + return true + } + + /* ************************************************************************* */ + /* Specific event addition utilities (per event class) */ + /* ************************************************************************* */ + + fun addThreadStartEvent(iThread: Int): AtomicThreadEvent { + val label = ThreadStartLabel( + threadId = iThread, + kind = LabelKind.Request, + ) + val requestEvent = addRequestEvent(iThread, label) + val (responseEvent, responseEvents) = addResponseEvents(requestEvent) + checkNotNull(responseEvent) + check(responseEvents.size == 1) + return responseEvent + } + + fun addThreadFinishEvent(iThread: Int): AtomicThreadEvent { + val label = ThreadFinishLabel( + threadId = iThread, + ) + return addSendEvent(iThread, label) + } + + fun addThreadForkEvent(iThread: Int, forkThreadIds: Set): AtomicThreadEvent { + val label = ThreadForkLabel( + forkThreadIds = forkThreadIds + ) + return addSendEvent(iThread, label) + } + + fun addThreadJoinEvent(iThread: Int, joinThreadIds: Set): AtomicThreadEvent { + val label = ThreadJoinLabel( + kind = LabelKind.Request, + joinThreadIds = joinThreadIds, + ) + val requestEvent = addRequestEvent(iThread, label) + val (responseEvent, responseEvents) = addResponseEvents(requestEvent) + // TODO: handle case when ThreadJoin is not ready yet + checkNotNull(responseEvent) + check(responseEvents.size == 1) + return responseEvent + } + + fun addObjectAllocationEvent(iThread: Int, value: OpaqueValue): AtomicThreadEvent { + tryReplayEvent(iThread)?.let { event -> + check(event.label is ObjectAllocationLabel) + val id = event.label.objectID + val entry = ObjectEntry(id, value, event) + objectRegistry.register(entry) + addEventToCurrentExecution(event) + return event + } + val id = objectRegistry.nextObjectID + val label = ObjectAllocationLabel( + objectID = id, + className = value.unwrap().javaClass.simpleName, + memoryInitializer = { location -> + val initValue = memoryInitializer(location) + objectRegistry.getOrRegisterValueID(location.type, initValue) + }, + ) + val parent = playedFrontier[iThread] + val dependencies = listOf() + return createEvent(iThread, label, parent, dependencies)!!.also { event -> + val entry = ObjectEntry(id, value, event) + objectRegistry.register(entry) + addEventToCurrentExecution(event) + } + } + + fun addWriteEvent(iThread: Int, codeLocation: Int, location: MemoryLocation, value: ValueID, + readModifyWriteDescriptor: ReadModifyWriteDescriptor? = null): AtomicThreadEvent { + val label = WriteAccessLabel( + location = location, + writeValue = value, // TODO: change API of other methods to also take ValueID + readModifyWriteDescriptor = readModifyWriteDescriptor, + codeLocation = codeLocation, + ) + return addSendEvent(iThread, label) + } + + fun addReadRequest(iThread: Int, codeLocation: Int, location: MemoryLocation, + readModifyWriteDescriptor: ReadModifyWriteDescriptor? = null): AtomicThreadEvent { + // we create a read-request event with an unknown (null) value, + // value will be filled later in the read-response event + val label = ReadAccessLabel( + kind = LabelKind.Request, + location = location, + readValue = NULL_OBJECT_ID, + readModifyWriteDescriptor = readModifyWriteDescriptor, + codeLocation = codeLocation, + ) + return addRequestEvent(iThread, label) + } + + fun addReadResponse(iThread: Int): AtomicThreadEvent { + val readRequest = playedFrontier[iThread].ensure { + it != null && it.label.isRequest && it.label is ReadAccessLabel + } + val (responseEvent, _) = addResponseEvents(readRequest!!) + // TODO: think again --- is it possible that there is no write to read-from? + // Probably not, because in Kotlin variables are always initialized by default? + // What about initialization-related issues? + checkNotNull(responseEvent) + if (isSpinLoopBoundReached(responseEvent)) { + internalThreadSwitchCallback(responseEvent.threadId, SwitchReason.SPIN_BOUND) + } + return responseEvent + } + + fun addLockRequestEvent(iThread: Int, mutex: OpaqueValue, + isReentry: Boolean = false, reentrancyDepth: Int = 1, + isSynthetic: Boolean = false): AtomicThreadEvent { + val label = LockLabel( + kind = LabelKind.Request, + mutexID = objectRegistry.getOrRegisterObjectID(mutex), + isReentry = isReentry, + reentrancyDepth = reentrancyDepth, + isSynthetic = isSynthetic, + ) + return addRequestEvent(iThread, label) + } + + fun addLockResponseEvent(lockRequest: AtomicThreadEvent): AtomicThreadEvent? { + require(lockRequest.label.isRequest && lockRequest.label is LockLabel) + return addResponseEvents(lockRequest).first + } + + fun addUnlockEvent(iThread: Int, mutex: OpaqueValue, + isReentry: Boolean = false, reentrancyDepth: Int = 1, + isSynthetic: Boolean = false): AtomicThreadEvent { + val label = UnlockLabel( + mutexID = objectRegistry.getOrRegisterObjectID(mutex), + isReentry = isReentry, + reentrancyDepth = reentrancyDepth, + isSynthetic = isSynthetic, + ) + return addSendEvent(iThread, label) + } + + fun addWaitRequestEvent(iThread: Int, mutex: OpaqueValue): AtomicThreadEvent { + val label = WaitLabel( + kind = LabelKind.Request, + mutexID = objectRegistry.getOrRegisterObjectID(mutex), + ) + return addRequestEvent(iThread, label) + + } + + fun addWaitResponseEvent(waitRequest: AtomicThreadEvent): AtomicThreadEvent? { + require(waitRequest.label.isRequest && waitRequest.label is WaitLabel) + return addResponseEvents(waitRequest).first + } + + fun addNotifyEvent(iThread: Int, mutex: OpaqueValue, isBroadcast: Boolean): AtomicThreadEvent { + // TODO: we currently ignore isBroadcast flag and handle `notify` similarly as `notifyAll`. + // It is correct wrt. Java's semantics, since `wait` can wake-up spuriously according to the spec. + // Thus multiple wake-ups due to single notify can be interpreted as spurious. + // However, if one day we will want to support wait semantics without spurious wake-ups + // we will need to revisit this. + val label = NotifyLabel( + mutexID = objectRegistry.getOrRegisterObjectID(mutex), + isBroadcast = isBroadcast, + ) + return addSendEvent(iThread, label) + } + + fun addParkRequestEvent(iThread: Int): AtomicThreadEvent { + val label = ParkLabel(LabelKind.Request, iThread) + return addRequestEvent(iThread, label) + } + + fun addParkResponseEvent(parkRequest: AtomicThreadEvent): AtomicThreadEvent? { + require(parkRequest.label.isRequest && parkRequest.label is ParkLabel) + return addResponseEvents(parkRequest).first + } + + fun addUnparkEvent(iThread: Int, unparkingThreadId: Int): AtomicThreadEvent { + val label = UnparkLabel(unparkingThreadId) + return addSendEvent(iThread, label) + } + + fun addCoroutineSuspendRequestEvent(iThread: Int, iActor: Int, promptCancellation: Boolean = false): AtomicThreadEvent { + val label = CoroutineSuspendLabel(LabelKind.Request, iThread, iActor, promptCancellation = promptCancellation) + return addRequestEvent(iThread, label) + } + + fun addCoroutineSuspendResponseEvent(iThread: Int, iActor: Int): AtomicThreadEvent { + val request = getPendingBlockingRequest(iThread)!!.ensure { event -> + event.label.satisfies { actorId == iActor } + } + val (response, events) = addResponseEvents(request) + check(events.size == 1) + return response!! + } + + fun addCoroutineCancelResponseEvent(iThread: Int, iActor: Int): AtomicThreadEvent { + val request = getPendingBlockingRequest(iThread)!!.ensure { event -> + event.label.satisfies { actorId == iActor } + } + val label = (request.label as CoroutineSuspendLabel).getResponse(root.label)!! + tryReplayEvent(iThread)?.let { event -> + check(event.label == label) + addEventToCurrentExecution(event) + return event + } + return createEvent(iThread, label, parent = request, dependencies = listOf(root))!!.also { event -> + addEventToCurrentExecution(event) + } + } + + fun addCoroutineResumeEvent(iThread: Int, iResumedThread: Int, iResumedActor: Int): AtomicThreadEvent? { + val label = CoroutineResumeLabel(iResumedThread, iResumedActor) + for (event in execution) { + if (event in playedFrontier && event.label == label) return null + } + tryReplayEvent(iThread)?.let { event -> + check(event.label == label) + addEventToCurrentExecution(event) + return event + } + val parent = playedFrontier[iThread] + return createEvent(iThread, label, parent, dependencies = listOf())!!.also { event -> + addEventToCurrentExecution(event) + } + } + + fun addActorStartEvent(iThread: Int, actor: Actor): AtomicThreadEvent { + val label = ActorLabel(SpanLabelKind.Start, iThread, actor) + return addEvent(iThread, label, dependencies = listOf()).also { + resetReadCodeLocationsCounter(iThread) + } + } + + fun addActorEndEvent(iThread: Int, actor: Actor): AtomicThreadEvent { + val label = ActorLabel(SpanLabelKind.End, iThread, actor) + return addEvent(iThread, label, dependencies = listOf()) + } + + fun tryReplayRandomEvent(iThread: Int): AtomicThreadEvent? { + tryReplayEvent(iThread)?.let { event -> + check(event.label is RandomLabel) + addEventToCurrentExecution(event) + return event + } + return null + } + + fun addRandomEvent(iThread: Int, generated: Int): AtomicThreadEvent { + val label = RandomLabel(generated) + val parent = playedFrontier[iThread] + return createEvent(iThread, label, parent, dependencies = emptyList())!!.also { event -> + addEventToCurrentExecution(event) + } + } + + /* ************************************************************************* */ + /* Miscellaneous */ + /* ************************************************************************* */ + + /** + * Calculates the view for specific memory location observed at the given point of execution + * given by [observation] vector clock. Memory location view is a vector clock itself + * that maps each thread id to the last write access event to the given memory location at the given thread. + * + * @param location the memory location. + * @param observation the vector clock specifying the point of execution for the view calculation. + * @return the view (i.e. vector clock) for the given memory location. + * + * TODO: move to Execution? + */ + fun calculateMemoryLocationView( + location: MemoryLocation, + observation: ExecutionFrontier + ): ExecutionFrontier = + observation.threadMap.map { (tid, event) -> + val lastWrite = event + ?.ensure { it in execution } + ?.pred(inclusive = true) { + it.label.asMemoryAccessLabel(location)?.takeIf { label -> label.isWrite } != null + } + (tid to lastWrite as? AtomicThreadEvent?) + }.let { + executionFrontierOf(*it.toTypedArray()) + } + + /** + * Calculates a list of all racy writes to specific memory location observed at the given point of execution + * given by [observation] vector clock. In other words, the resulting list contains all program-order maximal + * racy writes observed at the given point. + * + * @param location the memory location. + * @param observation the vector clock specifying the point of execution for the view calculation. + * @return list of program-order maximal racy write events. + * + * TODO: move to Execution? + */ + fun calculateRacyWrites( + location: MemoryLocation, + observation: ExecutionFrontier + ): List { + val writes = calculateMemoryLocationView(location, observation).events + return writes.filter { write -> + !writes.any { other -> + causalityOrder(write, other) + } + } + } + + private fun isSpinLoopBoundReached(event: ThreadEvent): Boolean { + check(event.label is ReadAccessLabel && event.label.isResponse) + val readLabel = (event.label as ReadAccessLabel) + val location = readLabel.location + val readValue = readLabel.readValue + val codeLocation = readLabel.codeLocation + // check code locations counter to detect spin-loop + val counter = readCodeLocationsCounter.compute(event.threadId to codeLocation) { _, count -> + 1 + (count ?: 0) + }!! + // a potential spin-loop occurs when we have visited the same code location more than N times + if (counter < SPIN_BOUND) + return false + // if the last 3 reads with the same code location read the same value, + // then we consider this a spin-loop + var spinEvent: ThreadEvent = event + var spinCounter = SPIN_BOUND + while (spinCounter-- > 0) { + spinEvent = spinEvent.pred { + it.label.isResponse && it.label.satisfies { + this.location == location && this.codeLocation == codeLocation + } + } ?: return false + if ((spinEvent.label as ReadAccessLabel).readValue != readValue) + return false + } + return true + } + + private fun resetReadCodeLocationsCounter(iThread: Int) { + // reset all code-locations counters of the given thread + readCodeLocationsCounter.keys.retainAll { (tid, _) -> tid != iThread } + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventStructureStrategy.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventStructureStrategy.kt new file mode 100644 index 000000000..b7674f541 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventStructureStrategy.kt @@ -0,0 +1,849 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.execution.* +import org.jetbrains.kotlinx.lincheck.runner.* +import org.jetbrains.kotlinx.lincheck.strategy.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.verifier.* +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.util.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingCTestConfiguration +import org.jetbrains.kotlinx.lincheck.transformation.LincheckJavaAgent +import sun.nio.ch.lincheck.TestThread +import java.lang.reflect.* +import org.objectweb.asm.Type + +class EventStructureStrategy( + testCfg: ModelCheckingCTestConfiguration, + testClass: Class<*>, + scenario: ExecutionScenario, + validationFunction: Actor?, + stateRepresentation: Method? +) : ManagedStrategy(testClass, scenario, validationFunction, stateRepresentation, testCfg) { + + private val memoryInitializer: MemoryInitializer = { location -> + runInIgnoredSection { + location.read(eventStructure.objectRegistry::getValue)?.opaque() + } + } + + private val eventStructure: EventStructure = + EventStructure(nThreads, memoryInitializer, ::onInconsistency) { iThread, reason -> + switchCurrentThread(iThread, reason, mustSwitch = true) + } + + private var isTestInstanceRegistered = false + + // Tracker of objects. + override val objectTracker: ObjectTracker = + EventStructureObjectTracker(eventStructure) + // Tracker of shared memory accesses. + override val memoryTracker: MemoryTracker = + EventStructureMemoryTracker(eventStructure, objectTracker) + // Tracker of monitors operations. + override val monitorTracker: MonitorTracker = + EventStructureMonitorTracker(eventStructure, eventStructure.objectRegistry) + // Tracker of thread parking + override val parkingTracker: ParkingTracker = + EventStructureParkingTracker(eventStructure) + + override val trackFinalFields: Boolean = true + + val stats = Stats() + + override fun shouldInvokeBeforeEvent(): Boolean { + // TODO: fixme + return false + } + + override fun nextInvocation(): Boolean { + // check that we have the next invocation to explore + return eventStructure.startNextExploration() + } + + override fun initializeInvocation() { + super.initializeInvocation() + isTestInstanceRegistered = false + eventStructure.initializeExploration() + } + + override fun runInvocationImpl(): InvocationResult { + val (result, inconsistency) = runNextExploration() + if (inconsistency != null) { + return InconsistentInvocationResult(inconsistency) + } + check(result != null) + // TODO: re-verify that it is safe to omit the memory dump at the end; + // it should be safe, because currently in the event-structure based algorithm, + // the intercepted writes are still performed, so the actual state of the memory + // reflects the state modelled by the current execution graph. + // runInIgnoredSection { + // memoryTracker.dumpMemory() + // } + return result + } + + // TODO: rename & refactor! + fun runNextExploration(): Pair { + var result: InvocationResult? = null + var inconsistency: Inconsistency? = eventStructure.checkConsistency() + if (inconsistency == null) { + eventStructure.addThreadStartEvent(eventStructure.mainThreadId) + result = super.runInvocationImpl() + // if invocation was aborted, we also abort the current execution inside event structure + if (result.isAbortedInvocation()) { + eventStructure.abortExploration() + } + // patch clocks + if (result is CompletedInvocationResult) { + val patchedResult = patchResultsClock(eventStructure.execution, result.results) + result = CompletedInvocationResult(patchedResult) + } + inconsistency = when (result) { + is InconsistentInvocationResult -> result.inconsistency + is SpinLoopBoundInvocationResult -> null + else -> eventStructure.checkConsistency() + } + } + + // println(eventStructure.execution) + // println("inconsistency: $inconsistency") + // println() + + stats.update(result, inconsistency) + // println(stats.totalInvocations) + return (result to inconsistency) + } + + // TODO: temporarily disable trace collection for event structure strategy + override fun tryCollectTrace(result: InvocationResult): Trace? { + // return super.tryCollectTrace(result) + return null + } + + class Stats { + + var consistentInvocations: Int = 0 + private set + + var blockedInvocations: Int = 0 + private set + + private var lockConsistencyViolationCount: Int = 0 + + private var atomicityInconsistenciesCount: Int = 0 + + private var relAcqInconsistenciesCount: Int = 0 + + private var seqCstApproximationInconsistencyCount: Int = 0 + + private var seqCstCoherenceViolationCount: Int = 0 + + private var seqCstReplayViolationCount: Int = 0 + + private val sequentialConsistencyViolationsCount: Int + get() = + seqCstApproximationInconsistencyCount + + seqCstCoherenceViolationCount + + seqCstReplayViolationCount + + val inconsistentInvocations: Int + get() = + lockConsistencyViolationCount + + atomicityInconsistenciesCount + + relAcqInconsistenciesCount + + sequentialConsistencyViolationsCount + + val totalInvocations: Int + get() = consistentInvocations + inconsistentInvocations + blockedInvocations + + fun update(result: InvocationResult?, inconsistency: Inconsistency?) { + if (result is SpinLoopBoundInvocationResult) { + check(inconsistency == null) + blockedInvocations++ + return + } + if (inconsistency == null) { + consistentInvocations++ + return + } + when(inconsistency) { + is LockConsistencyViolation -> + lockConsistencyViolationCount++ + is ReadModifyWriteAtomicityViolation -> + atomicityInconsistenciesCount++ + is ReleaseAcquireInconsistency -> + relAcqInconsistenciesCount++ + is SequentialConsistencyApproximationInconsistency -> + seqCstApproximationInconsistencyCount++ + is CoherenceViolation -> + seqCstCoherenceViolationCount++ + is SequentialConsistencyReplayViolation -> + seqCstReplayViolationCount++ + } + } + + override fun toString(): String = """ + #Total invocations = $totalInvocations + #consistent = $consistentInvocations + #inconsistent = $inconsistentInvocations + #blocked = $blockedInvocations + #Lock violations = $lockConsistencyViolationCount + #Atom. violations = $atomicityInconsistenciesCount + #RelAcq violations = $relAcqInconsistenciesCount + #SeqCst violations = $sequentialConsistencyViolationsCount + #approx. phase = $seqCstApproximationInconsistencyCount + #coher. phase = $seqCstCoherenceViolationCount + #replay phase = $seqCstReplayViolationCount + """.trimIndent() + + } + + // a hack to reset happens-before clocks computed by scheduler, + // because these clocks can be not in sync with with + // happens-before relation constructed by the event structure + // TODO: refactor this --- we need a more robust solution; + // for example, we can compute happens before relation induced by + // the event structure and pass it on + private fun patchResultsClock(execution: Execution, executionResult: ExecutionResult): ExecutionResult { + val initPartSize = executionResult.initResults.size + val postPartSize = executionResult.postResults.size + val hbClockSize = executionResult.parallelResultsWithClock.size + val patchedParallelResults = executionResult.parallelResultsWithClock + .map { it.map { resultWithClock -> ResultWithClock(resultWithClock.result, resultWithClock.clockOnStart) }} + val (actorsExecution, _) = execution.aggregate(ActorAggregator(execution)) + check(actorsExecution.threadIDs.size == hbClockSize + 1) + for (tid in patchedParallelResults.indices) { + var actorEvents: List = actorsExecution[tid]!! + // cut init/post part + if (tid == 0) { + actorEvents = actorEvents.subList( + fromIndex = initPartSize, + toIndex = actorEvents.size - postPartSize, + ) + } + val actorResults = patchedParallelResults[tid] + actorResults.forEachIndexed { i, result -> + val actorEvent = actorEvents.getOrNull(i) + val prevHBClock = actorResults.getOrNull(i - 1)?.clockOnStart?.copy() + ?: emptyClock(hbClockSize) + val clockSize = result.clockOnStart.clock.size + val hbClock = actorEvent?.causalityClock?.toHBClock(clockSize, tid, i) + ?: prevHBClock.apply { clock[tid] = i } + // cut init part actors + hbClock.clock[0] -= initPartSize + check(hbClock[tid] == i) + result.clockOnStart.set(hbClock) + } + } + return ExecutionResult( + initResults = executionResult.initResults, + parallelResultsWithClock = patchedParallelResults, + postResults = executionResult.postResults, + afterInitStateRepresentation = executionResult.afterInitStateRepresentation, + afterParallelStateRepresentation = executionResult.afterParallelStateRepresentation, + afterPostStateRepresentation = executionResult.afterPostStateRepresentation, + ) + } + + override fun shouldSwitch(iThread: Int): ThreadSwitchDecision { + // If strategy is in replay phase we first need to execute replaying threads + if (eventStructure.inReplayPhase() && !eventStructure.inReplayPhase(iThread)) { + return ThreadSwitchDecision.MAY + } + // If strategy is in replay mode for given thread + // we should wait until replaying the next event become possible + // (i.e. when all the dependencies will be replayed too) + if (eventStructure.inReplayPhase(iThread)) { + return if (eventStructure.canReplayNextEvent(iThread)) + ThreadSwitchDecision.NOT + else + ThreadSwitchDecision.MUST + } + + /* For event structure strategy enforcing context switches is not necessary, + * because it is guaranteed that the strategy will explore all + * executions anyway, no matter of the order of context switches. + * Thus, in principle it is possible to explore threads in fixed static order, + * (e.g. always return false here). + * In practice, however, the order of context switches may influence performance + * of the model checking, and time-to-first-bug-discovered metric. + * Thus we might want to customize scheduling strategy. + * TODO: make scheduling strategy configurable + */ + return ThreadSwitchDecision.NOT + } + + override fun chooseThread(iThread: Int): Int { + // see comment in `shouldSwitch` method + // TODO: make scheduling strategy configurable + return if (runner.currentExecutionPart == ExecutionPart.PARALLEL) + switchableThreads(iThread).first() + else + eventStructure.mainThreadId + } + + override fun isActive(iThread: Int): Boolean { + return super.isActive(iThread) && (eventStructure.inReplayPhase() implies { + eventStructure.inReplayPhase(iThread) && eventStructure.canReplayNextEvent(iThread) + }) + } + + override fun beforePart(part: ExecutionPart) { + super.beforePart(part) + val forkedThreads = (0 until eventStructure.nThreads) + .filter { it != eventStructure.mainThreadId && it != eventStructure.initThreadId } + .toSet() + when (part) { + ExecutionPart.INIT -> { + registerTestInstance() + } + ExecutionPart.PARALLEL -> { + if (!isTestInstanceRegistered) { + registerTestInstance() + } + if (forkedThreads.isNotEmpty()) { + eventStructure.addThreadForkEvent(eventStructure.mainThreadId, forkedThreads) + } + } + ExecutionPart.POST -> { + if (forkedThreads.isNotEmpty()) { + eventStructure.addThreadJoinEvent(eventStructure.mainThreadId, forkedThreads) + } + } + else -> {} + } + } + + private fun registerTestInstance() { + check(!isTestInstanceRegistered) + eventStructure.addObjectAllocationEvent(eventStructure.mainThreadId, runner.testInstance.opaque()) + isTestInstanceRegistered = true + } + + override fun onStart(iThread: Int) { + super.onStart(iThread) + if (iThread != eventStructure.mainThreadId && iThread != eventStructure.initThreadId) { + eventStructure.addThreadStartEvent(iThread) + } + } + + override fun onFinish(iThread: Int) { + // TODO: refactor, make `switchCurrentThread` private again in ManagedStrategy, + // call overridden `onStart` and `onFinish` methods only when thread is active + // and the `currentThread` lock is held + awaitTurn(iThread) + // TODO: extract this check into a method ? + while (eventStructure.inReplayPhase() && !eventStructure.canReplayNextEvent(iThread)) { + switchCurrentThread(iThread, mustSwitch = true) + } + eventStructure.addThreadFinishEvent(iThread) + super.onFinish(iThread) + } + + override fun onActorStart(iThread: Int) { + super.onActorStart(iThread) + // TODO: move ignored section to ManagedStrategyRunner + runInIgnoredSection { + if (runner.currentExecutionPart == ExecutionPart.VALIDATION) + return@runInIgnoredSection + val actor = scenario.threads[iThread][currentActorId[iThread]] + eventStructure.addActorStartEvent(iThread, actor) + } + } + + override fun onActorFinish(iThread: Int) { + // TODO: move ignored section to ManagedStrategyRunner + runInIgnoredSection { + if (runner.currentExecutionPart == ExecutionPart.VALIDATION) + return@runInIgnoredSection + val actor = scenario.threads[iThread][currentActorId[iThread]] + eventStructure.addActorEndEvent(iThread, actor) + } + super.onActorFinish(iThread) + } + + private fun onInconsistency(inconsistency: Inconsistency) { + suddenInvocationResult = InconsistentInvocationResult(inconsistency) + throw ForcibleExecutionFinishError + } + + override fun afterCoroutineSuspended(iThread: Int) { + eventStructure.addCoroutineSuspendRequestEvent(iThread, currentActorId[iThread]) + super.afterCoroutineSuspended(iThread) + } + + override fun afterCoroutineResumed(iThread: Int) { + super.afterCoroutineResumed(iThread) + eventStructure.addCoroutineSuspendResponseEvent(iThread, currentActorId[iThread]) + } + + override fun afterCoroutineCancelled(iThread: Int, promptCancellation: Boolean, cancellationResult: CancellationResult) { + super.afterCoroutineCancelled(iThread, promptCancellation, cancellationResult) + if (cancellationResult == CancellationResult.CANCELLATION_FAILED) + return + eventStructure.addCoroutineSuspendRequestEvent(iThread, currentActorId[iThread], promptCancellation) + eventStructure.addCoroutineCancelResponseEvent(iThread, currentActorId[iThread]) + } + + override fun onResumeCoroutine(iThread: Int, iResumedThread: Int, iResumedActor: Int) { + super.onResumeCoroutine(iThread, iResumedThread, iResumedActor) + eventStructure.addCoroutineResumeEvent(iThread, iResumedThread, iResumedActor) + } + + override fun isCoroutineResumed(iThread: Int, iActor: Int): Boolean { + if (!super.isCoroutineResumed(iThread, iActor)) + return false + val resumeEvent = eventStructure.execution.find { + it.label.satisfies { threadId == iThread && actorId == iActor } + } + return (resumeEvent != null) + } +} + +typealias ReportInconsistencyCallback = (Inconsistency) -> Unit +typealias InternalThreadSwitchCallback = (ThreadID, SwitchReason) -> Unit + +private class EventStructureObjectTracker( + private val eventStructure: EventStructure, +) : ObjectTracker { + + override fun registerNewObject(obj: Any) { + val iThread = (Thread.currentThread() as TestThread).threadId + eventStructure.addObjectAllocationEvent(iThread, obj.opaque()) + } + + override fun registerObjectLink(fromObject: Any, toObject: Any?) {} + + override fun initializeObject(obj: Any) { + val isRegistered = (eventStructure.objectRegistry[obj.opaque()] != null) + if (!isRegistered && !obj.isPrimitive()) { + registerNewObject(obj) + } + } + + override fun shouldTrackObjectAccess(obj: Any): Boolean = true + + override fun getObjectId(obj: Any): ObjectID { + return eventStructure.objectRegistry.getOrRegisterObjectID(obj.opaque()) + } + + override fun reset() {} + +} + +private class EventStructureMemoryTracker( + private val eventStructure: EventStructure, + private val objectTracker: ObjectTracker, +) : MemoryTracker { + + private val objectRegistry: ObjectRegistry + get() = eventStructure.objectRegistry + + private fun getValueID(location: MemoryLocation, value: OpaqueValue?): ValueID = + objectRegistry.getOrRegisterValueID(location.type, value) + + private fun getValue(location: MemoryLocation, valueID: ValueID) = + objectRegistry.getValue(location.type, valueID) + + private fun getValue(type: Type, valueID: ValueID) = + objectRegistry.getValue(type, valueID) + + private fun addWriteEvent(iThread: Int, codeLocation: Int, location: MemoryLocation, value: OpaqueValue?, + rmwWriteDescriptor: ReadModifyWriteDescriptor? = null) { + // force evaluation of initial value (before possibly overwriting it) + // TODO: refactor this! + eventStructure.allocationEvent(location.objID)?.label?.asWriteAccessLabel(location) + eventStructure.addWriteEvent(iThread, codeLocation, location, getValueID(location, value), rmwWriteDescriptor) + } + + private fun addReadRequest(iThread: Int, codeLocation: Int, location: MemoryLocation, + readModifyWriteDescriptor: ReadModifyWriteDescriptor? = null) { + eventStructure.addReadRequest(iThread, codeLocation, location, readModifyWriteDescriptor) + } + + private fun addReadResponse(iThread: Int): OpaqueValue? { + val event = eventStructure.addReadResponse(iThread) + val label = (event.label as ReadAccessLabel) + val rmwDescriptor = label.readModifyWriteDescriptor + // regular non-RMW read - return the read value + if (rmwDescriptor == null) { + return getValue(label.location, label.value) + } + // handle different kinds of RMWs + // TODO: perform actual write to memory for successful CAS + when (rmwDescriptor) { + is ReadModifyWriteDescriptor.GetAndSetDescriptor -> { + val newValueID = rmwDescriptor.newValue + val newValue = getValue(label.location, newValueID) + eventStructure.addWriteEvent(iThread, label.codeLocation, label.location, newValueID, rmwDescriptor) + label.location.write(newValue?.unwrap(), objectRegistry::getValue) + return getValue(label.location, label.value) + } + + is ReadModifyWriteDescriptor.CompareAndSetDescriptor -> { + if (label.value == rmwDescriptor.expectedValue) { + val newValueID = rmwDescriptor.newValue + val newValue = getValue(label.location, newValueID) + eventStructure.addWriteEvent(iThread, label.codeLocation, label.location, newValueID, rmwDescriptor) + label.location.write(newValue?.unwrap(), objectRegistry::getValue) + return getValue(Type.BOOLEAN_TYPE, true.toInt().toLong()) + } + return getValue(Type.BOOLEAN_TYPE, false.toInt().toLong()) + } + + is ReadModifyWriteDescriptor.CompareAndExchangeDescriptor -> { + if (label.value == rmwDescriptor.expectedValue) { + val newValueID = rmwDescriptor.newValue + val newValue = getValue(label.location, newValueID) + eventStructure.addWriteEvent(iThread, label.codeLocation, label.location, newValueID, rmwDescriptor) + label.location.write(newValue?.unwrap(), objectRegistry::getValue) + } + return getValue(label.location, label.value) + } + + is ReadModifyWriteDescriptor.FetchAndAddDescriptor -> { + val newValueID = label.value + rmwDescriptor.delta + val newValue = getValue(label.location, newValueID) + eventStructure.addWriteEvent(iThread, label.codeLocation, label.location, newValueID, rmwDescriptor) + label.location.write(newValue?.unwrap(), objectRegistry::getValue) + return when (rmwDescriptor.kind) { + ReadModifyWriteDescriptor.IncrementKind.Pre -> getValue(label.location, label.value) + ReadModifyWriteDescriptor.IncrementKind.Post -> getValue(label.location, newValueID) + } + } + } + } + + override fun beforeWrite(iThread: Int, codeLocation: Int, location: MemoryLocation, value: Any?) { + addWriteEvent(iThread, codeLocation, location, value?.opaque()) + } + + override fun beforeRead(iThread: Int, codeLocation: Int, location: MemoryLocation) { + addReadRequest(iThread, codeLocation, location) + } + + override fun beforeGetAndSet(iThread: Int, codeLocation: Int, location: MemoryLocation, newValue: Any?) { + eventStructure.addReadRequest(iThread, codeLocation, location, + readModifyWriteDescriptor = ReadModifyWriteDescriptor.GetAndSetDescriptor( + newValue = getValueID(location, newValue?.opaque()) + ) + ) + } + + override fun beforeCompareAndSet(iThread: Int, codeLocation: Int, location: MemoryLocation, expectedValue: Any?, newValue: Any?) { + eventStructure.addReadRequest(iThread, codeLocation, location, + readModifyWriteDescriptor = ReadModifyWriteDescriptor.CompareAndSetDescriptor( + expectedValue = getValueID(location, expectedValue?.opaque()), + newValue = getValueID(location, newValue?.opaque()), + ) + ) + } + + override fun beforeCompareAndExchange(iThread: Int, codeLocation: Int, location: MemoryLocation, expectedValue: Any?, newValue: Any?) { + eventStructure.addReadRequest(iThread, codeLocation, location, + readModifyWriteDescriptor = ReadModifyWriteDescriptor.CompareAndExchangeDescriptor( + expectedValue = getValueID(location, expectedValue?.opaque()), + newValue = getValueID(location, newValue?.opaque()), + ) + ) + } + + override fun beforeGetAndAdd(iThread: Int, codeLocation: Int, location: MemoryLocation, delta: Number) { + eventStructure.addReadRequest(iThread, codeLocation, location, + readModifyWriteDescriptor = ReadModifyWriteDescriptor.FetchAndAddDescriptor( + delta = getValueID(location, delta.opaque()), + kind = ReadModifyWriteDescriptor.IncrementKind.Pre, + ) + ) + } + + override fun beforeAddAndGet(iThread: Int, codeLocation: Int, location: MemoryLocation, delta: Number) { + eventStructure.addReadRequest(iThread, codeLocation, location, + readModifyWriteDescriptor = ReadModifyWriteDescriptor.FetchAndAddDescriptor( + delta = getValueID(location, delta.opaque()), + kind = ReadModifyWriteDescriptor.IncrementKind.Post, + ) + ) + } + + override fun interceptReadResult(iThread: Int): Any? { + return addReadResponse(iThread)?.unwrap()?.also { + LincheckJavaAgent.ensureObjectIsTransformed(it) + } + } + + override fun interceptArrayCopy(iThread: Int, codeLocation: Int, srcArray: Any?, srcPos: Int, dstArray: Any?, dstPos: Int, length: Int) { + val srcType = srcArray!!::class.getType() + val dstType = dstArray!!::class.getType() + for (i in 0 until length) { + val readLocation = objectTracker.getArrayAccessMemoryLocation(srcArray, srcPos + i, srcType) + val writeLocation = objectTracker.getArrayAccessMemoryLocation(dstArray, dstPos + i, dstType) + val value = run { + beforeRead(iThread, codeLocation, readLocation) + interceptReadResult(iThread) + } + beforeWrite(iThread, codeLocation, writeLocation, value) + writeLocation.write(value, objectRegistry::getValue) + } + } + + override fun reset() {} + +} + +private class EventStructureMonitorTracker( + private val eventStructure: EventStructure, + private val objectRegistry: ObjectRegistry, +) : MonitorTracker { + + // for each mutex object acquired by some thread, + // this map stores a mapping from the mutex object to the lock-response event; + // to handle lock re-entrance, we actually store a stack of lock-response events + private val lockStacks = mutableMapOf>() + + // for threads waiting on the mutex, + // stores the lock stack of the current thread for the awaited mutex + private val waitLockStack = ArrayIntMap>(eventStructure.nThreads) + + private fun canAcquireMonitor(iThread: Int, mutexID: ValueID): Boolean { + val lockStack = lockStacks[mutexID] + return (lockStack == null) || (lockStack.last().threadId == iThread) + } + + private fun canAcquireMonitor(iThread: Int, monitor: OpaqueValue): Boolean { + val mutexID = objectRegistry[monitor]!!.id + return canAcquireMonitor(iThread, mutexID) + } + + override fun acquireMonitor(iThread: Int, monitor: Any): Boolean { + // issue lock-request event + val lockRequest = issueLockRequest(iThread, monitor.opaque()) + // if lock is acquired by another thread then postpone addition of lock-response event + if (!canAcquireMonitor(iThread, monitor.opaque())) + return false + // try to add lock-response event + val lockResponse = tryCompleteLockResponse(lockRequest) + // return true if the lock-response event was created successfully + return (lockResponse != null) + } + + override fun releaseMonitor(iThread: Int, monitor: Any) { + issueUnlock(iThread, monitor.opaque()) + } + + private fun issueLockRequest(iThread: Int, monitor: OpaqueValue): AtomicThreadEvent { + val mutexID = objectRegistry[monitor]!!.id + // check if the thread is already blocked on the lock-request + val blockingRequest = eventStructure.getPendingBlockingRequest(iThread) + ?.ensure { it.label.satisfies { this.mutexID == mutexID } } + if (blockingRequest != null) + return blockingRequest + // check if it is a re-entrance lock and obtain lock re-entrance depth + val lockStack = lockStacks[mutexID] + ?.ensure { it.isNotEmpty() } + ?.takeIf { it.last().threadId == iThread } + val depth = lockStack?.size ?: 0 + // finally, add the new lock-request + return eventStructure.addLockRequestEvent(iThread, monitor, + isReentry = depth > 0, + reentrancyDepth = 1 + depth, + ) + } + + private fun tryCompleteLockResponse(lockRequest: AtomicThreadEvent): AtomicThreadEvent? { + val mutexID = (lockRequest.label as LockLabel).mutexID + // try to add lock-response event + return eventStructure.addLockResponseEvent(lockRequest)?.also { lockResponse -> + // if lock-response was added successfully, then push it to the lock stack + lockStacks.updateInplace(mutexID, default = mutableListOf()) { + check(isNotEmpty() implies { last().threadId == lockResponse.threadId }) + add(lockResponse) + } + } + } + + private fun issueUnlock(iThread: Int, monitor: OpaqueValue): AtomicThreadEvent { + val mutexID = objectRegistry[monitor]!!.id + // obtain current lock-responses stack, and ensure that + // the lock is indeed acquired by the releasing thread + val lockStack = lockStacks[mutexID]!! + .ensure { it.isNotEmpty() && (it.last().threadId == iThread) } + val depth = lockStack.size + // add unlock event to the event structure + return eventStructure.addUnlockEvent(iThread, monitor, + isReentry = (depth > 1), + reentrancyDepth = depth, + ).also { + // remove last lock-response event from the stack, + // since we just released the lock one time + lockStack.removeLast() + if (lockStack.isEmpty()) { + lockStacks.remove(mutexID) + } + } + } + + override fun isWaiting(iThread: Int): Boolean { + val blockingRequest = eventStructure.getPendingBlockingRequest(iThread) + ?.takeIf { (it.label is LockLabel || it.label is WaitLabel) } + ?: return false + val mutexID = (blockingRequest.label as MutexLabel).mutexID + return !(eventStructure.isPendingUnblockedRequest(blockingRequest) && + canAcquireMonitor(iThread, mutexID)) + } + + override fun waitOnMonitor(iThread: Int, monitor: Any): Boolean { + val mutexID = objectRegistry[monitor.opaque()]!!.id + // check if the thread is already blocked on wait-request or (synthetic) lock-request + val blockingRequest = eventStructure.getPendingBlockingRequest(iThread) + ?.ensure { it.label.satisfies { this.mutexID == mutexID } } + ?.ensure { it.label is LockLabel || it.label is WaitLabel } + var waitRequest = blockingRequest?.takeIf { it.label is WaitLabel } + var lockRequest = blockingRequest?.takeIf { it.label is LockLabel } + // if the thread is not blocked yet, issue wait-request event; + // this procedure will also add synthetic unlock event + if (blockingRequest == null) { + check(waitLockStack[iThread] == null) + waitRequest = issueWaitRequest(iThread, monitor.opaque()) + } + // if the wait-request was already issued, try to complete it by wait-response; + // this procedure will also add synthetic lock-request event + if (waitRequest != null) { + val (_, _lockRequest) = tryCompleteWaitResponse(monitor.opaque(), waitRequest) + ?: return true + lockRequest = _lockRequest + } + // finally, check that the thread can acquire the lock back, + // and try to complete the lock-request by lock-response + check(lockRequest != null) + if (!canAcquireMonitor(iThread, mutexID)) + return true + val lockResponse = tryCompleteWaitLockResponse(lockRequest) + // exit waiting if the lock response was added successfully + return (lockResponse == null) + } + + override fun notify(iThread: Int, monitor: Any, notifyAll: Boolean) { + issueNotify(iThread, monitor.opaque(), notifyAll) + } + + private fun issueWaitRequest(iThread: Int, monitor: OpaqueValue): AtomicThreadEvent { + val mutexID = objectRegistry[monitor]!!.id + // obtain the current lock-responses stack, and ensure that + // the lock is indeed acquired by the waiting thread + val lockStack = lockStacks[mutexID]!! + .ensure { it.isNotEmpty() && (it.last().threadId == iThread) } + val depth = lockStack.size + // add synthetic unlock event to release the mutex + eventStructure.addUnlockEvent(iThread, monitor, + isSynthetic = true, + isReentry = false, + reentrancyDepth = depth, + ) + // save the lock-responses stack to restore it later + waitLockStack[iThread] = lockStack + lockStacks.remove(mutexID) + // add the new wait-request + return eventStructure.addWaitRequestEvent(iThread, monitor) + } + + private fun tryCompleteWaitResponse(monitor: OpaqueValue, waitRequest: AtomicThreadEvent): Pair? { + require(waitRequest.label.isRequest) + require(waitRequest.label is WaitLabel) + val mutexID = (waitRequest.label as WaitLabel).mutexID + // try to complete wait-response + val waitResponse = eventStructure.addWaitResponseEvent(waitRequest) + ?: return null + val unlockEvent = waitRequest.parent!! + check(unlockEvent.label.satisfies { + this.mutexID == mutexID && isSynthetic + }) + // issue synthetic lock-request to acquire the mutex back + val iThread = waitRequest.threadId + val depth = (unlockEvent.label as UnlockLabel).reentrancyDepth + val lockRequest = eventStructure.addLockRequestEvent(iThread, monitor, + isSynthetic = true, + isReentry = false, + reentrancyDepth = depth, + ) + return (waitResponse to lockRequest) + } + + private fun tryCompleteWaitLockResponse(lockRequest: AtomicThreadEvent): AtomicThreadEvent? { + val iThread = lockRequest.threadId + val mutexID = (lockRequest.label as LockLabel).mutexID + // try to add lock-response event + return eventStructure.addLockResponseEvent(lockRequest)?.also { + // if lock-response was added successfully, then restore + // the lock stack of the acquiring thread + val lockStack = waitLockStack[iThread]!! + lockStacks.put(mutexID, lockStack).ensureNull() + waitLockStack.remove(iThread) + } + } + + private fun issueNotify(iThread: Int, monitor: OpaqueValue, notifyAll: Boolean) { + eventStructure.addNotifyEvent(iThread, monitor, notifyAll) + } + + override fun reset() { + lockStacks.clear() + waitLockStack.clear() + } + +} + +private class EventStructureParkingTracker( + private val eventStructure: EventStructure, +) : ParkingTracker { + + override fun park(iThread: Int) { + eventStructure.addParkRequestEvent(iThread) + } + + override fun waitUnpark(iThread: Int): Boolean { + val parkRequest = eventStructure.getPendingBlockingRequest(iThread) + ?.takeIf { it.label is ParkLabel } + ?: return false + val parkResponse = eventStructure.addParkResponseEvent(parkRequest) + return (parkResponse == null) + } + + override fun unpark(iThread: Int, unparkedThreadId: Int) { + eventStructure.addUnparkEvent(iThread, unparkedThreadId) + } + + override fun isParked(iThread: Int): Boolean { + val blockingRequest = eventStructure.getPendingBlockingRequest(iThread) + ?.takeIf { it.label is ParkLabel } + ?: return false + return !eventStructure.isPendingUnblockedRequest(blockingRequest) + } + + override fun reset() {} + +} + +internal const val SPIN_BOUND = 5 \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventSynchronization.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventSynchronization.kt new file mode 100644 index 000000000..cb911930d --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/EventSynchronization.kt @@ -0,0 +1,544 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.util.* + +/** + * Synchronization algebra describes how event labels can synchronize to form new labels. + * For example, write access label can synchronize with read-request label to form a read-response label. + * The response label takes the value written by the write access as its read value. + * + * When appropriate, we use notation `\+` to denote synchronization binary operation: + * + * ``` + * Write(x, v) \+ Read^{req}(x) = Read^{rsp}(x, v) + * ``` + * + * Synchronize operation is expected to be associative + * (and commutative in case of [CommutativeSynchronizationAlgebra]). + * It is a partial operation --- some labels cannot participate in synchronization + * (e.g., write access label cannot synchronize with another write access label). + * In such cases, the synchronization operation returns null: + * + * ``` + * Write(x, v) \+ Write(y, u) = null + * ``` + * + * In case when a pair of labels can synchronize, we also say that they are synchronizable. + * Given a pair of synchronizable labels, we say that these labels synchronize into the synchronization result label. + * We use notation `\>>` to denote the synchronize-into relation and `<> C` and `B \>> C`. + * + * The [synchronize] method should implement synchronization operation. + * It is not obligatory to override [synchronizable] method, which checks if a pair of labels is synchronization. + * This is because the default implementation is guaranteed to be consistent with [synchronize] + * (it just checks that result of [synchronize] is not null). + * However, the overridden implementation can optimize this check. + * + * **Note**: formally, synchronization algebra is the special algebraic structure + * deriving from partial commutative monoid [1]. + * The synchronizes-into relation corresponds to the irreflexive kernel of + * the divisibility pre-order associated with the synchronization monoid. + * + * [[1]] "Event structure semantics for CCS and related languages." + * _Glynn Winskel._ + * _International Colloquium on Automata, Languages, and Programming._ + * _Springer, Berlin, Heidelberg, 1982._ + * + */ +interface SynchronizationAlgebra { + + /** + * The synchronization type of this label. + * + * @see SynchronizationType + */ + fun syncType(label: EventLabel): SynchronizationType? + + /** + * Synchronizes two event labels. + * + * @return label representing the result of synchronization + * or null if this label cannot synchronize with [label]. + */ + fun synchronize(label: EventLabel, other: EventLabel): EventLabel? + + /** + * Checks whether two labels can synchronize. + * Default implementation just checks that result of [synchronize] is not null, + * overridden implementation can optimize this check. + */ + fun synchronizable(label: EventLabel, other: EventLabel): Boolean = + (synchronize(label, other) != null) + + /* TODO: make synchronization algebras cancellative and splittable PCM? + * With these properties we can define `split` function that returns + * a unique decomposition of any given label. + * Then we can derive implementation of `synchronizesInto` function. + * To do this we need to guarantee unique decomposition, + * currently it does not always hold + * (e.g. because of InitializationLabel synchronizing with ReadLabel). + * We need to apply some tricks to overcome this. + */ +} + +/** + * Synchronizes two nullable event labels. + * + * @return the label representing the result of synchronization, + * or null if the labels cannot synchronize, or one of the labels is null. + */ +fun SynchronizationAlgebra.synchronize(label: EventLabel?, other: EventLabel?): EventLabel? = when { + label == null -> other + other == null -> label + else -> synchronize(label, other) +} + +/** + * Synchronizes a list of events using the provided synchronization algebra. + * + * @param events the list of events which labels need to be synchronized. + * @return the label representing the result of synchronization, + * or null if event labels are not synchronizable, or the list of events is empty. + */ +fun SynchronizationAlgebra.synchronize(events: List): EventLabel? { + if (events.isEmpty()) + return null + return events.fold (null) { label: EventLabel?, event -> + synchronize(label, event.label) + } +} + +/** + * A commutative synchronization algebra is a synchronization algebra + * whose [synchronize] operation is expected to be commutative. + * + * @see SynchronizationAlgebra + */ +interface CommutativeSynchronizationAlgebra : SynchronizationAlgebra + +/** + * Constructs commutative synchronization algebra derived from the non-commutative one + * by trying to apply [synchronize] operation in both directions. + */ +fun CommutativeSynchronizationAlgebra(algebra: SynchronizationAlgebra) = object : CommutativeSynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = + algebra.syncType(label) + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = + algebra.synchronize(label, other) ?: algebra.synchronize(other, label) + +} + +/** + * Checks whether this label has binary synchronization. + */ +fun SynchronizationAlgebra.isBinarySynchronizing(label: EventLabel): Boolean = + (syncType(label) == SynchronizationType.Binary) + +/** + * Checks whether this label has barrier synchronization. + */ +fun SynchronizationAlgebra.isBarrierSynchronizing(label: EventLabel): Boolean = + (syncType(label) == SynchronizationType.Barrier) + +/** + * Type of synchronization used by label. + * Currently, two types of synchronization are supported. + * + * - [SynchronizationType.Binary] binary synchronization --- only a pair of events can synchronize. + * For example, write access label can synchronize with read-request label, + * but the resulting read-response label can no longer synchronize with any other label. + * + * - [SynchronizationType.Barrier] barrier synchronization --- a set of events can synchronize. + * For example, several thread finish labels can synchronize with a single thread + * join-request label waiting for all of these threads to complete. + * + * @see [EventLabel] + */ +enum class SynchronizationType { Binary, Barrier } + +/** + * Thread synchronization algebra defines synchronization rules + * for different types of thread event labels. + * + * The rules are as follows. + * + * - Thread fork synchronizes with thread start + * if the thread id of the starting thread is in the set of forked threads: + * + * ``` + * TFork(ts) \+ TStart^req(t) = TStart^rsp(t) | if t in ts + * ``` + * + * - Thread finish synchronizes with thread join, + * wherein the set of finished thread ids is subtracted from the set of joined thread ids. + * + * ``` + * TFinish(ts) \+ TJoin^{req|rsp}(ts') = TJoin^rsp(ts' \ ts) + * ``` + * + * - Two thread finish labels synchronize, and their sets of finished thread ids are joined. + * + * ``` + * TFinish(ts) \+ TFinish(ts') = TFinish(ts + ts') + * ``` + */ +private val ThreadSynchronizationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when (label) { + is ThreadForkLabel -> SynchronizationType.Binary + is ThreadStartLabel -> SynchronizationType.Binary + is ThreadFinishLabel -> SynchronizationType.Barrier + is ThreadJoinLabel -> SynchronizationType.Barrier + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + other is ThreadStartLabel && other.isRequest -> + other.getResponse(label) + other is ThreadJoinLabel -> + other.getResponse(label) + other is ThreadFinishLabel -> + other.join(label) + else -> null + } + +} + +/** + * Checks whether this [ThreadStartLabel] is a valid response to the given [label]. + * + * @see ThreadStartLabel.getResponse + */ +fun ThreadStartLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest) + return label is ThreadStartLabel && threadId == label.threadId +} + +fun ThreadStartLabel.getResponse(label: EventLabel): ThreadStartLabel? = when { + isRequest -> label.asThreadForkLabel() + ?.takeIf { isRequest && threadId in it.forkThreadIds } + ?.let { this.copy(kind = LabelKind.Response) } + + else -> null +} + +fun ThreadJoinLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest || label.isResponse) + return label is ThreadJoinLabel && label.joinThreadIds.containsAll(joinThreadIds) +} + +fun ThreadJoinLabel.getResponse(label: EventLabel): EventLabel? = when { + !isReceive && label is ThreadFinishLabel && joinThreadIds.containsAll(label.finishedThreadIds) -> + this.copy( + kind = LabelKind.Response, + joinThreadIds = joinThreadIds - label.finishedThreadIds + ) + + else -> null +} + +fun ThreadFinishLabel.join(label: EventLabel): ThreadFinishLabel? = when { + (label is ThreadFinishLabel) -> + this.copy(finishedThreadIds = finishedThreadIds + label.finishedThreadIds) + + else -> null +} + +/** + * Memory access synchronization algebra defines synchronization rules + * for different types of memory access event labels. + * + * The rules are as follows. + * + * - Write access synchronizes with read access from the same memory location, + * passing its value into it: + * + * ``` + * Write(x, v) \+ Read^req(x) = Read^rsp(x, v) + * ``` + */ +private val MemoryAccessSynchronizationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when(label) { + is MemoryAccessLabel -> SynchronizationType.Binary + is ObjectAllocationLabel -> SynchronizationType.Binary + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + other is ReadAccessLabel && other.isRequest -> + other.getResponse(label) + else -> null + } + +} + +fun ReadAccessLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest) + return label is ReadAccessLabel && + location == label.location && + isExclusive == label.isExclusive +} + +fun ReadAccessLabel.getResponse(label: EventLabel): EventLabel? = when { + isRequest -> label.asWriteAccessLabel(location)?.let { write -> + // TODO: perform dynamic type-check + this.copy( + kind = LabelKind.Response, + readValue = write.value, + ) + } + + else -> null +} + + +/** + * Mutex synchronization algebra defines synchronization rules + * for different types of mutex event labels. + * + * The rules are as follows. + * + * - Unlock synchronizes with lock-request to the same mutex: + * + * ``` + * Unlock(m) \+ Lock^req(m) = Lock^rsp(m) + * ``` + * + * - Notify synchronizes with wait-request on the same mutex: + * + * ``` + * Notify(m) \+ Wait^req(m) = Wait^rsp(m) + * ``` + */ +private val MutexSynchronizationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when(label) { + is MutexLabel -> SynchronizationType.Binary + is ObjectAllocationLabel -> SynchronizationType.Binary + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + other is LockLabel && other.isRequest -> + other.getResponse(label) + other is WaitLabel && other.isRequest -> + other.getResponse(label) + else -> null + } + +} + +fun LockLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest) + return label is LockLabel && + mutexID == label.mutexID && + reentrancyDepth == label.reentrancyDepth +} + +fun LockLabel.getResponse(label: EventLabel): LockLabel? = when { + isRequest -> label.asUnlockLabel(mutexID) + ?.takeIf { isReentry implies label.isInitializingUnlock() } + ?.let { this.copy(kind = LabelKind.Response) } + + else -> null +} + +fun WaitLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest) + return label is WaitLabel && mutexID == label.mutexID +} + +fun WaitLabel.getResponse(label: EventLabel): WaitLabel? = when { + isRequest -> label.asNotifyLabel(mutexID)?.let { + WaitLabel(LabelKind.Response, mutexID) + } + else -> null +} + +/** + * Parking synchronization algebra defines synchronization rules + * for different types of parking event labels. + * + * The rules are as follows. + * + * - Unpark synchronizes with park-request to the same thread: + * + * ``` + * Unpark(t) \+ Park^req(t) = Park^rsp(t) + * ``` + */ +private val ParkingSynchronizationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when(label) { + is ParkingEventLabel -> SynchronizationType.Binary + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + other is ParkLabel && other.isRequest -> + other.getResponse(label) + else -> null + } + +} + +fun ParkLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest) + return label is ParkingEventLabel && threadId == label.threadId +} + +fun ParkLabel.getResponse(label: EventLabel): ParkLabel? = when { + isRequest && label is UnparkLabel && threadId == label.threadId -> + this.copy(kind = LabelKind.Response) + + // TODO: provide an option to enable spurious wake-ups + else -> null +} + +/** + * Coroutine synchronization algebra defines synchronization rules + * for different types of parking event labels. + * + * The rules are as follows. + * + * - Coroutine resume synchronizes with suspend-request on the same suspension point: + * + * ``` + * Resume(s) \+ Suspend^req(s) = Suspend^rsp(s) + * ``` + * + * - Coroutine cancel synchronizes with suspend-request on the same suspension point: + * + * ``` + * Cancel(s) \+ Suspend^req(s) = Suspend^rsp[cancelled](s) + * ``` + * + */ +private val CoroutineSynchronizationAlgebra = object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when(label) { + is CoroutineLabel -> SynchronizationType.Binary + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when { + other is CoroutineSuspendLabel && other.isRequest -> + other.getResponse(label) + else -> null + } + +} + +fun CoroutineSuspendLabel.isValidResponse(label: EventLabel): Boolean { + require(isResponse) + require(label.isRequest) + return label is CoroutineSuspendLabel && + threadId == label.threadId && + actorId == label.actorId && + (label.promptCancellation implies cancelled) +} + +fun CoroutineSuspendLabel.getResponse(label: EventLabel): CoroutineSuspendLabel? = when { + isRequest && !promptCancellation + && label is CoroutineResumeLabel + && threadId == label.threadId + && actorId == label.actorId -> + this.copy(kind = LabelKind.Response) + + // TODO: use separate CoroutineCancel label instead of InitializationLabel + isRequest && label is InitializationLabel -> + this.copy(kind = LabelKind.Response, cancelled = true, promptCancellation = false) + + else -> null +} + + +/** + * Atomic events synchronization algebra is a commutative synchronization algebra + * that combines multiple sub-algebras of synchronization. + * + * In particular, it encompasses the following synchronization algebras: + * - thread synchronization algebra, + * - memory access synchronization algebra, + * - mutex synchronization algebra, + * - parking synchronization algebra, + * - coroutine synchronization algebra. + * + */ +val AtomicSynchronizationAlgebra = CommutativeSynchronizationAlgebra(object : SynchronizationAlgebra { + + override fun syncType(label: EventLabel): SynchronizationType? = when(label) { + is ThreadEventLabel -> ThreadSynchronizationAlgebra.syncType(label) + is MemoryAccessLabel -> MemoryAccessSynchronizationAlgebra.syncType(label) + is MutexLabel -> MutexSynchronizationAlgebra.syncType(label) + is ParkingEventLabel -> ParkingSynchronizationAlgebra.syncType(label) + is CoroutineLabel -> CoroutineSynchronizationAlgebra.syncType(label) + + // special treatment of ObjectAllocationLabel, + // because it can contribute to several sub-algebras + is ObjectAllocationLabel -> SynchronizationType.Binary + + else -> null + } + + override fun synchronize(label: EventLabel, other: EventLabel): EventLabel? = when (other) { + is ThreadEventLabel -> ThreadSynchronizationAlgebra.synchronize(label, other) + is MemoryAccessLabel -> MemoryAccessSynchronizationAlgebra.synchronize(label, other) + is MutexLabel -> MutexSynchronizationAlgebra.synchronize(label, other) + is ParkLabel -> ParkingSynchronizationAlgebra.synchronize(label, other) + is CoroutineLabel -> CoroutineSynchronizationAlgebra.synchronize(label, other) + else -> null + } + +}) + +fun EventLabel.isValidResponse(label: EventLabel): Boolean = when (this) { + is ThreadStartLabel -> isValidResponse(label) + is ThreadJoinLabel -> isValidResponse(label) + is ReadAccessLabel -> isValidResponse(label) + is LockLabel -> isValidResponse(label) + is WaitLabel -> isValidResponse(label) + is ParkLabel -> isValidResponse(label) + is CoroutineSuspendLabel -> isValidResponse(label) + else -> throw IllegalArgumentException() +} + +fun EventLabel.getResponse(label: EventLabel): EventLabel? = when (this) { + is ThreadStartLabel -> getResponse(label) + is ThreadJoinLabel -> getResponse(label) + is ReadAccessLabel -> getResponse(label) + is LockLabel -> getResponse(label) + is WaitLabel -> getResponse(label) + is ParkLabel -> getResponse(label) + is CoroutineSuspendLabel -> getResponse(label) + else -> null +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Execution.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Execution.kt new file mode 100644 index 000000000..c2c3e7f73 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Execution.kt @@ -0,0 +1,442 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.util.* + + +/** + * Execution represents a set of events belonging to a single program's execution. + * + * The set of events in the execution is causally closed, meaning that if + * an event belongs to the execution, all its causal predecessors must also belong to it. + * + * We assume that all the events within the same thread are totally ordered by the program order. + * Thus, the execution is represented as a list of threads, + * with each thread being a list of thread's events given in program order. + * + * Since we also assume that the program order is consistent the enumeration order of events, + * (that is, `(a, b) \in po` implies `a.id < b.id`), + * the threads' events lists are sorted with respect to the enumeration order. + * + * @param E the type of events stored in the execution, must extend `ThreadEvent`. + */ +interface Execution : Collection { + + /** + * A map from thread IDs to the list of thread events. + */ + val threadMap: ThreadMap> + + /** + * Retrieves the list of events for the specified thread ID. + * + * @param tid The thread ID for which to retrieve the events. + * @return The list of events for the specified thread ID, + * or null if the requested thread does not belong to the execution. + */ + operator fun get(tid: ThreadID): SortedList? = + threadMap[tid] + + /** + * Checks if the given event is present at the execution. + * + * @param element The event to check for presence in the execution. + * @return true if the execution contains the event, false otherwise. + */ + override fun contains(element: @UnsafeVariance E): Boolean = + get(element.threadId, element.threadPosition) == element + + /** + * Checks if all events in the given collection are present at the execution. + * + * @param elements The collection of events to check for presence in the execution. + * @return true if all events in the collection are present in the execution, false otherwise. + */ + override fun containsAll(elements: Collection<@UnsafeVariance E>): Boolean = + elements.all { contains(it) } + + + /** + * Returns an iterator over the events in the execution. + * + * @return an iterator that iterates over the events of the execution. + */ + override fun iterator(): Iterator = + threadIDs.map { get(it)!! }.asSequence().flatten().iterator() + +} + +/** + * Mutable execution represents a modifiable set of events belonging to a single program's execution. + * + * Mutable execution supports events' addition. + * Events can only be added according to the causal order. + * That is, whenever a new event is added to the execution, + * all the events on which it depends, including its program order predecessors, + * should already be added into the execution. + * + * @param E the type of events stored in the execution, must extend `ThreadEvent`. + * + * @see Execution + * @see ExecutionFrontier + */ +interface MutableExecution : Execution { + + /** + * Adds the specified event to the mutable execution. + * All the causal predecessors of the event must already be added to the execution. + * + * @param event the event to be added to the execution. + */ + fun add(event: E) + + // TODO: support single (causally-last) event removal (?) +} + +val Execution<*>.threadIDs: Set + get() = threadMap.keys + +val Execution<*>.maxThreadID: ThreadID + get() = threadIDs.maxOrNull() ?: -1 + +fun Execution<*>.getThreadSize(tid: ThreadID): Int = + get(tid)?.size ?: 0 + +fun Execution<*>.lastPosition(tid: ThreadID): Int = + getThreadSize(tid) - 1 + +fun Execution.firstEvent(tid: ThreadID): E? = + get(tid)?.firstOrNull() + +fun Execution.lastEvent(tid: ThreadID): E? = + get(tid)?.lastOrNull() + +operator fun Execution.get(tid: ThreadID, pos: Int): E? = + get(tid)?.getOrNull(pos) + +fun Execution.nextEvent(event: E): E? = + get(event.threadId)?.let { events -> + require(events[event.threadPosition] == event) + events.getOrNull(event.threadPosition + 1) + } + +// TODO: make default constructor +fun Execution(nThreads: Int): Execution = + MutableExecution(nThreads) + +fun MutableExecution(nThreads: Int): MutableExecution = + ExecutionImpl(ArrayIntMap(*(0 until nThreads) + .map { (it to sortedArrayListOf()) } + .toTypedArray() + )) + +fun executionOf(vararg pairs: Pair>): Execution = + mutableExecutionOf(*pairs) + +fun mutableExecutionOf(vararg pairs: Pair>): MutableExecution = + ExecutionImpl(ArrayIntMap(*pairs + .map { (tid, events) -> (tid to SortedArrayList(events)) } + .toTypedArray() + )) + +private class ExecutionImpl( + override val threadMap: ArrayIntMap> +) : MutableExecution { + + override var size: Int = threadMap.values.sumOf { it.size } + private set + + override fun isEmpty(): Boolean = + (size == 0) + + override fun get(tid: ThreadID): SortedMutableList? = + threadMap[tid] + + override fun add(event: E) { + ++size + threadMap[event.threadId]!! + .ensure { event.parent == it.lastOrNull() } + .also { it.add(event) } + } + + override fun equals(other: Any?): Boolean = + (other is ExecutionImpl<*>) && (size == other.size) && (threadMap == other.threadMap) + + override fun hashCode(): Int = + threadMap.hashCode() + + override fun toString(): String = buildString { + appendLine("<======== Execution Graph @${hashCode()} ========>") + threadIDs.toList().sorted().forEach { tid -> + val events = threadMap[tid] ?: return@forEach + appendLine("[-------- Thread #${tid} --------]") + for (event in events) { + appendLine("$event") + if (event.dependencies.isNotEmpty()) { + appendLine(" dependencies: ${event.dependencies.joinToString()}}") + } + } + } + } + +} + +fun Execution.toFrontier(): ExecutionFrontier = + toMutableFrontier() + +fun Execution.toMutableFrontier(): MutableExecutionFrontier = + threadIDs.map { tid -> + tid to get(tid)?.lastOrNull() + }.let { + mutableExecutionFrontierOf(*it.toTypedArray()) + } + +fun Execution.calculateFrontier(clock: VectorClock): MutableExecutionFrontier = + (0 until maxThreadID).mapNotNull { tid -> + val timestamp = clock[tid] + if (timestamp >= 0) + tid to this[tid, timestamp] + else null + }.let { + mutableExecutionFrontierOf(*it.toTypedArray()) + } + +fun Execution.buildEnumerator() = object : Enumerator { + + private val events = enumerationOrderSorted() + + private val eventIndices = threadMap.mapValues { (_, threadEvents) -> + List(threadEvents.size) { pos -> + events.indexOf(threadEvents[pos]).ensure { it >= 0 } + } + } + + override fun get(i: Int): E { + return events[i] + } + + override fun get(x: E): Int { + return eventIndices[x.threadId]!![x.threadPosition] + } + +} + +fun Execution<*>.locations(): Set { + val locations = mutableSetOf() + for (event in this) { + val location = (event.label as? MemoryAccessLabel)?.location + ?: continue + locations.add(location) + } + return locations +} + +fun Execution.getResponse(request: AtomicThreadEvent): AtomicThreadEvent? { + // TODO: handle the case of span-start label + require(request.label.isRequest && !request.label.isSpanLabel) + return this[request.threadId, request.threadPosition + 1]?.ensure { + it.isValidResponse(request) + } +} + +fun Execution.isBlockedDanglingRequest(event: E): Boolean = + event.label.isRequest && + event.label.isBlocking && + event == this[event.threadId]?.last() + +/** + * Computes the backward vector clock for the given event and relation. + * + * Backward vector clock encodes the position of the last event in each thread, + * on which the given event depends on according to the given relation. + * + * Graphically, this can be illustrated by the following picture: + * e1 e2 e3 + * \ | / + * e + * + * Formally, for an event `e` and relation `r` the backward vector clock stores + * a mapping `tid -> pos` such that `e' = execution[tid, pos]` is + * the last event in thread `tid` such that `(e', e) \in r`. + * + * This function has time complexity O(E^2) where E is the number of events in execution. + * If the given relation respects program order (see definition below), + * then the time complexity can be optimized (using binary search) + * to O(E * T * log E) where T is the number of threads. + * + * The relation is said to respect the program order (in the backward direction) if the following is true: + * (x, y) \in r and (z, x) \in po implies (z, y) \in r + * + * + * @param event The event for which to compute the backward vector clock. + * @param relation The relation used to determine the causality between events. + * @return The computed backward vector clock. + */ +fun Execution.computeBackwardVectorClock(event: E, relation: Relation, + respectsProgramOrder: Boolean = false +): VectorClock { + val capacity = 1 + this.maxThreadID + val clock = MutableVectorClock(capacity) + for (i in 0 until capacity) { + val threadEvents = get(i) ?: continue + val position = if (respectsProgramOrder) { + threadEvents.binarySearch { !relation(it, event) } + } else { + threadEvents.indexOfFirst { !relation(it, event) } + } + clock[i] = (position - 1).coerceAtLeast(-1) + } + return clock +} + +/** + * Computes the forward vector clock for the given event and relation. + * + * Forward vector clock encodes the position of the first event in each thread, + * that depends on given event according to the given relation. + * + * Graphically, this can be illustrated by the following picture: + * e + * / | \ + * e1 e2 e3 + * + * Formally, for an event `e` and relation `r` the forward vector clock stores + * a mapping `tid -> pos` such that `e' = execution[tid, pos]` is + * the first event in thread `tid` such that `(e, e') \in r`. + * + * This function has time complexity O(E^2) where E is the number of events in execution. + * If the given relation respects program order (see definition below), + * then the time complexity can be optimized (using binary search) + * to O(E * T * log E) where T is the number of threads. + * + * The relation is said to respect the program order (in the forward direction) if the following is true: + * (x, y) \in r and (y, z) \in po implies (x, z) \in r + * + * + * @param event The event for which to compute the forward vector clock. + * @param relation The relation used to determine the causality between events. + * @return The computed forward vector clock. + */ +fun Execution.computeForwardVectorClock(event: E, relation: Relation, + respectsProgramOrder: Boolean = false, +): VectorClock { + val capacity = 1 + this.maxThreadID + val clock = MutableVectorClock(capacity) + for (i in 0 until capacity) { + val threadEvents = get(i) ?: continue + val position = if (respectsProgramOrder) { + threadEvents.binarySearch { relation(event, it) } + } else { + threadEvents.indexOfFirst { relation(event, it) } + } + clock[i] = position + } + return clock +} + +fun VectorClock.observes(event: ThreadEvent): Boolean = + observes(event.threadId, event.threadPosition) + +fun VectorClock.observes(execution: Execution<*>): Boolean = + execution.threadMap.values.all { events -> + events.lastOrNull()?.let { observes(it) } ?: true + } + +fun Covering.coverable(event: E, clock: VectorClock): Boolean = + this(event).all { clock.observes(it) } + +fun Covering.allCoverable(events: List, clock: VectorClock): Boolean = + events.all { event -> this(event).all { clock.observes(it) || it in events } } + +fun Covering.firstCoverable(events: List, clock: VectorClock): Boolean = + coverable(events.first(), clock) + +fun Collection.enumerationOrderSorted(): List = + this.sorted() + +fun Execution.enumerationOrderSorted(): List = + this.sorted() + + +fun Execution.buildGraph( + relation: Relation, + respectsProgramOrder: Boolean = false, +) = object : Graph { + private val execution = this@buildGraph + + override val nodes: Collection + get() = execution + + private val enumerator = execution.buildEnumerator() + + private val nThreads = 1 + execution.maxThreadID + + private val adjacencyList = Array(nodes.size) { i -> + val event = enumerator[i] + val clock = execution.computeForwardVectorClock(event, relation, + respectsProgramOrder = respectsProgramOrder + ) + (0 until nThreads).mapNotNull { tid -> + if (clock[tid] != -1) execution[tid, clock[tid]] else null + } + } + + override fun adjacent(node: E): List { + val idx = enumerator[node] + return adjacencyList[idx] + } + +} + +// TODO: include parent event in covering (?) and remove `External` +fun Execution.buildExternalCovering( + relation: Relation, + respectsProgramOrder: Boolean = false, +) = object : Covering { + + // TODO: document this precondition! + init { + // require(respectsProgramOrder) + } + + private val execution = this@buildExternalCovering + private val enumerator = execution.buildEnumerator() + + private val nThreads = 1 + maxThreadID + + val covering: List> = execution.indices.map { index -> + val event = enumerator[index] + val clock = execution.computeBackwardVectorClock(event, relation, + respectsProgramOrder = respectsProgramOrder + ) + (0 until nThreads).mapNotNull { tid -> + if (tid != event.threadId && clock[tid] != -1) + execution[tid, clock[tid]] + else null + } + } + + override fun invoke(x: E): List = + covering[enumerator[x]] + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExecutionFrontier.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExecutionFrontier.kt new file mode 100644 index 000000000..52bf8495c --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExecutionFrontier.kt @@ -0,0 +1,136 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.util.* + +// TODO: implement VectorClock interface?? +interface ExecutionFrontier { + val threadMap: ThreadMap +} + +interface MutableExecutionFrontier : ExecutionFrontier { + override val threadMap: MutableThreadMap +} + +val ExecutionFrontier<*>.threadIDs: Set + get() = threadMap.keys + +val ExecutionFrontier.events: List + get() = threadMap.values.filterNotNull() + +operator fun ExecutionFrontier.get(tid: ThreadID): E? = + threadMap[tid] + +fun ExecutionFrontier<*>.getPosition(tid: ThreadID): Int = + get(tid)?.threadPosition ?: -1 + +fun ExecutionFrontier<*>.getNextPosition(tid: ThreadID): Int = + 1 + getPosition(tid) + +operator fun ExecutionFrontier.contains(event: E): Boolean { + val lastEvent = get(event.threadId) ?: return false + return programOrder.orEqual(event, lastEvent) +} + +operator fun MutableExecutionFrontier.set(tid: ThreadID, event: E?) { + require(tid == (event?.threadId ?: tid)) + threadMap[tid] = event +} + +fun MutableExecutionFrontier.update(event: E) { + check(event.parent == get(event.threadId)) + set(event.threadId, event) +} + +fun MutableExecutionFrontier.merge(other: ExecutionFrontier) { + threadMap.mergeReduce(other.threadMap) { x, y -> when { + x == null -> y + y == null -> x + else -> programOrder.max(x, y) + }} +} + +fun ExecutionFrontier.isBlockedDanglingRequest(event: E): Boolean = + event.label.isRequest && + event.label.isBlocking && + event == this[event.threadId] + +fun ExecutionFrontier.getDanglingRequests(): List { + return threadMap.mapNotNull { (_, lastEvent) -> + lastEvent?.takeIf { it.label.isRequest && !it.label.isSpanLabel } + } +} + +inline fun MutableExecutionFrontier.cut(event: E) { + // TODO: optimize for a case of a single event + cut(listOf(event)) +} + +inline fun MutableExecutionFrontier.cut(events: List) { + if (events.isEmpty()) + return + // TODO: optimize --- extract sublist of maximal events having no causal successors, + // to remove them faster without the need to compute vector clocks + threadMap.forEach { (tid, lastEvent) -> + // find the program-order latest event, not observing any of the cut events + // TODO: optimize --- transform events into vector clock + // TODO: optimize using binary search + val pred = lastEvent?.pred(inclusive = true) { + (it is E) && !events.any { cutEvent -> + it.causalityClock.observes(cutEvent.threadId, cutEvent.threadPosition) + } + } + set(tid, pred as? E) + } +} + +fun ExecutionFrontier(nThreads: Int): ExecutionFrontier = + MutableExecutionFrontier(nThreads) + +fun MutableExecutionFrontier(nThreads: Int): MutableExecutionFrontier = + ExecutionFrontierImpl(ArrayIntMap(nThreads)) + +fun executionFrontierOf(vararg pairs: Pair): ExecutionFrontier = + mutableExecutionFrontierOf(*pairs) + +fun mutableExecutionFrontierOf(vararg pairs: Pair): MutableExecutionFrontier = + ExecutionFrontierImpl(ArrayIntMap(*pairs)) + +fun ExecutionFrontier.copy(): MutableExecutionFrontier { + check(this is ExecutionFrontierImpl) + return ExecutionFrontierImpl(threadMap.copy()) +} + +private class ExecutionFrontierImpl( + override val threadMap: ArrayIntMap +): MutableExecutionFrontier + +inline fun ExecutionFrontier.toExecution(): Execution = + toMutableExecution() + +inline fun ExecutionFrontier.toMutableExecution(): MutableExecution = + threadIDs.map { tid -> + val events = get(tid)?.threadPrefix(inclusive = true) + tid to (events?.refine() ?: listOf()) + }.let { + mutableExecutionOf(*it.toTypedArray()) + } \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExecutionTracker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExecutionTracker.kt new file mode 100644 index 000000000..75ad2b06d --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExecutionTracker.kt @@ -0,0 +1,59 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +/** + * The execution tracker interface is used to track the execution's modifications. + * It can be used to maintain additional data structures associated with the execution, + * such as various events' indices or auxiliary relations on events. + * + * There are two types of modifications that can be tracked. + + * The first type of modifications is incremental changes, such as the addition of a new event. + * The tracker can take advantage of the incremental nature of such changes to + * optimize the update of its internal data structures. + * + * The second type of modifications is complete reset of the execution state + * to some novel state, containing a different set of events. + * In response to that, the tracker might need to re-compute its internal data structures. + * Nevertheless, the tracker can utilize the fact that some subset of events might be preserved after reset, + * and thus save some part of its internal data structures. + * + * @param E the type of events of the tracked execution. + */ +interface ExecutionTracker> { + + /** + * This method is called when a new event is added to the execution. + * + * @param event the newly added event. + * + * @see MutableExecution.add + */ + fun onAdd(event: E) + + /** + * This method is called when the execution is reset to a novel state. + * + * @param execution the execution representing the new set of events after reset. + */ + fun onReset(execution: X) +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExtendedExecution.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExtendedExecution.kt new file mode 100644 index 000000000..3603a9b4d --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ExtendedExecution.kt @@ -0,0 +1,348 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency.* +import org.jetbrains.kotlinx.lincheck.util.* + + +/** + * Extended execution extends the regular execution with additional information, + * such as various event indices and auxiliary relations. + * This additional information is mainly used to maintain the consistency of the execution. + * + * @see Execution + */ +interface ExtendedExecution : Execution { + + /** + * Index for memory access events in the extended execution. + * + * @see AtomicMemoryAccessEventIndex + */ + val memoryAccessEventIndex : AtomicMemoryAccessEventIndex + + /** + * The read-modify-write order of the execution. + * + * @see ReadModifyWriteOrder + */ + val readModifyWriteOrder: Relation + + /** + * The writes-before (wb) order of the execution. + * + * @see WritesBeforeOrder + */ + val writesBeforeOrder: Relation + + /** + * The coherence (co) order of the execution + * + * @see CoherenceOrder + */ + val coherenceOrder: Relation + + /** + * The extended coherence (eco) relation of the execution + * + * @see ExtendedCoherenceOrder + */ + val extendedCoherence: Relation + + /** + * The sequential consistency order (sc) of the execution. + * + * @see SequentialConsistencyOrder + */ + val sequentialConsistencyOrder: Relation + + /** + * The execution order (xo) of the execution. + * + * @see ExecutionOrder + */ + val executionOrder: Relation + + val inconsistency: Inconsistency? +} + +/** + * Represents a mutable extended execution, which extends the regular execution + * with additional information and supports modification. + * + * The mutable extended execution allows adding events to the execution, + * or resetting the execution to a new set of events, + * rebuilding the auxiliary data structures accordingly. + */ +interface MutableExtendedExecution : ExtendedExecution, MutableExecution { + + override val memoryAccessEventIndex: MutableAtomicMemoryAccessEventIndex + + val readModifyWriteOrderComputable: ComputableNode + + val writesBeforeOrderComputable: ComputableNode + + val coherenceOrderComputable: ComputableNode + + val extendedCoherenceComputable: ComputableNode + + val sequentialConsistencyOrderComputable: ComputableNode + + val executionOrderComputable: ComputableNode + + /** + * Resets the mutable execution to contain the new set of events + * and rebuilds all the auxiliary data structures accordingly. + * + * To ensure causal closure, the reset method takes the new set of events as an [ExecutionFrontier] object. + * That is, after the reset, the execution will contain the events of the frontier as well as + * all of its causal predecessors. + * + * @see ExecutionFrontier + */ + fun reset(frontier: ExecutionFrontier) + + fun checkConsistency(): Inconsistency? +} + +fun ExtendedExecution(nThreads: Int): ExtendedExecution = + MutableExtendedExecution(nThreads) + +fun MutableExtendedExecution(nThreads: Int): MutableExtendedExecution = + ExtendedExecutionImpl(ResettableExecution(nThreads)) + + +/* private */ class ExtendedExecutionImpl( + val execution: ResettableExecution +) : MutableExtendedExecution, MutableExecution by execution { + + override val memoryAccessEventIndex = + MutableAtomicMemoryAccessEventIndex().apply { index(execution) } + + override val readModifyWriteOrderComputable = computable { ReadModifyWriteOrder(execution) } + + override val readModifyWriteOrder: Relation by readModifyWriteOrderComputable + + override val writesBeforeOrderComputable = computable { + WritesBeforeOrder( + execution, + memoryAccessEventIndex, + readModifyWriteOrderComputable.value, + causalityOrder + ) + } + .dependsOn(readModifyWriteOrderComputable, soft = true, invalidating = true) + + override val writesBeforeOrder: Relation by writesBeforeOrderComputable + + override val coherenceOrderComputable = computable { + CoherenceOrder( + execution, + memoryAccessEventIndex, + readModifyWriteOrderComputable.value, + causalityOrder union writesBeforeOrderComputable.value, // TODO: add eco or sc? + ) + } + .dependsOn(readModifyWriteOrderComputable, soft = true, invalidating = true) + .dependsOn(writesBeforeOrderComputable, soft = true, invalidating = true) + + override val coherenceOrder: Relation by coherenceOrderComputable + + override val extendedCoherenceComputable = computable { + ExtendedCoherenceOrder( + execution, + memoryAccessEventIndex, + causalityOrder union writesBeforeOrderComputable.value // TODO: add coherence + ) + } + .dependsOn(writesBeforeOrderComputable, soft = true, invalidating = true) + .apply { + // add reference to coherence order, so once it is computed + // it can force-set the extended coherence order + coherenceOrderComputable.value.extendedCoherenceOrder = this + } + + override val extendedCoherence: Relation by extendedCoherenceComputable + + override val sequentialConsistencyOrderComputable = computable { + SequentialConsistencyOrder( + execution, + memoryAccessEventIndex, + causalityOrder union extendedCoherenceComputable.value, + // TODO: refine eco order after sc order computation (?) + ) + } + .dependsOn(extendedCoherenceComputable, soft = true, invalidating = true) + + override val sequentialConsistencyOrder: Relation by sequentialConsistencyOrderComputable + + override val executionOrderComputable = computable { + ExecutionOrder( + execution, + memoryAccessEventIndex, + causalityOrder union extendedCoherence, // TODO: add sc order + ) + } + .dependsOn(extendedCoherenceComputable, soft = true, invalidating = true) + .apply { + // add reference to coherence order, so once it is computed + // it can force-set the execution order + coherenceOrderComputable.value.executionOrder = this + } + + override val executionOrder: Relation by executionOrderComputable + + private val consistencyChecker = aggregateConsistencyCheckers( + execution = this, + listOf( + ReadModifyWriteAtomicityChecker(execution = this), + + IncrementalSequentialConsistencyChecker( + execution = this, + checkReleaseAcquireConsistency = true, + approximateSequentialConsistency = false + ) + ), + listOf(), + ) + + private val trackers = listOf( + memoryAccessEventIndex.incrementalTracker(), + consistencyChecker.incrementalTracker(), + ) + + override val inconsistency: Inconsistency? + get() = consistencyChecker.state.inconsistency + + override fun checkConsistency(): Inconsistency? { + return consistencyChecker.check() + } + + override fun add(event: AtomicThreadEvent) { + execution.add(event) + for (tracker in trackers) + tracker.onAdd(event) + } + + override fun reset(frontier: ExecutionFrontier) { + execution.reset(frontier) + for (tracker in trackers) + tracker.onReset(this) + } + + override fun toString(): String = + execution.toString() + +} + +/* private */ class ResettableExecution(nThreads: Int) : MutableExecution { + + private var execution = MutableExecution(nThreads) + + constructor(execution: MutableExecution) : this(0) { + this.execution = execution + } + + override val size: Int + get() = execution.size + + override val threadMap: ThreadMap> + get() = execution.threadMap + + override fun isEmpty(): Boolean = + execution.isEmpty() + + override fun add(event: AtomicThreadEvent) { + execution.add(event) + } + + fun reset(frontier: ExecutionFrontier) { + execution = frontier.toMutableExecution() + } + + override fun equals(other: Any?): Boolean = + (other is ResettableExecution) && (execution == other.execution) + + override fun hashCode(): Int = + execution.hashCode() + + override fun toString(): String = + execution.toString() + +} + +private typealias ExtendedExecutionTracker = ExecutionTracker + +private fun MutableEventIndex.incrementalTracker(): ExtendedExecutionTracker { + return object : ExtendedExecutionTracker { + override fun onAdd(event: AtomicThreadEvent) { + index(event) + } + + override fun onReset(execution: MutableExtendedExecution) { + reset() + index(execution) + } + } +} + +private typealias AtomicEventConsistencyChecker = + IncrementalConsistencyChecker + +private fun AtomicEventConsistencyChecker.incrementalTracker(): ExtendedExecutionTracker { + return object : ExtendedExecutionTracker { + override fun onAdd(event: AtomicThreadEvent) { + check(event) + } + + override fun onReset(execution: MutableExtendedExecution) { + reset(execution) + } + } +} + +// private fun ComputableNode.incrementalTracker(): ExecutionTracker +// where I : Computable, +// I : Incremental +// { +// return object : ExecutionTracker { +// override fun onAdd(event: AtomicThreadEvent) { +// if (computed) value.add(event) +// } +// +// override fun onReset(execution: Execution) { +// reset() +// } +// } +// } +// +// private fun ComputableNode<*>.resettingTracker(): ExecutionTracker { +// return object : ExecutionTracker { +// override fun onAdd(event: AtomicThreadEvent) { +// reset() +// } +// +// override fun onReset(execution: Execution) { +// reset() +// } +// } +// } \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Graph.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Graph.kt new file mode 100644 index 000000000..7353bc229 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Graph.kt @@ -0,0 +1,123 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.util.* +import java.util.* + +// adjacency list is a function that +// for each node of a graph returns destination nodes of its outgoing edges +interface Graph { + val nodes: Collection + fun adjacent(node: T): List +} + +fun topologicalSorting(graph: Graph): List? { + val result = mutableListOf() + val state = graph.initializeTopoSortState() + val queue: Queue = LinkedList() + for (node in graph.nodes) { + if (state[node]!!.indegree == 0) + queue.add(node) + } + while (queue.isNotEmpty()) { + val node = queue.poll() + result.add(node) + graph.adjacent(node).forEach { + if (--state[it]!!.indegree == 0) + queue.add(it) + } + } + if (result.size != graph.nodes.size) { + return null + } + return result +} + +fun topologicalSortings(graph: Graph): Sequence> { + val result = MutableList(graph.nodes.size) { null } + val state = graph.initializeTopoSortState() + return sequence { + yieldTopologicalSortings(graph, 0, result, state) + } +} + +private suspend fun SequenceScope>.yieldTopologicalSortings( + graph: Graph, + depth: Int, + result: MutableList, + state: TopoSortState, +) { + // this flag is used to detect terminal recursive calls + var isTerminal = true + // iterate through all nodes + for (node in graph.nodes) { + val nodeState = state[node]!! + // skip visited and not-yet ready nodes + if (nodeState.visited || nodeState.indegree != 0) + continue + // push the current node on top of the result list + result[depth] = node + // mark node as visited + nodeState.visited = true + // decrease indegree of all adjacent nodes + for (adjacentNode in graph.adjacent(node)) + state[adjacentNode]!!.indegree-- + // explore topological sortings recursively + yieldTopologicalSortings(graph, depth + 1, result, state) + // since we made recursive call, reset the isTerminal flag + isTerminal = false + // rollback the state + for (adjacentNode in graph.adjacent(node)) + state[adjacentNode]!!.indegree++ + nodeState.visited = false + result[depth] = null + } + // if we are at terminal call, yield the resulting sorted list + if (isTerminal) { + val sorting = result.toMutableList().requireNoNulls() + yield(sorting) + } +} + +private typealias TopoSortState = MutableMap + +private data class TopoSortNodeState( + var visited: Boolean, + var indegree: Int, +) { + companion object { + fun initial() = TopoSortNodeState(false, 0) + } +} + +private fun Graph.initializeTopoSortState(): TopoSortState { + val state = mutableMapOf() + for (node in nodes) { + state.putIfAbsent(node, TopoSortNodeState.initial()) + for (adjacentNode in adjacent(node)) { + state.updateInplace(adjacentNode, default = TopoSortNodeState.initial()) { + indegree++ + } + } + } + return state +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ObjectRegistry.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ObjectRegistry.kt new file mode 100644 index 000000000..16f88a90a --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/ObjectRegistry.kt @@ -0,0 +1,189 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.util.* +import kotlin.collections.HashMap +import java.util.IdentityHashMap +import org.objectweb.asm.Type + + +data class ObjectEntry( + val id: ObjectID, + val obj: OpaqueValue, + val allocation: AtomicThreadEvent, +) { + init { + require(id != NULL_OBJECT_ID) + require(allocation.label is InitializationLabel || allocation.label is ObjectAllocationLabel) + require((id == STATIC_OBJECT_ID || obj.isPrimitive()) implies (allocation.label is InitializationLabel)) + } + + val isExternal: Boolean + get() = (allocation.label is InitializationLabel) + +} + +class ObjectRegistry { + + private var objectCounter = 0L + + private val objectIdIndex = HashMap() + + private val objectIndex = IdentityHashMap() + private val primitiveIndex = HashMap() + + val nextObjectID: ObjectID + get() = 1 + objectCounter + + private var initEvent: AtomicThreadEvent? = null + + fun initialize(initEvent: AtomicThreadEvent) { + require(initEvent.label is InitializationLabel) + this.initEvent = initEvent + } + + fun register(entry: ObjectEntry) { + check(entry.id != NULL_OBJECT_ID) + check(entry.id <= objectCounter + 1) + check(!entry.obj.isPrimitive) + objectIdIndex.put(entry.id, entry).ensureNull() + objectIndex.put(entry.obj.unwrap(), entry).ensureNull() + if (entry.id != STATIC_OBJECT_ID) { + objectCounter++ + } + } + + private fun registerPrimitiveObject(obj: OpaqueValue): ObjectID { + check(obj.isPrimitive) + val entry = primitiveIndex.computeIfAbsent(obj.unwrap()) { + val id = ++objectCounter + val entry = ObjectEntry(id, obj, initEvent!!) + objectIdIndex.put(entry.id, entry).ensureNull() + return@computeIfAbsent entry + } + return entry.id + } + + private fun registerExternalObject(obj: OpaqueValue): ObjectID { + if (obj.isPrimitive) { + return registerPrimitiveObject(obj) + } + val id = nextObjectID + val entry = ObjectEntry(id, obj, initEvent!!) + register(entry) + return id + } + + fun getOrRegisterObjectID(obj: OpaqueValue): ObjectID { + get(obj)?.let { return it.id } + val className = obj.unwrap().javaClass.simpleName + val id = registerExternalObject(obj) + (initEvent!!.label as InitializationLabel).trackExternalObject(className, id) + return id + } + + operator fun get(id: ObjectID): ObjectEntry? = + objectIdIndex[id] + + operator fun get(obj: OpaqueValue): ObjectEntry? = + if (obj.isPrimitive) primitiveIndex[obj.unwrap()] else objectIndex[obj.unwrap()] + + fun retain(predicate: (ObjectEntry) -> Boolean) { + objectIdIndex.values.retainAll(predicate) + objectIndex.values.retainAll(predicate) + primitiveIndex.values.retainAll(predicate) + } + +} + +fun ObjectRegistry.getOrRegisterObjectID(obj: OpaqueValue?): ObjectID = + if (obj == null) NULL_OBJECT_ID else getOrRegisterObjectID(obj) + +fun ObjectRegistry.getValue(type: Type, id: ValueID): OpaqueValue? = when (type.sort) { + Type.LONG -> id.opaque() + Type.INT -> id.toInt().opaque() + Type.BYTE -> id.toByte().opaque() + Type.SHORT -> id.toShort().opaque() + Type.CHAR -> id.toChar().opaque() + Type.BOOLEAN -> id.toInt().toBoolean().opaque() + else -> when (type) { + LONG_TYPE_BOXED -> id.opaque() + INT_TYPE_BOXED -> id.toInt().opaque() + BYTE_TYPE_BOXED -> id.toByte().opaque() + SHORT_TYPE_BOXED -> id.toShort().opaque() + CHAR_TYPE_BOXED -> id.toChar().opaque() + BOOLEAN_TYPE_BOXED -> id.toInt().toBoolean().opaque() + else -> get(id)?.obj + } +} + +fun ObjectRegistry.getValueID(type: Type, value: OpaqueValue?): ValueID { + if (value == null) return NULL_OBJECT_ID + return when (type.sort) { + Type.LONG -> (value.unwrap() as Long) + Type.INT -> (value.unwrap() as Int).toLong() + Type.BYTE -> (value.unwrap() as Byte).toLong() + Type.SHORT -> (value.unwrap() as Short).toLong() + Type.CHAR -> (value.unwrap() as Char).toLong() + Type.BOOLEAN -> (value.unwrap() as Boolean).toInt().toLong() + else -> when (type) { + LONG_TYPE_BOXED -> (value.unwrap() as Long) + INT_TYPE_BOXED -> (value.unwrap() as Int).toLong() + BYTE_TYPE_BOXED -> (value.unwrap() as Byte).toLong() + SHORT_TYPE_BOXED -> (value.unwrap() as Short).toLong() + CHAR_TYPE_BOXED -> (value.unwrap() as Char).toLong() + BOOLEAN_TYPE_BOXED -> (value.unwrap() as Boolean).toInt().toLong() + else -> get(value)?.id ?: NULL_OBJECT_ID + } + } +} + +fun ObjectRegistry.getOrRegisterValueID(type: Type, value: OpaqueValue?): ValueID { + if (value == null) return NULL_OBJECT_ID + return when (type.sort) { + Type.LONG -> (value.unwrap() as Long) + Type.INT -> (value.unwrap() as Int).toLong() + Type.SHORT -> (value.unwrap() as Short).toLong() + Type.CHAR -> (value.unwrap() as Char).toLong() + + // sometimes, due to JVM internals, boolean values can be reinterpreted as byte values + // (e.g., because of BALOAD and BASTORE instructions are used for both boolean and byte arrays); + // thus if the type-cast failed, we try to reinterpret the value and cast it to manually + Type.BYTE -> + (value.unwrap() as? Byte)?.toLong() ?: + (value.unwrap() as Boolean).toInt().toLong() + Type.BOOLEAN -> + (value.unwrap() as? Boolean)?.toInt()?.toLong() ?: + (value.unwrap() as Byte).toBoolean().toInt().toLong() + + else -> when (type) { + LONG_TYPE_BOXED -> (value.unwrap() as Long) + INT_TYPE_BOXED -> (value.unwrap() as Int).toLong() + BYTE_TYPE_BOXED -> (value.unwrap() as Byte).toLong() + SHORT_TYPE_BOXED -> (value.unwrap() as Short).toLong() + CHAR_TYPE_BOXED -> (value.unwrap() as Char).toLong() + BOOLEAN_TYPE_BOXED -> (value.unwrap() as Boolean).toInt().toLong() + else -> getOrRegisterObjectID(value) + } + } +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Relation.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Relation.kt new file mode 100644 index 000000000..840f6e753 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Relation.kt @@ -0,0 +1,247 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.util.* + +fun interface Relation { + operator fun invoke(x: T, y: T): Boolean + + companion object { + fun empty() = Relation { _, _ -> false } + } +} + +infix fun Relation.union(relation: Relation) = Relation { x, y -> + this(x, y) || relation(x, y) +} + +infix fun Relation.intersection(relation: Relation) = Relation { x, y -> + this(x, y) && relation(x, y) +} + +fun Relation.orEqual(x: T, y: T): Boolean { + return (x == y) || this(x, y) +} + +fun Relation.unordered(x: T, y: T): Boolean { + return (x != y) && !this(x, y) && !this(y, x) +} + +fun Relation.maxOrNull(x: T, y: T): T? = when { + x == y -> x + this(x, y) -> y + this(y, x) -> x + else -> null +} + +fun Relation.max(x: T, y: T): T = + maxOrNull(x, y) ?: throw IncomparableArgumentsException("$x and $y are incomparable") + +class IncomparableArgumentsException(message: String): Exception(message) + +// covering for each element returns a list of elements on which it depends; +// in terms of a graph: for each node it returns source nodes of its incoming edges +fun interface Covering { + operator fun invoke(x: T): List +} + +interface Enumerator { + + operator fun get(x: T): Int + + // TODO: change to `List` and use it to optimize iteration in `RelationMatrix`? + operator fun get(i: Int): T + +} + +fun> SortedList.toEnumerator(): Enumerator = object : Enumerator { + private val list = this@toEnumerator + override fun get(i: Int): T = list[i] + override fun get(x: T): Int = list.indexOf(x) +} + +class RelationMatrix( + // TODO: take nodes from the enumerator (?) + val nodes: Collection, + val enumerator: Enumerator, +) : Relation { + + private val size = nodes.size + + private val matrix = Array(size) { BooleanArray(size) } + + private var version = 0 + + constructor(nodes: Collection, enumerator: Enumerator, relation: Relation) : this (nodes, enumerator) { + add(relation) + } + + override operator fun invoke(x: T, y: T): Boolean = + get(x, y) + + private inline operator fun get(i: Int, j: Int): Boolean { + return matrix[i][j] + } + + operator fun get(x: T, y: T): Boolean = + get(enumerator[x], enumerator[y]) + + private operator fun set(i: Int, j: Int, value: Boolean) { + version += (matrix[i][j] != value).toInt() + matrix[i][j] = value + } + + operator fun set(x: T, y: T, value: Boolean) = + set(enumerator[x], enumerator[y], value) + + fun add(relation: Relation) { + for (i in 0 until size) { + val x = enumerator[i] + for (j in 0 until size) { + this[i, j] = this[i, j] || relation(x, enumerator[j]) + } + } + } + + fun order(ordering: List, strict: Boolean = true) { + for (i in ordering.indices) { + for (j in i until ordering.size) { + if (strict && i == j) + continue + this[ordering[i], ordering[j]] = true + } + } + } + + fun remove(relation: Relation) { + for (i in 0 until size) { + for (j in 0 until size) { + this[i, j] = this[i, j] && !relation(enumerator[i], enumerator[j]) + } + } + } + + fun filter(relation: Relation) { + for (i in 0 until size) { + for (j in 0 until size) { + this[i, j] = this[i, j] && relation(enumerator[i], enumerator[j]) + } + } + } + + private fun swap(i: Int, j: Int) { + val value = this[i, j] + this[i, j] = this[j, i] + this[j, i] = value + } + + fun transpose() { + for (i in 1 until size) { + for (j in 0 until i) { + swap(i, j) + } + } + } + + fun transitiveClosure() { + // TODO: optimize -- skip the computation for already transitive relation; + // track this by saving relation version number at the last call to `transitiveClosure` + kLoop@for (k in 0 until size) { + iLoop@for (i in 0 until size) { + if (!this[i, k]) + continue@iLoop + jLoop@for (j in 0 until size) { + this[i, j] = this[i, j] || this[k, j] + } + } + } + } + + fun transitiveReduction() { + jLoop@for (j in 0 until size) { + iLoop@for (i in 0 until size) { + if (!this[i, j]) + continue@iLoop + kLoop@for (k in 0 until size) { + if (this[i, k] && this[j, k]) { + this[i, k] = false + } + } + } + } + } + + fun equivalenceClosure(equivClassMapping : (T) -> List?) { + for (i in 0 until size) { + val x = enumerator[i] + val xClass = equivClassMapping(x) + for (j in 0 until size) { + val y = enumerator[j] + val yClass = equivClassMapping(y) + if (this[x, y] && xClass !== yClass) { + xClass?.forEach { this[it, y] = true } + yClass?.forEach { this[x, it] = true } + } + } + } + } + + fun fixpoint(block: RelationMatrix.() -> Unit) { + do { + val changed = trackChanges { block() } + } while (changed) + } + + fun trackChanges(block: RelationMatrix.() -> Unit): Boolean { + val version = this.version + block(this) + return (version != this.version) + } + + fun isIrreflexive(): Boolean { + for (i in 0 until size) { + if (this[i, i]) + return false + } + return true + } + + fun toGraph(): Graph = toGraph(nodes, enumerator) + +} + +fun Relation.toGraph(nodes: Collection, enumerator: Enumerator) = object : Graph { + private val relation = this@toGraph + + override val nodes: Collection + get() = nodes + + private val adjacencyList = Array(nodes.size) { i -> + val x = enumerator[i] + nodes.filter { y -> relation(x, y) } + } + + override fun adjacent(node: T): List { + val idx = enumerator[node] + return adjacencyList[idx] + } +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Utils.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Utils.kt new file mode 100644 index 000000000..5db3210a4 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/Utils.kt @@ -0,0 +1,40 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure + +import org.jetbrains.kotlinx.lincheck.util.* + +fun buildEnumerator(events: List) = object : Enumerator { + + // TODO: perhaps we can maintain event numbers directly in events themself + // and update them during replay? + val index: Map = + mutableMapOf().apply { + events.forEachIndexed { i, event -> + put(event, i).ensureNull() + } + } + + override fun get(i: Int): AtomicThreadEvent = events[i] + + override fun get(x: AtomicThreadEvent): Int = index[x]!! + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/CoherenceChecker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/CoherenceChecker.kt new file mode 100644 index 000000000..7d306017f --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/CoherenceChecker.kt @@ -0,0 +1,273 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.util.* + +typealias CoherenceList = List + +class CoherenceChecker : ConsistencyChecker { + + override fun check(execution: MutableExtendedExecution): Inconsistency? { + execution.coherenceOrderComputable.apply { + initialize() + compute() + } + val coherenceOrder = execution.coherenceOrderComputable.value + return if (!coherenceOrder.isConsistent()) + CoherenceViolation() + else null + } + +} + +class CoherenceOrder( + val execution: Execution, + val memoryAccessEventIndex: AtomicMemoryAccessEventIndex, + val rmwChainsStorage: ReadModifyWriteOrder, + val writesOrder: Relation, + var extendedCoherenceOrder: ComputableNode? = null, + var executionOrder: ComputableNode? = null, +) : Relation, Computable { + + private var consistent: Boolean = true + + private data class Entry( + val coherence: CoherenceList, + val positions: List, + val enumerator: Enumerator, + ) + + private val map = mutableMapOf() + + override fun invoke(x: AtomicThreadEvent, y: AtomicThreadEvent): Boolean { + val location = getLocationForSameLocationWriteAccesses(x, y) + ?: return false + val (_, positions, enumerator) = map[location] + ?: return writesOrder(x, y) + return positions[enumerator[x]] < positions[enumerator[y]] + } + + operator fun get(location: MemoryLocation): CoherenceList = + map[location]?.coherence ?: emptyList() + + fun isConsistent(): Boolean = + consistent + + override fun invalidate() { + reset() + } + + override fun reset() { + map.clear() + consistent = true + } + + override fun compute() { + check(map.isEmpty()) + generate(execution, memoryAccessEventIndex, rmwChainsStorage, writesOrder).forEach { coherence -> + val extendedCoherence = ExtendedCoherenceOrder(execution, memoryAccessEventIndex, + writesOrder = causalityOrder union coherence + ) + .apply { initialize(); compute() } + val executionOrder = ExecutionOrder(execution, memoryAccessEventIndex, + approximation = causalityOrder union extendedCoherence + ) + .apply { initialize(); compute() } + if (!executionOrder.isConsistent()) + return@forEach + this.map += coherence.map + this.extendedCoherenceOrder?.setComputed(extendedCoherence) + this.executionOrder?.setComputed(executionOrder) + return + } + // if we reached this point, then none of the generated coherence orderings is consistent + consistent = false + } + + companion object { + + private fun generate( + execution: Execution, + memoryAccessEventIndex: AtomicMemoryAccessEventIndex, + rmwChainsStorage: ReadModifyWriteOrder, + writesOrder: Relation + ): Sequence { + val coherenceOrderings = memoryAccessEventIndex.locations.mapNotNull { location -> + if (memoryAccessEventIndex.isWriteWriteRaceFree(location)) + return@mapNotNull null + val writes = memoryAccessEventIndex.getWrites(location) + .takeIf { it.size > 1 } ?: return@mapNotNull null + val enumerator = memoryAccessEventIndex.enumerator(AtomicMemoryAccessCategory.Write, location)!! + topologicalSortings(writesOrder.toGraph(writes, enumerator)).filter { + rmwChainsStorage.respectful(it) + } + } + if (coherenceOrderings.isEmpty()) { + return sequenceOf( + CoherenceOrder(execution, memoryAccessEventIndex, rmwChainsStorage, writesOrder) + ) + } + return coherenceOrderings.cartesianProduct().map { coherenceList -> + val coherenceOrder = CoherenceOrder(execution, memoryAccessEventIndex, rmwChainsStorage, writesOrder) + for (coherence in coherenceList) { + val location = coherence.getLocationForSameLocationWriteAccesses()!! + val enumerator = memoryAccessEventIndex.enumerator(AtomicMemoryAccessCategory.Write, location)!! + val positions = MutableList(coherence.size) { 0 } + coherence.forEachIndexed { i, write -> + positions[enumerator[write]] = i + } + coherenceOrder.map[location] = Entry(coherence, positions, enumerator) + } + return@map coherenceOrder + } + } + + } +} + +class ExtendedCoherenceOrder( + val execution: Execution, + val memoryAccessEventIndex: AtomicMemoryAccessEventIndex, + val writesOrder: Relation, +): Relation, Computable { + + private val relations: MutableMap> = mutableMapOf() + + override fun invoke(x: AtomicThreadEvent, y: AtomicThreadEvent): Boolean { + val location = getLocationForSameLocationAccesses(x, y) + ?: return false + if (!(isWriteOrReadResponse(x) && isWriteOrReadResponse(y))) + return false + return relations[location]?.get(x, y) ?: false + } + + private fun isWriteOrReadResponse(x: AtomicThreadEvent): Boolean { + return (x.label.isWriteAccess() || x.label is ReadAccessLabel && x.label.isResponse) + } + + fun isIrreflexive(): Boolean = + relations.all { (_, relation) -> relation.isIrreflexive() } + + override fun initialize() { + for (location in memoryAccessEventIndex.locations) { + val events = mutableListOf().apply { + addAll(memoryAccessEventIndex.getWrites(location)) + addAll(memoryAccessEventIndex.getReadResponses(location)) + } + relations[location] = RelationMatrix(events, buildEnumerator(events)) + } + } + + override fun compute() { + addCoherenceEdges() + addReadsFromEdges() + addReadsBeforeEdges() + addCoherenceReadFromEdges() + addReadsBeforeReadsFromEdges() + } + + override fun reset() { + relations.clear() + } + + private fun addCoherenceEdges() { + for (location in memoryAccessEventIndex.locations) { + addCoherenceEdges(location) + } + } + + private fun addCoherenceEdges(location: MemoryLocation) { + val relation = relations[location]!! + for (write1 in memoryAccessEventIndex.getWrites(location)) { + for (write2 in memoryAccessEventIndex.getWrites(location)) { + if (write1 != write2 && writesOrder(write1, write2)) + relation[write1, write2] = true + } + } + } + + private fun addReadsFromEdges() { + for (location in memoryAccessEventIndex.locations) { + addReadsFromEdges(location) + } + } + + private fun addReadsFromEdges(location: MemoryLocation) { + val relation = relations[location]!! + for (read in memoryAccessEventIndex.getReadResponses(location)) { + relation[read.readsFrom, read] = true + } + } + + private fun addReadsBeforeEdges() { + for (location in memoryAccessEventIndex.locations) { + addReadsBeforeEdges(location) + } + } + + private fun addReadsBeforeEdges(location: MemoryLocation) { + val relation = relations[location]!! + for (read in memoryAccessEventIndex.getReadResponses(location)) { + for (write in memoryAccessEventIndex.getWrites(location)) { + if (relation(read.readsFrom, write)) { + relation[read, write] = true + } + } + } + } + + private fun addCoherenceReadFromEdges() { + for (location in memoryAccessEventIndex.locations) { + addCoherenceReadFromEdges(location) + } + } + + private fun addCoherenceReadFromEdges(location: MemoryLocation) { + val relation = relations[location]!! + for (read in memoryAccessEventIndex.getReadResponses(location)) { + for (write in memoryAccessEventIndex.getWrites(location)) { + if (relation(write, read.readsFrom)) { + relation[write, read] = true + } + } + } + } + + private fun addReadsBeforeReadsFromEdges() { + for (location in memoryAccessEventIndex.locations) { + addReadsBeforeReadsFromEdges(location) + } + } + + private fun addReadsBeforeReadsFromEdges(location: MemoryLocation) { + val relation = relations[location]!! + for (read1 in memoryAccessEventIndex.getReadResponses(location)) { + for (read2 in memoryAccessEventIndex.getReadResponses(location)) { + if (relation(read1, read2.readsFrom)) { + relation[read1, read2] = true + } + } + } + } +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ConsistencyChecker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ConsistencyChecker.kt new file mode 100644 index 000000000..c6acf735c --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ConsistencyChecker.kt @@ -0,0 +1,277 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* + +/** + * Represents an inconsistency in the consistency check of an execution. + */ +abstract class Inconsistency + +/** + * Represents an exception that is thrown when an inconsistency is detected during execution. + * + * @param inconsistency The inconsistency that caused the exception. + */ +class InconsistentExecutionException(val inconsistency: Inconsistency) : Exception(inconsistency.toString()) + +fun interface ConsistencyChecker> { + fun check(execution: X): Inconsistency? +} + +interface IncrementalConsistencyChecker> { + + /** + * Represents the current state of the consistency checker. + * + * @see ConsistencyVerdict + */ + val state: ConsistencyVerdict + + /** + * Performs incremental consistency check, + * verifying whether adding the given [event] to the current execution retains execution's consistency. + * The implementation is allowed to approximate the check in the following sense. + * - The check can be incomplete --- it can miss some inconsistencies. + * - The check should be sound --- if an inconsistency is reported, the execution should indeed be inconsistent. + * If an inconsistency was missed by an incremental check, + * a later full consistency check via [check] function should detect this inconsistency. + * + * @return consistency verdict. + * @see ConsistencyVerdict + */ + fun check(event: E): ConsistencyVerdict + + /** + * Performs full consistency check. + * The check should be sound and complete: + * an inconsistency should be reported if and only if the execution is indeed inconsistent. + * + * @return `null` if execution remains consistent, + * otherwise returns non-null [Inconsistency] object + * representing the reason of inconsistency. + */ + fun check(): Inconsistency? + + /** + * Resets the internal state of the consistency checker to [execution]. + * + * @return consistency verdict on the execution after the reset. + */ + fun reset(execution: X): ConsistencyVerdict +} + +/** + * Represents the verdict of an incremental consistency check. + * The verdict can be either: consistent, inconsistent, or unknown. + */ +sealed class ConsistencyVerdict { + object Unknown : ConsistencyVerdict() + object Consistent : ConsistencyVerdict() + class Inconsistent(val inconsistency: Inconsistency): ConsistencyVerdict() +} + +val ConsistencyVerdict.inconsistency: Inconsistency? + get() = when (this) { + is ConsistencyVerdict.Inconsistent -> this.inconsistency + else -> null + } + +fun ConsistencyVerdict.join(doCheck: () -> ConsistencyVerdict): ConsistencyVerdict { + // if the inconsistency is already detected -- do not evaluate the second argument + // and return inconsistency immediately + if (this is ConsistencyVerdict.Inconsistent) + return this + // otherwise, evaluate the second argument to determine the result + val other = doCheck() + // if inconsistency is detected, return it + if (other is ConsistencyVerdict.Inconsistent) + return other + // otherwise, return "consistent" verdict only if both arguments are "consistent", + // if one of them is "unknown" -- then return "unknown" + return when { + this is ConsistencyVerdict.Consistent && other is ConsistencyVerdict.Consistent -> + ConsistencyVerdict.Consistent + else -> + ConsistencyVerdict.Unknown + } +} + +abstract class AbstractIncrementalConsistencyChecker>( + execution: X +) : IncrementalConsistencyChecker { + + protected var execution: X = execution + private set + + final override var state: ConsistencyVerdict = ConsistencyVerdict.Unknown + private set + + // `true` means a full consistency check was already performed and its result was cached; + // `false` means the check was not performed or a new event was added since the last check + private var fullCheckCached: Boolean = false + + final override fun check(event: E): ConsistencyVerdict { + // reset check cache + fullCheckCached = false + // do case analysis + when (state) { + // skip the check if the checker is in unknown state + is ConsistencyVerdict.Unknown -> return state + // return inconsistency if it was detected earlier + is ConsistencyVerdict.Inconsistent -> return state + // otherwise, actually perform the incremental check + else -> { + state = doIncrementalCheck(event) + return state + } + } + } + + protected abstract fun doIncrementalCheck(event: E): ConsistencyVerdict + + final override fun check(): Inconsistency? { + // if the full consistency check was already performed, + // and there were no new events added, return the cached result + if (fullCheckCached) { + check(state !is ConsistencyVerdict.Unknown) + return state.inconsistency + } + // return inconsistency if it was detected before by the incremental check + if (state is ConsistencyVerdict.Inconsistent) { + fullCheckCached = true + return state.inconsistency!! + } + // otherwise do the full check + state = when (val inconsistency = doCheck()) { + is Inconsistency -> ConsistencyVerdict.Inconsistent(inconsistency) + else -> ConsistencyVerdict.Consistent + } + // cache the result and return + fullCheckCached = true + return state.inconsistency + } + + protected abstract fun doCheck(): Inconsistency? + + final override fun reset(execution: X): ConsistencyVerdict { + this.execution = execution + fullCheckCached = false + state = ConsistencyVerdict.Consistent + state = doReset() + if (state !is ConsistencyVerdict.Unknown) { + fullCheckCached = true + } + return state + } + + protected abstract fun doReset(): ConsistencyVerdict + +} + +abstract class AbstractPartialIncrementalConsistencyChecker>( + execution: X, + val checker: ConsistencyChecker, +) : AbstractIncrementalConsistencyChecker(execution) { + + override fun doCheck(): Inconsistency? { + // the parent class should guarantee that at this point the state is + // either "consistent" or "unknown". + check(state !is ConsistencyVerdict.Inconsistent) + // do a lightweight check before falling back to full consistency check + return when (val verdict = doLightweightCheck()) { + // if lightweight check returns verdict "consistent", + // then the whole execution is consistent --- return null + is ConsistencyVerdict.Consistent -> null + // if inconsistency is detected, return it + is ConsistencyVerdict.Inconsistent -> verdict.inconsistency + // otherwise, do the full consistency check + is ConsistencyVerdict.Unknown -> doFullCheck() + } + } + + protected abstract fun doLightweightCheck(): ConsistencyVerdict + + private fun doFullCheck(): Inconsistency? { + return checker.check(execution) + } + +} + +abstract class AbstractFullyIncrementalConsistencyChecker>( + execution: X +) : AbstractIncrementalConsistencyChecker(execution) { + + override fun doCheck(): Inconsistency? { + // if a checker is fully incremental, + // it can detect inconsistencies precisely upon each event addition; + // thus we should not reach this point while being in the unknown state + check(state != ConsistencyVerdict.Unknown) + return state.inconsistency + } + +} + +class AggregatedIncrementalConsistencyChecker>( + execution: X, + val incrementalConsistencyCheckers: List>, + val consistencyCheckers: List>, +) : AbstractIncrementalConsistencyChecker(execution) { + + override fun doIncrementalCheck(event: E): ConsistencyVerdict { + var verdict: ConsistencyVerdict = ConsistencyVerdict.Consistent + for (incrementalChecker in incrementalConsistencyCheckers) { + verdict = verdict.join { incrementalChecker.check(event) } + } + return verdict + } + + override fun doCheck(): Inconsistency? { + for (incrementalChecker in incrementalConsistencyCheckers) { + incrementalChecker.check()?.let { return it } + } + for (checker in consistencyCheckers) { + checker.check(execution)?.let { return it } + } + return null + } + + override fun doReset(): ConsistencyVerdict { + var result: ConsistencyVerdict = ConsistencyVerdict.Consistent + for (incrementalChecker in incrementalConsistencyCheckers) { + val verdict = incrementalChecker.reset(execution) + result = result.join { verdict } + } + return result + } + +} + +fun> aggregateConsistencyCheckers( + execution: X, + incrementalConsistencyCheckers: List>, + consistencyCheckers: List>, +) = AggregatedIncrementalConsistencyChecker( + execution, + incrementalConsistencyCheckers, + consistencyCheckers, + ) \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/LockConsistencyChecker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/LockConsistencyChecker.kt new file mode 100644 index 000000000..3c78f4bb8 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/LockConsistencyChecker.kt @@ -0,0 +1,58 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.util.* + +class LockConsistencyViolation(val event1: Event, val event2: Event) : Inconsistency() + +class LockConsistencyChecker : ConsistencyChecker { + + override fun check(execution: MutableExtendedExecution): Inconsistency? { + /* + * We construct a map that maps an unlock (or notify) event to its single matching lock (or wait) event. + * Because the single initialization event may encode several initial unlock events + * (for different lock objects), we need to handle this case in a special way. + * Therefore, for the first lock (the one synchronizing-with the initialization event), + * we instead use the lock object itself as the key in the map. + */ + // TODO: make incremental (inconsistency can be triggered when processing unlock/notify event) + // TODO: generalize (to arbitrary unique-flagged events) and refactor! + // TODO: unify with the atomicity checker? + val mapping = mutableMapOf() + for (event in execution) { + val label = event.label.refine { isResponse && (this is LockLabel || this is WaitLabel) } + ?: continue + if (label is WaitLabel && (event.notifiedBy.label as NotifyLabel).isBroadcast) + continue + val key: Any = when (event.syncFrom.label) { + is UnlockLabel, is NotifyLabel -> event.syncFrom + else -> label.mutexID + } + val other = mapping.put(key, event) + if (other != null) + return LockConsistencyViolation(event, other) + } + return null + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ReadModifyWriteAtomicityChecker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ReadModifyWriteAtomicityChecker.kt new file mode 100644 index 000000000..5dcc789d8 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ReadModifyWriteAtomicityChecker.kt @@ -0,0 +1,224 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.util.* + +// TODO: restore atomicity violation information +class ReadModifyWriteAtomicityViolation(/*val write1: Event, val write2: Event*/) : Inconsistency() { + override fun toString(): String { + return "Atomicity violation detected" + } +} + +class ReadModifyWriteAtomicityChecker( + execution: MutableExtendedExecution +) : AbstractFullyIncrementalConsistencyChecker(execution) { + + override fun doIncrementalCheck(event: AtomicThreadEvent): ConsistencyVerdict { + check(execution.readModifyWriteOrderComputable.computed) + val readModifyWriteOrder = execution.readModifyWriteOrderComputable.value + val entry = readModifyWriteOrder.add(event) + ?: return ConsistencyVerdict.Consistent + if (!entry.isConsistent()) { + val inconsistency = ReadModifyWriteAtomicityViolation() + return ConsistencyVerdict.Inconsistent(inconsistency) + } + return ConsistencyVerdict.Consistent + } + + override fun doReset(): ConsistencyVerdict { + execution.readModifyWriteOrderComputable.apply { + reset() + initialize() + compute() + } + val readModifyWriteOrder = execution.readModifyWriteOrderComputable.value + if (!readModifyWriteOrder.isConsistent()) { + val inconsistency = ReadModifyWriteAtomicityViolation() + return ConsistencyVerdict.Inconsistent(inconsistency) + } + return ConsistencyVerdict.Consistent + } + +} + +typealias ReadModifyWriteChain = List +typealias MutableReadModifyWriteChain = MutableList + +class ReadModifyWriteOrder( + val execution: Execution, +) : Relation, Computable { + + abstract class Entry { + abstract val event: AtomicThreadEvent + abstract val chain: ReadModifyWriteChain + abstract val position: Int + } + + private data class EntryImpl( + override val event: AtomicThreadEvent, + override val chain: MutableReadModifyWriteChain, + override val position: Int, + ) : Entry() { + + init { + require(isValid()) + } + + private fun isValid(): Boolean = + chain.isNotEmpty() && when { + // we store only write accesses in the chain + !event.label.isWriteAccess() -> false + // non-exclusive write access can only be the first one in the chain + !event.label.isExclusiveWriteAccess() -> (position == 0) && (event == chain[0]) + // otherwise, consider entry to be valid + else -> true + } + } + + val entries: Collection + get() = eventMap.values + + private val eventMap = mutableMapOf() + + private val rmwChainsMap = mutableMapOf>() + + override fun invoke(x: AtomicThreadEvent, y: AtomicThreadEvent): Boolean { + if (x.label.isInitializingWriteAccess()) + return x == eventMap[y]?.chain?.get(0) + val xEntry = eventMap[x] ?: return false + val yEntry = eventMap[y] ?: return false + return (xEntry.chain == yEntry.chain) && (xEntry.position < yEntry.position) + } + + operator fun get(location: MemoryLocation, event: AtomicThreadEvent): Entry? { + /* Because the initialization (or object allocation) event + * may encode several initialization writes (e.g., one for each field of an object), + * we cannot map this initialization event to a single rmw chain. + * Instead, we need to map it to a different rmw chain for each location. + * To do so, we can utilize the fact that in such a scenario, + * the first chain for a given location may start with the chain + * beginning at the initialization event. + * If the first chain starts with some other event, + * it means that initialization event does not belong to any rmw chain. + */ + if (event.label.isInitializingWriteAccess()) { + val chain = rmwChainsMap[location]?.get(0)?.takeIf { it[0] == event } + return chain?.let { EntryImpl(event, chain, position = 0) } + } + /* otherwise we simply take the chain mapped to the read-from event */ + return eventMap[event] + } + + operator fun get(location: MemoryLocation): List? { + return rmwChainsMap[location] + } + + fun add(event: AtomicThreadEvent): Entry? { + val writeLabel = event.label.refine { isExclusive } + ?: return null + val location = writeLabel.location + val readFrom = event.exclusiveReadPart.readsFrom + val chain = get(location, readFrom)?.chain?.ensure { it.isNotEmpty() } ?: arrayListOf() + check(chain is MutableReadModifyWriteChain) + // if the read-from event is not yet mapped to any rmw chain, + // then we about to start a new one + if (chain.isEmpty()) { + check(!readFrom.label.isExclusiveWriteAccess()) + chain.add(readFrom) + if (!readFrom.label.isInitializingWriteAccess()) { + eventMap[readFrom] = EntryImpl(readFrom, chain, position = 0) + } + rmwChainsMap.updateInplace(location, default = arrayListOf()) { + // we order chains with respect to the enumeration order of their starting events + var position = indexOfFirst { readFrom.id < it[0].id } + if (position == -1) + position = size + add(position, chain) + } + } + chain.add(event) + return EntryImpl(event, chain, position = chain.size - 1).also { + eventMap[event] = it + } + } + + override fun compute() { + /* It is important to add events in some causality-compatible order + * (such as event enumeration order). + * This guarantees the following property: a mapping `w -> c`, where + * - `w` is an exclusive write event from the rmw pair of events (r, w), + * - `c` is a rmw chain to which `w` belongs to, + * would be added to the map only after mapping `w' -> c` is added, + * where `w'` is the write event from which `r` reads-from. + * In particular, for atomic-consistent executions, it implies that + * the rmw-chains would be added in their order from the begging of the chain to its end. + */ + for (event in execution.enumerationOrderSorted()) { + add(event) + } + } + + override fun reset() { + eventMap.clear() + rmwChainsMap.clear() + } + + fun respectful(events: List): Boolean { + check(events.isNotEmpty()) + val location = events.getLocationForSameLocationWriteAccesses()!! + val chains = rmwChainsMap[location]?.ensure { it.isNotEmpty() } + ?: return true + /* atomicity violation occurs when a write event is put in the middle of some rmw chain */ + var i = 0 + var pos = 0 + while (pos + chains[i].size <= events.size) { + if (events[pos] == chains[i].first()) { + if (events.subList(pos, pos + chains[i].size) != chains[i]) + return false + pos += chains[i].size + if (++i == chains.size) + return true + continue + } + pos++ + } + return false + } + +} + +fun ReadModifyWriteOrder.Entry.isConsistent() = when { + // write-part of atomic-read-modify write operation should read-from + // the preceding write event in the chain + event.label.isExclusiveWriteAccess() -> + event.exclusiveReadPart.readsFrom == chain[position - 1] + // the other case is a non-exclusive write access, + // which should be the first event in the chain + // (this is enforced by the `isValid` check in the constructor) + else -> true +} + +fun ReadModifyWriteOrder.isConsistent() = + entries.all { it.isConsistent() } \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ReleaseAcquireConsistencyChecker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ReleaseAcquireConsistencyChecker.kt new file mode 100644 index 000000000..33af43216 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/ReleaseAcquireConsistencyChecker.kt @@ -0,0 +1,134 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.util.* + +// TODO: what information should we display to help identify the cause of inconsistency: +// a cycle in writes-before relation? +class ReleaseAcquireInconsistency : Inconsistency() { + override fun toString(): String { + return "Release/Acquire inconsistency detected" + } +} + +class ReleaseAcquireConsistencyChecker : ConsistencyChecker { + + override fun check(execution: MutableExtendedExecution): Inconsistency? { + execution.writesBeforeOrderComputable.apply { + initialize() + compute() + } + val writesBeforeOrder = execution.writesBeforeOrderComputable.value + return if (!writesBeforeOrder.isConsistent()) + ReleaseAcquireInconsistency() + else null + } + +} + +class WritesBeforeOrder( + val execution: Execution, + val memoryAccessEventIndex: AtomicMemoryAccessEventIndex, + val rmwChainsStorage: ReadModifyWriteOrder, + val happensBefore: Relation, +) : Relation, Computable { + + private val relations: MutableMap> = mutableMapOf() + + override fun invoke(x: AtomicThreadEvent, y: AtomicThreadEvent): Boolean { + val location = getLocationForSameLocationWriteAccesses(x, y) + ?: return false + return relations[location]?.get(x, y) ?: happensBefore(x, y) + } + + fun isConsistent(): Boolean = + isIrreflexive() + + private fun isIrreflexive(): Boolean = + relations.values.all { it.isIrreflexive() } + + override fun initialize() { + for (location in memoryAccessEventIndex.locations) { + if (memoryAccessEventIndex.isWriteWriteRaceFree(location)) + continue + initialize(location) + } + } + + private fun initialize(location: MemoryLocation) { + val writes = memoryAccessEventIndex.getWrites(location) + val enumerator = memoryAccessEventIndex.enumerator(AtomicMemoryAccessCategory.Write, location)!! + relations[location] = RelationMatrix(writes, enumerator) + } + + override fun reset() { + relations.clear() + } + + override fun compute() { + for ((location, relation) in relations) { + relation.apply { + compute(location) + transitiveClosure() + } + } + } + + private fun RelationMatrix.compute(location: MemoryLocation) { + addHappensBeforeEdges(location) + addOverwrittenWriteEdges(location) + computeReadModifyWriteChainsClosure(location) + } + + private fun RelationMatrix.addHappensBeforeEdges(location: MemoryLocation) { + val relation = this + for (write1 in memoryAccessEventIndex.getWrites(location)) { + for (write2 in memoryAccessEventIndex.getWrites(location)) { + // TODO: also add `rf^?;hb` edges (it is required for any model where `causalityOrder < happensBefore`) + if (happensBefore(write1, write2) && write1 != write2) { + relation[write1, write2] = true + } + } + } + } + + private fun RelationMatrix.addOverwrittenWriteEdges(location: MemoryLocation) { + val relation = this + for (read in memoryAccessEventIndex.getReadResponses(location)) { + for (write in memoryAccessEventIndex.getWrites(location)) { + // TODO: change this check from `(w,r) \in hb` to `(w,r) \in rf^?;hb` + if (happensBefore(write, read) && write != read.readsFrom) { + relation[write, read.readsFrom] = true + } + } + } + } + + private fun RelationMatrix.computeReadModifyWriteChainsClosure(location: MemoryLocation) { + this.equivalenceClosure { event -> + rmwChainsStorage[location, event]?.chain + } + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/SequentialConsistencyChecker.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/SequentialConsistencyChecker.kt new file mode 100644 index 000000000..e6a34d2b8 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/SequentialConsistencyChecker.kt @@ -0,0 +1,374 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.util.* +import kotlin.collections.* + + +abstract class SequentialConsistencyViolation : Inconsistency() + +class SequentialConsistencyChecker( + val checkReleaseAcquireConsistency: Boolean = true, + val approximateSequentialConsistency: Boolean = true, + val checkCoherence: Boolean = true, +) : ConsistencyChecker { + + private val releaseAcquireChecker : ReleaseAcquireConsistencyChecker? = + if (checkReleaseAcquireConsistency) ReleaseAcquireConsistencyChecker() else null + + private val coherenceChecker : CoherenceChecker? = + if (checkCoherence) CoherenceChecker() else null + + override fun check(execution: MutableExtendedExecution): Inconsistency? { + check(!execution.executionOrderComputable.computed) + releaseAcquireChecker?.check(execution) + ?.let { return it } + coherenceChecker?.check(execution) + ?.let { return it } + check(execution.executionOrderComputable.computed) + val executionOrder = execution.executionOrderComputable.value + .ensure { it.isConsistent() } + SequentialConsistencyReplayer(1 + execution.maxThreadID).ensure { + it.replay(executionOrder.ordering) != null + } + return null + } + + /* + // we will gradually approximate the total sequential execution order of events + // by a partial order, starting with the partial causality order + var executionOrderApproximation : Relation = causalityOrder + // first try to check release/acquire consistency (it is cheaper) --- + // release/acquire inconsistency will also imply violation of sequential consistency, + if (releaseAcquireChecker != null) { + when (val verdict = releaseAcquireChecker.check(execution)) { + is ReleaseAcquireInconsistency -> return verdict + is ConsistencyWitness -> { + // if execution is release/acquire consistent, + // the writes-before relation can be used + // to refine the execution ordering approximation + val rmwChainsStorage = verdict.witness.rmwChainsStorage + val writesBefore = verdict.witness.writesBefore + executionOrderApproximation = executionOrderApproximation union writesBefore + // TODO: combine SC approximation phase with coherence phase + if (computeCoherenceOrdering) { + return checkByCoherenceOrdering(execution, memoryAccessEventIndex, rmwChainsStorage, writesBefore) + } + } + } + } + // TODO: combine SC approximation phase with coherence phase (and remove this check) + check(!computeCoherenceOrdering) + if (approximateSequentialConsistency) { + // TODO: embed the execution order approximation relation into the execution instance, + // so that this (and following stages) can be implemented as separate consistency check classes + val executionIndex = MutableAtomicMemoryAccessEventIndex() + .apply { index(execution) } + val scApprox = SequentialConsistencyOrder(execution, executionIndex, executionOrderApproximation).apply { + initialize() + compute() + } + if (!scApprox.isConsistent()) { + return SequentialConsistencyApproximationInconsistency() + } + executionOrderApproximation = scApprox + } + // get dependency covering to guide the search + val covering = execution.buildExternalCovering(executionOrderApproximation) + // aggregate atomic events before replaying + val (aggregated, remapping) = execution.aggregate(ThreadAggregationAlgebra.aggregator()) + // check consistency by trying to replay execution using sequentially consistent abstract machine + return checkByReplaying(aggregated, covering.aggregate(remapping)) + */ + + + // private fun checkByCoherenceOrdering( + // execution: Execution, + // executionIndex: AtomicMemoryAccessEventIndex, + // rmwChainsStorage: ReadModifyWriteOrder, + // wbRelation: WritesBeforeOrder, + // ): ConsistencyVerdict { + // val writesOrder = causalityOrder union wbRelation + // val executionOrderComputable = computable { + // ExecutionOrder(execution, executionIndex, Relation.empty()) + // } + // val coherence = CoherenceOrder(execution, executionIndex, rmwChainsStorage, writesOrder, + // executionOrder = executionOrderComputable + // ) + // .apply { initialize(); compute() } + // if (!coherence.isConsistent()) + // return SequentialConsistencyCoherenceViolation() + // val executionOrder = executionOrderComputable.value.ensure { it.isConsistent() } + // SequentialConsistencyReplayer(1 + execution.maxThreadID).ensure { + // it.replay(executionOrder.ordering) != null + // } + // return SequentialConsistencyWitness.create(executionOrder.ordering) + // } + +} + +class CoherenceViolation : SequentialConsistencyViolation() { + override fun toString(): String { + // TODO: what information should we display to help identify the cause of inconsistency? + return "Sequential consistency coherence violation detected" + } +} + +class IncrementalSequentialConsistencyChecker( + execution: MutableExtendedExecution, + checkReleaseAcquireConsistency: Boolean = true, + approximateSequentialConsistency: Boolean = true +) : AbstractPartialIncrementalConsistencyChecker( + execution = execution, + checker = SequentialConsistencyChecker( + checkReleaseAcquireConsistency, + approximateSequentialConsistency, + ) +) { + + private val lockConsistencyChecker = LockConsistencyChecker() + + override fun doIncrementalCheck(event: AtomicThreadEvent): ConsistencyVerdict { + check(state is ConsistencyVerdict.Consistent) + check(execution.executionOrderComputable.computed) + resetRelations() + val executionOrder = execution.executionOrderComputable.value + if (!executionOrder.isConsistentExtension(event)) { + // if we end up in an unknown state, reset the execution order, + // so it can be re-computed by the full consistency check + execution.executionOrderComputable.reset() + return ConsistencyVerdict.Unknown + } + executionOrder.add(event) + return ConsistencyVerdict.Consistent + } + + override fun doLightweightCheck(): ConsistencyVerdict { + // TODO: extract into separate checker + lockConsistencyChecker.check(execution)?.let { inconsistency -> + return ConsistencyVerdict.Inconsistent(inconsistency) + } + // check by trying to replay execution order + if (state == ConsistencyVerdict.Consistent) { + check(execution.executionOrderComputable.computed) + val replayer = SequentialConsistencyReplayer(1 + execution.maxThreadID) + val executionOrder = execution.executionOrderComputable.value + if (replayer.replay(executionOrder.ordering) != null) { + // if replay is successful, return "consistent" verdict + return ConsistencyVerdict.Consistent + } + } + // if we end up in an unknown state, reset the execution order, + // so it can be re-computed by the full consistency check + execution.executionOrderComputable.reset() + return ConsistencyVerdict.Unknown + } + + override fun doReset(): ConsistencyVerdict { + resetRelations() + execution.executionOrderComputable.apply { + reset() + // set state to `computed`, + // so we can push the events into the execution order + setComputed() + } + for (event in execution.enumerationOrderSorted()) { + val verdict = doIncrementalCheck(event) + if (verdict is ConsistencyVerdict.Unknown) { + return ConsistencyVerdict.Unknown + } + } + return ConsistencyVerdict.Consistent + } + + private fun ExecutionOrder.isConsistentExtension(event: AtomicThreadEvent): Boolean { + val last = ordering.lastOrNull() + val label = event.label + // TODO: for this check to be more robust, + // can we generalize it to work with the arbitrary aggregation algebra? + return when { + label is ReadAccessLabel && label.isResponse -> + // TODO: also check that read reads-from some consistent write: + // e.g. the globally last write, or the last observed write + event.isValidResponse(last!!) + + label is WriteAccessLabel && label.isExclusive -> + event.isWritePartOfAtomicUpdate(last!!) + + else -> true + } + } + + // TODO: move to corresponding individual consistency checkers + private fun resetRelations() { + execution.writesBeforeOrderComputable.reset() + execution.coherenceOrderComputable.reset() + execution.extendedCoherenceComputable.reset() + } + +} + + +class SequentialConsistencyApproximationInconsistency : SequentialConsistencyViolation() { + override fun toString(): String { + // TODO: what information should we display to help identify the cause of inconsistency? + return "Approximate sequential inconsistency detected" + } +} + +class SequentialConsistencyOrder( + val execution: Execution, + val memoryAccessEventIndex: AtomicMemoryAccessEventIndex, + val memoryAccessOrder: Relation, +) : Relation, Computable { + + // TODO: make cached delegate? + private var consistent = true + + private var relation: RelationMatrix? = null + + override fun invoke(x: AtomicThreadEvent, y: AtomicThreadEvent): Boolean = + relation?.invoke(x, y) ?: false + + fun isConsistent(): Boolean { + return consistent + } + + override fun initialize() { + // TODO: optimize -- build the relation only for write and read-response events + relation = RelationMatrix(execution, execution.buildEnumerator()) + } + + override fun compute() { + val relation = this.relation!! + relation.add(memoryAccessOrder) + relation.fixpoint { + // TODO: maybe we can remove this check without affecting performance? + if (!isIrreflexive()) { + consistent = false + return@fixpoint + } + coherenceClosure() + transitiveClosure() + } + } + + override fun invalidate() { + consistent = true + } + + override fun reset() { + invalidate() + relation = null + } + + private fun RelationMatrix.coherenceClosure() { + for (location in memoryAccessEventIndex.locations) { + coherenceClosure(location) + } + } + + private fun RelationMatrix.coherenceClosure(location: MemoryLocation) { + val relation = this + for (read in memoryAccessEventIndex.getReadResponses(location)) { + for (write in memoryAccessEventIndex.getWrites(location)) { + if (relation(write, read) && write != read.readsFrom) { + relation[write, read.readsFrom] = true + } + if (relation(read.readsFrom, write)) { + relation[read, write] = true + } + } + } + } + +} + +class ExecutionOrder( + val execution: Execution, + val memoryAccessEventIndex: AtomicMemoryAccessEventIndex, + val approximation: Relation, +) : Relation, Computable { + + private var consistent = true + + private val _ordering = mutableListOf() + + val ordering: List + get() = _ordering + + private val constraints = Relation { x, y -> + when { + // put wait-request before notify event + x.label.isRequest && x.label is WaitLabel -> + (y == execution.getResponse(x)?.notifiedBy) + + else -> false + } + } + + override fun invoke(x: AtomicThreadEvent, y: AtomicThreadEvent): Boolean { + TODO("Not yet implemented") + } + + fun isConsistent(): Boolean = + // TODO: embed failure state into ComputableNode state machine? + consistent + + fun add(event: AtomicThreadEvent) { + check(consistent) + _ordering.add(event) + } + + override fun compute() { + check(_ordering.isEmpty()) + val relation = approximation union constraints + // construct aggregated execution consisting of atomic events + // to incorporate the atomicity constraints during the search for topological sorting + val (aggregatedExecution, _) = execution.aggregate(ThreadAggregationAlgebra.aggregator()) + val aggregatedRelation = relation.existsLifting() + // TODO: optimization --- we can build graph only for a subset of events, excluding: + // - non-blocking request events + // - events accessing race-free locations + // - what else? + // and then insert them back into the topologically sorted list + val graph = aggregatedExecution.buildGraph(aggregatedRelation) + val ordering = topologicalSorting(graph) + if (ordering == null) { + consistent = false + return + } + this._ordering.addAll(ordering.flatMap { it.events }) + } + + override fun invalidate() { + consistent = true + } + + override fun reset() { + _ordering.clear() + invalidate() + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/SequentialConsistencyReplayer.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/SequentialConsistencyReplayer.kt new file mode 100644 index 000000000..0b4ec895c --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/eventstructure/consistency/SequentialConsistencyReplayer.kt @@ -0,0 +1,240 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.consistency + +import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingMonitorTracker +import org.jetbrains.kotlinx.lincheck.util.* + +fun checkByReplaying( + execution: Execution, + covering: Covering +): Inconsistency? { + // TODO: this is just a DFS search. + // In fact, we can generalize this algorithm to + // two arbitrary labelled transition systems by taking their product LTS + // and trying to find a trace in this LTS leading to terminal state. + val context = Context(execution, covering) + val initState = State.initial(execution) + val stack = ArrayDeque(listOf(initState)) + val visited = mutableSetOf(initState) + with(context) { + while (stack.isNotEmpty()) { + val state = stack.removeLast() + if (state.isTerminal) { + return null + // return SequentialConsistencyWitness.create( + // executionOrder = state.history.flatMap { it.events } + // ) + } + state.transitions().forEach { + val unvisited = visited.add(it) + if (unvisited) { + stack.addLast(it) + } + } + } + return SequentialConsistencyReplayViolation() + } +} + +class SequentialConsistencyReplayViolation : SequentialConsistencyViolation() { + override fun toString(): String { + // TODO: what information should we display to help identify the cause of inconsistency? + return "Sequential consistency replay violation detected" + } +} + +internal data class SequentialConsistencyReplayer( + val nThreads: Int, + val memoryView: MutableMap = mutableMapOf(), + val monitorTracker: ModelCheckingMonitorTracker = ModelCheckingMonitorTracker(nThreads), + val monitorMapping: MutableMap = mutableMapOf() +) { + + fun replay(event: AtomicThreadEvent): SequentialConsistencyReplayer? { + val label = event.label + return when { + + label is ReadAccessLabel && label.isRequest -> + this + + label is ReadAccessLabel && label.isResponse -> + this.takeIf { + // TODO: do we really need this `if` here? + if (event.readsFrom.label is WriteAccessLabel) + memoryView[label.location] == event.readsFrom + else memoryView[label.location] == null + } + + label is WriteAccessLabel -> + this.copy().apply { memoryView[label.location] = event } + + label is LockLabel && label.isRequest -> + this + + label is LockLabel && label.isResponse && !label.isSynthetic -> { + val monitor = getMonitor(label.mutexID) + if (this.monitorTracker.canAcquireMonitor(event.threadId, monitor)) { + this.copy().apply { monitorTracker.acquireMonitor(event.threadId, monitor).ensure() } + } else null + } + + label is UnlockLabel && !label.isSynthetic -> + this.copy().apply { monitorTracker.releaseMonitor(event.threadId, getMonitor(label.mutexID)) } + + label is WaitLabel && label.isRequest -> + this.copy().apply { monitorTracker.waitOnMonitor(event.threadId, getMonitor(label.mutexID)).ensure() } + + label is WaitLabel && label.isResponse -> { + val monitor = getMonitor(label.mutexID) + if (this.monitorTracker.canAcquireMonitor(event.threadId, monitor)) { + this.copy().takeIf { !it.monitorTracker.waitOnMonitor(event.threadId, monitor) } + } else null + } + + label is NotifyLabel -> + this.copy().apply { monitorTracker.notify(event.threadId, getMonitor(label.mutexID), label.isBroadcast) } + + // auxiliary unlock/lock events inserted before/after wait events + label is LockLabel && label.isSynthetic -> + this + label is UnlockLabel && label.isSynthetic -> + this + + label is InitializationLabel -> this + label is ObjectAllocationLabel -> this + label is ThreadEventLabel -> this + // TODO: do we need to care about parking? + label is ParkingEventLabel -> this + label is CoroutineLabel -> this + label is ActorLabel -> this + label is RandomLabel -> this + + else -> unreachable() + + } + } + + fun replay(events: Iterable): SequentialConsistencyReplayer? { + var replayer = this + for (event in events) { + replayer = replayer.replay(event) ?: return null + } + return replayer + } + + fun replay(event: HyperThreadEvent): SequentialConsistencyReplayer? { + return replay(event.events) + } + + fun copy(): SequentialConsistencyReplayer = + SequentialConsistencyReplayer( + nThreads, + memoryView.toMutableMap(), + monitorTracker.copy(), + monitorMapping.toMutableMap(), + ) + + private fun getMonitor(objID: ObjectID): Any { + check(objID != NULL_OBJECT_ID) + return monitorMapping.computeIfAbsent(objID) { Any() } + } + +} + +private data class State( + val executionClock: MutableVectorClock, + val replayer: SequentialConsistencyReplayer, +) { + + // TODO: move to Context + var history: List = listOf() + private set + + constructor( + executionClock: MutableVectorClock, + replayer: SequentialConsistencyReplayer, + history: List, + ) : this(executionClock, replayer) { + this.history = history + } + + companion object { + fun initial(execution: Execution) = State( + executionClock = MutableVectorClock(1 + execution.maxThreadID), + replayer = SequentialConsistencyReplayer(1 + execution.maxThreadID), + ) + } + + override fun equals(other: Any?): Boolean { + if (this === other) + return true + return (other is State) + && executionClock == other.executionClock + && replayer == other.replayer + } + + override fun hashCode(): Int { + var result = executionClock.hashCode() + result = 31 * result + replayer.hashCode() + return result + } + +} + +private class Context(val execution: Execution, val covering: Covering) { + + fun State.covered(event: HyperThreadEvent): Boolean = + executionClock.observes(event) + + fun State.coverable(event: HyperThreadEvent): Boolean = + covering.coverable(event, executionClock) + + val State.isTerminal: Boolean + get() = executionClock.observes(execution) + + fun State.transition(threadId: Int): State? { + val position = 1 + executionClock[threadId] + val event = execution[threadId, position] + ?.takeIf { coverable(it) } + ?: return null + val view = replayer.replay(event) + ?: return null + return State( + replayer = view, + history = this.history + event, + executionClock = this.executionClock.copy().apply { + increment(event.threadId) + }, + ) + } + + fun State.transitions() : List { + val states = arrayListOf() + for (threadId in execution.threadIDs) { + transition(threadId)?.let { states.add(it) } + } + return states + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingCTestConfiguration.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingCTestConfiguration.kt index 2fa92c3a2..e9a9275d0 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingCTestConfiguration.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingCTestConfiguration.kt @@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.lincheck.Actor import org.jetbrains.kotlinx.lincheck.execution.* import org.jetbrains.kotlinx.lincheck.strategy.* import org.jetbrains.kotlinx.lincheck.strategy.managed.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* import org.jetbrains.kotlinx.lincheck.transformation.InstrumentationMode import org.jetbrains.kotlinx.lincheck.transformation.InstrumentationMode.* import org.jetbrains.kotlinx.lincheck.verifier.* @@ -26,7 +27,8 @@ class ModelCheckingCTestConfiguration(testClass: Class<*>, iterations: Int, thre checkObstructionFreedom: Boolean, hangingDetectionThreshold: Int, invocationsPerIteration: Int, guarantees: List, minimizeFailedScenario: Boolean, sequentialSpecification: Class<*>, timeoutMs: Long, - customScenarios: List + customScenarios: List, + experimentalModelChecking: Boolean, ) : ManagedCTestConfiguration( testClass = testClass, iterations = iterations, @@ -46,7 +48,11 @@ class ModelCheckingCTestConfiguration(testClass: Class<*>, iterations: Int, thre customScenarios = customScenarios ) { - override val instrumentationMode: InstrumentationMode get() = MODEL_CHECKING + private val useExperimentalModelChecking = + experimentalModelChecking || System.getProperty("lincheck.useExperimentalModelChecking")?.toBoolean() ?: false + + override val instrumentationMode: InstrumentationMode get() = + if (useExperimentalModelChecking) EXPERIMENTAL_MODEL_CHECKING else MODEL_CHECKING private var isReplayModeForIdeaPluginEnabled = false @@ -59,5 +65,13 @@ class ModelCheckingCTestConfiguration(testClass: Class<*>, iterations: Int, thre scenario: ExecutionScenario, validationFunction: Actor?, stateRepresentationMethod: Method?, - ): Strategy = ModelCheckingStrategy(this, testClass, scenario, validationFunction, stateRepresentationMethod, isReplayModeForIdeaPluginEnabled) + ): Strategy = + if (useExperimentalModelChecking) + EventStructureStrategy(this, testClass, scenario, validationFunction, stateRepresentationMethod) + else + ModelCheckingStrategy(this, testClass, scenario, validationFunction, stateRepresentationMethod, + replay = isReplayModeForIdeaPluginEnabled + ) } + + diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingOptions.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingOptions.kt index 1d2101db7..cb7fdd9f9 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingOptions.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/strategy/managed/modelchecking/ModelCheckingOptions.kt @@ -16,6 +16,13 @@ import org.jetbrains.kotlinx.lincheck.strategy.managed.* * Options for [model checking][ModelCheckingStrategy] strategy. */ class ModelCheckingOptions : ManagedOptions() { + private var experimentalModelChecking = false + + internal fun useExperimentalModelChecking(): ModelCheckingOptions { + experimentalModelChecking = true + return this + } + override fun createTestConfigurations(testClass: Class<*>): ModelCheckingCTestConfiguration { return ModelCheckingCTestConfiguration( testClass = testClass, @@ -33,7 +40,9 @@ class ModelCheckingOptions : ManagedOptions 0 - ExecutionPart.PARALLEL -> currentInterleaving.chooseThread(0) - ExecutionPart.POST -> 0 - ExecutionPart.VALIDATION -> 0 - } - loopDetector.beforePart(nextThread) - currentThread = nextThread - } + override fun chooseThread(iThread: Int): Int = - currentInterleaving.chooseThread(iThread).also { - check(it in switchableThreads(iThread)) { """ - Trying to switch the execution to thread $it, - but only the following threads are eligible to switch: ${switchableThreads(iThread)} - """.trimIndent() } - } + currentInterleaving.chooseThread(iThread) /** * An abstract node with an execution choice in the interleaving tree. @@ -462,7 +452,6 @@ internal class ModelCheckingStrategy( executionPosition = -1 // the first execution position will be zero interleavingFinishingRandom = Random(2) // random with a constant seed nextThreadToSwitch = threadSwitchChoices.iterator() - loopDetector.initialize() lastNotInitializedNodeChoices = null lastNotInitializedNode?.let { // Create a mutable list for the initialization of the not initialized node choices. @@ -581,6 +570,10 @@ internal class LocalObjectManager : ObjectTracker { objects.forEach { markObjectNonLocal(it) } } + override fun initializeObject(obj: Any) { + throw UnsupportedOperationException("Model checking strategy does not track object initialization") + } + override fun shouldTrackObjectAccess(obj: Any): Boolean = !isLocalObject(obj) @@ -589,6 +582,12 @@ internal class LocalObjectManager : ObjectTracker { */ private fun isLocalObject(obj: Any?) = localObjects.containsKey(obj) + override fun getObjectId(obj: Any): ObjectID { + // model checking strategy does not currently use object IDs + throw UnsupportedOperationException("Model checking strategy does not track unique object IDs") + // return System.identityHashCode(obj).toLong() + } + override fun reset() { localObjects.clear() } @@ -597,7 +596,7 @@ internal class LocalObjectManager : ObjectTracker { /** * Tracks synchronization operations on the monitors (intrinsic locks) */ -internal class ModelCheckingMonitorTracker(nThreads: Int) : MonitorTracker { +internal class ModelCheckingMonitorTracker(val nThreads: Int) : MonitorTracker { // Maintains a set of acquired monitors with an information on which thread // performed the acquisition and the reentrancy depth. private val acquiredMonitors = IdentityHashMap() @@ -653,7 +652,7 @@ internal class ModelCheckingMonitorTracker(nThreads: Int) : MonitorTracker { * Returns `true` if the monitor is already acquired by * the thread [threadId], or if this monitor is free to acquire. */ - private fun canAcquireMonitor(threadId: Int, monitor: Any) = + fun canAcquireMonitor(threadId: Int, monitor: Any) = acquiredMonitors[monitor]?.threadId?.equals(threadId) ?: true /** @@ -706,11 +705,40 @@ internal class ModelCheckingMonitorTracker(nThreads: Int) : MonitorTracker { waitForNotify.fill(false) } + fun copy(): ModelCheckingMonitorTracker { + val tracker = ModelCheckingMonitorTracker(nThreads) + acquiredMonitors.forEach { (monitor, info) -> + tracker.acquiredMonitors[monitor] = info.copy() + } + waitingMonitor.forEachIndexed { thread, info -> + tracker.waitingMonitor[thread] = info?.copy() + } + waitForNotify.copyInto(tracker.waitForNotify) + return tracker + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + return (other is ModelCheckingMonitorTracker) && + (nThreads == other.nThreads) && + (acquiredMonitors == other.acquiredMonitors) && + (waitingMonitor.contentEquals(other.waitingMonitor)) && + (waitForNotify.contentEquals(other.waitForNotify)) + } + + override fun hashCode(): Int { + var result = acquiredMonitors.hashCode() + result = 31 * result + waitingMonitor.contentHashCode() + result = 31 * result + waitForNotify.contentHashCode() + return result + } + /** * Stores the [monitor], id of the thread acquired the monitor [threadId], * and the number of reentrant acquisitions [timesAcquired]. */ - private class MonitorAcquiringInfo(val monitor: Any, val threadId: Int, var timesAcquired: Int) + // TODO: monitor should be opaque for the correctness of the generated equals/hashCode (?) + private data class MonitorAcquiringInfo(val monitor: Any, val threadId: Int, var timesAcquired: Int) } class ModelCheckingParkingTracker(val nThreads: Int, val allowSpuriousWakeUps: Boolean = false) : ParkingTracker { diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckClassVisitor.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckClassVisitor.kt index bba907c8f..833caefe7 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckClassVisitor.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckClassVisitor.kt @@ -90,7 +90,9 @@ internal class LincheckClassVisitor( return mv } if (methodName == "") { - mv = ObjectCreationTransformer(fileName, className, methodName, mv.newAdapter()) + mv = ObjectCreationTransformer(fileName, className, methodName, mv.newAdapter(), + interceptObjectInitialization = (instrumentationMode == EXPERIMENTAL_MODEL_CHECKING), + ) return mv } if (className.contains("ClassLoader")) { @@ -111,11 +113,18 @@ internal class LincheckClassVisitor( // `skipVisitor` must not capture `MethodCallTransformer` // (to filter static method calls inserted by coverage library) val skipVisitor: MethodVisitor = mv - mv = MethodCallTransformer(fileName, className, methodName, mv.newAdapter()) + mv = MethodCallTransformer(fileName, className, methodName, mv.newAdapter(), + interceptAtomicMethodCallResult = (instrumentationMode == EXPERIMENTAL_MODEL_CHECKING), + ) mv = MonitorTransformer(fileName, className, methodName, mv.newAdapter()) mv = WaitNotifyTransformer(fileName, className, methodName, mv.newAdapter()) mv = ParkingTransformer(fileName, className, methodName, mv.newAdapter()) - mv = ObjectCreationTransformer(fileName, className, methodName, mv.newAdapter()) + mv = ObjectCreationTransformer(fileName, className, methodName, mv.newAdapter(), + interceptObjectInitialization = (instrumentationMode == EXPERIMENTAL_MODEL_CHECKING), + ) + mv = ReflectionTransformer(fileName, className, methodName, mv.newAdapter(), + interceptArrayCopyMethod = (instrumentationMode == EXPERIMENTAL_MODEL_CHECKING), + ) mv = DeterministicHashCodeTransformer(fileName, className, methodName, mv.newAdapter()) mv = DeterministicTimeTransformer(mv.newAdapter()) mv = DeterministicRandomTransformer(fileName, className, methodName, mv.newAdapter()) @@ -126,7 +135,9 @@ internal class LincheckClassVisitor( // which should be put in front of the byte-code transformer chain, // so that it can correctly analyze the byte-code and compute required type-information mv = run { - val sv = SharedMemoryAccessTransformer(fileName, className, methodName, mv.newAdapter()) + val sv = SharedMemoryAccessTransformer(fileName, className, methodName, mv.newAdapter(), + interceptReadAccesses = (instrumentationMode == EXPERIMENTAL_MODEL_CHECKING), + ) val aa = AnalyzerAdapter(className, access, methodName, desc, sv) sv.analyzer = aa aa diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckJavaAgent.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckJavaAgent.kt index bff4751a4..416afe57d 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckJavaAgent.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/LincheckJavaAgent.kt @@ -13,8 +13,7 @@ package org.jetbrains.kotlinx.lincheck.transformation import net.bytebuddy.agent.ByteBuddyAgent import org.jetbrains.kotlinx.lincheck.canonicalClassName import org.jetbrains.kotlinx.lincheck.runInIgnoredSection -import org.jetbrains.kotlinx.lincheck.transformation.InstrumentationMode.MODEL_CHECKING -import org.jetbrains.kotlinx.lincheck.transformation.InstrumentationMode.STRESS +import org.jetbrains.kotlinx.lincheck.transformation.InstrumentationMode.* import org.jetbrains.kotlinx.lincheck.transformation.LincheckClassFileTransformer.shouldTransform import org.jetbrains.kotlinx.lincheck.transformation.LincheckClassFileTransformer.transformedClassesStress import org.jetbrains.kotlinx.lincheck.transformation.LincheckJavaAgent.INSTRUMENT_ALL_CLASSES @@ -22,14 +21,14 @@ import org.jetbrains.kotlinx.lincheck.transformation.LincheckJavaAgent.instrumen import org.jetbrains.kotlinx.lincheck.transformation.LincheckJavaAgent.instrumentationMode import org.jetbrains.kotlinx.lincheck.transformation.LincheckJavaAgent.instrumentedClasses import org.jetbrains.kotlinx.lincheck.util.readFieldViaUnsafe -import org.objectweb.asm.ClassReader -import org.objectweb.asm.ClassWriter import sun.misc.Unsafe import java.io.File -import java.lang.instrument.ClassFileTransformer -import java.lang.instrument.Instrumentation -import java.lang.reflect.Modifier -import java.security.ProtectionDomain +import org.objectweb.asm.* +import org.objectweb.asm.util.* +import java.io.* +import java.lang.instrument.* +import java.lang.reflect.* +import java.security.* import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.jar.JarFile @@ -49,16 +48,30 @@ internal inline fun withLincheckJavaAgent(instrumentationMode: InstrumentationMo internal enum class InstrumentationMode { /** - * In this mode, Lincheck transforms bytecode - * only to track coroutine suspensions. + * In this mode, Lincheck transforms bytecode only to track coroutine suspensions. */ STRESS, /** - * In this mode, Lincheck tracks - * all shared memory manipulations. + * In this mode, Lincheck tracks all shared memory manipulations. */ - MODEL_CHECKING + MODEL_CHECKING, + + /** + * In this mode, in addition to tracking all shared memory manipulations, + * Lincheck also can intercept read results. + */ + EXPERIMENTAL_MODEL_CHECKING, +} + +internal fun InstrumentationMode.isStressMode(): Boolean = when (this) { + STRESS -> true + else -> false +} + +internal fun InstrumentationMode.isModelCheckingMode(): Boolean = when (this) { + MODEL_CHECKING, EXPERIMENTAL_MODEL_CHECKING -> true + else -> false } /** @@ -74,8 +87,9 @@ internal object LincheckJavaAgent { private val instrumentation = ByteBuddyAgent.install() /** - * Determines how to transform classes; - * see [InstrumentationMode.STRESS] and [InstrumentationMode.MODEL_CHECKING]. + * Determines how to transform classes. + * + * @see [InstrumentationMode] */ lateinit var instrumentationMode: InstrumentationMode @@ -87,7 +101,7 @@ internal object LincheckJavaAgent { private var isBootstrapJarAddedToClasspath = false /** - * Names (canonical) of the classes that were instrumented since the last agent installation. + * TODO */ val instrumentedClasses = HashSet() @@ -124,7 +138,7 @@ internal object LincheckJavaAgent { // In the stress testing mode, Lincheck needs to track coroutine suspensions. // As an optimization, we remember the set of loaded classes that actually // have suspension points, so later we can re-transform only those classes. - instrumentationMode == STRESS -> { + instrumentationMode.isStressMode() -> { check(instrumentedClasses.isEmpty()) val classes = getLoadedClassesToInstrument().filter { val canonicalClassName = it.name @@ -138,7 +152,7 @@ internal object LincheckJavaAgent { } // In the model checking mode, Lincheck processes classes lazily, only when they are used. - instrumentationMode == MODEL_CHECKING -> { + instrumentationMode.isModelCheckingMode() -> { check(instrumentedClasses.isEmpty()) } } @@ -294,13 +308,23 @@ internal object LincheckJavaAgent { return } // Traverse static fields. - clazz.declaredFields - .filter { !it.type.isPrimitive } - .filter { Modifier.isStatic(it.modifiers) } - .mapNotNull { readFieldViaUnsafe(null, it, Unsafe::getObject) } - .forEach { - ensureObjectIsTransformed(it, processedObjects) + val staticFields = clazz.declaredFields.filter { Modifier.isStatic(it.modifiers) } + if (staticFields.isNotEmpty()) { + // ensure the class is loaded and initialized before reading its static field + Class.forName(clazz.name) + for (field in staticFields) { + val value = readFieldViaUnsafe(null, field, Unsafe::getObject) + if (!field.type.isPrimitive && value != null) { + ensureObjectIsTransformed(value, processedObjects) + } } + } + // Traverse interfaces. + clazz.interfaces.forEach { + if (it.name in instrumentedClasses) return // already instrumented + ensureClassHierarchyIsTransformed(it, processedObjects) + } + // Traverse superclass. clazz.superclass?.let { if (it.name in instrumentedClasses) return // already instrumented ensureClassHierarchyIsTransformed(it, processedObjects) @@ -327,13 +351,16 @@ internal object LincheckClassFileTransformer : ClassFileTransformer { * Notice that the transformation depends on the [InstrumentationMode]. * Additionally, this object caches bytes of non-transformed classes. */ - val transformedClassesModelChecking = ConcurrentHashMap() val transformedClassesStress = ConcurrentHashMap() + val transformedClassesModelChecking = ConcurrentHashMap() + val transformedClassesExperimentalModelChecking = ConcurrentHashMap() + val nonTransformedClasses = ConcurrentHashMap() private val transformedClassesCache get() = when (instrumentationMode) { STRESS -> transformedClassesStress MODEL_CHECKING -> transformedClassesModelChecking + EXPERIMENTAL_MODEL_CHECKING -> transformedClassesExperimentalModelChecking } override fun transform( @@ -350,7 +377,7 @@ internal object LincheckClassFileTransformer : ClassFileTransformer { // In the model checking mode, we transform classes lazily, // once they are used in the testing code. if (!INSTRUMENT_ALL_CLASSES && - instrumentationMode == MODEL_CHECKING && + instrumentationMode.isModelCheckingMode() && internalClassName.canonicalClassName !in instrumentedClasses) { return null } diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/MethodCallTransformer.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/MethodCallTransformer.kt index dd98091ad..1c8262213 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/MethodCallTransformer.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/MethodCallTransformer.kt @@ -10,9 +10,7 @@ package org.jetbrains.kotlinx.lincheck.transformation.transformers -import org.jetbrains.kotlinx.lincheck.* import org.jetbrains.kotlinx.lincheck.transformation.* -import org.jetbrains.kotlinx.lincheck.util.* import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type import org.objectweb.asm.Type.* @@ -28,6 +26,7 @@ internal class MethodCallTransformer( className: String, methodName: String, adapter: GeneratorAdapter, + private val interceptAtomicMethodCallResult: Boolean = false, ) : ManagedStrategyMethodVisitor(fileName, className, methodName, adapter) { override fun visitMethodInsn(opcode: Int, owner: String, name: String, desc: String, itf: Boolean) = adapter.run { @@ -77,17 +76,41 @@ internal class MethodCallTransformer( // STACK [INVOKEVIRTUAL]: owner, owner, className, methodName, codeLocation, methodId // STACK [INVOKESTATIC]: null, className, methodName, codeLocation, methodId pushArray(argumentLocals) - // STACK: ..., argumentsArray + // STACK: ..., arguments invokeStatic(Injections::beforeMethodCall) invokeBeforeEventIfPluginEnabled("method call $methodName", setMethodEventId = true) - // STACK [INVOKEVIRTUAL]: owner, arguments - // STACK [INVOKESTATIC] : arguments val methodCallEndLabel = newLabel() val handlerExceptionStartLabel = newLabel() visitTryCatchBlock(methodCallStartLabel, methodCallEndLabel, handlerExceptionStartLabel, null) visitLabel(methodCallStartLabel) - loadLocals(argumentLocals) - visitMethodInsn(opcode, owner, name, desc, itf) + if (interceptAtomicMethodCallResult) { + // STACK: shouldInterceptMethodResult + ifStatement( + condition = { /* already on stack */ }, + ifClause = { + val resultType = Type.getReturnType(desc) + if (opcode != INVOKESTATIC) { + pop() + } + // STACK : + invokeStatic(Injections::interceptMethodCallResult) + if (resultType == Type.VOID_TYPE) { + pop() + } else { + unbox(resultType) + } + }, + elseClause = { + loadLocals(argumentLocals) + visitMethodInsn(opcode, owner, name, desc, itf) + } + ) + } else { + // STACK: shouldInterceptMethodResult + pop() + loadLocals(argumentLocals) + visitMethodInsn(opcode, owner, name, desc, itf) + } visitLabel(methodCallEndLabel) // STACK [INVOKEVIRTUAL]: owner, arguments // STACK [INVOKESTATIC] : arguments diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ObjectCreationTransformer.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ObjectCreationTransformer.kt index 8c19616f1..112ebaa4b 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ObjectCreationTransformer.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ObjectCreationTransformer.kt @@ -15,6 +15,7 @@ import org.objectweb.asm.commons.GeneratorAdapter import org.objectweb.asm.commons.InstructionAdapter.OBJECT_TYPE import org.jetbrains.kotlinx.lincheck.canonicalClassName import org.jetbrains.kotlinx.lincheck.transformation.* +import org.objectweb.asm.Type import sun.nio.ch.lincheck.* /** @@ -25,29 +26,56 @@ internal class ObjectCreationTransformer( fileName: String, className: String, methodName: String, - adapter: GeneratorAdapter + adapter: GeneratorAdapter, + private val interceptObjectInitialization: Boolean = false, ) : ManagedStrategyMethodVisitor(fileName, className, methodName, adapter) { - override fun visitMethodInsn(opcode: Int, owner: String, name: String, desc: String, itf: Boolean) = - adapter.run { - if (name == "" && owner == "java/lang/Object") { - invokeIfInTestingCode( - original = { - visitMethodInsn(opcode, owner, name, desc, itf) - }, - code = { - val objectLocal = newLocal(OBJECT_TYPE) - dup() - storeLocal(objectLocal) - visitMethodInsn(opcode, owner, name, desc, itf) - loadLocal(objectLocal) - invokeStatic(Injections::afterNewObjectCreation) - } - ) - } else { + override fun visitMethodInsn(opcode: Int, owner: String, name: String, desc: String, itf: Boolean) = adapter.run { + val isInit = (opcode == INVOKESPECIAL && name == "") + if (!isInit) { + visitMethodInsn(opcode, owner, name, desc, itf) + return + } + // handle the common case separately + if (name == "" && owner == "java/lang/Object") { + invokeIfInTestingCode( + original = { + visitMethodInsn(opcode, owner, name, desc, itf) + }, + code = { + val objLocal = newLocal(OBJECT_TYPE).also { copyLocal(it) } + visitMethodInsn(opcode, owner, name, desc, itf) + loadLocal(objLocal) + invokeStatic(Injections::afterNewObjectCreation) + } + ) + return + } + if (!interceptObjectInitialization) { + visitMethodInsn(opcode, owner, name, desc, itf) + return + } + // TODO: this code handles the situation when Object. constructor is not called, + // because the base class, say class `Foo`, in some hierarchy is not transformed. + // In this case, the user code will only instrument call to `Foo.`, but + // the constructor of `Foo` itself will not be instrumented, + // and therefore the call to Object. will be missed. + invokeIfInTestingCode( + original = { + visitMethodInsn(opcode, owner, name, desc, itf) + }, + code = { + val objType = Type.getObjectType(owner) + val constructorType = Type.getType(desc) + val params = storeLocals(constructorType.argumentTypes) + val objLocal = newLocal(objType).also { copyLocal(it) } + params.forEach { loadLocal(it) } visitMethodInsn(opcode, owner, name, desc, itf) + loadLocal(objLocal) + invokeStatic(Injections::afterObjectInitialization) } - } + ) + } override fun visitIntInsn(opcode: Int, operand: Int) = adapter.run { adapter.visitIntInsn(opcode, operand) diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ReflectionTransformer.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ReflectionTransformer.kt new file mode 100644 index 000000000..325722a1c --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/ReflectionTransformer.kt @@ -0,0 +1,75 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This Source Code Form is subject to the terms of the + * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed + * with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.jetbrains.kotlinx.lincheck.transformation.transformers + +import org.jetbrains.kotlinx.lincheck.transformation.ManagedStrategyMethodVisitor +import org.jetbrains.kotlinx.lincheck.transformation.invokeIfInTestingCode +import org.jetbrains.kotlinx.lincheck.transformation.invokeStatic +import org.objectweb.asm.Opcodes.INVOKESTATIC +import org.objectweb.asm.commons.GeneratorAdapter +import sun.nio.ch.lincheck.EventTracker +import sun.nio.ch.lincheck.Injections + +/** + * [ReflectionTransformer] tracks some of the reflection method calls, + * injecting invocations of corresponding [EventTracker] methods. + */ +internal class ReflectionTransformer( + fileName: String, + className: String, + methodName: String, + adapter: GeneratorAdapter, + private val interceptArrayCopyMethod: Boolean = false, +) : ManagedStrategyMethodVisitor(fileName, className, methodName, adapter) { + + override fun visitMethodInsn(opcode: Int, owner: String, name: String, desc: String, itf: Boolean) = adapter.run { + when { + opcode == INVOKESTATIC && owner == "java/lang/reflect/Array" && name == "newInstance" -> + visitInvokeArrayNewInstance(opcode, owner, name, desc, itf) + + opcode == INVOKESTATIC && owner == "java/lang/System" && name == "arraycopy" -> + visitArrayCopyMethod(opcode, owner, name, desc, itf) + + else -> + visitMethodInsn(opcode, owner, name, desc, itf) + } + } + + private fun visitInvokeArrayNewInstance(opcode: Int, owner: String, name: String, descriptor: String, isInterface: Boolean) = adapter.run { + // TODO: should also call beforeNewObjectCreation? + visitMethodInsn(opcode, owner, name, descriptor, isInterface) + // STACK: array + invokeIfInTestingCode( + original = {}, + code = { + dup() + invokeStatic(Injections::afterNewObjectCreation) + } + ) + } + + private fun visitArrayCopyMethod(opcode: Int, owner: String, name: String, desc: String, itf: Boolean) = adapter.run { + if (!interceptArrayCopyMethod) { + visitMethodInsn(opcode, owner, name, desc, itf) + return + } + // STACK: srcArray, srcPos, dstArray, dstPos, length + invokeIfInTestingCode( + original = { + visitMethodInsn(opcode, owner, name, desc, itf) + }, + code = { + invokeStatic(Injections::onArrayCopy) + } + ) + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/SharedMemoryAccessTransformer.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/SharedMemoryAccessTransformer.kt index de1020306..b80be4a10 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/SharedMemoryAccessTransformer.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/transformation/transformers/SharedMemoryAccessTransformer.kt @@ -28,6 +28,7 @@ internal class SharedMemoryAccessTransformer( className: String, methodName: String, adapter: GeneratorAdapter, + private val interceptReadAccesses: Boolean = false, ) : ManagedStrategyMethodVisitor(fileName, className, methodName, adapter) { lateinit var analyzer: AnalyzerAdapter @@ -49,10 +50,11 @@ internal class SharedMemoryAccessTransformer( pushNull() push(owner) push(fieldName) + push(desc) loadNewCodeLocationId() push(true) // isStatic push(FinalFields.isFinalField(owner, fieldName)) // isFinal - // STACK: null, className, fieldName, codeLocation, isStatic, isFinal + // STACK: null, className, fieldName, typeDescriptor, codeLocation, isStatic, isFinal invokeStatic(Injections::beforeReadField) // STACK: isTracePointCreated ifStatement( @@ -60,9 +62,15 @@ internal class SharedMemoryAccessTransformer( ifClause = { invokeBeforeEventIfPluginEnabled("read static field") }, - elseClause = {}) + elseClause = {} + ) // STACK: - visitFieldInsn(opcode, owner, fieldName, desc) + if (interceptReadAccesses) { + invokeStatic(Injections::interceptReadResult) + unbox(getType(desc)) + } else { + visitFieldInsn(opcode, owner, fieldName, desc) + } // STACK: value invokeAfterRead(getType(desc)) // STACK: value @@ -82,10 +90,11 @@ internal class SharedMemoryAccessTransformer( // STACK: obj, obj push(owner) push(fieldName) + push(desc) loadNewCodeLocationId() push(false) // isStatic push(FinalFields.isFinalField(owner, fieldName)) // isFinal - // STACK: obj, obj, className, fieldName, codeLocation, isStatic, isFinal + // STACK: obj, obj, className, fieldName, typeDescriptor, codeLocation, isStatic, isFinal invokeStatic(Injections::beforeReadField) // STACK: obj, isTracePointCreated ifStatement( @@ -96,7 +105,13 @@ internal class SharedMemoryAccessTransformer( elseClause = {} ) // STACK: obj - visitFieldInsn(opcode, owner, fieldName, desc) + if (interceptReadAccesses) { + pop() + invokeStatic(Injections::interceptReadResult) + unbox(getType(desc)) + } else { + visitFieldInsn(opcode, owner, fieldName, desc) + } // STACK: value invokeAfterRead(getType(desc)) // STACK: value @@ -118,12 +133,13 @@ internal class SharedMemoryAccessTransformer( pushNull() push(owner) push(fieldName) + push(desc) loadLocal(valueLocal) box(valueType) loadNewCodeLocationId() push(true) // isStatic push(FinalFields.isFinalField(owner, fieldName)) // isFinal - // STACK: value, null, className, fieldName, value, codeLocation, isStatic, isFinal + // STACK: value, null, className, fieldName, typeDescriptor, value, codeLocation, isStatic, isFinal invokeStatic(Injections::beforeWriteField) // STACK: isTracePointCreated ifStatement( @@ -156,12 +172,13 @@ internal class SharedMemoryAccessTransformer( // STACK: obj, obj push(owner) push(fieldName) + push(desc) loadLocal(valueLocal) box(valueType) loadNewCodeLocationId() push(false) // isStatic push(FinalFields.isFinalField(owner, fieldName)) // isFinal - // STACK: obj, obj, className, fieldName, value, codeLocation, isStatic, isFinal + // STACK: obj, obj, className, fieldName, typeDescriptor, value, codeLocation, isStatic, isFinal invokeStatic(Injections::beforeWriteField) // STACK: isTracePointCreated ifStatement( @@ -196,12 +213,13 @@ internal class SharedMemoryAccessTransformer( visitInsn(opcode) }, code = { - // STACK: array: Array, index: Int + // STACK: array, index val arrayElementType = getArrayElementType(opcode) dup2() - // STACK: array: Array, index: Int, array: Array, index: Int + // STACK: array, index, array, index + push(arrayElementType.descriptor) loadNewCodeLocationId() - // STACK: array: Array, index: Int, array: Array, index: Int, codeLocation: Int + // STACK: array, index, array, index, typeDescriptor, codeLocation invokeStatic(Injections::beforeReadArray) ifStatement( condition = { /* already on stack */ }, @@ -210,8 +228,15 @@ internal class SharedMemoryAccessTransformer( }, elseClause = {} ) - // STACK: array: Array, index: Int - visitInsn(opcode) + // STACK: array, index + if (interceptReadAccesses) { + pop() + pop() + invokeStatic(Injections::interceptReadResult) + unbox(arrayElementType) + } else { + visitInsn(opcode) + } // STACK: value invokeAfterRead(arrayElementType) // STACK: value @@ -225,17 +250,18 @@ internal class SharedMemoryAccessTransformer( visitInsn(opcode) }, code = { - // STACK: array: Array, index: Int, value: Object + // STACK: array, index, value val arrayElementType = getArrayElementType(opcode) val valueLocal = newLocal(arrayElementType) // we cannot use DUP as long/double require DUP2 storeLocal(valueLocal) - // STACK: array: Array, index: Int + // STACK: array, index dup2() - // STACK: array: Array, index: Int, array: Array, index: Int + // STACK: array, index, array, index + push(arrayElementType.descriptor) loadLocal(valueLocal) box(arrayElementType) loadNewCodeLocationId() - // STACK: array: Array, index: Int, array: Array, index: Int, value: Object, codeLocation: Int + // STACK: array, index, array, index, typeDescriptor, value, codeLocation invokeStatic(Injections::beforeWriteArray) ifStatement( condition = { /* already on stack */ }, @@ -244,9 +270,9 @@ internal class SharedMemoryAccessTransformer( }, elseClause = {} ) - // STACK: array: Array, index: Int + // STACK: array, index loadLocal(valueLocal) - // STACK: array: Array, index: Int, value: Object + // STACK: array, index, value visitInsn(opcode) // STACK: invokeStatic(Injections::afterWrite) diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/AtomicMethods.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/AtomicMethods.kt index 38b641047..e835a0716 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/AtomicMethods.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/AtomicMethods.kt @@ -14,6 +14,8 @@ import java.util.concurrent.atomic.* import org.jetbrains.kotlinx.lincheck.util.AtomicMethodKind.* import org.jetbrains.kotlinx.lincheck.util.MemoryOrdering.* import java.lang.invoke.VarHandle +import org.objectweb.asm.Type +import org.objectweb.asm.commons.InstructionAdapter.OBJECT_TYPE internal data class AtomicMethodDescriptor( val kind: AtomicMethodKind, @@ -66,7 +68,6 @@ internal fun isAtomic(receiver: Any?) = receiver is kotlinx.atomicfu.AtomicInt || receiver is kotlinx.atomicfu.AtomicLong - internal fun isAtomicClass(className: String) = // java.util.concurrent className == "java.util.concurrent.atomic.AtomicInteger" || @@ -94,7 +95,6 @@ internal fun isAtomicArray(receiver: Any?) = receiver is kotlinx.atomicfu.AtomicIntArray || receiver is kotlinx.atomicfu.AtomicLongArray - internal fun isAtomicArrayClass(className: String) = // java.util.concurrent className == "java.util.concurrent.atomic.AtomicReferenceArray" || @@ -109,6 +109,17 @@ internal fun isAtomicArrayClass(className: String) = internal fun isAtomicArrayMethod(className: String, methodName: String) = isAtomicArrayClass(className) && methodName in atomicMethods +internal fun getAtomicType(atomic: Any?): Type? = when (atomic) { + is AtomicReference<*> -> OBJECT_TYPE + is AtomicBoolean -> Type.BOOLEAN_TYPE + is AtomicInteger -> Type.INT_TYPE + is AtomicLong -> Type.LONG_TYPE + is AtomicReferenceArray<*> -> OBJECT_TYPE + is AtomicIntegerArray -> Type.INT_TYPE + is AtomicLongArray -> Type.LONG_TYPE + else -> null +} + internal fun isAtomicFieldUpdater(obj: Any?) = obj is AtomicReferenceFieldUpdater<*, *> || obj is AtomicIntegerFieldUpdater<*> || @@ -141,6 +152,19 @@ internal fun isUnsafeClass(className: String) = internal fun isUnsafeMethod(className: String, methodName: String) = isUnsafeClass(className) && methodName in unsafeMethods +internal fun parseUnsafeMethodAccessType(methodName: String): Type? = when { + "Boolean" in methodName -> Type.BOOLEAN_TYPE + "Byte" in methodName -> Type.BYTE_TYPE + "Short" in methodName -> Type.SHORT_TYPE + "Int" in methodName -> Type.INT_TYPE + "Long" in methodName -> Type.LONG_TYPE + "Float" in methodName -> Type.FLOAT_TYPE + "Double" in methodName -> Type.DOUBLE_TYPE + "Reference" in methodName -> OBJECT_TYPE + "Object" in methodName -> OBJECT_TYPE + else -> null +} + private val atomicMethods = mapOf( // get "get" to AtomicMethodDescriptor(GET, VOLATILE), diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Computable.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Computable.kt new file mode 100644 index 000000000..5a8f636a6 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Computable.kt @@ -0,0 +1,230 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.util + +import kotlin.reflect.KProperty + +interface Computable { + fun initialize() {} + fun compute() + fun invalidate() {} + fun reset() +} + +interface Incremental { + fun add(element: T) +} + +fun computable(builder: () -> T) = + ComputableNode(builder) + +class ComputableNode(val builder: () -> T) : Computable { + + private var _value: T? = null + + val value: T get() = + _value ?: builder().also { _value = it } + + private data class Dependency( + val computable: ComputableNode<*>, + // computations of soft dependencies are not enforced + val soft: Boolean, + ) + + private val dependencies = mutableListOf() + + private data class Subscriber( + val computable: ComputableNode<*>, + // tells whether this subscriber should be + // recursively invalidated on invalidation + val invalidating: Boolean, + // tells whether this subscriber should be + // recursively reset on resetting + val resetting: Boolean, + ) + + private val subscribers = mutableListOf() + + private enum class State { + UNSET, INITIALIZED, COMPUTED + } + + private var state = State.UNSET + + val unset: Boolean + get() = (state.ordinal == State.UNSET.ordinal) + + val initialized: Boolean + get() = (state.ordinal >= State.INITIALIZED.ordinal) + + val computed: Boolean + get() = (state.ordinal >= State.COMPUTED.ordinal) + + override fun initialize() { + if (!initialized) { + initializeDependencies() + initializeValue() + } + } + + private fun initializeValue() { + value.initialize() + state = State.INITIALIZED + } + + override fun compute() { + if (!computed) { + computeDependencies() + computeValue() + } + } + + private fun computeValue() { + value.compute() + state = State.COMPUTED + } + + override fun invalidate() { + if (state == State.COMPUTED) { + invalidateValue() + invalidateSubscribers() + } + } + + private fun invalidateValue() { + value.invalidate() + state = State.INITIALIZED + } + + override fun reset() { + if (!unset) { + resetValue() + resetSubscribers() + } + } + + private fun resetValue() { + value.reset() + state = State.UNSET + } + + fun setComputed() { + state = State.COMPUTED + invalidateSubscribers() + } + + fun setComputed(value: T) { + this._value = value + setComputed() + } + + fun addDependency(dependency: ComputableNode<*>, + soft: Boolean = false, + invalidating: Boolean = false, + resetting: Boolean = false + ) { + dependencies.add(Dependency(dependency, soft)) + dependency.subscribers.add(Subscriber(this, invalidating, resetting)) + } + + operator fun getValue(thisRef: Any?, property: KProperty<*>): T { + initialize() + compute() + return value + } + + private fun initializeDependencies() { + traverseDependencies { dependency -> + if (!dependency.soft && !dependency.computable.initialized) { + dependency.computable.initializeValue() + true + } else false + } + } + + private fun computeDependencies() { + traverseDependencies { dependency -> + if (!dependency.soft && !dependency.computable.computed) { + dependency.computable.computeValue() + true + } else false + } + } + + private fun invalidateSubscribers() { + traverseSubscribers { subscriber -> + if (subscriber.invalidating && subscriber.computable.computed) { + subscriber.computable.invalidateValue() + true + } else false + } + } + + private fun resetSubscribers() { + traverseSubscribers { subscriber -> + if (subscriber.resetting && !subscriber.computable.unset) { + subscriber.computable.resetValue() + true + } else false + } + } + + private fun traverseDependencies(action: (Dependency) -> Boolean) { + val stack = ArrayDeque>(listOf(this)) + val visited = mutableSetOf>(this) + while (stack.isNotEmpty()) { + val computable = stack.removeLast() + computable.dependencies.forEach { dependency -> + val unvisited = visited.add(dependency.computable) + if (unvisited) { + val expandable = action(dependency) + if (expandable) stack.add(dependency.computable) + } + } + } + } + + private fun traverseSubscribers(action: (Subscriber) -> Boolean) { + val stack = ArrayDeque>(listOf(this)) + val visited = mutableSetOf>(this) + while (stack.isNotEmpty()) { + val computable = stack.removeLast() + computable.subscribers.forEach { subscriber -> + val unvisited = visited.add(subscriber.computable) + if (unvisited) { + val expandable = action(subscriber) + if (expandable) stack.add(subscriber.computable) + } + } + } + } + +} + +fun ComputableNode.dependsOn(dependency: ComputableNode<*>, + soft: Boolean = false, + invalidating: Boolean = false, + resetting: Boolean = false +): ComputableNode { + addDependency(dependency, soft, invalidating, resetting) + return this +} + diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/IntMap.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/IntMap.kt new file mode 100644 index 000000000..66a73f696 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/IntMap.kt @@ -0,0 +1,335 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.util + +import kotlin.math.max + +interface IntMap { + + interface Entry { + val key: Int + val value: T + } + + val keys: Set + + val values: Collection + + val entries: Set> + + val size: Int + + fun isEmpty(): Boolean + + fun containsKey(key: Int): Boolean + + fun containsValue(value: @UnsafeVariance T): Boolean + + operator fun get(key: Int): T? + +} + +interface MutableIntMap: IntMap { + + interface MutableEntry : IntMap.Entry { + fun setValue(newValue: T): T + } + + override val keys: MutableSet + + override val values: MutableCollection + + override val entries: MutableSet> + + fun put(key: Int, value: T): T? + + fun remove(key: Int) + + fun clear() + +} + +operator fun IntMap.Entry.component1(): Int = key +operator fun IntMap.Entry.component2(): T = value + +fun intMapOf(vararg pairs: Pair) : IntMap = + mutableIntMapOf(*pairs) + +fun mutableIntMapOf(vararg pairs: Pair) : MutableIntMap = + ArrayIntMap(*pairs) + +fun IntMap.forEach(action: (IntMap.Entry) -> Unit) = + entries.forEach(action) + +fun IntMap.map(transform: (IntMap.Entry) -> R) = + entries.map(transform) + +fun IntMap.mapNotNull(transform: (IntMap.Entry) -> R?) = + entries.mapNotNull(transform) + +fun IntMap.mapValues(transform: (IntMap.Entry) -> R) = + entries + .map { entry -> entry.key to transform(entry) } + .let { intMapOf(*it.toTypedArray()) } + +inline fun MutableIntMap.getOrPut(key: Int, defaultValue: () -> T): T { + get(key)?.let { return it } + return defaultValue().also { put(key, it) } +} + +operator fun MutableIntMap.set(key: Int, value: T) { + put(key, value) +} + +fun MutableIntMap.update(key: Int, default: T, transform: (T) -> T) { + // TODO: could it be done with a single lookup in a map? + put(key, get(key)?.let(transform) ?: default) +} + +fun MutableIntMap.mergeReduce(other: IntMap, reduce: (T, T) -> T) { + other.forEach { (key, value) -> + update(key, default = value) { reduce(it, value) } + } +} + +class ArrayIntMap(capacity: Int) : MutableIntMap { + + override var size: Int = 0 + private set + + private val array = MutableList(capacity) { null } + + // we keep a separate bitmap identifying mapped keys + // to distinguish user-provided `null` value from `null` as not-yet-mapped value + private var bitmap = BooleanArray(capacity) { false } + + val capacity: Int + get() = array.size + + override val keys: MutableSet + get() = KeySet() + + override val values: MutableCollection + get() = ValueCollection() + + override val entries: MutableSet> + get() = EntrySet() + + constructor(vararg pairs: Pair) : this(pairs.calculateCapacity()) { + if (capacity == 0) + return + pairs.forEach { (key, value) -> + ++size + array[key] = value + bitmap[key] = true + } + } + + override fun isEmpty(): Boolean = + size == 0 + + override fun containsKey(key: Int): Boolean = + bitmap[key] + + override fun containsValue(value: T): Boolean { + for (i in array.indices) { + if (bitmap[i] && value == array[i]) + return true + } + return false + } + + override fun get(key: Int): T? = + array.getOrNull(key) + + override fun put(key: Int, value: T): T? { + val oldValue = get(key) + if (key >= array.size) { + val newCapacity = key + 1 + expand(newCapacity) + } + if (!bitmap[key]) { + size++ + } + array[key] = value + bitmap[key] = true + return oldValue + } + + override fun remove(key: Int) { + if (bitmap[key]) { + size-- + } + array[key] = null + bitmap[key] = false + } + + override fun clear() { + size = 0 + array.clear() + bitmap.fill(false) + } + + fun copy() = ArrayIntMap(capacity).also { + for (i in 0 until capacity) { + if (!bitmap[i]) + continue + it[i] = this[i] as T + } + } + + private fun expand(newCapacity: Int) { + require(newCapacity > capacity) + array.expand(newCapacity, null) + bitmap = BooleanArray(newCapacity) { i -> i < bitmap.size && bitmap[i] } + } + + override fun equals(other: Any?): Boolean { + if (other !is ArrayIntMap<*>) + return false + for (i in 0 until max(array.size, other.array.size)) { + if (this[i] != other[i] || bitmap[i] != other.bitmap[i]) + return false + } + return true + } + + override fun hashCode(): Int { + var hashCode = 1 + for (i in array.indices) { + if (!bitmap[i]) + continue + hashCode = 31 * hashCode + (this[i]?.hashCode() ?: 0) + } + return hashCode + } + + override fun toString(): String = + keys.map { key -> key to get(key) }.toString() + + private inner class KeySet : AbstractMutableSet() { + + override val size: Int + get() = this@ArrayIntMap.size + + override fun contains(element: Int): Boolean = + this@ArrayIntMap.containsKey(element) + + override fun add(element: Int): Boolean { + throw UnsupportedOperationException("Unsupported operation.") + } + + // TODO: cannot override because of weird compiler bug (probably due to boxing and override) + // override fun remove(element: Int): Boolean { + // return this@ArrayIntMap.containsKey(element).also { + // this@ArrayIntMap.remove(element) + // } + // } + + override fun iterator() = object : IteratorBase() { + override fun getElement(key: Int): Int { + return key + } + } + } + + private inner class ValueCollection : AbstractMutableCollection() { + + override val size: Int + get() = this@ArrayIntMap.size + + override fun add(element: T): Boolean { + throw UnsupportedOperationException("Unsupported operation.") + } + + override fun iterator() = object : IteratorBase() { + override fun getElement(key: Int): T { + return this@ArrayIntMap[key] as T + } + } + } + + private inner class EntrySet : AbstractMutableSet>() { + + override val size: Int + get() = this@ArrayIntMap.size + + override fun add(element: MutableIntMap.MutableEntry): Boolean { + val (key, value) = element + val prev = this@ArrayIntMap.put(key, value) + return (prev != value) + } + + override fun remove(element: MutableIntMap.MutableEntry): Boolean { + val (key, value) = element + val prev = this@ArrayIntMap[key] + this@ArrayIntMap.remove(key) + return (prev == value) + } + + override fun iterator() = object : IteratorBase>() { + override fun getElement(key: Int): MutableIntMap.MutableEntry { + val value = this@ArrayIntMap[key] as T + return Entry(key, value) + } + } + } + + private data class Entry( + override val key: Int, + override var value: T, + ) : MutableIntMap.MutableEntry { + override fun setValue(newValue: T): T { + val prev = value + value = newValue + return prev + } + } + + private abstract inner class IteratorBase : AbstractIterator(), MutableIterator { + private var index: Int = -1 + + override fun computeNext() { + while (++index < this@ArrayIntMap.capacity) { + if (this@ArrayIntMap.containsKey(index)) { + setNext(getElement(index)) + return + } + } + done() + } + + override fun remove() { + throw UnsupportedOperationException("Unsupported operation.") + } + + abstract fun getElement(key: Int): E + + } + + companion object { + private fun Array>.calculateCapacity(): Int { + require(all { (i, _) -> i >= 0 }) + return 1 + (maxOfOrNull { (i, _) -> i } ?: -1) + } + } + +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/SortedList.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/SortedList.kt new file mode 100644 index 000000000..1566813af --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/SortedList.kt @@ -0,0 +1,163 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.util + +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.Relation + +interface SortedList> : List { + + override fun contains(element: @UnsafeVariance T): Boolean = + binarySearch(element) >= 0 + + override fun containsAll(elements: Collection<@UnsafeVariance T>): Boolean = + elements.all { contains(it) } + + override fun indexOf(element: @UnsafeVariance T): Int { + var i: Int + var j = size + do { + i = j + j = binarySearch(element, toIndex = i) + } while (j >= 0) + return i + } + + override fun lastIndexOf(element: @UnsafeVariance T): Int { + var i: Int + var j = 0 + do { + i = j + j = binarySearch(element, fromIndex = i) + } while (j >= 0) + return i + } + +} + +interface SortedMutableList> : MutableList, SortedList + +private class SortedListImpl>(val list: List) : SortedList { + + init { + require(list.isSorted()) + } + + override val size: Int + get() = list.size + + override fun isEmpty(): Boolean = + list.isEmpty() + + override fun get(index: Int): T = + list[index] + + override fun subList(fromIndex: Int, toIndex: Int): SortedList { + return SortedListImpl(list.subList(fromIndex, toIndex)) + } + + override fun iterator(): Iterator = + list.iterator() + + override fun listIterator(): ListIterator = + list.listIterator() + + override fun listIterator(index: Int): ListIterator = + list.listIterator(index) + +} + +class SortedArrayList> : ArrayList, SortedMutableList { + + constructor() : super() + + constructor(initialCapacity: Int) : super(initialCapacity) + + constructor(elements: Collection) : super(elements) { + require(isSorted()) { "Expected sorted list" } + } + + override fun add(element: T): Boolean { + require(isNotEmpty() implies { last() <= element }) + return super.add(element) + } + + override fun add(index: Int, element: T) { + require((index - 1 >= 0) implies { get(index - 1) <= element }) + require((index < size) implies { element <= get(index) }) + super.add(index, element) + } + + override fun addAll(elements: Collection): Boolean { + val oldSize = size + return super.addAll(elements).ensure { + isSorted(fromIndex = oldSize) + } + } + + override fun addAll(index: Int, elements: Collection): Boolean { + return super.addAll(index, elements).ensure { + val lastIndex = index + elements.size + val fromIndex = if (index - 1 >= 0) (index - 1) else index + val toIndex = if (lastIndex + 1 < size) (lastIndex + 1) else lastIndex + isSorted(fromIndex = fromIndex, toIndex = toIndex) + } + } + + override fun set(index: Int, element: T): T { + require((index - 1 >= 0) implies { get(index - 1) <= element }) + require((index + 1 < size) implies { element <= get(index + 1) }) + return super.set(index, element) + } + + override fun contains(element: T): Boolean = + super.contains(element) + + override fun containsAll(elements: Collection): Boolean = + super.containsAll(elements) + + override fun indexOf(element: T): Int = + super.indexOf(element) + + override fun lastIndexOf(element: T): Int = + super.lastIndexOf(element) + +} + +fun> sortedListOf(vararg elements: T): SortedList = + SortedListImpl(elements.asList()) + +fun> sortedMutableListOf(vararg elements: T): SortedMutableList = + sortedArrayListOf(*elements) + +fun> sortedArrayListOf(vararg elements: T): SortedArrayList = + SortedArrayList(elements.asList()) + + +fun > List.isSorted(fromIndex : Int = 0, toIndex : Int = size): Boolean = + isChain(fromIndex, toIndex) { x, y -> x <= y } + +fun List.isChain(fromIndex : Int = 0, toIndex : Int = size, relation: Relation): Boolean { + for (i in fromIndex until toIndex - 1) { + if (!relation(get(i), get(i + 1))) + return false + } + return true +} \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Threads.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Threads.kt new file mode 100644 index 000000000..c7ae524cd --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Threads.kt @@ -0,0 +1,25 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.util + +typealias ThreadID = Int +typealias ThreadMap = IntMap +typealias MutableThreadMap = MutableIntMap diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/UnsafeHolder.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/UnsafeHolder.kt index 6d8623a78..ae1e573dc 100644 --- a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/UnsafeHolder.kt +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/UnsafeHolder.kt @@ -34,4 +34,51 @@ internal inline fun readFieldViaUnsafe(obj: Any?, field: Field, getter: Unsa val offset = UnsafeHolder.UNSAFE.objectFieldOffset(field) return UnsafeHolder.UNSAFE.getter(obj, offset) } -} \ No newline at end of file +} + +internal inline fun writeFieldViaUnsafe(obj: Any?, field: Field, value: Any?, setter: Unsafe.(Any?, Long, Any?) -> Unit) { + if (Modifier.isStatic(field.modifiers)) { + val base = UnsafeHolder.UNSAFE.staticFieldBase(field) + val offset = UnsafeHolder.UNSAFE.staticFieldOffset(field) + return UnsafeHolder.UNSAFE.setter(base, offset, value) + } else { + val offset = UnsafeHolder.UNSAFE.objectFieldOffset(field) + return UnsafeHolder.UNSAFE.setter(obj, offset, value) + } +} + +internal fun readField(obj: Any?, field: Field): Any? { + if (!field.type.isPrimitive) { + return readFieldViaUnsafe(obj, field, Unsafe::getObject) + } + return when (field.type) { + Boolean::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getBoolean) + Byte::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getByte) + Char::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getChar) + Short::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getShort) + Int::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getInt) + Long::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getLong) + Double::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getDouble) + Float::class.javaPrimitiveType -> readFieldViaUnsafe(obj, field, Unsafe::getFloat) + else -> error("No more types expected") + } +} + +internal fun writeField(obj: Any?, field: Field, value: Any?) { + if (!field.type.isPrimitive) { + return writeFieldViaUnsafe(obj, field, value, Unsafe::putObject) + } + // TODO: clean-up! + return when (field.type) { + Boolean::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putBoolean(obj, field, value as Boolean) } + Byte::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putByte(obj, field, value as Byte) } + Char::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putChar(obj, field, value as Char) } + Short::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putShort(obj, field, value as Short) } + Int::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putInt(obj, field, value as Int) } + Long::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putLong(obj, field, value as Long) } + Double::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putDouble(obj, field, value as Double) } + Float::class.javaPrimitiveType -> writeFieldViaUnsafe(obj, field, value) { obj, field, value -> putFloat(obj, field, value as Float) } + else -> error("No more types expected") + } +} + diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Utils.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Utils.kt new file mode 100644 index 000000000..516affe4c --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Utils.kt @@ -0,0 +1,306 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.util + +import org.objectweb.asm.Type +import org.objectweb.asm.Type.* +import org.objectweb.asm.commons.InstructionAdapter.OBJECT_TYPE +import kotlin.reflect.KClass + +fun Boolean.toInt(): Int = this.compareTo(false) + +fun Byte.toBoolean(): Boolean = when (this) { + 0.toByte() -> false + 1.toByte() -> true + else -> throw IllegalArgumentException("Byte $this is not a Boolean") +} + +fun Int.toBoolean() = (this == 1) + +infix fun Boolean.implies(other: Boolean): Boolean = + !this || other + +infix fun Boolean.implies(other: () -> Boolean): Boolean = + !this || other() + +infix fun Boolean.equivalent(other: Boolean): Boolean = + (this && other) || (!this && !other) + +inline fun T.runIf(boolean: Boolean, block: T.() -> T): T = + if (boolean) block() else this + +inline fun Any?.satisfies(predicate: T.() -> Boolean): Boolean = + this is T && predicate(this) + +inline fun Any?.refine(predicate: T.() -> Boolean): T? = + if (this is T && predicate(this)) this else null + +inline fun List.refine(): List? { + return if (all { it is T }) (this as List) else null +} + +fun List.findMapped(transform: (T) -> R?): R? { + for (element in this) { + transform(element)?.let { return it } + } + return null +} + +inline fun Boolean.ensure(): Boolean { + // TODO: add contract? + // contract { + // returns() implies this + // } + check(this) + return this +} + +inline fun Boolean.ensure(lazyMessage: () -> Any): Boolean { + check(this, lazyMessage) + return this +} + +inline fun Boolean.ensureFalse(): Boolean { + check(!this) + return this +} + +inline fun Boolean.ensureFalse(lazyMessage: () -> Any): Boolean { + check(!this, lazyMessage) + return this +} + +inline fun T?.ensureNull(): T? { + check(this == null) + return this +} + +inline fun T?.ensureNull(lazyMessage: (T?) -> Any): T? { + check(this == null) { lazyMessage(this) } + return this +} + +inline fun T?.ensureNotNull(): T { + checkNotNull(this) + return this +} + +inline fun T?.ensureNotNull(lazyMessage: () -> Any): T { + checkNotNull(this, lazyMessage) + return this +} + +inline fun T.ensure(predicate: (T) -> Boolean): T { + check(predicate(this)) + return this +} + +inline fun T.ensure(predicate: (T) -> Boolean, lazyMessage: (T?) -> Any): T { + check(predicate(this)) { lazyMessage(this) } + return this +} + + +private fun rangeCheck(size: Int, fromIndex: Int, toIndex: Int) { + when { + fromIndex > toIndex -> throw IllegalArgumentException("fromIndex ($fromIndex) is greater than toIndex ($toIndex).") + fromIndex < 0 -> throw IndexOutOfBoundsException("fromIndex ($fromIndex) is less than zero.") + toIndex > size -> throw IndexOutOfBoundsException("toIndex ($toIndex) is greater than size ($size).") + } +} + +fun List.binarySearch(fromIndex: Int = 0, toIndex: Int = size, predicate: (T) -> Boolean): Int { + rangeCheck(size, fromIndex, toIndex) + var low = fromIndex - 1 + var high = toIndex + while (low + 1 < high) { + val mid = (low + high).ushr(1) // safe from overflows + if (predicate(get(mid))) + high = mid + else + low = mid + } + return high +} + +fun MutableList.expand(size: Int, defaultValue: T) { + if (size > this.size) { + addAll(List(size - this.size) { defaultValue }) + } +} + +fun MutableList.cut(index: Int) { + require(index <= size) + subList(index, size).clear() +} + +fun MutableMap.updateInplace(key: K, default: V, apply: V.() -> Unit) { + computeIfAbsent(key) { default }.also(apply) +} + +fun MutableMap.update(key: K, default: V, transform: (V) -> V) { + // TODO: could it be done with a single lookup in a map? + put(key, get(key)?.let(transform) ?: default) +} + +fun MutableMap.mergeReduce(other: Map, reduce: (V, V) -> V) { + other.forEach { (key, value) -> + update(key, default = value) { reduce(it, value) } + } +} + +fun List.squash(relation: (T, T) -> Boolean): List> { + if (isEmpty()) + return emptyList() + val squashed = arrayListOf>() + var pos = 0 + while (pos < size) { + val i = pos + var j = i + while (++j < size) { + if (!relation(get(j - 1), get(j))) + break + } + squashed.add(subList(i, j)) + pos = j + } + return squashed +} + +fun List>.cartesianProduct(): Sequence> = sequence { + val sequences = this@cartesianProduct + if (sequences.isEmpty()) + return@sequence + + // prepare iterators of argument sequences + val iterators = sequences.map { it.iterator() } + .toMutableList() + // compute the first element of each argument sequence, + // while also count the number of non-empty sequences + var count = 0 + val elements = iterators.map { + if (it.hasNext()) it.next().also { count++ } else null + }.toMutableList() + // return the empty sequence if at least one of the argument sequences is empty + if (count != iterators.size) + return@sequence + // can cast here since the list can only contain elements + // returned by iterators' `next()` function + elements as MutableList + + // produce tuples in a loop + while (true) { + // yield current tuple (make a copy) + yield(elements.toMutableList()) + // prepare the next tuple: + // while the last sequence has elements, spawn it + if (iterators.last().hasNext()) { + elements[iterators.lastIndex] = iterators.last().next() + continue + } + // otherwise, reset the last sequence iterator, + // advance a preceding sequence, and repeat this process + // until we find a non-exceeded sequence + var idx = iterators.indices.last + while (idx >= 0 && !iterators[idx].hasNext()) { + iterators[idx] = sequences[idx].iterator() + elements[idx] = iterators[idx].next() + idx -= 1 + } + // if all sequences have been exceeded, return + if (idx < 0) + return@sequence + // otherwise, advance the non-exceeded sequence + elements[idx] = iterators[idx].next() + } +} + +class UnreachableException(message: String?): Exception(message) + +fun unreachable(message: String? = null): Nothing { + throw UnreachableException(message) +} + +internal fun Type.getKClass(): KClass<*> = when (sort) { + Type.INT -> Int::class + Type.BYTE -> Byte::class + Type.SHORT -> Short::class + Type.LONG -> Long::class + Type.FLOAT -> Float::class + Type.DOUBLE -> Double::class + Type.CHAR -> Char::class + Type.BOOLEAN -> Boolean::class + Type.OBJECT -> when (this) { + INT_TYPE_BOXED -> Int::class + BYTE_TYPE_BOXED -> Byte::class + SHORT_TYPE_BOXED -> Short::class + LONG_TYPE_BOXED -> Long::class + CHAR_TYPE_BOXED -> Char::class + BOOLEAN_TYPE_BOXED -> Boolean::class + else -> Any::class + } + Type.ARRAY -> when (elementType.sort) { + Type.INT -> IntArray::class + Type.BYTE -> ByteArray::class + Type.SHORT -> ShortArray::class + Type.LONG -> LongArray::class + Type.FLOAT -> FloatArray::class + Type.DOUBLE -> DoubleArray::class + Type.CHAR -> CharArray::class + Type.BOOLEAN -> BooleanArray::class + else -> Array::class + } + else -> throw IllegalArgumentException() +} + +internal fun KClass<*>.getType(): Type = when (this) { + Int::class -> INT_TYPE + Byte::class -> BYTE_TYPE + Short::class -> SHORT_TYPE + Long::class -> LONG_TYPE + Float::class -> FLOAT_TYPE + Double::class -> DOUBLE_TYPE + Char::class -> CHAR_TYPE + Boolean::class -> BOOLEAN_TYPE + else -> OBJECT_TYPE +} + +internal fun KClass<*>.getArrayElementType(): Type = when (this) { + IntArray::class -> INT_TYPE + ByteArray::class -> BYTE_TYPE + ShortArray::class -> SHORT_TYPE + LongArray::class -> LONG_TYPE + FloatArray::class -> FLOAT_TYPE + DoubleArray::class -> DOUBLE_TYPE + CharArray::class -> CHAR_TYPE + BooleanArray::class -> BOOLEAN_TYPE + Array::class -> OBJECT_TYPE + // TODO: should we handle atomic arrays? + + else -> throw IllegalArgumentException("Argument is not array") +} + +internal val INT_TYPE_BOXED = Type.getType("Ljava/lang/Integer") +internal val LONG_TYPE_BOXED = Type.getType("Ljava/lang/Long") +internal val SHORT_TYPE_BOXED = Type.getType("Ljava/lang/Short") +internal val BYTE_TYPE_BOXED = Type.getType("Ljava/lang/Byte") +internal val CHAR_TYPE_BOXED = Type.getType("Ljava/lang/Character") +internal val BOOLEAN_TYPE_BOXED = Type.getType("Ljava/lang/Boolean") \ No newline at end of file diff --git a/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/VectorClock.kt b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/VectorClock.kt new file mode 100644 index 000000000..32e93c7a0 --- /dev/null +++ b/src/jvm/main/org/jetbrains/kotlinx/lincheck/util/VectorClock.kt @@ -0,0 +1,154 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck.util + +import org.jetbrains.kotlinx.lincheck.execution.HBClock +import org.jetbrains.kotlinx.lincheck.execution.emptyClock +import kotlin.math.* + +interface VectorClock { + fun isEmpty(): Boolean + + operator fun get(tid: ThreadID): Int +} + +interface MutableVectorClock : VectorClock { + operator fun set(tid: ThreadID, timestamp: Int) + + fun increment(tid: ThreadID, n: Int) { + set(tid, get(tid) + n) + } + + fun merge(other: VectorClock) + + fun clear() +} + +fun VectorClock.observes(tid: ThreadID, timestamp: Int): Boolean = + timestamp <= get(tid) + +fun MutableVectorClock.increment(tid: ThreadID) { + increment(tid, 1) +} + +operator fun VectorClock.plus(other: VectorClock): MutableVectorClock = + copy().apply { merge(other) } + +fun VectorClock(capacity: Int = 0): VectorClock = + MutableVectorClock(capacity) + +fun MutableVectorClock(capacity: Int = 0): MutableVectorClock = + IntArrayClock(capacity) + +fun VectorClock.copy(): MutableVectorClock { + // TODO: make VectorClock sealed interface? + check(this is IntArrayClock) + return copy() +} + +private class IntArrayClock(capacity: Int = 0) : MutableVectorClock { + var clock = emptyIntArrayClock(capacity) + + val capacity: Int + get() = clock.size + + override fun isEmpty(): Boolean = + clock.all { it == -1 } + + override fun get(tid: ThreadID): Int = + if (tid < capacity) clock[tid] else -1 + + override fun set(tid: ThreadID, timestamp: Int) { + expandIfNeeded(tid) + clock[tid] = timestamp + } + + override fun increment(tid: ThreadID, n: Int) { + expandIfNeeded(tid) + clock[tid] += n + } + + override fun merge(other: VectorClock) { + // TODO: make VectorClock sealed interface? + check(other is IntArrayClock) + if (capacity < other.capacity) { + expand(other.capacity) + } + for (i in 0 until capacity) { + clock[i] = max(clock[i], other[i]) + } + } + + override fun clear() { + clock.fill(-1) + } + + private fun expand(newCapacity: Int) { + require(newCapacity > capacity) + val newClock = emptyIntArrayClock(newCapacity) + copyInto(newClock) + clock = newClock + } + + private fun expandIfNeeded(tid: ThreadID) { + if (tid >= capacity) { + expand(tid + 1) + } + } + + fun copy() = IntArrayClock(capacity).also { copyInto(it.clock) } + + private fun copyInto(other: IntArray) { + require(other.size >= capacity) + // TODO: use arraycopy? + // System.arraycopy(old, 0, clock, 0, capacity) + for (i in 0 until capacity) { + other[i] = clock[i] + } + } + + override fun equals(other: Any?): Boolean = + (other is IntArrayClock) && (clock.contentEquals(other.clock)) + + override fun hashCode(): Int = + clock.contentHashCode() + + override fun toString() = + clock.joinToString(prefix = "[", separator = ",", postfix = "]") + + companion object { + private fun emptyIntArrayClock(capacity: Int) = + IntArray(capacity) { -1 } + } +} + +fun VectorClock.toHBClock(capacity: Int, tid: ThreadID, aid: Int): HBClock { + check(this is IntArrayClock) + val result = emptyClock(capacity) + for (i in 0 until capacity) { + if (i == tid) { + result.clock[i] = get(i) + continue + } + result.clock[i] = 1 + get(i) + } + return result +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/AbstractLincheckTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/AbstractLincheckTest.kt index 8bed18304..2b74e21d7 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/AbstractLincheckTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/AbstractLincheckTest.kt @@ -11,9 +11,8 @@ package org.jetbrains.kotlinx.lincheck_test import org.jetbrains.kotlinx.lincheck.* import org.jetbrains.kotlinx.lincheck.strategy.* -import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions import org.jetbrains.kotlinx.lincheck.strategy.stress.* -import org.jetbrains.kotlinx.lincheck.verifier.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* import org.jetbrains.kotlinx.lincheck_test.util.* import org.junit.* import kotlin.reflect.* @@ -56,6 +55,14 @@ abstract class AbstractLincheckTest( runInternalTest() } + @Test(timeout = TIMEOUT) + fun testWithExperimentalModelCheckingStrategy(): Unit = ModelCheckingOptions().run { + useExperimentalModelChecking() + invocationsPerIteration(1_000) + commonConfiguration() + runInternalTest() + } + private fun > O.commonConfiguration(): Unit = run { iterations(30) actorsBefore(2) @@ -63,6 +70,7 @@ abstract class AbstractLincheckTest( actorsPerThread(2) actorsAfter(2) minimizeFailedScenario(false) + logLevel(LoggingLevel.INFO) customize() } } diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/generator/ParamGeneratorsTests.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/generator/ParamGeneratorsTests.kt index c3f9ebb23..8c76e9d56 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/generator/ParamGeneratorsTests.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/generator/ParamGeneratorsTests.kt @@ -21,6 +21,7 @@ import org.junit.Test import org.jetbrains.kotlinx.lincheck.paramgen.* import org.jetbrains.kotlinx.lincheck_test.verifier.linearizability.SpinLockBasedSet import org.junit.Assert.* +import org.junit.Ignore import kotlin.math.pow /** @@ -225,6 +226,7 @@ class EnumParamGeneratorTest { */ @Param(name = "operation_type", gen = EnumGen::class) @Param(name = "key", gen = IntGen::class, conf = "1:5") +@Ignore class NamedEnumParamGeneratorTest { private val set = SpinLockBasedSet() @@ -255,6 +257,7 @@ class NamedEnumParamGeneratorTest { * This test checks enum generation with in-place configured unnamed enum generator */ @Param(name = "key", gen = IntGen::class, conf = "1:5") +@Ignore class UnnamedEnumParamGeneratorTest() { private val set = SpinLockBasedSet() @@ -284,6 +287,7 @@ class UnnamedEnumParamGeneratorTest() { * Test checks that enum generator will be used even without [Param] annotation */ @Param(name = "key", gen = IntGen::class, conf = "1:5") +@Ignore class EnumParamWithoutAnnotationGeneratorTest: BaseEnumSetTest() { @Operation @@ -319,6 +323,7 @@ abstract class BaseEnumSetTest { * Test checks that if one named parameter generator is associated with many types, then an exception is thrown */ @Param(name = "type", gen = EnumGen::class) +@Ignore class MultipleTypesAssociatedWithNamedEnumParameterGeneratorTest { @Operation @@ -344,6 +349,7 @@ class MultipleTypesAssociatedWithNamedEnumParameterGeneratorTest { * Checks configuration works with enums with spaces in values names */ @Param(name = "type", gen = EnumGen::class, conf = "FIRST OPTION, SECOND OPTION") +@Ignore class EnumsWithWhitespacesInNameConfigurationTest { @Operation diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentLinkedDequeTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentLinkedDequeTest.kt index 812a5e4b5..ee7c18d86 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentLinkedDequeTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentLinkedDequeTest.kt @@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* import org.junit.* import java.util.concurrent.* +@Ignore class ConcurrentLinkedDequeTest { private val deque = ConcurrentLinkedDeque() diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentMapTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentMapTest.kt index 2fb3c2337..f1f86aae1 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentMapTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ConcurrentMapTest.kt @@ -33,6 +33,7 @@ class ConcurrentHashMapTest { .check(this::class) } +@Ignore class ConcurrentSkipListMapTest { private val map = ConcurrentSkipListMap() diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/MPSCQueueTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/MPSCQueueTest.kt index 742eaf613..401495cbf 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/MPSCQueueTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/MPSCQueueTest.kt @@ -17,6 +17,7 @@ import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* import org.jetbrains.kotlinx.lincheck.strategy.stress.* import org.junit.* +@Ignore class MPSCQueueTest { private val queue = MpscLinkedAtomicQueue() diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ObstructionFreedomViolationTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ObstructionFreedomViolationTest.kt index 789b64057..b73c9dcd9 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ObstructionFreedomViolationTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/guide/ObstructionFreedomViolationTest.kt @@ -52,6 +52,7 @@ class MSQueueBlocking { ) } +@Ignore class ObstructionFreedomViolationTest { private val q = MSQueueBlocking() diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/runner/TestThreadExecutionHelperTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/runner/TestThreadExecutionHelperTest.kt index dffe1b8cb..4e4756787 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/runner/TestThreadExecutionHelperTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/runner/TestThreadExecutionHelperTest.kt @@ -47,12 +47,14 @@ class TestThreadExecutionHelperTest { return false } - override fun afterCoroutineCancelled(iThread: Int) {} + override fun afterCoroutineCancelled(iThread: Int, promptCancellation: Boolean, result: CancellationResult) {} override fun afterCoroutineResumed(iThread: Int) {} override fun afterCoroutineSuspended(iThread: Int) {} + override fun onResumeCoroutine(iResumedThread: Int, iResumedActor: Int) {} + override fun onFailure(iThread: Int, e: Throwable) {} override fun onFinish(iThread: Int) {} diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/MemoryModelTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/MemoryModelTest.kt new file mode 100644 index 000000000..29e7d01b9 --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/MemoryModelTest.kt @@ -0,0 +1,134 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck_test.strategy.eventstructure + +import org.jetbrains.kotlinx.lincheck.execution.parallelResults +import org.jetbrains.kotlinx.lincheck.scenario +import java.util.concurrent.atomic.* + +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.junit.Ignore + +import org.junit.Test + +/** + * These tests check that [EventStructureStrategy] adheres to the weak memory model. + * It contains various litmus tests to check for specific weak behaviors. + */ +@Ignore +class MemoryModelTest { + + private val read = SharedMemory::read + private val write = SharedMemory::write + private val compareAndSet = SharedMemory::compareAndSet + private val fetchAndAdd = SharedMemory::fetchAndAdd + + companion object { + const val x = 0 + const val y = 1 + const val z = 2 + } + + @Test + fun testRRWW() { + val testScenario = scenario { + parallel { + thread { + actor(read, x) + actor(read, y) + } + thread { + actor(write, y, 1) + } + thread { + actor(write, x, 1) + } + } + } + val outcomes: Set> = setOf( + (0 to 0), + (0 to 1), + (1 to 0), + (1 to 1) + ) + litmusTest(SharedMemory::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[0][1]!!) + (r1 to r2) + } + } + + /* ======== Store Buffering ======== */ + + @Test + fun testSB() { + val testScenario = scenario { + parallel { + thread { + actor(write, x, 1) + actor(read, y) + } + thread { + actor(write, y, 1) + actor(read, x) + } + } + } + val outcomes: Set> = setOf( + (0 to 1), + (1 to 0), + (1 to 1) + ) + litmusTest(SharedMemory::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][1]!!) + val r2 = getValue(results.parallelResults[1][1]!!) + (r1 to r2) + } + } + +} + +internal class SharedMemory(size: Int = 16) { + // TODO: use AtomicIntegerArray once it is fixed + // TODO: In the future we would likely want to switch to atomicfu primitives. + // However, atomicfu currently does not support various access modes that we intend to test here. + private val memory = Array(size) { AtomicInteger() } + + val size: Int + get() = memory.size + + fun write(location: Int, value: Int) { + memory[location].set(value) + } + + fun read(location: Int): Int { + return memory[location].get() + } + + // TODO: use `compareAndExchange` once Java 9 is available? + fun compareAndSet(location: Int, expected: Int, desired: Int): Boolean { + return memory[location].compareAndSet(expected, desired) + } + + fun fetchAndAdd(location: Int, delta: Int): Int { + return memory[location].getAndAdd(delta) + } +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/PrimitivesTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/PrimitivesTest.kt new file mode 100644 index 000000000..52d79b852 --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/PrimitivesTest.kt @@ -0,0 +1,1249 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +// this line is required to import jdk.internal.misc.Unsafe +// TODO: better solution? +@file:Suppress("JAVA_MODULE_DOES_NOT_EXPORT_PACKAGE") + +package org.jetbrains.kotlinx.lincheck_test.strategy.eventstructure + +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.execution.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import java.util.concurrent.atomic.* +import java.util.concurrent.locks.LockSupport.* +import java.lang.invoke.MethodHandles +import jdk.internal.misc.Unsafe +// import sun.misc.Unsafe +import kotlin.coroutines.* +import kotlinx.coroutines.* +import org.junit.Ignore +import org.junit.Test +import kotlin.reflect.jvm.javaMethod + +class PrimitivesTest { + + class PlainPrimitiveVariable { + private var variable: Int = 0 + + fun write(value: Int) { + variable = value + } + + fun read(): Int { + return variable + } + } + + @Test + fun testPlainPrimitiveAccesses() { + val write = PlainPrimitiveVariable::write + val read = PlainPrimitiveVariable::read + val testScenario = scenario { + parallel { + thread { + actor(write, 1) + } + thread { + actor(read) + } + thread { + actor(write, 2) + } + } + } + // TODO: when we will implement various access modes, + // we should probably report races on plain variables as errors (or warnings at least) + val outcomes: Set = setOf(0, 1, 2) + litmusTest(PlainPrimitiveVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class PlainReferenceVariable { + private var variable: String = "" + + fun write(value: String) { + variable = value + } + + fun read(): String { + return variable + } + } + + @Test + fun testPlainReferenceAccesses() { + val write = PlainReferenceVariable::write + val read = PlainReferenceVariable::read + val testScenario = scenario { + parallel { + thread { + actor(write, "a") + } + thread { + actor(read) + } + thread { + actor(write, "b") + } + } + } + val outcomes: Set = setOf("", "a", "b") + litmusTest(PlainReferenceVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class PrimitiveArray { + private val array = IntArray(8) + + fun write(index: Int, value: Int) { + array[index] = value + } + + fun read(index: Int): Int { + return array[index] + } + } + + @Test + fun testPrimitiveArrayAccesses() { + val write = PrimitiveArray::write + val read = PrimitiveArray::read + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, 1) + } + thread { + actor(read, index) + } + thread { + actor(write, index, 2) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(PrimitiveArray::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class ReferenceArray { + private val array = Array(8) { "" } + + fun write(index: Int, value: String) { + array[index] = value + } + + fun read(index: Int): String { + return array[index] + } + } + + @Test + fun testReferenceArrayAccesses() { + val write = ReferenceArray::write + val read = ReferenceArray::read + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, "a") + } + thread { + actor(read, index) + } + thread { + actor(write, index, "b") + } + } + } + val outcomes: Set = setOf("", "a", "b") + litmusTest(ReferenceArray::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class AtomicVariable { + // TODO: In the future we would likely want to switch to atomicfu primitives. + // However, atomicfu currently does not support various access modes that we intend to test here. + private val variable = AtomicInteger() + + fun write(value: Int) { + variable.set(value) + } + + fun read(): Int { + return variable.get() + } + + fun compareAndSet(expected: Int, desired: Int): Boolean { + return variable.compareAndSet(expected, desired) + } + + fun addAndGet(delta: Int): Int { + return variable.addAndGet(delta) + } + + fun getAndAdd(delta: Int): Int { + return variable.getAndAdd(delta) + } + } + + @Test + fun testAtomicAccesses() { + val read = AtomicVariable::read + val write = AtomicVariable::write + val testScenario = scenario { + parallel { + thread { + actor(write, 1) + } + thread { + actor(read) + } + thread { + actor(write, 2) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(AtomicVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testCompareAndSet() { + val read = AtomicVariable::read + val compareAndSet = AtomicVariable::compareAndSet + val testScenario = scenario { + parallel { + thread { + actor(compareAndSet, 0, 1) + } + thread { + actor(compareAndSet, 0, 1) + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(true, false, 1), + Triple(false, true, 1) + ) + litmusTest(AtomicVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + @Test + fun testGetAndAdd() { + val read = AtomicVariable::read + val getAndAdd = AtomicVariable::getAndAdd + val testScenario = scenario { + parallel { + thread { + actor(getAndAdd, 1) + } + thread { + actor(getAndAdd, 1) + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(0, 1, 2), + Triple(1, 0, 2) + ) + litmusTest(AtomicVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + @Test + fun testAddAndGet() { + val read = AtomicVariable::read + val addAndGet = AtomicVariable::addAndGet + val testScenario = scenario { + parallel { + thread { + actor(addAndGet, 1) + } + thread { + actor(addAndGet, 1) + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(1, 2, 2), + Triple(2, 1, 2) + ) + litmusTest(AtomicVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + class GlobalAtomicVariable { + + companion object { + // TODO: In the future we would likely want to switch to atomicfu primitives. + // However, atomicfu currently does not support various access modes that we intend to test here. + private val globalVariable = AtomicInteger(0) + } + + fun write(value: Int) { + globalVariable.set(value) + } + + fun read(): Int { + return globalVariable.get() + } + + fun compareAndSet(expected: Int, desired: Int): Boolean { + return globalVariable.compareAndSet(expected, desired) + } + + fun addAndGet(delta: Int): Int { + return globalVariable.addAndGet(delta) + } + + fun getAndAdd(delta: Int): Int { + return globalVariable.getAndAdd(delta) + } + } + + // TODO: repair!!! + @Ignore + @Test + fun testGlobalAtomicAccesses() { + val read = GlobalAtomicVariable::read + val write = GlobalAtomicVariable::write + val testScenario = scenario { + parallel { + thread { + actor(write, 1) + } + thread { + actor(read) + } + thread { + actor(write, 2) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(GlobalAtomicVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + // TODO: handle IntRef (var variables accessed from multiple threads) + + class VolatileReferenceVariable { + @Volatile + private var variable: String? = null + + companion object { + private val updater = + AtomicReferenceFieldUpdater.newUpdater(VolatileReferenceVariable::class.java, String::class.java, "variable") + + private val handle = run { + val lookup = MethodHandles.lookup() + lookup.findVarHandle(VolatileReferenceVariable::class.java, "variable", String::class.java) + } + + private val U = Unsafe.getUnsafe() + + private val offset = U.objectFieldOffset(VolatileReferenceVariable::class.java, "variable") + + } + + fun read(): String? { + return variable + } + + fun afuRead(): String? { + return updater.get(this) + } + + fun vhRead(): String? { + return handle.get(this) as String? + } + + fun unsafeRead(): String? { + return U.getObject(this, offset) as String? + } + + fun write(value: String?) { + variable = value + } + + fun afuWrite(value: String?) { + updater.set(this, value) + } + + fun vhWrite(value: String?) { + handle.set(this, value) + } + + fun unsafeWrite(value: String?) { + U.putObject(this, offset, value) + } + + fun afuCompareAndSet(expected: String?, desired: String?): Boolean { + return updater.compareAndSet(this, expected, desired) + } + + fun vhCompareAndSet(expected: String?, desired: String?): Boolean { + return handle.compareAndSet(this, expected, desired) + } + + fun unsafeCompareAndSet(expected: String?, desired: String?): Boolean { + return U.compareAndSetObject(this, offset, expected, desired) + } + + } + + @Test + fun testAtomicFieldUpdaterAccesses() { + val read = VolatileReferenceVariable::afuRead + val write = VolatileReferenceVariable::afuWrite + val testScenario = scenario { + parallel { + thread { + actor(write, "a") + } + thread { + actor(read) + } + thread { + actor(write, "b") + } + } + } + val outcomes: Set = setOf(null, "a", "b") + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testVarHandleAccesses() { + val read = VolatileReferenceVariable::vhRead + val write = VolatileReferenceVariable::vhWrite + val testScenario = scenario { + parallel { + thread { + actor(write, "a") + } + thread { + actor(read) + } + thread { + actor(write, "b") + } + } + } + val outcomes: Set = setOf(null, "a", "b") + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testUnsafeAccesses() { + val read = VolatileReferenceVariable::unsafeRead + val write = VolatileReferenceVariable::unsafeWrite + val testScenario = scenario { + parallel { + thread { + actor(write, "a") + } + thread { + actor(read) + } + thread { + actor(write, "b") + } + } + } + val outcomes: Set = setOf(null, "a", "b") + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testAtomicFieldUpdaterCompareAndSet() { + val read = VolatileReferenceVariable::afuRead + val compareAndSet = VolatileReferenceVariable::afuCompareAndSet + val testScenario = scenario { + parallel { + thread { + actor(compareAndSet, null, "a") + } + thread { + actor(compareAndSet, null, "a") + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(true, false, "a"), + Triple(false, true, "a") + ) + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + @Test + fun testVarHandleCompareAndSet() { + val read = VolatileReferenceVariable::vhRead + val compareAndSet = VolatileReferenceVariable::vhCompareAndSet + val testScenario = scenario { + parallel { + thread { + actor(compareAndSet, null, "a") + } + thread { + actor(compareAndSet, null, "a") + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(true, false, "a"), + Triple(false, true, "a") + ) + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + @Test + fun testUnsafeCompareAndSet() { + val read = VolatileReferenceVariable::unsafeRead + val compareAndSet = VolatileReferenceVariable::unsafeCompareAndSet + val testScenario = scenario { + parallel { + thread { + actor(compareAndSet, null, "a") + } + thread { + actor(compareAndSet, null, "a") + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(true, false, "a"), + Triple(false, true, "a") + ) + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + private data class Quad( + val first: A, val second: B, val third: C, val forth: D + ) + + @Test + fun testMixedAccesses() { + val read = VolatileReferenceVariable::read + val afuRead = VolatileReferenceVariable::afuRead + val vhRead = VolatileReferenceVariable::vhRead + val unsafeRead = VolatileReferenceVariable::unsafeRead + val write = VolatileReferenceVariable::write + val afuWrite = VolatileReferenceVariable::afuWrite + val vhWrite = VolatileReferenceVariable::vhWrite + val unsafeWrite = VolatileReferenceVariable::unsafeWrite + // TODO: also add Unsafe accessors once they are supported + val testScenario = scenario { + parallel { + thread { + actor(write, "a") + } + thread { + actor(afuWrite, "b") + } + thread { + actor(vhWrite, "c") + } + thread { + actor(unsafeWrite, "d") + } + thread { + actor(read) + } + thread { + actor(afuRead) + } + thread { + actor(vhRead) + } + thread { + actor(unsafeRead) + } + } + } + val values = setOf(null, "a", "b", "c", "d") + val outcomes: Set> = + values.flatMap { a -> values.flatMap { b -> values.flatMap { c -> values.flatMap { d -> + listOf(Quad(a, b, c, d)) + }}}}.toSet() + litmusTest(VolatileReferenceVariable::class.java, testScenario, outcomes) { results -> + val a = getValue(results.parallelResults[4][0]!!) + val b = getValue(results.parallelResults[5][0]!!) + val c = getValue(results.parallelResults[6][0]!!) + val d = getValue(results.parallelResults[7][0]!!) + Quad(a, b, c, d) + } + } + + class UnsafeArrays { + private var byteArray: ByteArray = ByteArray(8) + private var shortArray: ShortArray = ShortArray(8) + private var intArray: IntArray = IntArray(8) + private var longArray: LongArray = LongArray(8) + private var referenceArray: Array = Array(8) { "" } + + companion object { + private val U = Unsafe.getUnsafe() + + private val byteArrayOffset = U.arrayBaseOffset(ByteArray::class.java) + private val shortArrayOffset = U.arrayBaseOffset(ShortArray::class.java) + private val intArrayOffset = U.arrayBaseOffset(IntArray::class.java) + private val longArrayOffset = U.arrayBaseOffset(LongArray::class.java) + private val referenceArrayOffset = U.arrayBaseOffset(Array::class.java) + + private val byteIndexScale = U.arrayIndexScale(ByteArray::class.java) + private val shortIndexScale = U.arrayIndexScale(ShortArray::class.java) + private val intIndexScale = U.arrayIndexScale(IntArray::class.java) + private val longIndexScale = U.arrayIndexScale(LongArray::class.java) + private val referenceIndexScale = U.arrayIndexScale(Array::class.java) + + } + + fun writeByte(index: Int, value: Byte) { + U.putByte(byteArray, (index.toLong() * byteIndexScale) + byteArrayOffset, value) + } + + fun writeShort(index: Int, value: Short) { + U.putShort(shortArray, (index.toLong() * shortIndexScale) + shortArrayOffset, value) + } + + fun writeInt(index: Int, value: Int) { + U.putInt(intArray, (index.toLong() * intIndexScale) + intArrayOffset, value) + } + + fun writeLong(index: Int, value: Long) { + U.putLong(longArray, (index.toLong() * longIndexScale) + longArrayOffset, value) + } + + fun writeReference(index: Int, value: String) { + U.putObject(referenceArray, (index.toLong() * referenceIndexScale) + referenceArrayOffset, value) + } + + fun readByte(index: Int): Byte { + return U.getByte(byteArray, (index.toLong() * byteIndexScale) + byteArrayOffset) + } + + fun readShort(index: Int): Short { + return U.getShort(shortArray, (index.toLong() * shortIndexScale) + shortArrayOffset) + } + + fun readInt(index: Int): Int { + return U.getInt(intArray, (index.toLong() * intIndexScale) + intArrayOffset) + } + + fun readLong(index: Int): Long { + return U.getLong(longArray, (index.toLong() * longIndexScale) + longArrayOffset) + } + + fun readReference(index: Int): String { + return U.getObject(referenceArray, (index.toLong() * referenceIndexScale) + referenceArrayOffset) as String + } + + } + + @Test + fun testUnsafeByteArrayAccesses() { + val read = UnsafeArrays::readByte + val write = UnsafeArrays::writeByte + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, 1.toByte()) + } + thread { + actor(read, index) + } + thread { + actor(write, index, 2.toByte()) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(UnsafeArrays::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testUnsafeShortArrayAccesses() { + val read = UnsafeArrays::readShort + val write = UnsafeArrays::writeShort + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, 1.toShort()) + } + thread { + actor(read, index) + } + thread { + actor(write, index, 2.toShort()) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(UnsafeArrays::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testUnsafeIntArrayAccesses() { + val read = UnsafeArrays::readInt + val write = UnsafeArrays::writeInt + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, 1) + } + thread { + actor(read, index) + } + thread { + actor(write, index, 2) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(UnsafeArrays::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testUnsafeLongArrayAccesses() { + val read = UnsafeArrays::readLong + val write = UnsafeArrays::writeLong + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, 1L) + } + thread { + actor(read, index) + } + thread { + actor(write, index, 2L) + } + } + } + val outcomes: Set = setOf(0, 1, 2) + litmusTest(UnsafeArrays::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + @Test + fun testUnsafeReferenceArrayAccesses() { + val read = UnsafeArrays::readReference + val write = UnsafeArrays::writeReference + val index = 2 + val testScenario = scenario { + parallel { + thread { + actor(write, index, "a") + } + thread { + actor(read, index) + } + thread { + actor(write, index, "b") + } + } + } + val outcomes: Set = setOf("", "a", "b") + litmusTest(UnsafeArrays::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class SynchronizedVariable { + + private var variable: Int = 0 + + @Synchronized + fun write(value: Int) { + variable = value + } + + @Synchronized + fun read(): Int { + return variable + } + + @Synchronized + fun waitAndRead(): Int { + // TODO: handle spurious wake-ups? + (this as Object).wait() + return variable + } + + @Synchronized + fun writeAndNotify(value: Int) { + variable = value + (this as Object).notify() + } + + @Synchronized + fun compareAndSet(expected: Int, desired: Int): Boolean { + return if (variable == expected) { + variable = desired + true + } else false + } + + @Synchronized + fun addAndGet(delta: Int): Int { + variable += delta + return variable + } + + @Synchronized + fun getAndAdd(delta: Int): Int { + val value = variable + variable += delta + return value + } + + } + + @Test + fun testSynchronized() { + val read = SynchronizedVariable::read + val addAndGet = SynchronizedVariable::addAndGet + val testScenario = scenario { + parallel { + thread { + actor(addAndGet, 1) + } + thread { + actor(addAndGet, 1) + } + } + post { + actor(read) + } + } + val outcomes: Set> = setOf( + Triple(1, 2, 2), + Triple(2, 1, 2) + ) + // TODO: investigate why `executionCount = 3` + litmusTest(SynchronizedVariable::class.java, testScenario, outcomes) { results -> + val r1 = getValue(results.parallelResults[0][0]!!) + val r2 = getValue(results.parallelResults[1][0]!!) + val r3 = getValue(results.postResults[0]!!) + Triple(r1, r2, r3) + } + } + + @Test + fun testWaitNotify() { + val writeAndNotify = SynchronizedVariable::writeAndNotify + val waitAndRead = SynchronizedVariable::waitAndRead + val testScenario = scenario { + parallel { + thread { + actor(writeAndNotify, 1) + } + thread { + actor(waitAndRead) + } + } + } + val outcomes = setOf(1) + litmusTest(SynchronizedVariable::class.java, testScenario, outcomes) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class ParkLatchedVariable { + + private var variable: Int = 0 + + @Volatile + private var parkedThread: Thread? = null + + @Volatile + private var delivered: Boolean = false + + fun parkAndRead(): Int? { + // TODO: handle spurious wake-ups? + parkedThread = Thread.currentThread() + return if (delivered) { + park() + variable + } else null + } + + fun writeAndUnpark(value: Int) { + variable = value + val thread = parkedThread + if (thread != null) + delivered = true + unpark(thread) + } + + } + + @Test + fun testParking() { + val writeAndUnpark = ParkLatchedVariable::writeAndUnpark + val parkAndRead = ParkLatchedVariable::parkAndRead + val testScenario = scenario { + parallel { + thread { + actor(writeAndUnpark, 1) + } + thread { + actor(parkAndRead) + } + } + } + val outcomes = setOf(null, 1) + litmusTest(ParkLatchedVariable::class.java, testScenario, outcomes, executionCount = 3) { results -> + getValue(results.parallelResults[1][0]!!) + } + } + + class CoroutineWrapper { + + val continuation = AtomicReference>() + var resumedOrCancelled = AtomicBoolean(false) + + @Operation(handleExceptionsAsResult = [CancelledOperationException::class]) + suspend fun suspend(): Int { + return suspendCancellableCoroutine { continuation -> + this.continuation.set(continuation) + } + } + + @InternalCoroutinesApi + @Operation + fun resume(value: Int): Boolean { + this.continuation.get()?.let { + if (resumedOrCancelled.compareAndSet(false, true)) { + val token = it.tryResume(value) + if (token != null) + it.completeResume(token) + return (token != null) + } + } + return false + } + + @Operation + fun cancel(): Boolean { + this.continuation.get()?.let { + if (resumedOrCancelled.compareAndSet(false, true)) { + it.cancel(CancelledOperationException) + return true + } + } + return false + } + } + + internal object CancelledOperationException : Exception() + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun testResume() { + val suspend = CoroutineWrapper::suspend + val resume = CoroutineWrapper::resume + val testScenario = scenario { + parallel { + thread { + actor(suspend) + } + thread { + actor(resume, 1) + } + } + } + val outcomes = setOf( + (Suspended to false), + (1 to true) + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r = getValueSuspended(results.parallelResults[0][0]!!) + val b = getValue(results.parallelResults[1][0]!!) + (r to b) + } + } + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun testCancel() { + val suspendActor = Actor( + method = CoroutineWrapper::suspend.javaMethod!!, + arguments = listOf(), + cancelOnSuspension = false + ) + val cancel = CoroutineWrapper::cancel + val testScenario = scenario { + parallel { + thread { + add(suspendActor) + } + thread { + actor(cancel) + } + } + } + val outcomes = setOf( + (Suspended to false), + (CancelledOperationException to true) + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r = getValueSuspended(results.parallelResults[0][0]!!) + val b = getValue(results.parallelResults[1][0]!!) + (r to b) + } + } + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun testLincheckCancellation() { + val suspendActor = Actor( + method = CoroutineWrapper::suspend.javaMethod!!, + arguments = listOf(), + cancelOnSuspension = true + ) + val resume = CoroutineWrapper::resume + val testScenario = scenario { + parallel { + thread { + add(suspendActor) + } + thread { + actor(resume, 1) + } + } + } + val outcomes = setOf( + (Cancelled to false), + (1 to true) + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r = getValueSuspended(results.parallelResults[0][0]!!) + val b = getValue(results.parallelResults[1][0]!!) + (r to b) + } + } + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun testLincheckPromptCancellation() { + val suspendActor = Actor( + method = CoroutineWrapper::suspend.javaMethod!!, + arguments = listOf(), + cancelOnSuspension = true, + promptCancellation = true, + ) + val resume = CoroutineWrapper::resume + val testScenario = scenario { + parallel { + thread { + add(suspendActor) + } + thread { + actor(resume, 1) + } + } + } + val outcomes = setOf( + (Cancelled to false), + (Cancelled to true), + // (1 to true), + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r = getValueSuspended(results.parallelResults[0][0]!!) + val b = getValue(results.parallelResults[1][0]!!) + (r to b) + } + } + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun testResumeCancel() { + val suspendActor = Actor( + method = CoroutineWrapper::suspend.javaMethod!!, + arguments = listOf(), + cancelOnSuspension = false + ) + val resume = CoroutineWrapper::resume + val cancel = CoroutineWrapper::cancel + val testScenario = scenario { + parallel { + thread { + add(suspendActor) + } + thread { + actor(resume, 1) + } + thread { + actor(cancel) + } + } + } + val outcomes = setOf( + Triple(Suspended, false, false), + Triple(1, true, false), + Triple(CancelledOperationException, false, true) + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r = getValueSuspended(results.parallelResults[0][0]!!) + val b1 = getValue(results.parallelResults[1][0]!!) + val b2 = getValue(results.parallelResults[2][0]!!) + Triple(r, b1, b2) + } + } + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun test1Resume2Suspend() { + val suspend = CoroutineWrapper::suspend + val resume = CoroutineWrapper::resume + val testScenario = scenario { + parallel { + thread { + actor(suspend) + } + thread { + actor(suspend) + } + thread { + actor(resume, 1) + } + } + } + val outcomes = setOf( + Triple(Suspended, Suspended, false), + Triple(Suspended, 1, true), + Triple(1, Suspended, true), + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r1 = getValueSuspended(results.parallelResults[0][0]!!) + val r2 = getValueSuspended(results.parallelResults[1][0]!!) + val b = getValue(results.parallelResults[2][0]!!) + Triple(r1, r2, b) + } + } + + @InternalCoroutinesApi + @Test(timeout = TIMEOUT) + fun test2Resume1Suspend() { + val suspend = CoroutineWrapper::suspend + val resume = CoroutineWrapper::resume + val testScenario = scenario { + parallel { + thread { + actor(suspend) + } + thread { + actor(resume, 1) + } + thread { + actor(resume, 2) + } + } + } + val outcomes = setOf( + Triple(Suspended, false, false), + Triple(1, true, false), + Triple(2, false, true), + ) + litmusTest(CoroutineWrapper::class.java, testScenario, outcomes, executionCount = UNKNOWN) { results -> + val r = getValueSuspended(results.parallelResults[0][0]!!) + val b1 = getValue(results.parallelResults[1][0]!!) + val b2 = getValue(results.parallelResults[2][0]!!) + Triple(r, b1, b2) + } + } + +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/Utils.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/Utils.kt new file mode 100644 index 000000000..08d66f74f --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/strategy/eventstructure/Utils.kt @@ -0,0 +1,108 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck_test.strategy.eventstructure + +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.execution.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.eventstructure.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions +import org.jetbrains.kotlinx.lincheck.strategy.runIteration +import org.jetbrains.kotlinx.lincheck.transformation.InstrumentationMode +import org.jetbrains.kotlinx.lincheck.transformation.withLincheckJavaAgent +import org.jetbrains.kotlinx.lincheck.verifier.* + +import org.junit.Assert + +internal const val UNIQUE = -1 +internal const val UNKNOWN = -2 + +internal fun litmusTest( + testClass: Class<*>, + testScenario: ExecutionScenario, + expectedOutcomes: Set, + executionCount: Int = UNIQUE, + getOutcome: (ExecutionResult) -> Outcome, +) { + require(executionCount >= 0 || executionCount == UNIQUE || executionCount == UNKNOWN) + + val outcomes: MutableSet = mutableSetOf() + val verifier = createVerifier(testScenario) { results -> + outcomes.add(getOutcome(results)) + true + } + withLincheckJavaAgent(InstrumentationMode.EXPERIMENTAL_MODEL_CHECKING) { + val strategy = createStrategy(testClass, testScenario) + val failure = strategy.runIteration(INVOCATIONS, verifier) + assert(failure == null) { failure.toString() } + Assert.assertEquals(expectedOutcomes, outcomes) + + val expectedCount = when (executionCount) { + UNIQUE -> expectedOutcomes.size + UNKNOWN -> strategy.stats.consistentInvocations + else -> executionCount + } + Assert.assertEquals(expectedCount, strategy.stats.consistentInvocations) + } +} + +private fun createConfiguration(testClass: Class<*>) = + ModelCheckingOptions() + .useExperimentalModelChecking() + // for tests debugging set large timeout + .invocationTimeout(60 * 60 * 1000) + .createTestConfigurations(testClass) + +internal fun createStrategy(testClass: Class<*>, scenario: ExecutionScenario): EventStructureStrategy { + return createConfiguration(testClass) + .createStrategy( + testClass = testClass, + scenario = scenario, + validationFunction = null, + stateRepresentationMethod = null, + ) as EventStructureStrategy +} + +internal fun createVerifier(testScenario: ExecutionScenario?, verify: (ExecutionResult) -> Boolean): Verifier = + object : Verifier { + + override fun verifyResults(scenario: ExecutionScenario?, results: ExecutionResult?): Boolean { + require(testScenario == scenario) + require(results != null) + return verify(results) + } + + } + +internal inline fun getValue(result: Result): T = + (result as ValueResult).value as T + +internal fun getValueSuspended(result: Result): Any? = when (result) { + is ValueResult -> result.value + is ExceptionResult -> result.throwable + is Suspended -> result + is Cancelled -> result + else -> throw IllegalArgumentException() +} + +internal const val TIMEOUT = 30 * 1000L // 30 sec + +// we expect for all litmus tests to have less than 1000 outcomes +private const val INVOCATIONS = 1000 \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/ArrayAccessesTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/ArrayAccessesTest.kt new file mode 100644 index 000000000..ec71bccec --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/ArrayAccessesTest.kt @@ -0,0 +1,219 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2023 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck_test.transformation + +import org.jetbrains.kotlinx.lincheck.Options +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck_test.AbstractLincheckTest + +/** + * Tests that int array accesses are properly transformed and tracked. + */ +class IntArrayAccessTest : AbstractLincheckTest() { + private var array = IntArray(3) { 0 } + + @Operation + fun operation() { + array[0] = 0 + array[1] = 1 + array[2] = 2 + check(array[0] == 0) + check(array[1] == 1) + check(array[2] == 2) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +/** + * Tests that byte array accesses are properly transformed and tracked. + */ +class ByteArrayAccessTest : AbstractLincheckTest() { + private var array = ByteArray(3) { 0 } + + @Operation + fun operation() { + array[0] = 0 + array[1] = 1 + array[2] = 2 + check(array[0] == 0.toByte()) + check(array[1] == 1.toByte()) + check(array[2] == 2.toByte()) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +/** + * Tests that byte array accesses are properly transformed and tracked. + */ +class ShortArrayAccessTest : AbstractLincheckTest() { + private var array = ShortArray(3) { 0 } + + @Operation + fun operation() { + array[0] = 0 + array[1] = 1 + array[2] = 2 + check(array[0] == 0.toShort()) + check(array[1] == 1.toShort()) + check(array[2] == 2.toShort()) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +/** + * Tests that long array accesses are properly transformed and tracked. + */ +class LongArrayAccessTest : AbstractLincheckTest() { + private var array = LongArray(3) { 0 } + + @Operation + fun operation() { + array[0] = 0 + array[1] = 1 + array[2] = 2 + check(array[0] == 0L) + check(array[1] == 1L) + check(array[2] == 2L) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +class CharArrayAccessTest : AbstractLincheckTest() { + private var array = CharArray(3) { 0.toChar() } + + @Operation + fun operation() { + array[0] = 'a' + array[1] = 'b' + array[2] = 'c' + check(array[0] == 'a') + check(array[1] == 'b') + check(array[2] == 'c') + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +class BooleanArrayAccessTest : AbstractLincheckTest() { + private var array = BooleanArray(3) { false } + + @Operation + fun operation() { + array[0] = false + array[1] = true + check(array[0] == false) + check(array[1] == true) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +/** + * Tests that boxed array accesses are properly transformed and tracked. + */ +class BoxedArrayAccessTest : AbstractLincheckTest() { + private var array = Array(3) { 0 } + + @Operation + fun operation() { + array[0] = 0 + array[1] = 1 + array[2] = 2 + check(array[0] == 0) + check(array[1] == 1) + check(array[2] == 2) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} + +/** + * Tests that multidimensional array accesses are properly transformed and tracked. + */ +class MultiDimensionalArrayAccessTest : AbstractLincheckTest() { + private var array = Array>(2) { Array(2) { 0 } } + + @Operation + fun operation() { + array[0][0] = 0 + array[0][1] = 1 + array[1][0] = 2 + array[1][1] = 3 + check(array[0][0] == 0) + check(array[0][1] == 1) + check(array[1][0] == 2) + check(array[1][1] == 3) + } + + override fun > O.customize() { + iterations(1) + threads(1) + actorsPerThread(1) + actorsBefore(0) + actorsAfter(0) + } +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/AtomicLongTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/AtomicLongTest.kt deleted file mode 100644 index ac2efba61..000000000 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/AtomicLongTest.kt +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Lincheck - * - * Copyright (C) 2019 - 2024 JetBrains s.r.o. - * - * This Source Code Form is subject to the terms of the - * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed - * with this file, You can obtain one at http://mozilla.org/MPL/2.0/. - */ - -package org.jetbrains.kotlinx.lincheck_test.transformation - -import org.jetbrains.kotlinx.lincheck.annotations.Operation -import org.jetbrains.kotlinx.lincheck_test.* -import java.util.concurrent.atomic.AtomicLong - -/** - * Checks that the AtomicLong.VMSupportsCS8() native method - * is correctly transformed. - */ -class AtomicLongTest : AbstractLincheckTest() { - val counter = AtomicLong() - - @Operation - fun inc() = counter.incrementAndGet() -} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/LocalObjectsTests.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/LocalObjectsTests.kt index ba082e730..811fc9086 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/LocalObjectsTests.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/LocalObjectsTests.kt @@ -47,7 +47,7 @@ class LocalObjectEliminationTest { a.array[1] = 54 val b = A(a.value, a.any, a.array) b.value = 65 - repeat(20) { + repeat(3) { b.array[0] = it } a.any = b diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicArrayTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicArrayTest.kt new file mode 100644 index 000000000..ad787108c --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicArrayTest.kt @@ -0,0 +1,62 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This Source Code Form is subject to the terms of the + * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed + * with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.jetbrains.kotlinx.lincheck_test.transformation.atomics + +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck.annotations.Param +import org.jetbrains.kotlinx.lincheck.paramgen.IntGen +import org.jetbrains.kotlinx.lincheck_test.AbstractLincheckTest +import java.util.concurrent.atomic.AtomicIntegerArray + +@Param(name = "idx", gen = IntGen::class, conf = "0:4") +class AtomicIntegerArrayTest : AbstractLincheckTest() { + val value = AtomicIntegerArray(5) + + @Operation + fun get(@Param(name = "idx") idx: Int) = + value.get(idx) + + @Operation + fun set(@Param(name = "idx") idx: Int, newValue: Int) = + value.set(idx, newValue) + + @Operation + fun getAndSet(@Param(name = "idx") idx: Int, newValue: Int) = + value.getAndSet(idx, newValue) + + @Operation + fun compareAndSet(@Param(name = "idx") idx: Int, expectedValue: Int, newValue: Int) = + value.compareAndSet(idx, expectedValue, newValue) + + @Operation + fun addAndGet(@Param(name = "idx") idx: Int, delta: Int) = + value.addAndGet(idx, delta) + + @Operation + fun getAndAdd(@Param(name = "idx") idx: Int, delta: Int) = + value.getAndAdd(idx, delta) + + @Operation + fun incrementAndGet(@Param(name = "idx") idx: Int) = + value.incrementAndGet(idx) + + @Operation + fun getAndIncrement(@Param(name = "idx") idx: Int) = + value.getAndIncrement(idx) + + @Operation + fun decrementAndGet(@Param(name = "idx") idx: Int) = + value.decrementAndGet(idx) + + @Operation + fun getAndDecrement(@Param(name = "idx") idx: Int) = + value.getAndDecrement(idx) +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicFUTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicFUTest.kt new file mode 100644 index 000000000..89af3e731 --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicFUTest.kt @@ -0,0 +1,167 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck_test.transformation.atomics + +import org.jetbrains.kotlinx.lincheck.annotations.Param +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck_test.AbstractLincheckTest +import kotlinx.atomicfu.* + +class AtomicFUBooleanTest : AbstractLincheckTest() { + val bool = atomic(false) + + @Operation + fun get() = bool.value + + @Operation + fun set(newValue: Boolean) { + bool.value = newValue + } + + @Operation + fun getAndSet(newValue: Boolean) = + bool.getAndSet(newValue) + + @Operation + fun compareAndSet(expectedValue: Boolean, newValue: Boolean) = + bool.compareAndSet(expectedValue, newValue) + +} + +class AtomicFUIntegerTest : AbstractLincheckTest() { + val int = atomic(0) + + @Operation + fun get() = int.value + + @Operation + fun set(newValue: Int) { + int.value = newValue + } + + @Operation + fun getAndSet(newValue: Int) = + int.getAndSet(newValue) + + @Operation + fun compareAndSet(expectedValue: Int, newValue: Int) = + int.compareAndSet(expectedValue, newValue) + + @Operation + fun getAndIncrement() = int.getAndIncrement() + + @Operation + fun getAndDecrement() = int.getAndDecrement() + + @Operation + fun getAndAdd(delta: Int) = int.getAndAdd(delta) + + @Operation + fun addAndGet(delta: Int) = int.addAndGet(delta) + + @Operation + fun incrementAndGet() = int.incrementAndGet() + + @Operation + fun decrementAndGet() = int.decrementAndGet() + + @Operation + fun plusAssign(delta: Int) { + int += delta + } + + @Operation + fun minusAssign(delta: Int) { + int -= delta + } +} + +class AtomicFULongTest : AbstractLincheckTest() { + val long = atomic(0L) + + @Operation + fun get() = long.value + + @Operation + fun set(newValue: Long) { + long.value = newValue + } + + @Operation + fun getAndSet(newValue: Long) = + long.getAndSet(newValue) + + @Operation + fun compareAndSet(expectedValue: Long, newValue: Long) = + long.compareAndSet(expectedValue, newValue) + + @Operation + fun getAndIncrement() = long.getAndIncrement() + + @Operation + fun getAndDecrement() = long.getAndDecrement() + + @Operation + fun getAndAdd(delta: Long) = long.getAndAdd(delta) + + @Operation + fun addAndGet(delta: Long) = long.addAndGet(delta) + + @Operation + fun incrementAndGet() = long.incrementAndGet() + + @Operation + fun decrementAndGet() = long.decrementAndGet() + + @Operation + fun plusAssign(delta: Long) { + long += delta + } + + @Operation + fun minusAssign(delta: Long) { + long -= delta + } +} + +// see comment on AtomicReferenceTest explaining usage of custom parameter generator here +@Param(name = "test", gen = TestStringGenerator::class) +class AtomicFUReferenceTest : AbstractLincheckTest() { + + val ref = atomic("") + + @Operation + fun get() = ref.value + + @Operation + fun set(@Param(name = "test") newValue: String) { + ref.value = newValue + } + + @Operation + fun getAndSet(@Param(name = "test") newValue: String) = + ref.getAndSet(newValue) + + @Operation + fun compareAndSet(@Param(name = "test") expectedValue: String, + @Param(name = "test") newValue: String) = + ref.compareAndSet(expectedValue, newValue) +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicFieldUpdaterTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicFieldUpdaterTest.kt new file mode 100644 index 000000000..be3ebb6e8 --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicFieldUpdaterTest.kt @@ -0,0 +1,124 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck_test.transformation.atomics + +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck.annotations.Param +import org.jetbrains.kotlinx.lincheck_test.AbstractLincheckTest +import java.util.concurrent.atomic.* + +class AtomicIntegerFieldUpdaterTest : AbstractLincheckTest() { + @Volatile + var value: Int = 0 + + private val updater = AtomicIntegerFieldUpdater.newUpdater(AtomicIntegerFieldUpdaterTest::class.java, "value") + + @Operation + fun get() = updater.get(this) + + @Operation + fun set(newValue: Int) = updater.set(this, newValue) + + @Operation + fun getAndSet(newValue: Int) = updater.getAndSet(this, newValue) + + @Operation + fun compareAndSet(expectedValue: Int, newValue: Int) = updater.compareAndSet(this, expectedValue, newValue) + + @Operation + fun addAndGet(delta: Int) = updater.addAndGet(this, delta) + + @Operation + fun getAndAdd(delta: Int) = updater.getAndAdd(this, delta) + + @Operation + fun incrementAndGet() = updater.incrementAndGet(this) + + @Operation + fun getAndIncrement() = updater.getAndIncrement(this) + + @Operation + fun decrementAndGet() = updater.decrementAndGet(this) + + @Operation + fun getAndDecrement() = updater.getAndDecrement(this) +} + +class AtomicLongFieldUpdaterTest : AbstractLincheckTest() { + @Volatile + var value: Long = 0 + + private val updater = AtomicLongFieldUpdater.newUpdater(AtomicLongFieldUpdaterTest::class.java, "value") + + @Operation + fun get() = updater.get(this) + + @Operation + fun set(newValue: Long) = updater.set(this, newValue) + + @Operation + fun getAndSet(newValue: Long) = updater.getAndSet(this, newValue) + + @Operation + fun compareAndSet(expectedValue: Long, newValue: Long) = updater.compareAndSet(this, expectedValue, newValue) + + @Operation + fun addAndGet(delta: Long) = updater.addAndGet(this, delta) + + @Operation + fun getAndAdd(delta: Long) = updater.getAndAdd(this, delta) + + @Operation + fun incrementAndGet() = updater.incrementAndGet(this) + + @Operation + fun getAndIncrement() = updater.getAndIncrement(this) + + @Operation + fun decrementAndGet() = updater.decrementAndGet(this) + + @Operation + fun getAndDecrement() = updater.getAndDecrement(this) +} + +@Param(name = "test", gen = TestStringGenerator::class) +class AtomicReferenceFieldUpdaterTest : AbstractLincheckTest() { + @Volatile + var value: String = "" + + private val updater = AtomicReferenceFieldUpdater.newUpdater( + AtomicReferenceFieldUpdaterTest::class.java, + String::class.java, + "value" + ) + + @Operation + fun get() = updater.get(this) + + @Operation + fun set(newValue: String) = updater.set(this, newValue) + + @Operation + fun getAndSet(newValue: String) = updater.getAndSet(this, newValue) + + @Operation + fun compareAndSet(expectedValue: String, newValue: String) = updater.compareAndSet(this, expectedValue, newValue) +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicTest.kt new file mode 100644 index 000000000..169f7fcd5 --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/AtomicTest.kt @@ -0,0 +1,160 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2022 JetBrains s.r.o. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Lesser Public License for more details. + * + * You should have received a copy of the GNU General Lesser Public + * License along with this program. If not, see + * + */ + +package org.jetbrains.kotlinx.lincheck_test.transformation.atomics + +import org.jetbrains.kotlinx.lincheck.RandomProvider +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck.annotations.Param +import org.jetbrains.kotlinx.lincheck.paramgen.ParameterGenerator +import org.jetbrains.kotlinx.lincheck_test.AbstractLincheckTest +import java.util.concurrent.atomic.* +import java.util.* + + +class AtomicBooleanTest : AbstractLincheckTest() { + val value = AtomicBoolean() + + @Operation + fun get() = value.get() + + @Operation + fun set(newValue: Boolean) = value.set(newValue) + + @Operation + fun getAndSet(newValue: Boolean) = value.getAndSet(newValue) + + @Operation + fun compareAndSet(expectedValue: Boolean, newValue: Boolean) = value.compareAndSet(expectedValue, newValue) + +} + +class AtomicIntegerTest : AbstractLincheckTest() { + val value = AtomicInteger() + + @Operation + fun get() = value.get() + + @Operation + fun set(newValue: Int) = value.set(newValue) + + @Operation + fun getAndSet(newValue: Int) = value.getAndSet(newValue) + + @Operation + fun compareAndSet(expectedValue: Int, newValue: Int) = value.compareAndSet(expectedValue, newValue) + + @Operation + fun addAndGet(delta: Int) = value.addAndGet(delta) + + @Operation + fun getAndAdd(delta: Int) = value.getAndAdd(delta) + + @Operation + fun incrementAndGet() = value.incrementAndGet() + + @Operation + fun getAndIncrement() = value.getAndIncrement() + + @Operation + fun decrementAndGet() = value.decrementAndGet() + + @Operation + fun getAndDecrement() = value.getAndDecrement() +} + +class AtomicLongTest : AbstractLincheckTest() { + val value = AtomicLong() + + @Operation + fun get() = value.get() + + @Operation + fun set(newValue: Long) = value.set(newValue) + + @Operation + fun getAndSet(newValue: Long) = value.getAndSet(newValue) + + @Operation + fun compareAndSet(expectedValue: Long, newValue: Long) = value.compareAndSet(expectedValue, newValue) + + @Operation + fun addAndGet(delta: Long) = value.addAndGet(delta) + + @Operation + fun getAndAdd(delta: Long) = value.getAndAdd(delta) + + @Operation + fun incrementAndGet() = value.incrementAndGet() + + @Operation + fun getAndIncrement() = value.getAndIncrement() + + @Operation + fun decrementAndGet() = value.decrementAndGet() + + @Operation + fun getAndDecrement() = value.getAndDecrement() +} + +/* We use here a generator choosing from a predefined array of strings, + * because the default string generator is not "referentially-stable". + * In other words, it can generate two strings with identical content but have different references. + * Even empty string "" can be represented by several objects. + * Besides that, because we are also testing the compare-and-swap method here, + * executing it on randomly generated strings will result in failures most of the time. + * On the contrary, by choosing from a fixed predefined list of strings, + * we increase the chance of CAS to succeed. + */ +@Param(name = "test", gen = TestStringGenerator::class) +class AtomicReferenceTest : AbstractLincheckTest() { + + val ref = AtomicReference("") + + @Operation + fun get() = ref.get() + + @Operation + fun set(@Param(name = "test") newValue: String) { + ref.set(newValue) + } + + @Operation + fun compareAndSet(@Param(name = "test") expectedValue: String, + @Param(name = "test") newValue: String) = + ref.compareAndSet(expectedValue, newValue) + + @Operation + fun getAndSet(@Param(name = "test") newValue: String) = + ref.getAndSet(newValue) + +} + +// TODO: this generator can be generalized to a generator choosing random element +// from an arbitrary user-defined list +class TestStringGenerator(randomProvider: RandomProvider, configuration: String): ParameterGenerator { + private val random = randomProvider.createRandom() + + private val strings = arrayOf("", "abc", "xyz") + + override fun generate(): String = + strings[random.nextInt(strings.size)] +} + diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/VarHandleTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/VarHandleTest.kt new file mode 100644 index 000000000..18eb6a88b --- /dev/null +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/transformation/atomics/VarHandleTest.kt @@ -0,0 +1,101 @@ +/* + * Lincheck + * + * Copyright (C) 2019 - 2024 JetBrains s.r.o. + * + * This Source Code Form is subject to the terms of the + * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed + * with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.jetbrains.kotlinx.lincheck_test.transformation.atomics + +import org.jetbrains.kotlinx.lincheck.Options +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck_test.AbstractLincheckTest +import java.lang.invoke.MethodHandles + +class VarHandleIntegerFieldTest : AbstractLincheckTest() { + @Volatile + var value: Int = 0 + + private val varHandle = run { + val lookup = MethodHandles.lookup() + lookup.findVarHandle(VarHandleIntegerFieldTest::class.java, "value", Int::class.javaPrimitiveType) + } + + @Operation + fun getField() = value + + @Operation + fun get() = varHandle.get(this) + + @Operation + fun getOpaque() = varHandle.getOpaque(this) + + @Operation + fun getAcquire() = varHandle.getAcquire(this) + + @Operation + fun getVolatile() = varHandle.getVolatile(this) + + @Operation + fun setField(newValue: Int) = run { value = newValue } + + @Operation + fun set(newValue: Int) = varHandle.set(this, newValue) + + @Operation + fun setOpaque(newValue: Int) = varHandle.setOpaque(this, newValue) + + @Operation + fun setRelease(newValue: Int) = varHandle.setRelease(this, newValue) + + @Operation + fun setVolatile(newValue: Int) = varHandle.setVolatile(this, newValue) + + @Operation + fun getAndSet(newValue: Int) = + varHandle.getAndSet(this, newValue) + + @Operation + fun getAndSetAcquire(newValue: Int) = + varHandle.getAndSetAcquire(this, newValue) + + @Operation + fun getAndSetRelease(newValue: Int) = + varHandle.getAndSetRelease(this, newValue) + + @Operation + fun compareAndSet(expectedValue: Int, newValue: Int) = + varHandle.compareAndSet(this, expectedValue, newValue) + + @Operation + fun weakCompareAndSet(expectedValue: Int, newValue: Int) = + varHandle.weakCompareAndSet(this, expectedValue, newValue) + + @Operation + fun weakCompareAndSetAcquire(expectedValue: Int, newValue: Int) = + varHandle.weakCompareAndSetAcquire(this, expectedValue, newValue) + + @Operation + fun weakCompareAndSetRelease(expectedValue: Int, newValue: Int) = + varHandle.weakCompareAndSetRelease(this, expectedValue, newValue) + + @Operation + fun weakCompareAndSetPlain(expectedValue: Int, newValue: Int) = + varHandle.weakCompareAndSetPlain(this, expectedValue, newValue) + + @Operation + fun getAndAdd(delta: Int) = + varHandle.getAndAdd(this, delta) + + @Operation + fun getAndAddAcquire(delta: Int) = + varHandle.getAndAddAcquire(this, delta) + + @Operation + fun getAndAddRelease(delta: Int) = + varHandle.getAndAddRelease(this, delta) + +} \ No newline at end of file diff --git a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/verifier/linearizability/BufferedChannelTest.kt b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/verifier/linearizability/BufferedChannelTest.kt index 856fd440e..7c147f021 100644 --- a/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/verifier/linearizability/BufferedChannelTest.kt +++ b/src/jvm/test/org/jetbrains/kotlinx/lincheck_test/verifier/linearizability/BufferedChannelTest.kt @@ -15,9 +15,11 @@ import org.jetbrains.kotlinx.lincheck.* import org.jetbrains.kotlinx.lincheck.annotations.* import org.jetbrains.kotlinx.lincheck.paramgen.IntGen import org.jetbrains.kotlinx.lincheck_test.* +import org.junit.Ignore @InternalCoroutinesApi @Param(name = "value", gen = IntGen::class, conf = "1:5") +@Ignore class BufferedChannelTest : AbstractLincheckTest() { private val c = Channel(2)