From dcd8bc7f61cc6891ca4c68194b634ff243cbfc4a Mon Sep 17 00:00:00 2001 From: Robby Date: Thu, 21 Nov 2024 09:02:05 -0600 Subject: [PATCH] Added backtracking simplification justification -- ESimpl. --- .../org/sireum/logika/RewritingSystem.scala | 79 +++++++++++++++---- .../sireum/logika/plugin/RewritePlugin.scala | 33 ++++++-- 2 files changed, 88 insertions(+), 24 deletions(-) diff --git a/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala b/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala index ba3e6949..13bf46bb 100644 --- a/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala +++ b/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala @@ -56,6 +56,45 @@ object RewritingSystem { val none: EvalConfig = EvalConfig(F, F, F, F, F, F, F, F, F, F, F) } + @record class BacktrackingSchedule(val backtracking: B, var schedule: ISZ[B], var index: Z) { + def done: B = { + return !backtracking || schedule.isEmpty || schedule(0) + } + def choose(): B = { + if (index >= schedule.size) { + schedule = schedule :+ F + index = index + 1 + return F + } else if (index == schedule.size - 1) { + schedule = schedule(index ~> T) + return T + } else { + assert(!schedule(index)) + var i = index + 1 + var restT = T + while (restT & i < schedule.size) { + if (!schedule(i)) { + restT = F + } + i = i + 1 + } + index = index + 1 + if (restT) { + for (j <- index until schedule.size) { + schedule = schedule(j ~> F) + } + return T + } else { + return F + } + } + } + } + + object BacktrackingSchedule { + @strictpure def empty(backtracking: B): BacktrackingSchedule = BacktrackingSchedule(backtracking, ISZ(), 0) + } + @record class Substitutor(var map: HashMap[AST.CoreExp, AST.CoreExp.ParamVarRef]) extends AST.MCoreExpTransformer { override def transformCoreExpBase(o: AST.CoreExp.Base): MOption[AST.CoreExp.Base] = { map.get(o) match { @@ -462,7 +501,7 @@ object RewritingSystem { case _ => if (r0.isEmpty && labeledOnly) { evalBase(th, EvalConfig.all, cache, methodPatterns, unfoldingMap, 0, - provenClaimStepIdMapEval, o, F, shouldTraceEval) match { + provenClaimStepIdMapEval, o, F, shouldTraceEval, BacktrackingSchedule.empty(F)) match { case Some((r1, t)) => trace = trace ++ t done = T @@ -500,7 +539,8 @@ object RewritingSystem { case _ => F } if (shouldUnfold && (!labeledOnly || inLabel)) { - evalBase(th, EvalConfig.none, cache, methodPatterns, unfoldingMap, 1, HashSMap.empty, o2, F, shouldTraceEval) match { + evalBase(th, EvalConfig.none, cache, methodPatterns, unfoldingMap, 1, HashSMap.empty, o2, F, shouldTraceEval, + BacktrackingSchedule.empty(F)) match { case Some((o3, t)) => trace = trace ++ t return Some(o3) @@ -559,7 +599,7 @@ object RewritingSystem { } val (o3Opt, t): (Option[AST.CoreExp.Base], ISZ[Trace]) = evalBase(th, EvalConfig.all, cache, methodPatterns, unfoldingMap, 0, - provenClaimStepIdMapEval, o2, F, shouldTraceEval) match { + provenClaimStepIdMapEval, o2, F, shouldTraceEval, BacktrackingSchedule.empty(F)) match { case Some((o3o, t)) => (Some(o3o), t) case _ => (None(), ISZ()) } @@ -1247,13 +1287,14 @@ object RewritingSystem { provenClaims: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId], exp: AST.CoreExp, removeLabels: B, - shouldTrace: B): Option[(AST.CoreExp, ISZ[Trace])] = { + shouldTrace: B, + backtrackingSchedule: BacktrackingSchedule): Option[(AST.CoreExp, ISZ[Trace])] = { exp match { case exp: AST.CoreExp.Arrow => var changed = F var trace = ISZ[Trace]() val left: AST.CoreExp.Base = evalBase(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding, - provenClaims, exp.left, removeLabels, shouldTrace) match { + provenClaims, exp.left, removeLabels, shouldTrace, backtrackingSchedule) match { case Some((l, t)) => trace = trace ++ t changed = T @@ -1261,7 +1302,7 @@ object RewritingSystem { case _ => exp.left } val right: AST.CoreExp = eval(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding, provenClaims, - exp.right, removeLabels, shouldTrace) match { + exp.right, removeLabels, shouldTrace, backtrackingSchedule) match { case Some((r, t)) => trace = trace ++ t changed = T @@ -1270,7 +1311,7 @@ object RewritingSystem { } return if (changed) Some((AST.CoreExp.Arrow(left, right), trace)) else None() case exp: AST.CoreExp.Base => evalBase(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding, - provenClaims, exp, removeLabels, shouldTrace) match { + provenClaims, exp, removeLabels, shouldTrace, backtrackingSchedule) match { case Some((e, t)) => return Some((e, t)) case _ => return None() } @@ -1278,7 +1319,8 @@ object RewritingSystem { } @pure def simplify(th: TypeHierarchy, exp: AST.CoreExp.Base): Option[AST.CoreExp.Base] = { - evalBase(th, EvalConfig.all, NoCache(), HashSMap.empty, MBox(HashSMap.empty), 1, HashSMap.empty, exp, T, F) match { + evalBase(th, EvalConfig.all, NoCache(), HashSMap.empty, MBox(HashSMap.empty), 1, HashSMap.empty, + exp, T, F, BacktrackingSchedule.empty(F)) match { case Some((r, _)) => return Some(r) case _ => return None() } @@ -1345,7 +1387,8 @@ object RewritingSystem { provenClaims: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId], exp: AST.CoreExp.Base, removeLabels: B, - shouldTrace: B): Option[(AST.CoreExp.Base, ISZ[Trace])] = { + shouldTrace: B, + backtrackingSchedule: BacktrackingSchedule): Option[(AST.CoreExp.Base, ISZ[Trace])] = { @strictpure def equivST(left: AST.CoreExp.Base, right: AST.CoreExp.Base): ST = AST.CoreExp.Binary(left, AST.Exp.BinaryOp.EquivUni, right, AST.Typed.b).prettyST @@ -1430,11 +1473,13 @@ object RewritingSystem { if (e.tipe == AST.Typed.b) { provenClaims.get(e) match { case Some(stepId) => - val r = AST.CoreExp.True - if (shouldTrace) { - trace = trace :+ Trace.Eval(st"using $stepId", e, r) + if (!backtrackingSchedule.choose()) { + val r = AST.CoreExp.True + if (shouldTrace) { + trace = trace :+ Trace.Eval(st"using $stepId", e, r) + } + rOpt = Some(r) } - rOpt = Some(r) case _ => } if (rOpt.isEmpty) { @@ -1453,10 +1498,12 @@ object RewritingSystem { val r = eqMap.get(e) r match { case Some((to, stepId)) => - if (shouldTrace) { - trace = trace :+ Trace.Eval(st"substitution using $stepId [${to.prettyST}/${e.prettyST}]", e, to) + if (!backtrackingSchedule.choose()) { + if (shouldTrace) { + trace = trace :+ Trace.Eval(st"substitution using $stepId [${to.prettyST}/${e.prettyST}]", e, to) + } + rOpt = Some(to) } - rOpt = Some(to) case _ => } } diff --git a/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala b/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala index 750c3b83..9156c8a7 100644 --- a/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala +++ b/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala @@ -29,7 +29,7 @@ package org.sireum.logika.plugin import org.sireum._ import org.sireum.lang.{ast => AST} import org.sireum.logika.Logika.Reporter -import org.sireum.logika.RewritingSystem.{Rewriter, toCondEquiv} +import org.sireum.logika.RewritingSystem.{BacktrackingSchedule, Rewriter, toCondEquiv} import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext} @datatype class RewritePlugin extends JustificationPlugin { @@ -41,7 +41,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext override def canHandle(logika: Logika, just: AST.ProofAst.Step.Justification): B = { just match { case just: AST.ProofAst.Step.Justification.Ref => - return just.id.value == "Simpl" && just.isOwnedBy(justificationName) + return (just.id.value == "Simpl" || just.id.value == "ESimpl") && just.isOwnedBy(justificationName) case just: AST.ProofAst.Step.Justification.Apply => just.invoke.ident.attr.resOpt match { case Some(res: AST.ResolvedInfo.Method) if (res.id == "Rewrite" || res.id == "RSimpl" || res.id == "Eval") && res.owner == justificationName => return T @@ -61,7 +61,8 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext step: AST.ProofAst.Step.Regular, reporter: Logika.Reporter): State = { @strictpure def err: State = state(status = State.Status.Error) @strictpure def justArgs: ISZ[AST.Exp] = step.just.asInstanceOf[AST.ProofAst.Step.Justification.Apply].args - val isSimpl = step.just.id.value == "Simpl" + val isESimpl = step.just.id.value == "ESimpl" + val isSimpl = isESimpl | step.just.id.value == "Simpl" val isEval = step.just.id.value == "Eval" val isRSimpl = step.just.id.value == "RSimpl" val (patterns, methodPatterns): (ISZ[Rewriter.Pattern.Claim], HashSMap[(ISZ[String], B), Rewriter.Pattern.Method]) = @@ -150,15 +151,29 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext rwPc.trace = rwPc.trace :+ RewritingSystem.Trace.Begin("simplifying", stepClaim) } - val stepClaimEv: AST.CoreExp.Base = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, + val schedule = BacktrackingSchedule.empty(isESimpl) + + var stepClaimEv: AST.CoreExp.Base = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns, MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, stepClaim, T, - logika.config.rwEvalTrace) match { + logika.config.rwEvalTrace, schedule) match { case Some((e, t)) => rwPc.trace = t e case _ => stepClaim } + while (stepClaimEv != AST.CoreExp.True && !schedule.done) { + schedule.index = 0 + stepClaimEv = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, + rwPc.methodPatterns, MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, stepClaim, T, + logika.config.rwEvalTrace, schedule) match { + case Some((e, t)) => + rwPc.trace = t + e + case _ => stepClaim + } + } + if (logika.config.rwEvalTrace) { rwPc.trace = rwPc.trace :+ RewritingSystem.Trace.Done(stepClaim, stepClaimEv) } @@ -201,7 +216,8 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext rwPc.trace = rwPc.trace :+ RewritingSystem.Trace.Begin("evaluating", fromCoreClaim) } rwClaim = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns, - MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, logika.config.rwEvalTrace) match { + MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, logika.config.rwEvalTrace, + BacktrackingSchedule.empty(F)) match { case Some((c, t)) => rwPc.trace = rwPc.trace ++ t c @@ -236,7 +252,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext } if (continu && !rwPc.labeledOnly) { rwClaim = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns, - MBox(HashSMap.empty), 0, rwPc.provenClaimStepIdMapEval, rwClaim, T, T) match { + MBox(HashSMap.empty), 0, rwPc.provenClaimStepIdMapEval, rwClaim, T, T, BacktrackingSchedule.empty(F)) match { case Some((r, t)) => rwPc.trace = rwPc.trace ++ t r @@ -247,7 +263,8 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext } if (continu && !rwPc.labeledOnly) { rwClaim = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns, - MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, T) match { + MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, T, + BacktrackingSchedule.empty(F)) match { case Some((r, t)) => rwPc.trace = rwPc.trace ++ t r