Skip to content

Commit

Permalink
SparseOracle refactor (#2744)
Browse files Browse the repository at this point in the history
* SparseOracle

* rework
  • Loading branch information
Kukovec authored Sep 22, 2023
1 parent e0a23e0 commit 281c66d
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles

import at.forsyte.apalache.tla.bmcmt.smt.SolverContext
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.RewriterScope
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction
import at.forsyte.apalache.tla.types.tla

/**
* Given a set `indices` of size `N`, sparse oracle is able to answer {{{chosenValueIsEqualToIndexedValue(_, i),}}} for
* any `i \in indices`, despite having a size, which may be smaller than some (ot all) elements of `indices`.
* {{{chosenValueIsEqualToIndexedValue(_, j),}}} for any `j` not in `indices` does not hold.
*
* Internally, every SparseOracle `s` maintains an oracle `o` of size `N-1`, such that for every scope `X`, and any
* index `i \in indices` the following holds:
* {{{s.chosenValueIsEqualToIndexedValue(X,i) = o.chosenValueIsEqualToIndexedValue(x, idxMap(i),}}} where `idxMap` is
* some bijection from `indices` to `{0,...,N-1}`.
*
* @author
* Jure Kukovec
*
* @param mkOracle
* the method to create the backend oracle, of a given size
* @param values
* the set S of oracle values
*/
class SparseOracle(mkOracle: Int => Oracle, val values: Set[Int]) extends Oracle {
private[oracles] val oracle = mkOracle(values.size)
private[oracles] val sortedValues: Seq[Int] = values.toSeq.sorted
private[oracles] val indexMap: Map[Int, Int] = Map(sortedValues.zipWithIndex: _*)

override def size: Int = values.size

def chosenValueIsEqualToIndexedValue(scope: RewriterScope, index: BigInt): TBuilderInstruction =
indexMap
.get(index.toInt)
.map {
oracle.chosenValueIsEqualToIndexedValue(scope, _)
}
.getOrElse(tla.bool(false))

override def caseAssertions(
scope: RewriterScope,
assertions: Seq[TBuilderInstruction],
elseAssertionsOpt: Option[Seq[TBuilderInstruction]] = None): TBuilderInstruction =
oracle.caseAssertions(scope, assertions, elseAssertionsOpt)

override def getIndexOfChosenValueFromModel(solverContext: SolverContext): BigInt = {
val oracleIdx = oracle.getIndexOfChosenValueFromModel(solverContext).toInt
sortedValues.applyOrElse[Int, Int](oracleIdx, _ => -1)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles

import at.forsyte.apalache.tla.bmcmt.arena.PureArenaAdapter
import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext}
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.RewriterScope
import at.forsyte.apalache.tla.bmcmt.types.CellT
import at.forsyte.apalache.tla.bmcmt.{ArenaCell, PureArena}
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction
import at.forsyte.apalache.tla.types.tla
import org.junit.runner.RunWith
import org.scalacheck.Gen
import org.scalacheck.Prop.forAll
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.junit.JUnitRunner
import org.scalatestplus.scalacheck.Checkers

@RunWith(classOf[JUnitRunner])
class TestSparseOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers {

var initScope: RewriterScope = RewriterScope.initial()
var intOracleCell: ArenaCell = PureArena.cellInvalid

def mkIntOracle(n: Int): Oracle = new IntOracle(intOracleCell, n)

override def beforeEach(): Unit = {
val scope0 = RewriterScope.initial()
initScope = scope0.copy(arena = scope0.arena.appendCell(CellT.fromType1(IntT1)))
intOracleCell = initScope.arena.topCell
}

val intGen: Gen[Int] = Gen.choose(-10, 10)
val nonNegIntGen: Gen[Int] = Gen.choose(0, 20)

val setGen: Gen[Set[Int]] = for {
maxSize <- nonNegIntGen
elems <- Gen.listOfN(maxSize, nonNegIntGen)
} yield elems.toSet

val nonemptySetAndIdxGen: Gen[(Set[Int], Int)] = for {
maxSize <- Gen.choose(1, 20)
elems <- Gen.listOfN(maxSize, nonNegIntGen)
idx <- Gen.oneOf(elems)
} yield (elems.toSet, idx)

test("chosenValueIsEqualToIndexedValue returns the correct value for any element of the constructor set") {
val prop =
forAll(Gen.zip(setGen, intGen)) { case (set, index) =>
val oracle = new SparseOracle(mkIntOracle, set)
val cmp: TlaEx = oracle.chosenValueIsEqualToIndexedValue(initScope, index)
if (!set.contains(index))
cmp == tla.bool(false).build
else {
cmp == oracle.oracle.chosenValueIsEqualToIndexedValue(initScope, oracle.indexMap(index)).build
}
}

check(prop, minSuccessful(1000), sizeRange(4))
}

val (assertionsA, assertionsB): (Seq[TBuilderInstruction], Seq[TBuilderInstruction]) = 0
.to(10)
.map { i =>
(tla.name(s"A$i", BoolT1), tla.name(s"B$i", BoolT1))
}
.unzip

// "caseAssertions" tests ignored, since SparseOracle literally just invokes the underlying oracle's method,
// which should have its own tests

// We cannot test getIndexOfChosenValueFromModel without running the solver
test("getIndexOfChosenValueFromModel recovers the index correctly") {
val ctx = new Z3SolverContext(SolverConfig.default)
val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization
val paa2 = paa.appendCell(IntT1) // also declares the cell
intOracleCell = paa2.topCell
initScope = initScope.copy(arena = paa2.arena)
val prop =
forAll(nonemptySetAndIdxGen) { case (set, index) =>
val oracle = new SparseOracle(mkIntOracle, set)
val eql = oracle.chosenValueIsEqualToIndexedValue(initScope, index)
ctx.push()
ctx.assertGroundExpr(eql)
ctx.sat()
val ret = oracle.getIndexOfChosenValueFromModel(ctx) == index
ctx.pop()
ret
}

// 1000 is too many, since each run invokes the solver
check(prop, minSuccessful(80), sizeRange(4))
}
}

0 comments on commit 281c66d

Please sign in to comment.