Skip to content

Commit

Permalink
Added backtracking simplification justification -- ESimpl.
Browse files Browse the repository at this point in the history
  • Loading branch information
robby-phd committed Nov 21, 2024
1 parent d62f1b9 commit dcd8bc7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 24 deletions.
79 changes: 63 additions & 16 deletions shared/src/main/scala/org/sireum/logika/RewritingSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,45 @@ object RewritingSystem {
val none: EvalConfig = EvalConfig(F, F, F, F, F, F, F, F, F, F, F)
}

@record class BacktrackingSchedule(val backtracking: B, var schedule: ISZ[B], var index: Z) {
def done: B = {
return !backtracking || schedule.isEmpty || schedule(0)
}
def choose(): B = {
if (index >= schedule.size) {
schedule = schedule :+ F
index = index + 1
return F
} else if (index == schedule.size - 1) {
schedule = schedule(index ~> T)
return T
} else {
assert(!schedule(index))
var i = index + 1
var restT = T
while (restT & i < schedule.size) {
if (!schedule(i)) {
restT = F
}
i = i + 1
}
index = index + 1
if (restT) {
for (j <- index until schedule.size) {
schedule = schedule(j ~> F)
}
return T
} else {
return F
}
}
}
}

object BacktrackingSchedule {
@strictpure def empty(backtracking: B): BacktrackingSchedule = BacktrackingSchedule(backtracking, ISZ(), 0)
}

@record class Substitutor(var map: HashMap[AST.CoreExp, AST.CoreExp.ParamVarRef]) extends AST.MCoreExpTransformer {
override def transformCoreExpBase(o: AST.CoreExp.Base): MOption[AST.CoreExp.Base] = {
map.get(o) match {
Expand Down Expand Up @@ -462,7 +501,7 @@ object RewritingSystem {
case _ =>
if (r0.isEmpty && labeledOnly) {
evalBase(th, EvalConfig.all, cache, methodPatterns, unfoldingMap, 0,
provenClaimStepIdMapEval, o, F, shouldTraceEval) match {
provenClaimStepIdMapEval, o, F, shouldTraceEval, BacktrackingSchedule.empty(F)) match {
case Some((r1, t)) =>
trace = trace ++ t
done = T
Expand Down Expand Up @@ -500,7 +539,8 @@ object RewritingSystem {
case _ => F
}
if (shouldUnfold && (!labeledOnly || inLabel)) {
evalBase(th, EvalConfig.none, cache, methodPatterns, unfoldingMap, 1, HashSMap.empty, o2, F, shouldTraceEval) match {
evalBase(th, EvalConfig.none, cache, methodPatterns, unfoldingMap, 1, HashSMap.empty, o2, F, shouldTraceEval,
BacktrackingSchedule.empty(F)) match {
case Some((o3, t)) =>
trace = trace ++ t
return Some(o3)
Expand Down Expand Up @@ -559,7 +599,7 @@ object RewritingSystem {
}
val (o3Opt, t): (Option[AST.CoreExp.Base], ISZ[Trace]) =
evalBase(th, EvalConfig.all, cache, methodPatterns, unfoldingMap, 0,
provenClaimStepIdMapEval, o2, F, shouldTraceEval) match {
provenClaimStepIdMapEval, o2, F, shouldTraceEval, BacktrackingSchedule.empty(F)) match {
case Some((o3o, t)) => (Some(o3o), t)
case _ => (None(), ISZ())
}
Expand Down Expand Up @@ -1247,21 +1287,22 @@ object RewritingSystem {
provenClaims: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId],
exp: AST.CoreExp,
removeLabels: B,
shouldTrace: B): Option[(AST.CoreExp, ISZ[Trace])] = {
shouldTrace: B,
backtrackingSchedule: BacktrackingSchedule): Option[(AST.CoreExp, ISZ[Trace])] = {
exp match {
case exp: AST.CoreExp.Arrow =>
var changed = F
var trace = ISZ[Trace]()
val left: AST.CoreExp.Base = evalBase(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding,
provenClaims, exp.left, removeLabels, shouldTrace) match {
provenClaims, exp.left, removeLabels, shouldTrace, backtrackingSchedule) match {
case Some((l, t)) =>
trace = trace ++ t
changed = T
l
case _ => exp.left
}
val right: AST.CoreExp = eval(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding, provenClaims,
exp.right, removeLabels, shouldTrace) match {
exp.right, removeLabels, shouldTrace, backtrackingSchedule) match {
case Some((r, t)) =>
trace = trace ++ t
changed = T
Expand All @@ -1270,15 +1311,16 @@ object RewritingSystem {
}
return if (changed) Some((AST.CoreExp.Arrow(left, right), trace)) else None()
case exp: AST.CoreExp.Base => evalBase(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding,
provenClaims, exp, removeLabels, shouldTrace) match {
provenClaims, exp, removeLabels, shouldTrace, backtrackingSchedule) match {
case Some((e, t)) => return Some((e, t))
case _ => return None()
}
}
}

@pure def simplify(th: TypeHierarchy, exp: AST.CoreExp.Base): Option[AST.CoreExp.Base] = {
evalBase(th, EvalConfig.all, NoCache(), HashSMap.empty, MBox(HashSMap.empty), 1, HashSMap.empty, exp, T, F) match {
evalBase(th, EvalConfig.all, NoCache(), HashSMap.empty, MBox(HashSMap.empty), 1, HashSMap.empty,
exp, T, F, BacktrackingSchedule.empty(F)) match {
case Some((r, _)) => return Some(r)
case _ => return None()
}
Expand Down Expand Up @@ -1345,7 +1387,8 @@ object RewritingSystem {
provenClaims: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId],
exp: AST.CoreExp.Base,
removeLabels: B,
shouldTrace: B): Option[(AST.CoreExp.Base, ISZ[Trace])] = {
shouldTrace: B,
backtrackingSchedule: BacktrackingSchedule): Option[(AST.CoreExp.Base, ISZ[Trace])] = {
@strictpure def equivST(left: AST.CoreExp.Base, right: AST.CoreExp.Base): ST =
AST.CoreExp.Binary(left, AST.Exp.BinaryOp.EquivUni, right, AST.Typed.b).prettyST

Expand Down Expand Up @@ -1430,11 +1473,13 @@ object RewritingSystem {
if (e.tipe == AST.Typed.b) {
provenClaims.get(e) match {
case Some(stepId) =>
val r = AST.CoreExp.True
if (shouldTrace) {
trace = trace :+ Trace.Eval(st"using $stepId", e, r)
if (!backtrackingSchedule.choose()) {
val r = AST.CoreExp.True
if (shouldTrace) {
trace = trace :+ Trace.Eval(st"using $stepId", e, r)
}
rOpt = Some(r)
}
rOpt = Some(r)
case _ =>
}
if (rOpt.isEmpty) {
Expand All @@ -1453,10 +1498,12 @@ object RewritingSystem {
val r = eqMap.get(e)
r match {
case Some((to, stepId)) =>
if (shouldTrace) {
trace = trace :+ Trace.Eval(st"substitution using $stepId [${to.prettyST}/${e.prettyST}]", e, to)
if (!backtrackingSchedule.choose()) {
if (shouldTrace) {
trace = trace :+ Trace.Eval(st"substitution using $stepId [${to.prettyST}/${e.prettyST}]", e, to)
}
rOpt = Some(to)
}
rOpt = Some(to)
case _ =>
}
}
Expand Down
33 changes: 25 additions & 8 deletions shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ package org.sireum.logika.plugin
import org.sireum._
import org.sireum.lang.{ast => AST}
import org.sireum.logika.Logika.Reporter
import org.sireum.logika.RewritingSystem.{Rewriter, toCondEquiv}
import org.sireum.logika.RewritingSystem.{BacktrackingSchedule, Rewriter, toCondEquiv}
import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext}

@datatype class RewritePlugin extends JustificationPlugin {
Expand All @@ -41,7 +41,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext
override def canHandle(logika: Logika, just: AST.ProofAst.Step.Justification): B = {
just match {
case just: AST.ProofAst.Step.Justification.Ref =>
return just.id.value == "Simpl" && just.isOwnedBy(justificationName)
return (just.id.value == "Simpl" || just.id.value == "ESimpl") && just.isOwnedBy(justificationName)
case just: AST.ProofAst.Step.Justification.Apply =>
just.invoke.ident.attr.resOpt match {
case Some(res: AST.ResolvedInfo.Method) if (res.id == "Rewrite" || res.id == "RSimpl" || res.id == "Eval") && res.owner == justificationName => return T
Expand All @@ -61,7 +61,8 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext
step: AST.ProofAst.Step.Regular, reporter: Logika.Reporter): State = {
@strictpure def err: State = state(status = State.Status.Error)
@strictpure def justArgs: ISZ[AST.Exp] = step.just.asInstanceOf[AST.ProofAst.Step.Justification.Apply].args
val isSimpl = step.just.id.value == "Simpl"
val isESimpl = step.just.id.value == "ESimpl"
val isSimpl = isESimpl | step.just.id.value == "Simpl"
val isEval = step.just.id.value == "Eval"
val isRSimpl = step.just.id.value == "RSimpl"
val (patterns, methodPatterns): (ISZ[Rewriter.Pattern.Claim], HashSMap[(ISZ[String], B), Rewriter.Pattern.Method]) =
Expand Down Expand Up @@ -150,15 +151,29 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext
rwPc.trace = rwPc.trace :+ RewritingSystem.Trace.Begin("simplifying", stepClaim)
}

val stepClaimEv: AST.CoreExp.Base = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache,
val schedule = BacktrackingSchedule.empty(isESimpl)

var stepClaimEv: AST.CoreExp.Base = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache,
rwPc.methodPatterns, MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, stepClaim, T,
logika.config.rwEvalTrace) match {
logika.config.rwEvalTrace, schedule) match {
case Some((e, t)) =>
rwPc.trace = t
e
case _ => stepClaim
}

while (stepClaimEv != AST.CoreExp.True && !schedule.done) {
schedule.index = 0
stepClaimEv = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache,
rwPc.methodPatterns, MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, stepClaim, T,
logika.config.rwEvalTrace, schedule) match {
case Some((e, t)) =>
rwPc.trace = t
e
case _ => stepClaim
}
}

if (logika.config.rwEvalTrace) {
rwPc.trace = rwPc.trace :+ RewritingSystem.Trace.Done(stepClaim, stepClaimEv)
}
Expand Down Expand Up @@ -201,7 +216,8 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext
rwPc.trace = rwPc.trace :+ RewritingSystem.Trace.Begin("evaluating", fromCoreClaim)
}
rwClaim = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns,
MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, logika.config.rwEvalTrace) match {
MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, logika.config.rwEvalTrace,
BacktrackingSchedule.empty(F)) match {
case Some((c, t)) =>
rwPc.trace = rwPc.trace ++ t
c
Expand Down Expand Up @@ -236,7 +252,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext
}
if (continu && !rwPc.labeledOnly) {
rwClaim = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns,
MBox(HashSMap.empty), 0, rwPc.provenClaimStepIdMapEval, rwClaim, T, T) match {
MBox(HashSMap.empty), 0, rwPc.provenClaimStepIdMapEval, rwClaim, T, T, BacktrackingSchedule.empty(F)) match {
case Some((r, t)) =>
rwPc.trace = rwPc.trace ++ t
r
Expand All @@ -247,7 +263,8 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext
}
if (continu && !rwPc.labeledOnly) {
rwClaim = RewritingSystem.evalBase(logika.th, RewritingSystem.EvalConfig.all, cache, rwPc.methodPatterns,
MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, T) match {
MBox(HashSMap.empty), logika.config.rwMax, rwPc.provenClaimStepIdMapEval, rwClaim, T, T,
BacktrackingSchedule.empty(F)) match {
case Some((r, t)) =>
rwPc.trace = rwPc.trace ++ t
r
Expand Down

0 comments on commit dcd8bc7

Please sign in to comment.