Skip to content

Commit

Permalink
Bitwuzla ctx improvement (#79)
Browse files Browse the repository at this point in the history
* Better context cache

* Reorder cache search operations for better performance

* Use `toHashSet`
  • Loading branch information
Saloed authored Feb 28, 2023
1 parent 490fd83 commit 5c4a30d
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 63 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package org.ksmt.solver.bitwuzla

import org.ksmt.KContext
import org.ksmt.decl.KDecl
import org.ksmt.decl.KFuncDecl
import org.ksmt.expr.KArrayLambda
import org.ksmt.expr.KConst
import org.ksmt.expr.KExistentialQuantifier
import org.ksmt.expr.KExpr
import org.ksmt.expr.KFunctionApp
import org.ksmt.expr.KFunctionAsArray
import org.ksmt.expr.KUniversalQuantifier
import org.ksmt.expr.transformer.KNonRecursiveTransformer
import org.ksmt.solver.KSolverException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
Expand All @@ -19,7 +27,7 @@ import org.ksmt.sort.KSort
import org.ksmt.sort.KSortVisitor
import org.ksmt.sort.KUninterpretedSort

open class KBitwuzlaContext : AutoCloseable {
open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
private var isClosed = false

val bitwuzla = Native.bitwuzlaNew()
Expand All @@ -28,43 +36,67 @@ open class KBitwuzlaContext : AutoCloseable {
val falseTerm: BitwuzlaTerm by lazy { Native.bitwuzlaMkFalse(bitwuzla) }
val boolSort: BitwuzlaSort by lazy { Native.bitwuzlaMkBoolSort(bitwuzla) }

private var expressions = hashMapOf<KExpr<*>, BitwuzlaTerm>()
private val exprGlobalCache = hashMapOf<KExpr<*>, BitwuzlaTerm>()
private val bitwuzlaExpressions = hashMapOf<BitwuzlaTerm, KExpr<*>>()

private val constantsGlobalCache = hashMapOf<KDecl<*>, BitwuzlaTerm>()
private val bitwuzlaConstants = hashMapOf<BitwuzlaTerm, KDecl<*>>()

private val sorts = hashMapOf<KSort, BitwuzlaSort>()
private val bitwuzlaSorts = hashMapOf<BitwuzlaSort, KSort>()
private val declSorts = hashMapOf<KDecl<*>, BitwuzlaSort>()

private val bitwuzlaValues = hashMapOf<BitwuzlaTerm, KExpr<*>>()

private var constants = hashMapOf<KDecl<*>, BitwuzlaTerm>()
private val bitwuzlaConstants = hashMapOf<BitwuzlaTerm, KDecl<*>>()
private var exprCurrentLevelCache = hashMapOf<KExpr<*>, BitwuzlaTerm>()
private val exprCacheLevel = hashMapOf<KExpr<*>, Int>()
private val exprLeveledCache = arrayListOf(exprCurrentLevelCache)
private var currentLevelExprMover = ExprMover()

private var declarations = hashSetOf<KDecl<*>>()
private var uninterpretedSorts = hashMapOf<KUninterpretedSort, HashSet<KDecl<*>>>()
private var uninterpretedSortRegisterer = UninterpretedSortRegisterer(uninterpretedSorts)
private val currentLevel: Int
get() = exprLeveledCache.lastIndex

private var declarationScope = DeclarationScope(
expressions = expressions,
constants = constants,
declarations = declarations,
uninterpretedSorts = uninterpretedSorts,
parentScope = null
)
private var currentLevelDeclarations = hashSetOf<KDecl<*>>()
private val leveledDeclarations = arrayListOf(currentLevelDeclarations)

private var currentLevelUninterpretedSorts = hashMapOf<KUninterpretedSort, HashSet<KDecl<*>>>()
private var currentLevelUninterpretedSortRegisterer = UninterpretedSortRegisterer(currentLevelUninterpretedSorts)
private val leveledUninterpretedSorts = arrayListOf(currentLevelUninterpretedSorts)

operator fun get(expr: KExpr<*>): BitwuzlaTerm? = expressions[expr]
operator fun get(sort: KSort): BitwuzlaSort? = sorts[sort]

/**
* Internalize ksmt expr into [BitwuzlaTerm] and cache internalization result to avoid
* internalization of already internalized expressions.
* Search for expression term and
* ensure correctness of currently known declarations.
*
* 1. Expression is in current level cache.
* All declarations are already known.
*
* 2. Expression is not in global cache.
* Expression is not internalized yet.
*
* [internalizer] must use special functions to internalize BitVec values ([internalizeBvValue])
* and constants ([mkConstant]).
* 3. Expression is in global cache but not in current level cache.
* Move expression and recollect declarations.
* See [ExprMover].
* */
fun internalizeExpr(expr: KExpr<*>, internalizer: (KExpr<*>) -> BitwuzlaTerm): BitwuzlaTerm =
expressions.getOrPut(expr) {
internalizer(expr) // don't reverse cache bitwuzla term since it may be rewrote
fun findExprTerm(expr: KExpr<*>): BitwuzlaTerm? {
val globalTerm = exprGlobalCache[expr] ?: return null

val currentLevelTerm = exprCurrentLevelCache[expr]
if (currentLevelTerm != null) return currentLevelTerm

currentLevelExprMover.apply(expr)

return globalTerm
}

fun saveExprTerm(expr: KExpr<*>, term: BitwuzlaTerm) {
if (exprCurrentLevelCache.putIfAbsent(expr, term) == null) {
exprGlobalCache[expr] = term
exprCacheLevel[expr] = currentLevel
}
}

operator fun get(sort: KSort): BitwuzlaSort? = sorts[sort]

fun internalizeSort(sort: KSort, internalizer: (KSort) -> BitwuzlaSort): BitwuzlaSort =
sorts.getOrPut(sort) {
Expand Down Expand Up @@ -92,7 +124,7 @@ open class KBitwuzlaContext : AutoCloseable {
fun findConvertedExpr(expr: BitwuzlaTerm): KExpr<*>? = bitwuzlaExpressions[expr]

fun convertExpr(expr: BitwuzlaTerm, converter: (BitwuzlaTerm) -> KExpr<*>): KExpr<*> =
convert(expressions, bitwuzlaExpressions, expr, converter)
convert(exprGlobalCache, bitwuzlaExpressions, expr, converter)

fun convertSort(sort: BitwuzlaSort, converter: (BitwuzlaSort) -> KSort): KSort =
convert(sorts, bitwuzlaSorts, sort, converter)
Expand All @@ -103,33 +135,45 @@ open class KBitwuzlaContext : AutoCloseable {
fun convertConstantIfKnown(term: BitwuzlaTerm): KDecl<*>? = bitwuzlaConstants[term]

// Find normal constant if it was previously internalized
fun findConstant(decl: KDecl<*>): BitwuzlaTerm? = constants[decl]
fun findConstant(decl: KDecl<*>): BitwuzlaTerm? = constantsGlobalCache[decl]

fun declarations(): Set<KDecl<*>> =
leveledDeclarations.flatMapTo(hashSetOf()) { it }

fun declarations(): Set<KDecl<*>> = declarations
fun uninterpretedSortsWithRelevantDecls(): Map<KUninterpretedSort, Set<KDecl<*>>> {
val result = hashMapOf<KUninterpretedSort, MutableSet<KDecl<*>>>()

fun uninterpretedSortsWithRelevantDecls(): Map<KUninterpretedSort, Set<KDecl<*>>> = uninterpretedSorts
leveledUninterpretedSorts.forEach { levelSorts ->
levelSorts.forEach { entry ->
val values = result.getOrPut(entry.key) { hashSetOf() }
values.addAll(entry.value)
}
}

return result
}

/**
* Add declaration to the current declaration scope.
* Also, if declaration sort is uninterpreted,
* register this declaration as relevant to the sort.
* */
private fun registerDeclaration(decl: KDecl<*>) {
if (declarations.add(decl)) {
uninterpretedSortRegisterer.decl = decl
decl.sort.accept(uninterpretedSortRegisterer)
if (currentLevelDeclarations.add(decl)) {
currentLevelUninterpretedSortRegisterer.decl = decl
decl.sort.accept(currentLevelUninterpretedSortRegisterer)
if (decl is KFuncDecl<*>) {
decl.argSorts.forEach { it.accept(uninterpretedSortRegisterer) }
decl.argSorts.forEach { it.accept(currentLevelUninterpretedSortRegisterer) }
}
}
}

/**
* Internalize constant.
* Internalize constant declaration.
* Since [Native.bitwuzlaMkConst] creates fresh constant on each invocation caches are used
* to guarantee that if two constants are equal in ksmt they are also equal in Bitwuzla.
* */
fun mkConstant(decl: KDecl<*>, sort: BitwuzlaSort): BitwuzlaTerm = constants.getOrPut(decl) {
fun mkConstant(decl: KDecl<*>, sort: BitwuzlaSort): BitwuzlaTerm = constantsGlobalCache.getOrPut(decl) {
Native.bitwuzlaMkConst(bitwuzla, sort, decl.name).also {
bitwuzlaConstants[it] = decl
}
Expand All @@ -141,30 +185,36 @@ open class KBitwuzlaContext : AutoCloseable {
* and must match to the corresponding assertion level ([KBitwuzlaSolver.push]).
* */
fun createNestedDeclarationScope() {
expressions = expressions.toMap(hashMapOf())
constants = constants.toMap(hashMapOf())
declarations = declarations.toHashSet()
uninterpretedSorts = uninterpretedSorts.mapValuesTo(hashMapOf()) { (_, decls) -> decls.toHashSet() }
uninterpretedSortRegisterer = UninterpretedSortRegisterer(uninterpretedSorts)
declarationScope = DeclarationScope(expressions, constants, declarations, uninterpretedSorts, declarationScope)
exprCurrentLevelCache = hashMapOf()
exprLeveledCache.add(exprCurrentLevelCache)
currentLevelExprMover = ExprMover()

currentLevelDeclarations = hashSetOf()
leveledDeclarations.add(currentLevelDeclarations)

currentLevelUninterpretedSorts = hashMapOf()
leveledUninterpretedSorts.add(currentLevelUninterpretedSorts)
currentLevelUninterpretedSortRegisterer = UninterpretedSortRegisterer(currentLevelUninterpretedSorts)
}

/**
* Pop declaration scope to ensure that [declarations] match
* the set of asserted declarations at the current assertion level ([KBitwuzlaSolver.pop]).
*
* We also invalidate [expressions] internalization cache, since it may contain
* We also invalidate expressions internalization cache, since it may contain
* expressions with invalidated declarations.
* */
fun popDeclarationScope() {
declarationScope = declarationScope.parentScope
?: error("Can't pop root declaration scope")

expressions = declarationScope.expressions
constants = declarationScope.constants
declarations = declarationScope.declarations
uninterpretedSorts = declarationScope.uninterpretedSorts
uninterpretedSortRegisterer = UninterpretedSortRegisterer(uninterpretedSorts)
exprLeveledCache.removeLast()
exprCurrentLevelCache = exprLeveledCache.last()
currentLevelExprMover = ExprMover()

leveledDeclarations.removeLast()
currentLevelDeclarations = leveledDeclarations.last()

leveledUninterpretedSorts.removeLast()
currentLevelUninterpretedSorts = leveledUninterpretedSorts.last()
currentLevelUninterpretedSortRegisterer = UninterpretedSortRegisterer(currentLevelUninterpretedSorts)
}

inline fun <reified T> bitwuzlaTry(body: () -> T): T = try {
Expand All @@ -178,9 +228,14 @@ open class KBitwuzlaContext : AutoCloseable {
isClosed = true
sorts.clear()
bitwuzlaSorts.clear()
expressions.clear()

exprGlobalCache.clear()
bitwuzlaExpressions.clear()
constants.clear()
constantsGlobalCache.clear()

exprLeveledCache.clear()
exprCurrentLevelCache.clear()

declSorts.clear()
bitwuzlaConstants.clear()
Native.bitwuzlaDelete(bitwuzla)
Expand All @@ -207,14 +262,6 @@ open class KBitwuzlaContext : AutoCloseable {
return converted
}

private class DeclarationScope(
val expressions: HashMap<KExpr<*>, BitwuzlaTerm>,
val constants: HashMap<KDecl<*>, BitwuzlaTerm>,
val declarations: HashSet<KDecl<*>>,
val uninterpretedSorts: HashMap<KUninterpretedSort, HashSet<KDecl<*>>>,
val parentScope: DeclarationScope?
)

private class UninterpretedSortRegisterer(
private val register: MutableMap<KUninterpretedSort, HashSet<KDecl<*>>>
) : KSortVisitor<Unit> {
Expand Down Expand Up @@ -248,4 +295,82 @@ open class KBitwuzlaContext : AutoCloseable {
override fun visit(sort: KFpRoundingModeSort) {
}
}

/**
* Move expressions from other cache level to the current cache level
* and register declarations for all moved expressions.
*
* 1. If expression is valid on previous levels we don't need to move it.
* Also, all expression declarations are known.
*
* 2. Otherwise, move expression to current level and recollect declarations.
* */
private inner class ExprMover : KNonRecursiveTransformer(ctx) {
override fun <T : KSort> transformExpr(expr: KExpr<T>): KExpr<T> {
// Move expr to current level
val term = exprGlobalCache.getValue(expr)
exprCacheLevel[expr] = currentLevel
exprCurrentLevelCache[expr] = term

return super.transformExpr(expr)
}

override fun <T : KSort> exprTransformationRequired(expr: KExpr<T>): Boolean {
val cachedLevel = exprCacheLevel[expr]
if (cachedLevel != null && cachedLevel < currentLevel) {
val levelCache = exprLeveledCache[cachedLevel]
// If expr is valid on its level we don't need to move it
return expr !in levelCache
}
return super.exprTransformationRequired(expr)
}

private var currentlyIgnoredDeclarations: Set<KDecl<*>>? = null

private fun registerDeclIfNotIgnored(decl: KDecl<*>) {
if (currentlyIgnoredDeclarations?.contains(decl) == true) {
return
}
registerDeclaration(decl)
}

override fun <T : KSort> transform(expr: KFunctionApp<T>): KExpr<T> {
registerDeclIfNotIgnored(expr.decl)
return super.transform(expr)
}

override fun <T : KSort> transform(expr: KConst<T>): KExpr<T> {
registerDeclIfNotIgnored(expr.decl)
return super.transform(expr)
}

override fun <D : KSort, R : KSort> transform(expr: KFunctionAsArray<D, R>): KExpr<KArraySort<D, R>> {
registerDeclIfNotIgnored(expr.function)
return super.transform(expr)
}

private val quantifiedVarsScope = arrayListOf<Pair<KExpr<*>, Set<KDecl<*>>?>>()

private fun <T : KSort> KExpr<T>.transformQuantifier(bounds: List<KDecl<*>>, body: KExpr<*>): KExpr<T> {
if (quantifiedVarsScope.lastOrNull()?.first != this) {
quantifiedVarsScope.add(this to currentlyIgnoredDeclarations)
val ignoredDecls = currentlyIgnoredDeclarations?.toHashSet() ?: hashSetOf()
ignoredDecls.addAll(bounds)
currentlyIgnoredDeclarations = ignoredDecls
}
return transformExprAfterTransformed(this, body) {
currentlyIgnoredDeclarations = quantifiedVarsScope.removeLast().second
this
}
}

override fun <D : KSort, R : KSort> transform(expr: KArrayLambda<D, R>): KExpr<KArraySort<D, R>> =
expr.transformQuantifier(listOf(expr.indexVarDecl), expr.body)

override fun transform(expr: KExistentialQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)

override fun transform(expr: KUniversalQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ open class KBitwuzlaExprInternalizer(
}

override fun findInternalizedExpr(expr: KExpr<*>): BitwuzlaTerm? {
return bitwuzlaCtx[expr]
return bitwuzlaCtx.findExprTerm(expr)
}

override fun saveInternalizedExpr(expr: KExpr<*>, internalized: BitwuzlaTerm) {
Expand Down Expand Up @@ -252,7 +252,7 @@ open class KBitwuzlaExprInternalizer(
fun <T : KSort> KDecl<T>.bitwuzlaFunctionSort(): BitwuzlaSort = accept(functionSortInternalizer)

private fun saveExprInternalizationResult(expr: KExpr<*>, term: BitwuzlaTerm) {
bitwuzlaCtx.internalizeExpr(expr) { term }
bitwuzlaCtx.saveExprTerm(expr, term)

// Save only constants
if (expr !is KInterpretedValue<*>) return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ open class KBitwuzlaModel(
ctx.ensureContextMatch(decl)
bitwuzlaCtx.ensureActive()

if (decl !in modelDeclarations) return null

val interpretation = interpretations.getOrPut(decl) {
// Constant was not internalized --> constant is unknown to solver --> constant is not present in model
val bitwuzlaConstant = bitwuzlaCtx.findConstant(decl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import org.ksmt.sort.KBoolSort
import kotlin.time.Duration

open class KBitwuzlaSolver(private val ctx: KContext) : KSolver<KBitwuzlaSolverConfiguration> {
open val bitwuzlaCtx = KBitwuzlaContext()
open val bitwuzlaCtx = KBitwuzlaContext(ctx)
open val exprInternalizer: KBitwuzlaExprInternalizer by lazy {
KBitwuzlaExprInternalizer(bitwuzlaCtx)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import kotlin.test.assertTrue

class ConverterTest {
private val ctx = KContext()
private val bitwuzlaCtx = KBitwuzlaContext()
private val bitwuzlaCtx = KBitwuzlaContext(ctx)
private val internalizer = KBitwuzlaExprInternalizer(bitwuzlaCtx)
private val converter = KBitwuzlaExprConverter(ctx, bitwuzlaCtx)
private val sortChecker = SortChecker(ctx)
Expand Down
Loading

0 comments on commit 5c4a30d

Please sign in to comment.