Skip to content

Commit

Permalink
oc checker refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
csanadtelbisz committed Apr 12, 2024
1 parent 5d0dd74 commit 36238c6
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import hu.bme.mit.theta.solver.SolverManager
import hu.bme.mit.theta.solver.SolverStatus
import java.util.*

class BasicOcChecker<E : Event> : OcCheckerBase<E> {
class BasicOcChecker<E : Event> : OcCheckerBase<E>() {

override val solver: Solver = SolverManager.resolveSolverFactory("Z3:4.13").createSolver()
private var relations: Array<Array<Reason?>>? = null
Expand Down Expand Up @@ -58,19 +58,13 @@ class BasicOcChecker<E : Event> : OcCheckerBase<E> {
val decision = OcAssignment(decisionStack.peek().rels, rf)
decisionStack.push(decision)
val reason0 = setAndClose(decision.rels, rf)
if (reason0 != null) {
solver.add(BoolExprs.Not(reason0.expr))
continue@dpllLoop
}
if (propagate(reason0)) continue@dpllLoop

val writes = events[rf.from.const.varDecl]!!.values.flatten()
.filter { it.type == EventType.WRITE && it.enabled == true }
for (w in writes) {
val reason = derive(decision.rels, rf, w)
if (reason != null) {
solver.add(BoolExprs.Not(reason.expr))
continue@dpllLoop
}
if (propagate(reason)) continue@dpllLoop
}
}

Expand All @@ -79,10 +73,7 @@ class BasicOcChecker<E : Event> : OcCheckerBase<E> {
decisionStack.push(decision)
for (rf in rfs[w.const.varDecl]!!.filter { it.enabled == true }) {
val reason = derive(decision.rels, rf, w)
if (reason != null) {
solver.add(BoolExprs.Not(reason.expr))
continue@dpllLoop
}
if (propagate(reason)) continue@dpllLoop
}
}

Expand All @@ -96,6 +87,13 @@ class BasicOcChecker<E : Event> : OcCheckerBase<E> {

override fun getRelations(): Array<Array<Reason?>>? = relations?.copy()

override fun propagate(reason: Reason?): Boolean {
reason ?: return false
propagated.add(reason)
solver.add(BoolExprs.Not(reason.expr))
return true
}

/**
* Returns true if obj is not on the stack (in other words, if the value of obj is changed in the new model)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,26 @@ interface OcChecker<E : Event> {
* blocks, see Event::clkId)
*/
fun getRelations(): Array<Array<Reason?>>?

/**
* Get the list of propagated clauses in the form of reasons.
*/
fun getPropagatedClauses(): List<Reason>
}

/**
* This interface implements basic utilities for an ordering consistency checker such as derivation rules and
* transitive closure operations.
*/
internal interface OcCheckerBase<E : Event> : OcChecker<E> {
abstract class OcCheckerBase<E : Event> : OcChecker<E> {

protected val propagated: MutableList<Reason> = mutableListOf()

fun derive(rels: Array<Array<Reason?>>, rf: Relation<E>, w: E): Reason? = when {
override fun getPropagatedClauses() = propagated.toList()

protected abstract fun propagate(reason: Reason?): Boolean

protected fun derive(rels: Array<Array<Reason?>>, rf: Relation<E>, w: E): Reason? = when {
rf.from.clkId == rf.to.clkId -> null // rf within an atomic block
w.clkId == rf.from.clkId || w.clkId == rf.to.clkId -> null // w within an atomic block with one of the rf ends

Expand All @@ -83,7 +94,7 @@ internal interface OcCheckerBase<E : Event> : OcChecker<E> {
else -> null
}

fun setAndClose(rels: Array<Array<Reason?>>, rel: Relation<E>): Reason? {
protected fun setAndClose(rels: Array<Array<Reason?>>, rel: Relation<E>): Reason? {
if (rel.from.clkId == rel.to.clkId) return null // within an atomic block
return setAndClose(rels, rel.from.clkId, rel.to.clkId,
if (rel.type == RelationType.PO) PoReason else RelationReason(rel))
Expand Down Expand Up @@ -118,7 +129,6 @@ internal interface OcCheckerBase<E : Event> : OcChecker<E> {
}
}


/**
* Reason(s) of an enabled relation.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,42 @@ import hu.bme.mit.theta.solver.javasmt.JavaSMTUserPropagator
import org.sosy_lab.java_smt.SolverContextFactory.Solvers.Z3
import java.util.*

class UserPropagatorOcChecker<E : Event> : OcCheckerBase<E>, JavaSMTUserPropagator() {
class UserPropagatorOcChecker<E : Event> : OcCheckerBase<E>() {

private val partialAssignment = Stack<OcAssignment<E>>()
override val solver: Solver = JavaSMTSolverFactory.create(Z3, arrayOf()).createSolverWithPropagators(this)
private var solverLevel: Int = 0

private lateinit var writes: Map<VarDecl<*>, Map<Int, List<E>>>
private lateinit var flatWrites: List<E>
private lateinit var rfs: Map<VarDecl<*>, List<Relation<E>>>
private lateinit var flatRfs: List<Relation<E>>

private val userPropagator: JavaSMTUserPropagator = object : JavaSMTUserPropagator() {
override fun onKnownValue(expr: Expr<BoolType>, value: Boolean) {
if (value) {
flatRfs.find { it.declRef == expr }?.let { rf -> propagate(rf) }
?: flatWrites.filter { it.guardExpr == expr }.forEach { w -> propagate(w) }
}
}

override fun onFinalCheck() =
flatWrites.filter { w -> w.guard.isEmpty() || partialAssignment.any { it.event == w } }.forEach { w ->
propagate(w)
}

override fun onPush() {
solverLevel++
}

override fun onPop(levels: Int) {
solverLevel -= levels
while (partialAssignment.isNotEmpty() && partialAssignment.peek().solverLevel > solverLevel) {
partialAssignment.pop()
}
}
}

override val solver: Solver = JavaSMTSolverFactory.create(Z3, arrayOf()).createSolverWithPropagators(userPropagator)
private var solverLevel: Int = 0

override fun check(
events: Map<VarDecl<*>, Map<Int, List<E>>>,
pos: List<Relation<E>>,
Expand All @@ -52,53 +77,26 @@ class UserPropagatorOcChecker<E : Event> : OcCheckerBase<E>, JavaSMTUserPropagat
pos.forEach { setAndClose(initialRels, it) }
partialAssignment.push(OcAssignment(rels = initialRels))

flatRfs.forEach { rf -> registerExpression(rf.declRef) }
flatWrites.forEach { w -> if (w.guard.isNotEmpty()) registerExpression(w.guardExpr) }
flatRfs.forEach { rf -> userPropagator.registerExpression(rf.declRef) }
flatWrites.forEach { w -> if (w.guard.isNotEmpty()) userPropagator.registerExpression(w.guardExpr) }

return solver.check()
}

override fun getRelations(): Array<Array<Reason?>>? = partialAssignment.lastOrNull()?.rels?.copy()

override fun onKnownValue(expr: Expr<BoolType>, value: Boolean) {
if (value) {
flatRfs.find { it.declRef == expr }?.let { rf -> propagate(rf) }
?: flatWrites.filter { it.guardExpr == expr }.forEach { w -> propagate(w) }
}
}

override fun onFinalCheck() =
flatWrites.filter { w -> w.guard.isEmpty() || partialAssignment.any { it.event == w } }.forEach { w ->
propagate(w)
}

override fun onPush() {
solverLevel++
}

override fun onPop(levels: Int) {
solverLevel -= levels
while (partialAssignment.isNotEmpty() && partialAssignment.peek().solverLevel > solverLevel) {
partialAssignment.pop()
}
}

private fun propagate(rf: Relation<E>) {
check(rf.type == RelationType.RFI || rf.type == RelationType.RFE)
val assignement = OcAssignment(partialAssignment.peek().rels, rf, solverLevel)
partialAssignment.push(assignement)
val reason0 = setAndClose(assignement.rels, rf)
if (reason0 != null) {
propagateConflict(reason0.exprs)
}
propagate(reason0)

val writes = writes[rf.from.const.varDecl]!!.values.flatten()
.filter { w -> w.guard.isEmpty() || partialAssignment.any { it.event == w } }
for (w in writes) {
val reason = derive(assignement.rels, rf, w)
if (reason != null) {
propagateConflict(reason.exprs)
}
propagate(reason)
}
}

Expand All @@ -109,10 +107,15 @@ class UserPropagatorOcChecker<E : Event> : OcCheckerBase<E>, JavaSMTUserPropagat
partialAssignment.push(assignment)
for (rf in rfs) {
val reason = derive(assignment.rels, rf, w)
if (reason != null) {
propagateConflict(reason.exprs)
}
propagate(reason)
}
}
}

override fun propagate(reason: Reason?): Boolean {
reason ?: return false
propagated.add(reason)
userPropagator.propagateConflict(reason.exprs)
return true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ class XcfaOcChecker(xcfa: XCFA, decisionProcedure: OcDecisionProcedureType, priv
return true
}

// Extract counterexample trace from model

private fun getTrace(model: Valuation): Trace<XcfaState<*>, XcfaAction> {
val stateList = mutableListOf<XcfaState<ExplState>>()
val actionList = mutableListOf<XcfaAction>()
Expand Down

0 comments on commit 36238c6

Please sign in to comment.