-
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* SparseOracle * rework
- Loading branch information
Showing
2 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
51 changes: 51 additions & 0 deletions
51
...c/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/SparseOracle.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles | ||
|
||
import at.forsyte.apalache.tla.bmcmt.smt.SolverContext | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.RewriterScope | ||
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction | ||
import at.forsyte.apalache.tla.types.tla | ||
|
||
/** | ||
* Given a set `indices` of size `N`, sparse oracle is able to answer {{{chosenValueIsEqualToIndexedValue(_, i),}}} for | ||
* any `i \in indices`, despite having a size, which may be smaller than some (ot all) elements of `indices`. | ||
* {{{chosenValueIsEqualToIndexedValue(_, j),}}} for any `j` not in `indices` does not hold. | ||
* | ||
* Internally, every SparseOracle `s` maintains an oracle `o` of size `N-1`, such that for every scope `X`, and any | ||
* index `i \in indices` the following holds: | ||
* {{{s.chosenValueIsEqualToIndexedValue(X,i) = o.chosenValueIsEqualToIndexedValue(x, idxMap(i),}}} where `idxMap` is | ||
* some bijection from `indices` to `{0,...,N-1}`. | ||
* | ||
* @author | ||
* Jure Kukovec | ||
* | ||
* @param mkOracle | ||
* the method to create the backend oracle, of a given size | ||
* @param values | ||
* the set S of oracle values | ||
*/ | ||
class SparseOracle(mkOracle: Int => Oracle, val values: Set[Int]) extends Oracle { | ||
private[oracles] val oracle = mkOracle(values.size) | ||
private[oracles] val sortedValues: Seq[Int] = values.toSeq.sorted | ||
private[oracles] val indexMap: Map[Int, Int] = Map(sortedValues.zipWithIndex: _*) | ||
|
||
override def size: Int = values.size | ||
|
||
def chosenValueIsEqualToIndexedValue(scope: RewriterScope, index: BigInt): TBuilderInstruction = | ||
indexMap | ||
.get(index.toInt) | ||
.map { | ||
oracle.chosenValueIsEqualToIndexedValue(scope, _) | ||
} | ||
.getOrElse(tla.bool(false)) | ||
|
||
override def caseAssertions( | ||
scope: RewriterScope, | ||
assertions: Seq[TBuilderInstruction], | ||
elseAssertionsOpt: Option[Seq[TBuilderInstruction]] = None): TBuilderInstruction = | ||
oracle.caseAssertions(scope, assertions, elseAssertionsOpt) | ||
|
||
override def getIndexOfChosenValueFromModel(solverContext: SolverContext): BigInt = { | ||
val oracleIdx = oracle.getIndexOfChosenValueFromModel(solverContext).toInt | ||
sortedValues.applyOrElse[Int, Int](oracleIdx, _ => -1) | ||
} | ||
} |
94 changes: 94 additions & 0 deletions
94
...st/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestSparseOracle.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles | ||
|
||
import at.forsyte.apalache.tla.bmcmt.arena.PureArenaAdapter | ||
import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext} | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.RewriterScope | ||
import at.forsyte.apalache.tla.bmcmt.types.CellT | ||
import at.forsyte.apalache.tla.bmcmt.{ArenaCell, PureArena} | ||
import at.forsyte.apalache.tla.lir._ | ||
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction | ||
import at.forsyte.apalache.tla.types.tla | ||
import org.junit.runner.RunWith | ||
import org.scalacheck.Gen | ||
import org.scalacheck.Prop.forAll | ||
import org.scalatest.BeforeAndAfterEach | ||
import org.scalatest.funsuite.AnyFunSuite | ||
import org.scalatestplus.junit.JUnitRunner | ||
import org.scalatestplus.scalacheck.Checkers | ||
|
||
@RunWith(classOf[JUnitRunner]) | ||
class TestSparseOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers { | ||
|
||
var initScope: RewriterScope = RewriterScope.initial() | ||
var intOracleCell: ArenaCell = PureArena.cellInvalid | ||
|
||
def mkIntOracle(n: Int): Oracle = new IntOracle(intOracleCell, n) | ||
|
||
override def beforeEach(): Unit = { | ||
val scope0 = RewriterScope.initial() | ||
initScope = scope0.copy(arena = scope0.arena.appendCell(CellT.fromType1(IntT1))) | ||
intOracleCell = initScope.arena.topCell | ||
} | ||
|
||
val intGen: Gen[Int] = Gen.choose(-10, 10) | ||
val nonNegIntGen: Gen[Int] = Gen.choose(0, 20) | ||
|
||
val setGen: Gen[Set[Int]] = for { | ||
maxSize <- nonNegIntGen | ||
elems <- Gen.listOfN(maxSize, nonNegIntGen) | ||
} yield elems.toSet | ||
|
||
val nonemptySetAndIdxGen: Gen[(Set[Int], Int)] = for { | ||
maxSize <- Gen.choose(1, 20) | ||
elems <- Gen.listOfN(maxSize, nonNegIntGen) | ||
idx <- Gen.oneOf(elems) | ||
} yield (elems.toSet, idx) | ||
|
||
test("chosenValueIsEqualToIndexedValue returns the correct value for any element of the constructor set") { | ||
val prop = | ||
forAll(Gen.zip(setGen, intGen)) { case (set, index) => | ||
val oracle = new SparseOracle(mkIntOracle, set) | ||
val cmp: TlaEx = oracle.chosenValueIsEqualToIndexedValue(initScope, index) | ||
if (!set.contains(index)) | ||
cmp == tla.bool(false).build | ||
else { | ||
cmp == oracle.oracle.chosenValueIsEqualToIndexedValue(initScope, oracle.indexMap(index)).build | ||
} | ||
} | ||
|
||
check(prop, minSuccessful(1000), sizeRange(4)) | ||
} | ||
|
||
val (assertionsA, assertionsB): (Seq[TBuilderInstruction], Seq[TBuilderInstruction]) = 0 | ||
.to(10) | ||
.map { i => | ||
(tla.name(s"A$i", BoolT1), tla.name(s"B$i", BoolT1)) | ||
} | ||
.unzip | ||
|
||
// "caseAssertions" tests ignored, since SparseOracle literally just invokes the underlying oracle's method, | ||
// which should have its own tests | ||
|
||
// We cannot test getIndexOfChosenValueFromModel without running the solver | ||
test("getIndexOfChosenValueFromModel recovers the index correctly") { | ||
val ctx = new Z3SolverContext(SolverConfig.default) | ||
val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization | ||
val paa2 = paa.appendCell(IntT1) // also declares the cell | ||
intOracleCell = paa2.topCell | ||
initScope = initScope.copy(arena = paa2.arena) | ||
val prop = | ||
forAll(nonemptySetAndIdxGen) { case (set, index) => | ||
val oracle = new SparseOracle(mkIntOracle, set) | ||
val eql = oracle.chosenValueIsEqualToIndexedValue(initScope, index) | ||
ctx.push() | ||
ctx.assertGroundExpr(eql) | ||
ctx.sat() | ||
val ret = oracle.getIndexOfChosenValueFromModel(ctx) == index | ||
ctx.pop() | ||
ret | ||
} | ||
|
||
// 1000 is too many, since each run invokes the solver | ||
check(prop, minSuccessful(80), sizeRange(4)) | ||
} | ||
} |