Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UninterpretedConstOracle refactor #2734

Merged
merged 11 commits into from
Oct 19, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles

import at.forsyte.apalache.tla.bmcmt._
import at.forsyte.apalache.tla.bmcmt.smt.SolverContext
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.{Rewriter, RewriterScope}
import at.forsyte.apalache.tla.bmcmt.types.CellT
import at.forsyte.apalache.tla.lir.ConstT1
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction
import at.forsyte.apalache.tla.types.tla

/**
* An oracle that uses a fixed collection of potential cells.
*
* The oracle value must be equal to one of the cells, if the collection is nonempty.
*
* @author
* Jure Kukovec
*/
class UninterpretedConstOracle(val valueCells: Seq[ArenaCell], val oracleCell: ArenaCell) extends Oracle {

/**
* The number of values that this oracle is defined over: |valueCells|
*/
override def size: Int = valueCells.size

override def chosenValueIsEqualToIndexedValue(scope: RewriterScope, index: BigInt): TBuilderInstruction =
if (valueCells.indices.contains(index)) tla.eql(oracleCell.toBuilder, valueCells(index.toInt).toBuilder)
else tla.bool(false)

def getIndexOfChosenValueFromModel(solverContext: SolverContext): BigInt =
// the oracle must be equal to one of the values. If not, indexWhere returns -1
valueCells.indexWhere { valueCell =>
val eq = tla.eql(valueCell.toBuilder, oracleCell.toBuilder)
solverContext.evalGroundExpr(eq) == tla.bool(true).build
}
}

object UninterpretedConstOracle {

/**
* Designated type to be used in this oracle.
*/
val UNINTERPRETED_TYPE: ConstT1 = ConstT1("_ORA")

def create(
rewriter: Rewriter,
cache: UninterpretedLiteralCache,
scope: RewriterScope,
nvalues: Int): (RewriterScope, UninterpretedConstOracle) = {
require(nvalues >= 0, "UninterpretedConstOracle must have a non-negative number of candidate values.")
val (newArena, valueCells) =
0.until(nvalues).map(_.toString).foldLeft((scope.arena, Seq.empty[ArenaCell])) { case ((arena, cells), name) =>
val (newArena, newCell) = cache.getOrCreate(arena, (UNINTERPRETED_TYPE, name))
(newArena, cells :+ newCell)
}
val arenaWithCell = newArena.appendCell(CellT.fromType1(UNINTERPRETED_TYPE))
val newScope = scope.copy(arena = arenaWithCell)
val oracleCell = arenaWithCell.topCell
val oracle = new UninterpretedConstOracle(valueCells, oracleCell)

// the oracle value must be equal to one of the value cells, if there are any
if (nvalues > 0)
rewriter.assert(tla.or(0.until(nvalues).map(i => oracle.chosenValueIsEqualToIndexedValue(newScope, i)): _*))
(newScope, oracle)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TestIntOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers {
check(prop, minSuccessful(100), sizeRange(4))
}

test("oracleValueIsEqualToIndexedValue returns an integer comparison") {
test("chosenValueIsEqualToIndexedValue returns an integer comparison") {
val prop =
forAll(maxSizeAndIndexGen) { case (size, index) =>
val (scope, oracle) = IntOracle.create(initScope, size)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles

import at.forsyte.apalache.tla.bmcmt.PureArena
import at.forsyte.apalache.tla.bmcmt.arena.PureArenaAdapter
import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext}
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.{Rewriter, RewriterScope, TestingRewriter}
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.oper.TlaOper
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction
import at.forsyte.apalache.tla.types.tla
import org.junit.runner.RunWith
import org.scalacheck.Prop.forAll
import org.scalacheck.{Gen, Prop}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.junit.JUnitRunner
import org.scalatestplus.scalacheck.Checkers

@RunWith(classOf[JUnitRunner])
class TestUCOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shortening the class name makes this hard to find (if I'm looking for the tests for UninterpretedConstOracle).

Suggested change
class TestUCOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers {
class TestUninterpretedConstOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll have to add this manually though to also rename the file.


var rewriter: Rewriter = TestingRewriter(Map.empty)
var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache
var initScope: RewriterScope = RewriterScope.initial()

override def beforeEach(): Unit = {
rewriter = TestingRewriter(Map.empty)
cache = new UninterpretedLiteralCache
initScope = RewriterScope.initial()
}

val intGen: Gen[Int] = Gen.choose(-10, 10)
val nonNegIntGen: Gen[Int] = Gen.choose(0, 10)

val maxSizeAndIndexGen: Gen[(Int, Int)] = for {
max <- Gen.choose(1, 10) // size 0 is degenerate
idx <- Gen.choose(0, max - 1) // index must be <
} yield (max, idx)

test("Oracle cannot be constructed with negative size") {
val prop =
forAll(intGen) {
case i if i < 0 =>
Prop.throws(classOf[IllegalArgumentException]) {
UninterpretedConstOracle.create(rewriter, cache, initScope, i)
}
case i => UninterpretedConstOracle.create(rewriter, cache, initScope, i)._2.size == i
}

check(prop, minSuccessful(100), sizeRange(4))
}

test("chosenValueIsEqualToIndexedValue returns an equality, or shorthands") {
val prop =
forAll(Gen.zip(nonNegIntGen, intGen)) { case (size, index) =>
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size)
val cmp: TlaEx = oracle.chosenValueIsEqualToIndexedValue(scope, index)
if (index < 0 || index >= size)
cmp == tla.bool(false).build
else
cmp match {
case OperEx(TlaOper.eq, NameEx(name1), NameEx(name2)) =>
name1 == oracle.oracleCell.toString && name2 == oracle.valueCells(index).toString
case _ => false
}
}

check(prop, minSuccessful(200), sizeRange(4))
}

val (assertionsA, assertionsB): (Seq[TBuilderInstruction], Seq[TBuilderInstruction]) = 0
.to(10)
.map { i =>
(tla.name(s"A$i", BoolT1), tla.name(s"B$i", BoolT1))
}
.unzip

test("caseAssertions requires assertion sequences of equal length") {
val assertionsGen: Gen[(Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for {
i <- Gen.choose(0, assertionsA.size)
j <- Gen.choose(0, assertionsB.size)
opt <- Gen.option(Gen.const(assertionsB.take(j)))
} yield (assertionsA.take(i), opt)

val prop =
forAll(Gen.zip(nonNegIntGen, assertionsGen)) { case (size, (assertionsIfTrue, assertionsIfFalseOpt)) =>
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size)
if (assertionsIfTrue.size != oracle.size || assertionsIfFalseOpt.exists { _.size != oracle.size })
Prop.throws(classOf[IllegalArgumentException]) {
oracle.caseAssertions(scope, assertionsIfTrue, assertionsIfFalseOpt)
}
else true
}

check(prop, minSuccessful(200), sizeRange(4))
}

test("caseAssertions constructs a collection of ITEs, or shorthands") {
val gen: Gen[(Int, Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for {
size <- nonNegIntGen
opt <- Gen.option(Gen.const(assertionsB.take(size)))
} yield (size, assertionsA.take(size), opt)

val prop =
forAll(gen) { case (size, assertionsIfTrue, assertionsIfFalseOpt) =>
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size)
val caseEx: TlaEx = oracle.caseAssertions(scope, assertionsIfTrue, assertionsIfFalseOpt)
size match {
case 0 =>
caseEx == PureArena.cellTrue(scope.arena).toBuilder.build
case 1 =>
caseEx == assertionsA.head.build
case _ =>
assertionsIfFalseOpt match {
case None =>
val ites = assertionsIfTrue.zip(oracle.valueCells).map { case (a, c) =>
tla.ite(tla.eql(oracle.oracleCell.toBuilder, c.toBuilder), a, tla.bool(true))
}
caseEx == tla.and(ites: _*).build
case Some(assertionsIfFalse) =>
val ites = assertionsIfTrue.zip(assertionsIfFalse).zip(oracle.valueCells).map { case ((at, af), c) =>
tla.ite(tla.eql(oracle.oracleCell.toBuilder, c.toBuilder), at, af)
}
caseEx == tla.and(ites: _*).build
}
}
}

check(prop, minSuccessful(200), sizeRange(4))
}

// We cannot test getIndexOfChosenValueFromModel without running the solver
test("getIndexOfChosenValueFromModel recovers the index correctly for nonempty cell collection") {
val ctx = new Z3SolverContext(SolverConfig.default)
val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization
initScope = initScope.copy(arena = paa.arena)
val prop =
forAll(maxSizeAndIndexGen) { case (size, index) =>
cache.dispose() // prevent redeclarations in every loop
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size)
ctx.push()
oracle.valueCells.foreach(ctx.declareCell)
ctx.declareCell(oracle.oracleCell)
cache.addAllConstraints(ctx)
val eql = oracle.chosenValueIsEqualToIndexedValue(scope, index)
ctx.assertGroundExpr(eql)
ctx.sat()
val ret = oracle.getIndexOfChosenValueFromModel(ctx) == index
ctx.pop()
ret
}

// 1000 is too many, since each run invokes the solver
check(prop, minSuccessful(80), sizeRange(4))
}

test("getIndexOfChosenValueFromModel recovers the index correctly for empty collections") {
val ctx = new Z3SolverContext(SolverConfig.default)
val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization
val (_, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope.copy(arena = paa.arena), 0)
ctx.declareCell(oracle.oracleCell)
ctx.sat()
assert(oracle.getIndexOfChosenValueFromModel(ctx) == -1)
}

}
Loading