From 5c4a30d78b3976f784c82ff0c1a04453c1cd7217 Mon Sep 17 00:00:00 2001 From: Valentyn Sobol Date: Tue, 28 Feb 2023 20:59:48 +0300 Subject: [PATCH] Bitwuzla ctx improvement (#79) * Better context cache * Reorder cache search operations for better performance * Use `toHashSet` --- .../ksmt/solver/bitwuzla/KBitwuzlaContext.kt | 241 +++++++++++++----- .../bitwuzla/KBitwuzlaExprInternalizer.kt | 4 +- .../ksmt/solver/bitwuzla/KBitwuzlaModel.kt | 2 + .../ksmt/solver/bitwuzla/KBitwuzlaSolver.kt | 2 +- .../org/ksmt/solver/bitwuzla/ConverterTest.kt | 2 +- .../kotlin/org/ksmt/test/TestWorkerProcess.kt | 2 +- 6 files changed, 190 insertions(+), 63 deletions(-) diff --git a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaContext.kt b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaContext.kt index 53af7d73c..b678fe7b7 100644 --- a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaContext.kt +++ b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaContext.kt @@ -1,8 +1,16 @@ package org.ksmt.solver.bitwuzla +import org.ksmt.KContext import org.ksmt.decl.KDecl import org.ksmt.decl.KFuncDecl +import org.ksmt.expr.KArrayLambda +import org.ksmt.expr.KConst +import org.ksmt.expr.KExistentialQuantifier import org.ksmt.expr.KExpr +import org.ksmt.expr.KFunctionApp +import org.ksmt.expr.KFunctionAsArray +import org.ksmt.expr.KUniversalQuantifier +import org.ksmt.expr.transformer.KNonRecursiveTransformer import org.ksmt.solver.KSolverException import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort @@ -19,7 +27,7 @@ import org.ksmt.sort.KSort import org.ksmt.sort.KSortVisitor import org.ksmt.sort.KUninterpretedSort -open class KBitwuzlaContext : AutoCloseable { +open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { private var isClosed = false val bitwuzla = Native.bitwuzlaNew() @@ -28,43 +36,67 @@ open class KBitwuzlaContext : AutoCloseable { val falseTerm: BitwuzlaTerm by lazy { Native.bitwuzlaMkFalse(bitwuzla) } val boolSort: BitwuzlaSort by lazy { Native.bitwuzlaMkBoolSort(bitwuzla) } - private var expressions = hashMapOf, BitwuzlaTerm>() + private val exprGlobalCache = hashMapOf, BitwuzlaTerm>() private val bitwuzlaExpressions = hashMapOf>() + + private val constantsGlobalCache = hashMapOf, BitwuzlaTerm>() + private val bitwuzlaConstants = hashMapOf>() + private val sorts = hashMapOf() private val bitwuzlaSorts = hashMapOf() private val declSorts = hashMapOf, BitwuzlaSort>() private val bitwuzlaValues = hashMapOf>() - private var constants = hashMapOf, BitwuzlaTerm>() - private val bitwuzlaConstants = hashMapOf>() + private var exprCurrentLevelCache = hashMapOf, BitwuzlaTerm>() + private val exprCacheLevel = hashMapOf, Int>() + private val exprLeveledCache = arrayListOf(exprCurrentLevelCache) + private var currentLevelExprMover = ExprMover() - private var declarations = hashSetOf>() - private var uninterpretedSorts = hashMapOf>>() - private var uninterpretedSortRegisterer = UninterpretedSortRegisterer(uninterpretedSorts) + private val currentLevel: Int + get() = exprLeveledCache.lastIndex - private var declarationScope = DeclarationScope( - expressions = expressions, - constants = constants, - declarations = declarations, - uninterpretedSorts = uninterpretedSorts, - parentScope = null - ) + private var currentLevelDeclarations = hashSetOf>() + private val leveledDeclarations = arrayListOf(currentLevelDeclarations) + + private var currentLevelUninterpretedSorts = hashMapOf>>() + private var currentLevelUninterpretedSortRegisterer = UninterpretedSortRegisterer(currentLevelUninterpretedSorts) + private val leveledUninterpretedSorts = arrayListOf(currentLevelUninterpretedSorts) - operator fun get(expr: KExpr<*>): BitwuzlaTerm? = expressions[expr] - operator fun get(sort: KSort): BitwuzlaSort? = sorts[sort] /** - * Internalize ksmt expr into [BitwuzlaTerm] and cache internalization result to avoid - * internalization of already internalized expressions. + * Search for expression term and + * ensure correctness of currently known declarations. + * + * 1. Expression is in current level cache. + * All declarations are already known. + * + * 2. Expression is not in global cache. + * Expression is not internalized yet. * - * [internalizer] must use special functions to internalize BitVec values ([internalizeBvValue]) - * and constants ([mkConstant]). + * 3. Expression is in global cache but not in current level cache. + * Move expression and recollect declarations. + * See [ExprMover]. * */ - fun internalizeExpr(expr: KExpr<*>, internalizer: (KExpr<*>) -> BitwuzlaTerm): BitwuzlaTerm = - expressions.getOrPut(expr) { - internalizer(expr) // don't reverse cache bitwuzla term since it may be rewrote + fun findExprTerm(expr: KExpr<*>): BitwuzlaTerm? { + val globalTerm = exprGlobalCache[expr] ?: return null + + val currentLevelTerm = exprCurrentLevelCache[expr] + if (currentLevelTerm != null) return currentLevelTerm + + currentLevelExprMover.apply(expr) + + return globalTerm + } + + fun saveExprTerm(expr: KExpr<*>, term: BitwuzlaTerm) { + if (exprCurrentLevelCache.putIfAbsent(expr, term) == null) { + exprGlobalCache[expr] = term + exprCacheLevel[expr] = currentLevel } + } + + operator fun get(sort: KSort): BitwuzlaSort? = sorts[sort] fun internalizeSort(sort: KSort, internalizer: (KSort) -> BitwuzlaSort): BitwuzlaSort = sorts.getOrPut(sort) { @@ -92,7 +124,7 @@ open class KBitwuzlaContext : AutoCloseable { fun findConvertedExpr(expr: BitwuzlaTerm): KExpr<*>? = bitwuzlaExpressions[expr] fun convertExpr(expr: BitwuzlaTerm, converter: (BitwuzlaTerm) -> KExpr<*>): KExpr<*> = - convert(expressions, bitwuzlaExpressions, expr, converter) + convert(exprGlobalCache, bitwuzlaExpressions, expr, converter) fun convertSort(sort: BitwuzlaSort, converter: (BitwuzlaSort) -> KSort): KSort = convert(sorts, bitwuzlaSorts, sort, converter) @@ -103,11 +135,23 @@ open class KBitwuzlaContext : AutoCloseable { fun convertConstantIfKnown(term: BitwuzlaTerm): KDecl<*>? = bitwuzlaConstants[term] // Find normal constant if it was previously internalized - fun findConstant(decl: KDecl<*>): BitwuzlaTerm? = constants[decl] + fun findConstant(decl: KDecl<*>): BitwuzlaTerm? = constantsGlobalCache[decl] + + fun declarations(): Set> = + leveledDeclarations.flatMapTo(hashSetOf()) { it } - fun declarations(): Set> = declarations + fun uninterpretedSortsWithRelevantDecls(): Map>> { + val result = hashMapOf>>() - fun uninterpretedSortsWithRelevantDecls(): Map>> = uninterpretedSorts + leveledUninterpretedSorts.forEach { levelSorts -> + levelSorts.forEach { entry -> + val values = result.getOrPut(entry.key) { hashSetOf() } + values.addAll(entry.value) + } + } + + return result + } /** * Add declaration to the current declaration scope. @@ -115,21 +159,21 @@ open class KBitwuzlaContext : AutoCloseable { * register this declaration as relevant to the sort. * */ private fun registerDeclaration(decl: KDecl<*>) { - if (declarations.add(decl)) { - uninterpretedSortRegisterer.decl = decl - decl.sort.accept(uninterpretedSortRegisterer) + if (currentLevelDeclarations.add(decl)) { + currentLevelUninterpretedSortRegisterer.decl = decl + decl.sort.accept(currentLevelUninterpretedSortRegisterer) if (decl is KFuncDecl<*>) { - decl.argSorts.forEach { it.accept(uninterpretedSortRegisterer) } + decl.argSorts.forEach { it.accept(currentLevelUninterpretedSortRegisterer) } } } } /** - * Internalize constant. + * Internalize constant declaration. * Since [Native.bitwuzlaMkConst] creates fresh constant on each invocation caches are used * to guarantee that if two constants are equal in ksmt they are also equal in Bitwuzla. * */ - fun mkConstant(decl: KDecl<*>, sort: BitwuzlaSort): BitwuzlaTerm = constants.getOrPut(decl) { + fun mkConstant(decl: KDecl<*>, sort: BitwuzlaSort): BitwuzlaTerm = constantsGlobalCache.getOrPut(decl) { Native.bitwuzlaMkConst(bitwuzla, sort, decl.name).also { bitwuzlaConstants[it] = decl } @@ -141,30 +185,36 @@ open class KBitwuzlaContext : AutoCloseable { * and must match to the corresponding assertion level ([KBitwuzlaSolver.push]). * */ fun createNestedDeclarationScope() { - expressions = expressions.toMap(hashMapOf()) - constants = constants.toMap(hashMapOf()) - declarations = declarations.toHashSet() - uninterpretedSorts = uninterpretedSorts.mapValuesTo(hashMapOf()) { (_, decls) -> decls.toHashSet() } - uninterpretedSortRegisterer = UninterpretedSortRegisterer(uninterpretedSorts) - declarationScope = DeclarationScope(expressions, constants, declarations, uninterpretedSorts, declarationScope) + exprCurrentLevelCache = hashMapOf() + exprLeveledCache.add(exprCurrentLevelCache) + currentLevelExprMover = ExprMover() + + currentLevelDeclarations = hashSetOf() + leveledDeclarations.add(currentLevelDeclarations) + + currentLevelUninterpretedSorts = hashMapOf() + leveledUninterpretedSorts.add(currentLevelUninterpretedSorts) + currentLevelUninterpretedSortRegisterer = UninterpretedSortRegisterer(currentLevelUninterpretedSorts) } /** * Pop declaration scope to ensure that [declarations] match * the set of asserted declarations at the current assertion level ([KBitwuzlaSolver.pop]). * - * We also invalidate [expressions] internalization cache, since it may contain + * We also invalidate expressions internalization cache, since it may contain * expressions with invalidated declarations. * */ fun popDeclarationScope() { - declarationScope = declarationScope.parentScope - ?: error("Can't pop root declaration scope") - - expressions = declarationScope.expressions - constants = declarationScope.constants - declarations = declarationScope.declarations - uninterpretedSorts = declarationScope.uninterpretedSorts - uninterpretedSortRegisterer = UninterpretedSortRegisterer(uninterpretedSorts) + exprLeveledCache.removeLast() + exprCurrentLevelCache = exprLeveledCache.last() + currentLevelExprMover = ExprMover() + + leveledDeclarations.removeLast() + currentLevelDeclarations = leveledDeclarations.last() + + leveledUninterpretedSorts.removeLast() + currentLevelUninterpretedSorts = leveledUninterpretedSorts.last() + currentLevelUninterpretedSortRegisterer = UninterpretedSortRegisterer(currentLevelUninterpretedSorts) } inline fun bitwuzlaTry(body: () -> T): T = try { @@ -178,9 +228,14 @@ open class KBitwuzlaContext : AutoCloseable { isClosed = true sorts.clear() bitwuzlaSorts.clear() - expressions.clear() + + exprGlobalCache.clear() bitwuzlaExpressions.clear() - constants.clear() + constantsGlobalCache.clear() + + exprLeveledCache.clear() + exprCurrentLevelCache.clear() + declSorts.clear() bitwuzlaConstants.clear() Native.bitwuzlaDelete(bitwuzla) @@ -207,14 +262,6 @@ open class KBitwuzlaContext : AutoCloseable { return converted } - private class DeclarationScope( - val expressions: HashMap, BitwuzlaTerm>, - val constants: HashMap, BitwuzlaTerm>, - val declarations: HashSet>, - val uninterpretedSorts: HashMap>>, - val parentScope: DeclarationScope? - ) - private class UninterpretedSortRegisterer( private val register: MutableMap>> ) : KSortVisitor { @@ -248,4 +295,82 @@ open class KBitwuzlaContext : AutoCloseable { override fun visit(sort: KFpRoundingModeSort) { } } + + /** + * Move expressions from other cache level to the current cache level + * and register declarations for all moved expressions. + * + * 1. If expression is valid on previous levels we don't need to move it. + * Also, all expression declarations are known. + * + * 2. Otherwise, move expression to current level and recollect declarations. + * */ + private inner class ExprMover : KNonRecursiveTransformer(ctx) { + override fun transformExpr(expr: KExpr): KExpr { + // Move expr to current level + val term = exprGlobalCache.getValue(expr) + exprCacheLevel[expr] = currentLevel + exprCurrentLevelCache[expr] = term + + return super.transformExpr(expr) + } + + override fun exprTransformationRequired(expr: KExpr): Boolean { + val cachedLevel = exprCacheLevel[expr] + if (cachedLevel != null && cachedLevel < currentLevel) { + val levelCache = exprLeveledCache[cachedLevel] + // If expr is valid on its level we don't need to move it + return expr !in levelCache + } + return super.exprTransformationRequired(expr) + } + + private var currentlyIgnoredDeclarations: Set>? = null + + private fun registerDeclIfNotIgnored(decl: KDecl<*>) { + if (currentlyIgnoredDeclarations?.contains(decl) == true) { + return + } + registerDeclaration(decl) + } + + override fun transform(expr: KFunctionApp): KExpr { + registerDeclIfNotIgnored(expr.decl) + return super.transform(expr) + } + + override fun transform(expr: KConst): KExpr { + registerDeclIfNotIgnored(expr.decl) + return super.transform(expr) + } + + override fun transform(expr: KFunctionAsArray): KExpr> { + registerDeclIfNotIgnored(expr.function) + return super.transform(expr) + } + + private val quantifiedVarsScope = arrayListOf, Set>?>>() + + private fun KExpr.transformQuantifier(bounds: List>, body: KExpr<*>): KExpr { + if (quantifiedVarsScope.lastOrNull()?.first != this) { + quantifiedVarsScope.add(this to currentlyIgnoredDeclarations) + val ignoredDecls = currentlyIgnoredDeclarations?.toHashSet() ?: hashSetOf() + ignoredDecls.addAll(bounds) + currentlyIgnoredDeclarations = ignoredDecls + } + return transformExprAfterTransformed(this, body) { + currentlyIgnoredDeclarations = quantifiedVarsScope.removeLast().second + this + } + } + + override fun transform(expr: KArrayLambda): KExpr> = + expr.transformQuantifier(listOf(expr.indexVarDecl), expr.body) + + override fun transform(expr: KExistentialQuantifier): KExpr = + expr.transformQuantifier(expr.bounds, expr.body) + + override fun transform(expr: KUniversalQuantifier): KExpr = + expr.transformQuantifier(expr.bounds, expr.body) + } } diff --git a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt index 2fe9eaddb..7680d39cf 100644 --- a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt +++ b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt @@ -181,7 +181,7 @@ open class KBitwuzlaExprInternalizer( } override fun findInternalizedExpr(expr: KExpr<*>): BitwuzlaTerm? { - return bitwuzlaCtx[expr] + return bitwuzlaCtx.findExprTerm(expr) } override fun saveInternalizedExpr(expr: KExpr<*>, internalized: BitwuzlaTerm) { @@ -252,7 +252,7 @@ open class KBitwuzlaExprInternalizer( fun KDecl.bitwuzlaFunctionSort(): BitwuzlaSort = accept(functionSortInternalizer) private fun saveExprInternalizationResult(expr: KExpr<*>, term: BitwuzlaTerm) { - bitwuzlaCtx.internalizeExpr(expr) { term } + bitwuzlaCtx.saveExprTerm(expr, term) // Save only constants if (expr !is KInterpretedValue<*>) return diff --git a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaModel.kt b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaModel.kt index 85e98c276..13d583a2e 100644 --- a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaModel.kt +++ b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaModel.kt @@ -67,6 +67,8 @@ open class KBitwuzlaModel( ctx.ensureContextMatch(decl) bitwuzlaCtx.ensureActive() + if (decl !in modelDeclarations) return null + val interpretation = interpretations.getOrPut(decl) { // Constant was not internalized --> constant is unknown to solver --> constant is not present in model val bitwuzlaConstant = bitwuzlaCtx.findConstant(decl) diff --git a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt index e4706f316..bad20f9d4 100644 --- a/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt +++ b/ksmt-bitwuzla/src/main/kotlin/org/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt @@ -14,7 +14,7 @@ import org.ksmt.sort.KBoolSort import kotlin.time.Duration open class KBitwuzlaSolver(private val ctx: KContext) : KSolver { - open val bitwuzlaCtx = KBitwuzlaContext() + open val bitwuzlaCtx = KBitwuzlaContext(ctx) open val exprInternalizer: KBitwuzlaExprInternalizer by lazy { KBitwuzlaExprInternalizer(bitwuzlaCtx) } diff --git a/ksmt-bitwuzla/src/test/kotlin/org/ksmt/solver/bitwuzla/ConverterTest.kt b/ksmt-bitwuzla/src/test/kotlin/org/ksmt/solver/bitwuzla/ConverterTest.kt index 4f40145ef..ba0f104f8 100644 --- a/ksmt-bitwuzla/src/test/kotlin/org/ksmt/solver/bitwuzla/ConverterTest.kt +++ b/ksmt-bitwuzla/src/test/kotlin/org/ksmt/solver/bitwuzla/ConverterTest.kt @@ -22,7 +22,7 @@ import kotlin.test.assertTrue class ConverterTest { private val ctx = KContext() - private val bitwuzlaCtx = KBitwuzlaContext() + private val bitwuzlaCtx = KBitwuzlaContext(ctx) private val internalizer = KBitwuzlaExprInternalizer(bitwuzlaCtx) private val converter = KBitwuzlaExprConverter(ctx, bitwuzlaCtx) private val sortChecker = SortChecker(ctx) diff --git a/ksmt-test/src/main/kotlin/org/ksmt/test/TestWorkerProcess.kt b/ksmt-test/src/main/kotlin/org/ksmt/test/TestWorkerProcess.kt index 3fb7a85be..c9c9667fb 100644 --- a/ksmt-test/src/main/kotlin/org/ksmt/test/TestWorkerProcess.kt +++ b/ksmt-test/src/main/kotlin/org/ksmt/test/TestWorkerProcess.kt @@ -85,7 +85,7 @@ class TestWorkerProcess : ChildProcessBase() { } private fun internalizeAndConvertBitwuzla(assertions: List>): List> = - KBitwuzlaContext().use { bitwuzlaCtx -> + KBitwuzlaContext(ctx).use { bitwuzlaCtx -> val internalizer = KBitwuzlaExprInternalizer(bitwuzlaCtx) val bitwuzlaAssertions = with(internalizer) { assertions.map { it.internalize() }