-
-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
UninterpretedConstOracle
refactor
#2734
Merged
Merged
Changes from 9 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
07d2acb
UninterpretedConstOracle
Kukovec 7a62103
Merge branch 'main' into jk/oracles2
Kukovec 1461206
Reduced # of tests for CI
Kukovec 4a64389
Merge branch 'main' into jk/oracles2
Kukovec 617fa06
solver reuse
Kukovec 67890e2
Merge branch 'main' into jk/oracles2
Kukovec 6a3d378
Merge branch 'main' into jk/oracles2
Kukovec 85afcfb
Merge branch 'main' into jk/oracles2
Kukovec 1cdb0c8
Merge branch 'main' into jk/oracles2
Kukovec c857f3d
Suggestion by Thomas
Kukovec 66d757c
test ignore
Kukovec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
67 changes: 67 additions & 0 deletions
67
.../at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/UninterpretedConstOracle.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles | ||
|
||
import at.forsyte.apalache.tla.bmcmt._ | ||
import at.forsyte.apalache.tla.bmcmt.smt.SolverContext | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.{Rewriter, RewriterScope} | ||
import at.forsyte.apalache.tla.bmcmt.types.CellT | ||
import at.forsyte.apalache.tla.lir.ConstT1 | ||
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction | ||
import at.forsyte.apalache.tla.types.tla | ||
|
||
/** | ||
* An oracle that uses a fixed collection of potential cells. | ||
* | ||
* The oracle value must be equal to one of the cells, if the collection is nonempty. | ||
* | ||
* @author | ||
* Jure Kukovec | ||
*/ | ||
class UninterpretedConstOracle(val valueCells: Seq[ArenaCell], val oracleCell: ArenaCell) extends Oracle { | ||
|
||
/** | ||
* The number of values that this oracle is defined over: |valueCells| | ||
*/ | ||
override def size: Int = valueCells.size | ||
|
||
override def chosenValueIsEqualToIndexedValue(scope: RewriterScope, index: BigInt): TBuilderInstruction = | ||
if (valueCells.indices.contains(index)) tla.eql(oracleCell.toBuilder, valueCells(index.toInt).toBuilder) | ||
else tla.bool(false) | ||
|
||
def getIndexOfChosenValueFromModel(solverContext: SolverContext): BigInt = | ||
// the oracle must be equal to one of the values. If not, indexWhere returns -1 | ||
valueCells.indexWhere { valueCell => | ||
val eq = tla.eql(valueCell.toBuilder, oracleCell.toBuilder) | ||
solverContext.evalGroundExpr(eq) == tla.bool(true).build | ||
} | ||
} | ||
|
||
object UninterpretedConstOracle { | ||
|
||
/** | ||
* Designated type to be used in this oracle. | ||
*/ | ||
val UNINTERPRETED_TYPE: ConstT1 = ConstT1("_ORA") | ||
|
||
def create( | ||
rewriter: Rewriter, | ||
cache: UninterpretedLiteralCache, | ||
scope: RewriterScope, | ||
nvalues: Int): (RewriterScope, UninterpretedConstOracle) = { | ||
require(nvalues >= 0, "UninterpretedConstOracle must have a non-negative number of candidate values.") | ||
val (newArena, valueCells) = | ||
0.until(nvalues).map(_.toString).foldLeft((scope.arena, Seq.empty[ArenaCell])) { case ((arena, cells), name) => | ||
val (newArena, newCell) = cache.getOrCreate(arena, (UNINTERPRETED_TYPE, name)) | ||
(newArena, cells :+ newCell) | ||
} | ||
val arenaWithCell = newArena.appendCell(CellT.fromType1(UNINTERPRETED_TYPE)) | ||
val newScope = scope.copy(arena = arenaWithCell) | ||
val oracleCell = arenaWithCell.topCell | ||
val oracle = new UninterpretedConstOracle(valueCells, oracleCell) | ||
|
||
// the oracle value must be equal to one of the value cells, if there are any | ||
if (nvalues > 0) | ||
rewriter.assert(tla.or(0.until(nvalues).map(i => oracle.chosenValueIsEqualToIndexedValue(newScope, i)): _*)) | ||
(newScope, oracle) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
167 changes: 167 additions & 0 deletions
167
...c/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestUCOracle.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles | ||
|
||
import at.forsyte.apalache.tla.bmcmt.PureArena | ||
import at.forsyte.apalache.tla.bmcmt.arena.PureArenaAdapter | ||
import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext} | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.{Rewriter, RewriterScope, TestingRewriter} | ||
import at.forsyte.apalache.tla.lir._ | ||
import at.forsyte.apalache.tla.lir.oper.TlaOper | ||
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction | ||
import at.forsyte.apalache.tla.types.tla | ||
import org.junit.runner.RunWith | ||
import org.scalacheck.Prop.forAll | ||
import org.scalacheck.{Gen, Prop} | ||
import org.scalatest.BeforeAndAfterEach | ||
import org.scalatest.funsuite.AnyFunSuite | ||
import org.scalatestplus.junit.JUnitRunner | ||
import org.scalatestplus.scalacheck.Checkers | ||
|
||
@RunWith(classOf[JUnitRunner]) | ||
class TestUCOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers { | ||
|
||
var rewriter: Rewriter = TestingRewriter(Map.empty) | ||
var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache | ||
var initScope: RewriterScope = RewriterScope.initial() | ||
|
||
override def beforeEach(): Unit = { | ||
rewriter = TestingRewriter(Map.empty) | ||
cache = new UninterpretedLiteralCache | ||
initScope = RewriterScope.initial() | ||
} | ||
|
||
val intGen: Gen[Int] = Gen.choose(-10, 10) | ||
val nonNegIntGen: Gen[Int] = Gen.choose(0, 10) | ||
|
||
val maxSizeAndIndexGen: Gen[(Int, Int)] = for { | ||
max <- Gen.choose(1, 10) // size 0 is degenerate | ||
idx <- Gen.choose(0, max - 1) // index must be < | ||
} yield (max, idx) | ||
|
||
test("Oracle cannot be constructed with negative size") { | ||
val prop = | ||
forAll(intGen) { | ||
case i if i < 0 => | ||
Prop.throws(classOf[IllegalArgumentException]) { | ||
UninterpretedConstOracle.create(rewriter, cache, initScope, i) | ||
} | ||
case i => UninterpretedConstOracle.create(rewriter, cache, initScope, i)._2.size == i | ||
} | ||
|
||
check(prop, minSuccessful(100), sizeRange(4)) | ||
} | ||
|
||
test("chosenValueIsEqualToIndexedValue returns an equality, or shorthands") { | ||
val prop = | ||
forAll(Gen.zip(nonNegIntGen, intGen)) { case (size, index) => | ||
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) | ||
val cmp: TlaEx = oracle.chosenValueIsEqualToIndexedValue(scope, index) | ||
if (index < 0 || index >= size) | ||
cmp == tla.bool(false).build | ||
else | ||
cmp match { | ||
case OperEx(TlaOper.eq, NameEx(name1), NameEx(name2)) => | ||
name1 == oracle.oracleCell.toString && name2 == oracle.valueCells(index).toString | ||
case _ => false | ||
} | ||
} | ||
|
||
check(prop, minSuccessful(200), sizeRange(4)) | ||
} | ||
|
||
val (assertionsA, assertionsB): (Seq[TBuilderInstruction], Seq[TBuilderInstruction]) = 0 | ||
.to(10) | ||
.map { i => | ||
(tla.name(s"A$i", BoolT1), tla.name(s"B$i", BoolT1)) | ||
} | ||
.unzip | ||
|
||
test("caseAssertions requires assertion sequences of equal length") { | ||
val assertionsGen: Gen[(Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for { | ||
i <- Gen.choose(0, assertionsA.size) | ||
j <- Gen.choose(0, assertionsB.size) | ||
opt <- Gen.option(Gen.const(assertionsB.take(j))) | ||
} yield (assertionsA.take(i), opt) | ||
|
||
val prop = | ||
forAll(Gen.zip(nonNegIntGen, assertionsGen)) { case (size, (assertionsIfTrue, assertionsIfFalseOpt)) => | ||
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) | ||
if (assertionsIfTrue.size != oracle.size || assertionsIfFalseOpt.exists { _.size != oracle.size }) | ||
Prop.throws(classOf[IllegalArgumentException]) { | ||
oracle.caseAssertions(scope, assertionsIfTrue, assertionsIfFalseOpt) | ||
} | ||
else true | ||
} | ||
|
||
check(prop, minSuccessful(200), sizeRange(4)) | ||
} | ||
|
||
test("caseAssertions constructs a collection of ITEs, or shorthands") { | ||
val gen: Gen[(Int, Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for { | ||
size <- nonNegIntGen | ||
opt <- Gen.option(Gen.const(assertionsB.take(size))) | ||
} yield (size, assertionsA.take(size), opt) | ||
|
||
val prop = | ||
forAll(gen) { case (size, assertionsIfTrue, assertionsIfFalseOpt) => | ||
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) | ||
val caseEx: TlaEx = oracle.caseAssertions(scope, assertionsIfTrue, assertionsIfFalseOpt) | ||
size match { | ||
case 0 => | ||
caseEx == PureArena.cellTrue(scope.arena).toBuilder.build | ||
case 1 => | ||
caseEx == assertionsA.head.build | ||
case _ => | ||
assertionsIfFalseOpt match { | ||
case None => | ||
val ites = assertionsIfTrue.zip(oracle.valueCells).map { case (a, c) => | ||
tla.ite(tla.eql(oracle.oracleCell.toBuilder, c.toBuilder), a, tla.bool(true)) | ||
} | ||
caseEx == tla.and(ites: _*).build | ||
case Some(assertionsIfFalse) => | ||
val ites = assertionsIfTrue.zip(assertionsIfFalse).zip(oracle.valueCells).map { case ((at, af), c) => | ||
tla.ite(tla.eql(oracle.oracleCell.toBuilder, c.toBuilder), at, af) | ||
} | ||
caseEx == tla.and(ites: _*).build | ||
} | ||
} | ||
} | ||
|
||
check(prop, minSuccessful(200), sizeRange(4)) | ||
} | ||
|
||
// We cannot test getIndexOfChosenValueFromModel without running the solver | ||
test("getIndexOfChosenValueFromModel recovers the index correctly for nonempty cell collection") { | ||
val ctx = new Z3SolverContext(SolverConfig.default) | ||
val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization | ||
initScope = initScope.copy(arena = paa.arena) | ||
val prop = | ||
forAll(maxSizeAndIndexGen) { case (size, index) => | ||
cache.dispose() // prevent redeclarations in every loop | ||
val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) | ||
ctx.push() | ||
oracle.valueCells.foreach(ctx.declareCell) | ||
ctx.declareCell(oracle.oracleCell) | ||
cache.addAllConstraints(ctx) | ||
val eql = oracle.chosenValueIsEqualToIndexedValue(scope, index) | ||
ctx.assertGroundExpr(eql) | ||
ctx.sat() | ||
val ret = oracle.getIndexOfChosenValueFromModel(ctx) == index | ||
ctx.pop() | ||
ret | ||
} | ||
|
||
// 1000 is too many, since each run invokes the solver | ||
check(prop, minSuccessful(80), sizeRange(4)) | ||
} | ||
|
||
test("getIndexOfChosenValueFromModel recovers the index correctly for empty collections") { | ||
val ctx = new Z3SolverContext(SolverConfig.default) | ||
val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization | ||
val (_, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope.copy(arena = paa.arena), 0) | ||
ctx.declareCell(oracle.oracleCell) | ||
ctx.sat() | ||
assert(oracle.getIndexOfChosenValueFromModel(ctx) == -1) | ||
} | ||
|
||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shortening the class name makes this hard to find (if I'm looking for the tests for
UninterpretedConstOracle
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I'll have to add this manually though to also rename the file.