From 39670238f53f9c5e3f7a6489376050642a0cb335 Mon Sep 17 00:00:00 2001 From: Robby Date: Sun, 11 Feb 2024 14:32:39 -0600 Subject: [PATCH] Rewriting system. --- .../org/sireum/logika/example/rewrite.sc | 4 +- .../main/scala/org/sireum/logika/Logika.scala | 10 +- .../org/sireum/logika/RewritingSystem.scala | 442 +++++++++++------- .../org/sireum/logika/StepProofContext.scala | 7 +- .../main/scala/org/sireum/logika/Task.scala | 2 +- .../logika/plugin/InceptionPlugin.scala | 8 +- .../logika/plugin/PredNatDedPlugin.scala | 4 +- .../logika/plugin/PropNatDedPlugin.scala | 6 +- .../sireum/logika/plugin/RewritePlugin.scala | 46 +- 9 files changed, 340 insertions(+), 189 deletions(-) diff --git a/jvm/src/test/scala/org/sireum/logika/example/rewrite.sc b/jvm/src/test/scala/org/sireum/logika/example/rewrite.sc index 5d5e1fb3..ccb67263 100644 --- a/jvm/src/test/scala/org/sireum/logika/example/rewrite.sc +++ b/jvm/src/test/scala/org/sireum/logika/example/rewrite.sc @@ -15,8 +15,8 @@ object Rules { (2 * c + 3 * c == d) |- (5 * c == d) Proof( //@formatter:off 1 (2 * c + 3 * c == d) by Premise, - 2 (5 * c == d) by Rewrite(RS(zDistribute _), 1), - 3 (5 * c == d) by Rewrite(myRewriteSet, 1) + 2 (5 * c == d) by Rewrite(RS(zDistribute _), 1) T, + 3 (5 * c == d) by Rewrite(myRewriteSet, 1) T //@formatter:on ) ) diff --git a/shared/src/main/scala/org/sireum/logika/Logika.scala b/shared/src/main/scala/org/sireum/logika/Logika.scala index 410e1ff8..f602ca5b 100644 --- a/shared/src/main/scala/org/sireum/logika/Logika.scala +++ b/shared/src/main/scala/org/sireum/logika/Logika.scala @@ -4736,7 +4736,7 @@ import Util._ cache.getTransitionAndUpdateSmt2(th, config, Cache.Transition.ProofStep(step, m.values), s0, smt2) match { case Some((ISZ(nextState), cached)) => reporter.coverage(F, cached, pos) - return (nextState, m + stepNo ~> StepProofContext.Regular(stepNo, step.claim, + return (nextState, m + stepNo ~> StepProofContext.Regular(th, stepNo, step.claim, ops.ISZOps(nextState.claims).slice(s0.claims.size, nextState.claims.size))) case _ => } @@ -4794,7 +4794,7 @@ import Util._ } else { reporter.coverage(F, zeroU64, pos) } - return (nextState, m + stepNo ~> StepProofContext.Regular(stepNo, step.claim, claims)) + return (nextState, m + stepNo ~> StepProofContext.Regular(th, stepNo, step.claim, claims)) } reporter.error(step.just.posOpt, Logika.kind, "Could not recognize justification form") return (s0(status = State.Status.Error), m) @@ -4802,7 +4802,7 @@ import Util._ val (ok, nextFresh, claims, claim) = evalRegularStepClaim(smt2, cache, s0, step.claim, step.id.posOpt, reporter) return (s0(status = State.statusOf(ok), nextFresh = nextFresh, claims = (s0.claims ++ claims) :+ claim), - m + stepNo ~> StepProofContext.Regular(stepNo, step.claim, claims :+ claim)) + m + stepNo ~> StepProofContext.Regular(th, stepNo, step.claim, claims :+ claim)) case step: AST.ProofAst.Step.SubProof => for (sub <- step.steps if s0.ok) { val p = evalProofStep(smt2, cache, (s0, m), sub, reporter) @@ -4840,7 +4840,7 @@ import Util._ val (ok, nextFresh, claims, claim) = evalRegularStepClaimRtCheck(smt2, cache, F, s0, step.claim, step.id.posOpt, reporter) return (s0(status = State.statusOf(ok), nextFresh = nextFresh, claims = (s0.claims ++ claims) :+ claim), - m + stepNo ~> StepProofContext.Regular(stepNo, step.claim, claims :+ claim)) + m + stepNo ~> StepProofContext.Regular(th, stepNo, step.claim, claims :+ claim)) case step: AST.ProofAst.Step.Let => for (sub <- step.steps if s0.ok) { val p = evalProofStep(smt2, cache, (s0, m), sub, reporter) @@ -5250,7 +5250,7 @@ import Util._ var st0 = st for (premise <- sequent.premises if st0.ok) { val (ok, nextFresh, claims, claim) = evalRegularStepClaim(smt2, cache, st0, premise, premise.posOpt, reporter) - r = r + id ~> StepProofContext.Regular(id(attr = AST.Attr(premise.posOpt)), premise, claims :+ claim) + r = r + id ~> StepProofContext.Regular(th, id(attr = AST.Attr(premise.posOpt)), premise, claims :+ claim) id = id(no = id.no - 1) st0 = st0(status = State.statusOf(ok), nextFresh = nextFresh, claims = (st0.claims ++ claims) :+ claim) diff --git a/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala b/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala index 9f43fc81..9deb5761 100644 --- a/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala +++ b/shared/src/main/scala/org/sireum/logika/RewritingSystem.scala @@ -33,30 +33,31 @@ import org.sireum.lang.tipe.TypeHierarchy object RewritingSystem { type FunStack = Stack[(String, AST.Typed)] type Local = (ISZ[String], String) - type LocalMap = HashSMap[Local, AST.CoreExp] + type LocalMap = HashSMap[Local, AST.CoreExp.Base] type LocalPatternSet = HashSSet[Local] - type PendingApplications = ISZ[(ISZ[String], String, ISZ[AST.CoreExp], AST.CoreExp)] - type UnificationMap = HashSMap[Local, AST.CoreExp] + type PendingApplications = ISZ[(ISZ[String], String, ISZ[AST.CoreExp.Base], AST.CoreExp.Base)] + type UnificationMap = HashSMap[Local, AST.CoreExp.Base] type UnificationErrorMessages = ISZ[String] type UnificationResult = Either[UnificationMap, UnificationErrorMessages] @datatype class EvalConfig(val constantPropagation: B, val funApplication: B, val quantApplication: B, + val equality: B, val tupleProjection: B, val seqIndexing: B, val fieldAccess: B, val instanceOf: B) object EvalConfig { - val all: EvalConfig = EvalConfig(T, T, T, T, T, T, T) - val none: EvalConfig = EvalConfig(F, F, F, F, F, F, F) + val all: EvalConfig = EvalConfig(T, T, T, T, T, T, T, T) + val none: EvalConfig = EvalConfig(F, F, F, F, F, F, F, F) val funApplicationOnly: EvalConfig = none(funApplication = T) val quantApplicationOnly: EvalConfig = none(quantApplication = T) } @record class Substitutor(var map: HashMap[AST.CoreExp, AST.CoreExp.ParamVarRef]) extends AST.MCoreExpTransformer { - override def transformCoreExp(o: AST.CoreExp): MOption[AST.CoreExp] = { + override def transformCoreExpBase(o: AST.CoreExp.Base): MOption[AST.CoreExp.Base] = { map.get(o) match { case Some(pvr) => return MSome(pvr) case _ => @@ -65,7 +66,7 @@ object RewritingSystem { val r0: MOption[AST.CoreExp.Param] = transformCoreExpParam(o.param) val oldMap = map map = HashMap ++ (for (p <- map.entries) yield (p._1.incDeBruijn(1), p._2.incDeBruijn(1))) - val r1: MOption[AST.CoreExp] = transformCoreExp(o.exp) + val r1: MOption[AST.CoreExp.Base] = transformCoreExpBase(o.exp) map = oldMap if (r0.nonEmpty || r1.nonEmpty) { return MSome(o(param = r0.getOrElse(o.param), exp = r1.getOrElse(o.exp))) @@ -76,7 +77,7 @@ object RewritingSystem { val r0: MOption[AST.CoreExp.Param] = transformCoreExpParam(o.param) val oldMap = map map = HashMap ++ (for (p <- map.entries) yield (p._1.incDeBruijn(1), p._2.incDeBruijn(1))) - val r1: MOption[AST.CoreExp] = transformCoreExp(o.exp) + val r1: MOption[AST.CoreExp.Base] = transformCoreExpBase(o.exp) map = oldMap if (r0.nonEmpty || r1.nonEmpty) { return MSome(o(param = r0.getOrElse(o.param), exp = r1.getOrElse(o.exp))) @@ -85,13 +86,13 @@ object RewritingSystem { } case _ => } - return super.transformCoreExp(o) + return super.transformCoreExpBase(o) } } } @record class LocalPatternDetector(val localPatterns: LocalPatternSet, var hasLocalPattern: B) extends AST.MCoreExpTransformer { - override def postCoreExpLocalVarRef(o: AST.CoreExp.LocalVarRef): MOption[AST.CoreExp] = { + override def postCoreExpLocalVarRef(o: AST.CoreExp.LocalVarRef): MOption[AST.CoreExp.Base] = { if (localPatterns.contains((o.context, o.id))) { hasLocalPattern = T } @@ -100,7 +101,7 @@ object RewritingSystem { } @record class LocalSubstitutor(val map: UnificationMap) extends AST.MCoreExpTransformer { - override def postCoreExpLocalVarRef(o: AST.CoreExp.LocalVarRef): MOption[AST.CoreExp] = { + override def postCoreExpLocalVarRef(o: AST.CoreExp.LocalVarRef): MOption[AST.CoreExp.Base] = { map.get((o.context, o.id)) match { case Some(e) => return MSome(e) case _ => return MNone() @@ -111,52 +112,128 @@ object RewritingSystem { @datatype class TraceElement(val name: ISZ[String], val rightToLeft: B, val pattern: AST.CoreExp, - val original: AST.CoreExp, - val rewritten: AST.CoreExp, - val evaluated: AST.CoreExp, - val done: AST.CoreExp) { - @strictpure def toST: ST = - st"""by ${(name, ".")}: ${pattern.prettyPatternST} - | ${original.prettyST} - | ${if (rightToLeft) "<" else ">"} ${rewritten.prettyST} - | ≡ ${evaluated.prettyST} - | ∴ ${done.prettyST}""" + val original: AST.CoreExp.Base, + val rewritten: AST.CoreExp.Base, + val evaluatedOpt: Option[AST.CoreExp.Base], + val done: AST.CoreExp.Base) { + @strictpure def toST: ST = evaluatedOpt match { + case Some(evaluated) => + st"""by ${(name, ".")}: ${pattern.prettyPatternST} + | ${original.prettyST} + | ${if (rightToLeft) "<" else ">"} ${rewritten.prettyST} + | ≡ ${evaluated.prettyST} + | ∴ ${done.prettyST}""" + case _ => + st"""by ${(name, ".")}: ${pattern.prettyPatternST} + | ${original.prettyST} + | ${if (rightToLeft) "<" else ">"} ${rewritten.prettyST} + | ∴ ${done.prettyST}""" + } } @record class Rewriter(val th: TypeHierarchy, - val patterns: ISZ[Rewriter.Pattern], + val provenClaims: HashSSet[AST.CoreExp.Base], + val rwPatterns: ISZ[Rewriter.Pattern], var trace: ISZ[TraceElement]) extends AST.MCoreExpTransformer { - override def preCoreExpIf(o: AST.CoreExp.If): AST.MCoreExpTransformer.PreResult[AST.CoreExp] = { + override def preCoreExpIf(o: AST.CoreExp.If): AST.MCoreExpTransformer.PreResult[AST.CoreExp.Base] = { o.cond match { case cond: AST.CoreExp.LitB => return AST.MCoreExpTransformer.PreResult(T, if (cond.value) MSome(o.tExp) else MSome(o.fExp)) case _ => return AST.MCoreExpTransformer.PreResult(F, MNone()) } } - override def postCoreExp(o: AST.CoreExp): MOption[AST.CoreExp] = { - var rOpt = MOption.none[AST.CoreExp]() + override def postCoreExpBase(o: AST.CoreExp.Base): MOption[AST.CoreExp.Base] = { + var rOpt = MOption.none[AST.CoreExp.Base]() var i = 0 + var patterns = rwPatterns while (rOpt.isEmpty && i < patterns.size) { val pattern = patterns(i) - val (from, to): (AST.CoreExp, AST.CoreExp) = pattern.exp match { - case AST.CoreExp.Binary(left, AST.Exp.BinaryOp.EquivUni, right) => - if (pattern.rightToLeft) (right, left) else (left, right) - case _ => halt("TODO") + var assumptions = ISZ[AST.CoreExp.Base]() + def arrowRec(e: AST.CoreExp): AST.CoreExp = { + e match { + case e: AST.CoreExp.Arrow => + assumptions = assumptions :+ e.left + return arrowRec(e.right) + case _ => return e + } } - unify(T, th, pattern.localPatternSet, ISZ(from), ISZ(o)) match { - case Either.Left(m) => - val o2 = LocalSubstitutor(m).transformCoreExp(to).getOrElse(o) - val o3 = eval(th, EvalConfig.all, o2).getOrElse(o) - trace = trace :+ TraceElement(pattern.name, pattern.rightToLeft, pattern.exp, o, o2, o3, o3) - rOpt = MSome(o3) - case _ => + def tryPattern(): Unit = { + val (from, to): (AST.CoreExp.Base, AST.CoreExp.Base) = arrowRec(pattern.exp) match { + case AST.CoreExp.Binary(left, AST.Exp.BinaryOp.EquivUni, right) => + if (pattern.rightToLeft) (right, left) else (left, right) + case _ => halt("Infeasible") + } + def last(m: UnificationMap): Unit = { + val o2 = LocalSubstitutor(m).transformCoreExpBase(to).getOrElse(o) + val o3Opt = evalBase(th, EvalConfig.all, o2) + val o3 = o3Opt.getOrElse(o2) + if (o == o3) { + // skip + } else if (pattern.isPermutative && o < o3) { + // skip + } else { + trace = trace :+ TraceElement(pattern.name, pattern.rightToLeft, pattern.exp, o, o2, o3Opt, o3) + rOpt = MSome(o3) + } + } + if (assumptions.isEmpty) { + unify(T, th, pattern.localPatternSet, ISZ(from), ISZ(o)) match { + case Either.Left(m) => last(m) + case _ => + } + } else { + var done = F + val pces = provenClaims.elements + def recAssumption(pendingApplications: PendingApplications, + substMap: HashMap[String, AST.Typed], + apcs: ISZ[AST.CoreExp.Base], + map: UnificationMap, + j: Z): Unit = { + if (j >= assumptions.size) { + val ems: MBox[UnificationErrorMessages] = MBox(ISZ()) + val pas = MBox(pendingApplications) + val sm = MBox(substMap) + val m = unifyExp(T, th, pattern.localPatternSet, from, o, map, pas, sm, ems) + if (ems.value.isEmpty) { + unifyPendingApplications(T, th, pattern.localPatternSet, m, pas, sm, ems) match { + case Either.Left(m2) => + for (k <- 0 until apcs.size) { + for (apc <- toCondEquiv(th, apcs(k))) { + patterns = patterns :+ Rewriter.Pattern(pattern.name :+ s"Assumption$k", F, + isPermutative(apc), HashSSet.empty, apc) + } + } + last(m2) + done = T + case _ => + } + } + return + } + val assumption = assumptions(j) + var k = 0 + while (k < pces.size && !done) { + val ems: MBox[UnificationErrorMessages] = MBox(ISZ()) + val pas = MBox(pendingApplications) + val sm = MBox(substMap) + val pc = pces(k) + val m = unifyExp(T, th, pattern.localPatternSet, assumption, pc, map, pas, sm, ems) + if (ems.value.isEmpty) { + recAssumption(pas.value, sm.value, apcs :+ pc, m, j + 1) + } + k = k + 1 + } + } + recAssumption(ISZ(), HashMap.empty, ISZ(), HashSMap.empty, 0) + } } + tryPattern() i = i + 1 } return rOpt } - override def postCoreExpIf(o: AST.CoreExp.If): MOption[AST.CoreExp] = { + override def postCoreExpIf(o: AST.CoreExp.If): MOption[AST.CoreExp.Base] = { o.cond match { case cond: AST.CoreExp.LitB => return if (cond.value) MSome(o.tExp) else MSome(o.fExp) case _ => return MNone() @@ -167,14 +244,15 @@ object RewritingSystem { object Rewriter { @datatype class Pattern(val name: ISZ[String], val rightToLeft: B, + val isPermutative: B, val localPatternSet: LocalPatternSet, val exp: AST.CoreExp) } @strictpure def paramId(n: String): String = s"_$n" - @pure def translate(th: TypeHierarchy, exp: AST.Exp): AST.CoreExp = { - @pure def recBody(body: AST.Body, funStack: FunStack, localMap: LocalMap): AST.CoreExp = { + @pure def translate(th: TypeHierarchy, isPattern: B, exp: AST.Exp): AST.CoreExp.Base = { + @pure def recBody(body: AST.Body, funStack: FunStack, localMap: LocalMap): AST.CoreExp.Base = { val stmts = body.stmts var m = localMap for (i <- 0 until stmts.size - 2) { @@ -182,7 +260,7 @@ object RewritingSystem { } return recAssignExp(stmts(stmts.size - 1).asAssignExp, funStack, m) } - @pure def recStmt(stmt: AST.Stmt, funStack: FunStack, localMap: LocalMap): (Option[AST.CoreExp], LocalMap) = { + @pure def recStmt(stmt: AST.Stmt, funStack: FunStack, localMap: LocalMap): (Option[AST.CoreExp.Base], LocalMap) = { stmt match { case stmt: AST.Stmt.Expr => return (Some(rec(stmt.exp, funStack, localMap)), localMap) case stmt: AST.Stmt.Var => @@ -201,11 +279,11 @@ object RewritingSystem { case _ => halt(s"Infeasible: $stmt") } } - @pure def recAssignExp(ae: AST.AssignExp, funStack: FunStack, localMap: LocalMap): AST.CoreExp = { + @pure def recAssignExp(ae: AST.AssignExp, funStack: FunStack, localMap: LocalMap): AST.CoreExp.Base = { val (Some(r), _) = recStmt(ae.asStmt, funStack, localMap) return r } - @pure def rec(e: AST.Exp, funStack: FunStack, localMap: LocalMap): AST.CoreExp = { + @pure def rec(e: AST.Exp, funStack: FunStack, localMap: LocalMap): AST.CoreExp.Base = { e match { case e: AST.Exp.LitB => return AST.CoreExp.LitB(e.value) case e: AST.Exp.LitZ => return AST.CoreExp.LitZ(e.value) @@ -246,7 +324,7 @@ object RewritingSystem { return AST.CoreExp.ParamVarRef(stackSize - i, id, p._2) } } - return AST.CoreExp.LocalVarRef(res.context, id, e.typedOpt.get) + return AST.CoreExp.LocalVarRef(isPattern, res.context, id, e.typedOpt.get) case res: AST.ResolvedInfo.Var if res.isInObject => return AST.CoreExp.ObjectVarRef(res.owner, res.id, e.typedOpt.get) case res: AST.ResolvedInfo.Method => halt(s"TODO: $e") @@ -313,7 +391,7 @@ object RewritingSystem { case e: AST.Exp.StrictPureBlock => return recStmt(e.block, funStack, localMap)._1.get case e: AST.Exp.Invoke => - val args: ISZ[AST.CoreExp] = for (arg <- e.args) yield rec(arg, funStack, localMap) + val args: ISZ[AST.CoreExp.Base] = for (arg <- e.args) yield rec(arg, funStack, localMap) e.receiverOpt match { case Some(receiver) => return AST.CoreExp.Apply( @@ -323,7 +401,7 @@ object RewritingSystem { args, e.typedOpt.get) } case e: AST.Exp.InvokeNamed => - val args = MS.create[Z, AST.CoreExp](e.args.size, AST.CoreExp.LitB(F)) + val args = MS.create[Z, AST.CoreExp.Base](e.args.size, AST.CoreExp.LitB(F)) for (arg <- e.args) { args(arg.index) = rec(arg.arg, funStack, localMap) } @@ -345,13 +423,13 @@ object RewritingSystem { @pure def unifyExp(silent: B, th: TypeHierarchy, localPatterns: LocalPatternSet, - pattern: AST.CoreExp, - exp: AST.CoreExp, + pattern: AST.CoreExp.Base, + exp: AST.CoreExp.Base, init: UnificationMap, pendingApplications: MBox[PendingApplications], substMap: MBox[HashMap[String, AST.Typed]], errorMessages: MBox[UnificationErrorMessages]): UnificationMap = { - @pure def rootLocalPatternOpt(e: AST.CoreExp, args: ISZ[AST.CoreExp]): Option[(ISZ[String], String, AST.Typed, ISZ[AST.CoreExp])] = { + @pure def rootLocalPatternOpt(e: AST.CoreExp.Base, args: ISZ[AST.CoreExp.Base]): Option[(ISZ[String], String, AST.Typed, ISZ[AST.CoreExp.Base])] = { e match { case e: AST.CoreExp.LocalVarRef => val p = (e.context, e.id) @@ -430,7 +508,7 @@ object RewritingSystem { } } } - def matchPatternLocals(p: AST.CoreExp, e: AST.CoreExp): Unit = { + def matchPatternLocals(p: AST.CoreExp.Base, e: AST.CoreExp.Base): Unit = { if (errorMessages.value.nonEmpty) { return } @@ -439,7 +517,7 @@ object RewritingSystem { if (p != e) { err(p, e) } - case (p: AST.CoreExp.LocalVarRef, e) => + case (p: AST.CoreExp.LocalVarRef, e) if p.isPattern => val key = (p.context, p.id) if (localPatterns.contains(key)) { map.get(key) match { @@ -454,6 +532,10 @@ object RewritingSystem { } else if (p != e) { err(p, e) } + case (p: AST.CoreExp.LocalVarRef, e: AST.CoreExp.LocalVarRef) => + if (!(p.id == e.id && p.context == e.context)) { + err(p, e) + } case (p: AST.CoreExp.ParamVarRef, e: AST.CoreExp.ParamVarRef) => if (p.deBruijn != e.deBruijn) { err(p, e) @@ -531,8 +613,8 @@ object RewritingSystem { val n = args.size - i substitutions = substitutions + args(i) ~> AST.CoreExp.ParamVarRef(n, paramId(n.string), argTypes(i)) } - val se = Substitutor(substitutions).transformCoreExp(e).getOrElse(e) - var r: AST.CoreExp = AST.CoreExp.Fun(AST.CoreExp.Param(paramId(1.string), argTypes(args.size - 1)), se) + val se = Substitutor(substitutions).transformCoreExpBase(e).getOrElse(e) + var r: AST.CoreExp.Base = AST.CoreExp.Fun(AST.CoreExp.Param(paramId(1.string), argTypes(args.size - 1)), se) for (i <- args.size - 2 to 0 by -1) { r = AST.CoreExp.Fun(AST.CoreExp.Param(paramId((args.size - i).string), argTypes(i)), r) } @@ -572,9 +654,6 @@ object RewritingSystem { } else { matchPatternLocals(p.exp, e.exp) } - case (p: AST.CoreExp.Arrow, e: AST.CoreExp.Arrow) => - matchPatternLocals(p.left, e.left) - matchPatternLocals(p.right, e.right) case (_, _) => err(p, e) } @@ -585,47 +664,42 @@ object RewritingSystem { return map } - @pure def unify(silent: B, th: TypeHierarchy, localPatterns: LocalPatternSet, patterns: ISZ[AST.CoreExp], exps: ISZ[AST.CoreExp]): UnificationResult = { - val errorMessages: MBox[UnificationErrorMessages] = MBox(ISZ()) - val pendingApplications: MBox[PendingApplications] = MBox(ISZ()) - val substMap: MBox[HashMap[String, AST.Typed]] = MBox(HashMap.empty) - var m: UnificationMap = HashSMap.empty - for (i <- 0 until patterns.size) { - m = unifyExp(silent, th, localPatterns, patterns(i), exps(i), m, pendingApplications, substMap, errorMessages) + @pure def unifyPendingApplications(silent: B, + th: TypeHierarchy, + localPatterns: LocalPatternSet, + map: UnificationMap, + pendingApplications: MBox[PendingApplications], + substMap: MBox[HashMap[String, AST.Typed]], + errorMessages: MBox[UnificationErrorMessages]): UnificationResult = { + var m = map + //while (pendingApplications.value.nonEmpty) { + val pas = pendingApplications.value + pendingApplications.value = ISZ() + for (pa <- pas) { + val (context, id, args, e) = pa + m.get((context, id)) match { + case Some(f: AST.CoreExp.Fun) => + evalBase(th, EvalConfig.funApplicationOnly, AST.CoreExp.Apply(f, args, e.tipe)) match { + case Some(pattern) => + m = unifyExp(silent, th, localPatterns, pattern, e, m, pendingApplications, substMap, errorMessages) + case _ => + if (silent) { + if (errorMessages.value.isEmpty) { + errorMessages.value = errorMessages.value :+ "" + } + } else { + errorMessages.value = errorMessages.value :+ + st"Could not reduce '$f(${(for (arg <- args) yield arg.prettyST, ", ")})'".render + } + } + case Some(f) => errorMessages.value = errorMessages.value :+ s"Expecting to infer a function, but found '$f'" + case _ => + } if (errorMessages.value.nonEmpty) { return Either.Right(errorMessages.value) } } - - /* while (pendingApplications.value.nonEmpty) */ { - val pas = pendingApplications.value - pendingApplications.value = ISZ() - for (pa <- pas) { - val (context, id, args, e) = pa - m.get((context, id)) match { - case Some(f: AST.CoreExp.Fun) => - eval(th, EvalConfig.funApplicationOnly, AST.CoreExp.Apply(f, args, e.tipe)) match { - case Some(pattern) => - m = unifyExp(silent, th, localPatterns, pattern, e, m, pendingApplications, substMap, errorMessages) - case _ => - if (silent) { - if (errorMessages.value.isEmpty) { - errorMessages.value = errorMessages.value :+ "" - } - } else { - errorMessages.value = errorMessages.value :+ - st"Could not reduce '$f(${(for (arg <- args) yield arg.prettyST, ", ")})'".render - } - } - case Some(f) => errorMessages.value = errorMessages.value :+ s"Expecting to infer a function, but found '$f'" - case _ => - } - if (errorMessages.value.nonEmpty) { - return Either.Right(errorMessages.value) - } - } - } - + //} for (localPattern <- localPatterns.elements if !m.contains(localPattern)) { if (silent) { if (errorMessages.value.isEmpty) { @@ -639,6 +713,20 @@ object RewritingSystem { else Either.Left(HashSMap ++ (for (p <- m.entries) yield (p._1, p._2.subst(substMap.value)))) } + @pure def unify(silent: B, th: TypeHierarchy, localPatterns: LocalPatternSet, patterns: ISZ[AST.CoreExp.Base], exps: ISZ[AST.CoreExp.Base]): UnificationResult = { + val errorMessages: MBox[UnificationErrorMessages] = MBox(ISZ()) + val pendingApplications: MBox[PendingApplications] = MBox(ISZ()) + val substMap: MBox[HashMap[String, AST.Typed]] = MBox(HashMap.empty) + var m: UnificationMap = HashSMap.empty + for (i <- 0 until patterns.size) { + m = unifyExp(silent, th, localPatterns, patterns(i), exps(i), m, pendingApplications, substMap, errorMessages) + if (errorMessages.value.nonEmpty) { + return Either.Right(errorMessages.value) + } + } + return unifyPendingApplications(silent, th, localPatterns, m, pendingApplications, substMap, errorMessages) + } + @strictpure def evalBinaryLit(lit1: AST.CoreExp.Lit, op: String, lit2: AST.CoreExp.Lit): AST.CoreExp.Lit = lit1 match { case lit1: AST.CoreExp.LitB => @@ -788,9 +876,33 @@ object RewritingSystem { } @pure def eval(th: TypeHierarchy, config: EvalConfig, exp: AST.CoreExp): Option[AST.CoreExp] = { - @strictpure def incDeBruijnMap(deBruijnMap: HashMap[Z, AST.CoreExp], inc: Z): HashMap[Z, AST.CoreExp] = + exp match { + case exp: AST.CoreExp.Arrow => + var changed = F + val left: AST.CoreExp.Base = evalBase(th, config, exp.left) match { + case Some(l) => + changed = T + l + case _ => exp.left + } + val right: AST.CoreExp = eval(th, config, exp.right) match { + case Some(r) => + changed = T + r + case _ => exp.right + } + return if (changed) Some(AST.CoreExp.Arrow(left, right)) else None() + case exp: AST.CoreExp.Base => evalBase(th, config, exp) match { + case Some(e) => return Some(e) + case _ => return None() + } + } + } + + @pure def evalBase(th: TypeHierarchy, config: EvalConfig, exp: AST.CoreExp.Base): Option[AST.CoreExp.Base] = { + @strictpure def incDeBruijnMap(deBruijnMap: HashMap[Z, AST.CoreExp.Base], inc: Z): HashMap[Z, AST.CoreExp.Base] = HashMap ++ (for (p <- deBruijnMap.entries) yield (p._1 + inc, p._2)) - @pure def rec(deBruijnMap: HashMap[Z, AST.CoreExp], e: AST.CoreExp): Option[AST.CoreExp] = { + @pure def rec(deBruijnMap: HashMap[Z, AST.CoreExp.Base], e: AST.CoreExp.Base): Option[AST.CoreExp.Base] = { e match { case _: AST.CoreExp.Lit => return None() case _: AST.CoreExp.LocalVarRef => return None() @@ -802,18 +914,33 @@ object RewritingSystem { case _: AST.CoreExp.ObjectVarRef => return None() case e: AST.CoreExp.Binary => var changed = F - val left: AST.CoreExp = rec(deBruijnMap, e.left) match { + val left: AST.CoreExp.Base = rec(deBruijnMap, e.left) match { case Some(l) => changed = T l case _ => e.left } - val right: AST.CoreExp = rec(deBruijnMap, e.right) match { + val right: AST.CoreExp.Base = rec(deBruijnMap, e.right) match { case Some(r) => changed = T r case _ => e.right } + if (config.equality) { + if (left == right) { + e.op match { + case AST.Exp.BinaryOp.EquivUni => return Some(AST.CoreExp.LitB(T)) + case AST.Exp.BinaryOp.InequivUni => return Some(AST.CoreExp.LitB(F)) + case AST.Exp.BinaryOp.Lt => return Some(AST.CoreExp.LitB(F)) + case AST.Exp.BinaryOp.Le => return Some(AST.CoreExp.LitB(T)) + case AST.Exp.BinaryOp.Gt => return Some(AST.CoreExp.LitB(F)) + case AST.Exp.BinaryOp.Ge => return Some(AST.CoreExp.LitB(T)) + case AST.Exp.BinaryOp.Eq => return Some(AST.CoreExp.LitB(T)) + case AST.Exp.BinaryOp.Ne => return Some(AST.CoreExp.LitB(F)) + case _ => + } + } + } if (config.constantPropagation) { (left, right) match { case (left: AST.CoreExp.Lit, right: AST.CoreExp.Lit) => return Some(evalBinaryLit(left, e.op, right)) @@ -823,7 +950,7 @@ object RewritingSystem { return if (changed) Some(e(left = left, right = right)) else None() case e: AST.CoreExp.Unary => var changed = F - val ue: AST.CoreExp = rec(deBruijnMap, e.exp) match { + val ue: AST.CoreExp.Base = rec(deBruijnMap, e.exp) match { case Some(exp2) => changed = T exp2 @@ -843,13 +970,13 @@ object RewritingSystem { } case e: AST.CoreExp.Update => var changed = F - val receiver: AST.CoreExp = rec(deBruijnMap, e.exp) match { + val receiver: AST.CoreExp.Base = rec(deBruijnMap, e.exp) match { case Some(exp2) => changed = T exp2 case _ => e.exp } - val arg: AST.CoreExp = rec(deBruijnMap, e.arg) match { + val arg: AST.CoreExp.Base = rec(deBruijnMap, e.arg) match { case Some(arg2) => changed = T arg2 @@ -858,13 +985,13 @@ object RewritingSystem { return if (changed) Some(e(exp = receiver, arg = arg)) else None() case e: AST.CoreExp.Indexing => var changed = F - val receiver: AST.CoreExp = rec(deBruijnMap, e.exp) match { + val receiver: AST.CoreExp.Base = rec(deBruijnMap, e.exp) match { case Some(exp2) => changed = T exp2 case _ => e.exp } - val index: AST.CoreExp = rec(deBruijnMap, e.index) match { + val index: AST.CoreExp.Base = rec(deBruijnMap, e.index) match { case Some(index2) => changed = T index2 @@ -873,19 +1000,19 @@ object RewritingSystem { return if (changed) Some(e(exp = receiver, index = index)) else None() case e: AST.CoreExp.IndexingUpdate => var changed = F - val receiver: AST.CoreExp = rec(deBruijnMap, e.exp) match { + val receiver: AST.CoreExp.Base = rec(deBruijnMap, e.exp) match { case Some(exp2) => changed = T exp2 case _ => e.exp } - val index: AST.CoreExp = rec(deBruijnMap, e.index) match { + val index: AST.CoreExp.Base = rec(deBruijnMap, e.index) match { case Some(index2) => changed = T index2 case _ => e.index } - val arg: AST.CoreExp = rec(deBruijnMap, e.arg) match { + val arg: AST.CoreExp.Base = rec(deBruijnMap, e.arg) match { case Some(arg2) => changed = T arg2 @@ -894,7 +1021,7 @@ object RewritingSystem { return if (changed) Some(e(exp = receiver, index = index, arg = arg)) else None() case e: AST.CoreExp.Tuple => var changed = F - var args = ISZ[AST.CoreExp]() + var args = ISZ[AST.CoreExp.Base]() for (arg <- e.args) { rec(deBruijnMap, arg) match { case Some(arg2) => @@ -907,7 +1034,7 @@ object RewritingSystem { return if (changed) Some(e(args = args)) else None() case e: AST.CoreExp.If => var changed = F - val cond: AST.CoreExp = rec(deBruijnMap, e.cond) match { + val cond: AST.CoreExp.Base = rec(deBruijnMap, e.cond) match { case Some(AST.CoreExp.LitB(b)) if config.constantPropagation => return if (b) rec(deBruijnMap, e.tExp) else rec(deBruijnMap, e.fExp) case Some(c) => @@ -925,7 +1052,7 @@ object RewritingSystem { changed = T case _ => } - var args = ISZ[AST.CoreExp]() + var args = ISZ[AST.CoreExp.Base]() for (arg <- e.args) { rec(deBruijnMap, arg) match { case Some(arg2) => @@ -938,7 +1065,7 @@ object RewritingSystem { op match { case f: AST.CoreExp.Fun if config.funApplication => var params = ISZ[(String, AST.Typed)]() - def recParams(fe: AST.CoreExp): AST.CoreExp = { + def recParams(fe: AST.CoreExp.Base): AST.CoreExp.Base = { fe match { case fe: AST.CoreExp.Fun if params.size < args.size => params = params :+ (fe.param.id, fe.param.tipe) @@ -966,7 +1093,7 @@ object RewritingSystem { return if (changed) Some(e(exp = op, args = args)) else None() case e: AST.CoreExp.Fun => var changed = F - val body: AST.CoreExp = rec(incDeBruijnMap(deBruijnMap, 1), e.exp) match { + val body: AST.CoreExp.Base = rec(incDeBruijnMap(deBruijnMap, 1), e.exp) match { case Some(b) => changed = T b @@ -975,7 +1102,7 @@ object RewritingSystem { return if (changed) Some(e(exp = body)) else None() case e: AST.CoreExp.Quant => var changed = F - val body: AST.CoreExp = rec(incDeBruijnMap(deBruijnMap, 1), e.exp) match { + val body: AST.CoreExp.Base = rec(incDeBruijnMap(deBruijnMap, 1), e.exp) match { case Some(b) => changed = T b @@ -992,63 +1119,50 @@ object RewritingSystem { case _ => } return if (changed) Some(e(exp = receiver)) else None() - case e: AST.CoreExp.Arrow => - var changed = F - val left: AST.CoreExp = rec(deBruijnMap, e.left) match { - case Some(e1) => - changed = T - e1 - case _ => e.left - } - val right: AST.CoreExp = rec(deBruijnMap, e.right) match { - case Some(e2) => - changed = T - e2 - case _ => e.right - } - return if (changed) Some(e(left = left, right = right)) else None() } } return rec(HashMap.empty, exp) } @pure def toCondEquiv(th: TypeHierarchy, exp: AST.CoreExp): ISZ[AST.CoreExp] = { - var done = ISZ[AST.CoreExp]() - var r = ISZ[AST.CoreExp]() - exp match { - case exp: AST.CoreExp.Unary if exp.op == AST.Exp.UnaryOp.Not => - case exp: AST.CoreExp.Binary => - exp.op match { - case AST.Exp.BinaryOp.Arrow => done = done :+ exp - case AST.Exp.BinaryOp.EquivUni => done = done :+ exp - case AST.Exp.BinaryOp.Imply => done = done :+ AST.CoreExp.Arrow(exp.left, exp.right) - case AST.Exp.BinaryOp.And => r = r ++ ISZ[AST.CoreExp](exp.left, exp.right) - case _ => r = r :+ exp - } - case exp: AST.CoreExp.Quant if exp.kind == AST.CoreExp.Quant.Kind.ForAll => - r = r :+ eval(th, EvalConfig.quantApplicationOnly, - AST.CoreExp.Apply( - exp, - ISZ(AST.CoreExp.LocalVarRef(ISZ(), paramId(exp.param.id), exp.param.tipe)), - AST.Typed.b)).get - case exp: AST.CoreExp.If => done = done ++ ISZ[AST.CoreExp]( - AST.CoreExp.Arrow(exp.cond, exp.tExp), - AST.CoreExp.Arrow(AST.CoreExp.Unary(AST.Exp.UnaryOp.Not, exp.cond), exp.fExp) - ) - case _ => r = r :+ exp + @pure def toEquiv(e: AST.CoreExp.Base): AST.CoreExp.Base = { + e match { + case AST.CoreExp.Binary(_, AST.Exp.BinaryOp.EquivUni, _) => return e + case _ => return AST.CoreExp.Binary(e, AST.Exp.BinaryOp.EquivUni, AST.CoreExp.LitB(T)) + } } - for (e <- r) { + @pure def toCondEquivH(e: AST.CoreExp.Base): ISZ[AST.CoreExp] = { e match { - case e: AST.CoreExp.Arrow => done = done :+ e + case e: AST.CoreExp.Unary if e.op == AST.Exp.UnaryOp.Not => + return ISZ(AST.CoreExp.Binary(e.exp, AST.Exp.BinaryOp.EquivUni, AST.CoreExp.LitB(F))) case e: AST.CoreExp.Binary => e.op match { - case AST.Exp.BinaryOp.EquivUni => done = done :+ e - case _ => done = done :+ AST.CoreExp.Binary(e, AST.Exp.BinaryOp.EquivUni, AST.CoreExp.LitB(T)) + case AST.Exp.BinaryOp.EquivUni => return ISZ(e) + case AST.Exp.BinaryOp.Imply => + return for (r <- toCondEquivH(e.right)) yield AST.CoreExp.Arrow(e.left, r) + case AST.Exp.BinaryOp.And => + return toCondEquivH(e.left) ++ toCondEquivH(e.right) + case _ => return ISZ(toEquiv(e)) } - case _ => done = done :+ AST.CoreExp.Binary(e, AST.Exp.BinaryOp.EquivUni, AST.CoreExp.LitB(T)) + case e: AST.CoreExp.Quant if e.kind == AST.CoreExp.Quant.Kind.ForAll => + return toCondEquivH(evalBase(th, EvalConfig.quantApplicationOnly, + AST.CoreExp.Apply( + e, + ISZ(AST.CoreExp.LocalVarRef(T, ISZ(), paramId(e.param.id), e.param.tipe)), + AST.Typed.b)).get) + case e: AST.CoreExp.If => + return (for (t <- toCondEquivH(e.tExp)) yield AST.CoreExp.Arrow(e.cond, t).asInstanceOf[AST.CoreExp]) ++ + (for (f <- toCondEquivH(e.fExp)) yield AST.CoreExp.Arrow(e.cond, f).asInstanceOf[AST.CoreExp]) + case e => return ISZ(toEquiv(e)) } } - return done + @pure def rec(e: AST.CoreExp): ISZ[AST.CoreExp] = { + e match { + case AST.CoreExp.Arrow(left, right) => return for (r <- rec(right)) yield AST.CoreExp.Arrow(left, r) + case e: AST.CoreExp.Base => return toCondEquivH(e) + } + } + return rec(exp) } def patternsOf(th: TypeHierarchy, cache: Logika.Cache, name: ISZ[String], rightToLeft: B): ISZ[Rewriter.Pattern] = { @@ -1065,11 +1179,11 @@ object RewritingSystem { for (p <- params) { localPatternSet = localPatternSet + (info.name, p.idOpt.get.value) } - RewritingSystem.translate(th, c) - case c => RewritingSystem.translate(th, c) + RewritingSystem.translate(th, T, c) + case c => RewritingSystem.translate(th, T, c) } for (c <- RewritingSystem.toCondEquiv(th, claim)) { - r = r :+ Rewriter.Pattern(name, F, localPatternSet, c) + r = r :+ Rewriter.Pattern(name, F, isPermutative(c), localPatternSet, c) } case info: Info.Fact => var localPatternSet: RewritingSystem.LocalPatternSet = HashSSet.empty @@ -1079,11 +1193,11 @@ object RewritingSystem { for (p <- params) { localPatternSet = localPatternSet + (info.name, p.idOpt.get.value) } - RewritingSystem.translate(th, c) - case c => RewritingSystem.translate(th, c) + RewritingSystem.translate(th, T, c) + case c => RewritingSystem.translate(th, T, c) } for (c <- RewritingSystem.toCondEquiv(th, claim)) { - r = r :+ Rewriter.Pattern(name, F, localPatternSet, c) + r = r :+ Rewriter.Pattern(name, F, isPermutative(c), localPatternSet, c) } } case info: Info.RsVal => r = r ++ retrievePatterns(th, cache, info.ast.init) @@ -1134,4 +1248,12 @@ object RewritingSystem { } return r } + + @strictpure def isPermutative(exp: AST.CoreExp): B = + exp match { + case exp: AST.CoreExp.Arrow => isPermutative(exp.right) + case AST.CoreExp.Binary(left, AST.Exp.BinaryOp.EquivUni, right) => + left.numberPattern(0)._2 == right.numberPattern(0)._2 + case _ => F + } } diff --git a/shared/src/main/scala/org/sireum/logika/StepProofContext.scala b/shared/src/main/scala/org/sireum/logika/StepProofContext.scala index 115ea93f..9689ce64 100644 --- a/shared/src/main/scala/org/sireum/logika/StepProofContext.scala +++ b/shared/src/main/scala/org/sireum/logika/StepProofContext.scala @@ -26,6 +26,7 @@ package org.sireum.logika import org.sireum._ +import org.sireum.lang.tipe.TypeHierarchy import org.sireum.lang.{ast => AST} @datatype trait StepProofContext { @@ -35,10 +36,14 @@ import org.sireum.lang.{ast => AST} object StepProofContext { - @datatype class Regular(val stepNo: AST.ProofAst.StepId, + @datatype class Regular(val th: TypeHierarchy, + val stepNo: AST.ProofAst.StepId, val exp: AST.Exp, val claims: ISZ[State.Claim]) extends StepProofContext { @strictpure override def prettyST: ST = st"(${stepNo.prettyST}, ${exp.prettyST}, ${(for (claim <- claims) yield claim.toRawST, ", ")})" + @memoize def coreExpClaim: AST.CoreExp.Base = { + return RewritingSystem.translate(th, F, exp) + } } @datatype class SubProof(val stepNo: AST.ProofAst.StepId, diff --git a/shared/src/main/scala/org/sireum/logika/Task.scala b/shared/src/main/scala/org/sireum/logika/Task.scala index 486dc66c..292a9f14 100644 --- a/shared/src/main/scala/org/sireum/logika/Task.scala +++ b/shared/src/main/scala/org/sireum/logika/Task.scala @@ -116,7 +116,7 @@ object Task { else theorem.claim) val spcEntries = p._2.entries for (i <- spcEntries.size - 1 to 0 by -1 if spcEntries(i)._2.isInstanceOf[StepProofContext.Regular]) { - val StepProofContext.Regular(_, claim, _) = spcEntries(i)._2 + val StepProofContext.Regular(_, _, claim, _) = spcEntries(i)._2 if (normClaim == th.normalizeExp(claim)) { if (logika.config.detailedInfo) { reporter.inform(normClaim.posOpt.get, Logika.Reporter.Info.Kind.Verified, diff --git a/shared/src/main/scala/org/sireum/logika/plugin/InceptionPlugin.scala b/shared/src/main/scala/org/sireum/logika/plugin/InceptionPlugin.scala index ac6af6d1..6ffb21fb 100644 --- a/shared/src/main/scala/org/sireum/logika/plugin/InceptionPlugin.scala +++ b/shared/src/main/scala/org/sireum/logika/plugin/InceptionPlugin.scala @@ -385,12 +385,12 @@ import InceptionPlugin._ // Uncomment to test RewritingSystem unification algorithm /* { - var patterns = ISZ[AST.CoreExp]() - var exps = ISZ[AST.CoreExp]() + var patterns = ISZ[AST.CoreExp.Base]() + var exps = ISZ[AST.CoreExp.Base]() for (q <- fromToStepIdChecks) { val (from, to, _, _) = q - patterns = patterns :+ org.sireum.logika.RewritingSystem.translate(logika.th, from) - exps = exps :+ org.sireum.logika.RewritingSystem.translate(logika.th, to) + patterns = patterns :+ org.sireum.logika.RewritingSystem.translate(logika.th, T, from) + exps = exps :+ org.sireum.logika.RewritingSystem.translate(logika.th, F, to) } val localPatternSet = HashSSet ++ (for (id <- paramIds.elements) yield (context, id)) org.sireum.logika.RewritingSystem.unify(F, logika.th, localPatternSet, patterns, exps) match { diff --git a/shared/src/main/scala/org/sireum/logika/plugin/PredNatDedPlugin.scala b/shared/src/main/scala/org/sireum/logika/plugin/PredNatDedPlugin.scala index 9cdcc5e9..e6ec808c 100644 --- a/shared/src/main/scala/org/sireum/logika/plugin/PredNatDedPlugin.scala +++ b/shared/src/main/scala/org/sireum/logika/plugin/PredNatDedPlugin.scala @@ -156,9 +156,9 @@ object PredNatDedPlugin { } val ISZ(existsP, subProofNo) = argsOpt.get val quant: AST.Exp.QuantType = spcMap.get(existsP) match { - case Some(StepProofContext.Regular(_, q@AST.Exp.QuantType(F, AST.Exp.Fun(_, _, _: AST.Stmt.Expr)), _)) => + case Some(StepProofContext.Regular(_, _, q@AST.Exp.QuantType(F, AST.Exp.Fun(_, _, _: AST.Stmt.Expr)), _)) => logika.th.normalizeQuantType(q).asInstanceOf[AST.Exp.QuantType] - case Some(StepProofContext.Regular(_, q@AST.Exp.QuantRange(F, _, _, _, AST.Exp.Fun(_, _, _: AST.Stmt.Expr)), _)) => + case Some(StepProofContext.Regular(_, _, q@AST.Exp.QuantRange(F, _, _, _, AST.Exp.Fun(_, _, _: AST.Stmt.Expr)), _)) => logika.th.normalizeQuantType(q).asInstanceOf[AST.Exp.QuantType] case _ => reporter.error(existsP.posOpt, Logika.kind, "Expecting a simple existential quantified type/range claim") diff --git a/shared/src/main/scala/org/sireum/logika/plugin/PropNatDedPlugin.scala b/shared/src/main/scala/org/sireum/logika/plugin/PropNatDedPlugin.scala index aa80d478..4c9f2dde 100644 --- a/shared/src/main/scala/org/sireum/logika/plugin/PropNatDedPlugin.scala +++ b/shared/src/main/scala/org/sireum/logika/plugin/PropNatDedPlugin.scala @@ -136,16 +136,16 @@ import org.sireum.logika.Logika.Reporter case string"OrE" => val ISZ(orClaimNo, leftSubProofNo, rightSubProofNo) = args val orClaim: AST.Exp.Binary = spcMap.get(orClaimNo) match { - case Some(StepProofContext.Regular(_, exp: AST.Exp.Binary, _)) if isBuiltIn(exp, AST.ResolvedInfo.BuiltIn.Kind.BinaryOr) => + case Some(StepProofContext.Regular(_, _, exp: AST.Exp.Binary, _)) if isBuiltIn(exp, AST.ResolvedInfo.BuiltIn.Kind.BinaryOr) => exp - case Some(StepProofContext.Regular(_, exp: AST.Exp.Binary, _)) if isBuiltIn(exp, AST.ResolvedInfo.BuiltIn.Kind.BinaryLe) => + case Some(StepProofContext.Regular(_, _, exp: AST.Exp.Binary, _)) if isBuiltIn(exp, AST.ResolvedInfo.BuiltIn.Kind.BinaryLe) => AST.Exp.Binary( exp(op = AST.Exp.BinaryOp.Lt, attr = exp.attr(resOpt = Some(AST.ResolvedInfo.BuiltIn(AST.ResolvedInfo.BuiltIn.Kind.BinaryLt)))), AST.Exp.BinaryOp.Or, exp(op = AST.Exp.BinaryOp.Eq, attr = exp.attr(resOpt = Some(AST.ResolvedInfo.BuiltIn(AST.ResolvedInfo.BuiltIn.Kind.BinaryEq)))), AST.ResolvedAttr(exp.posOpt, Some(AST.ResolvedInfo.BuiltIn(AST.ResolvedInfo.BuiltIn.Kind.BinaryOr)), AST.Typed.bOpt) ) - case Some(StepProofContext.Regular(_, exp: AST.Exp.Binary, _)) if isBuiltIn(exp, AST.ResolvedInfo.BuiltIn.Kind.BinaryGe) => + case Some(StepProofContext.Regular(_, _, exp: AST.Exp.Binary, _)) if isBuiltIn(exp, AST.ResolvedInfo.BuiltIn.Kind.BinaryGe) => AST.Exp.Binary( exp(op = AST.Exp.BinaryOp.Gt, attr = exp.attr(resOpt = Some(AST.ResolvedInfo.BuiltIn(AST.ResolvedInfo.BuiltIn.Kind.BinaryGt)))), AST.Exp.BinaryOp.Or, diff --git a/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala b/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala index 21785efd..b635d954 100644 --- a/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala +++ b/shared/src/main/scala/org/sireum/logika/plugin/RewritePlugin.scala @@ -29,7 +29,7 @@ package org.sireum.logika.plugin import org.sireum._ import org.sireum.lang.{ast => AST} import org.sireum.logika.Logika.Reporter -import org.sireum.logika.RewritingSystem.Rewriter +import org.sireum.logika.RewritingSystem.{Rewriter, toCondEquiv} import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext} @datatype class RewritePlugin extends JustificationPlugin { @@ -60,8 +60,13 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext spcMap: HashSMap[AST.ProofAst.StepId, StepProofContext], state: State, step: AST.ProofAst.Step.Regular, reporter: Logika.Reporter): Plugin.Result = { @strictpure def emptyResult: Plugin.Result = Plugin.Result(F, state.nextFresh, ISZ()) + @strictpure def checkRightMostLit(exp: AST.CoreExp): B = exp match { + case exp: AST.CoreExp.Arrow => checkRightMostLit(exp.right) + case _: AST.CoreExp.Lit => T + case _ => F + } val just = step.just.asInstanceOf[AST.ProofAst.Step.Justification.Apply] - val patterns = RewritingSystem.retrievePatterns(logika.th, cache, just.args(0)) + var patterns = RewritingSystem.retrievePatterns(logika.th, cache, just.args(0)) val from: AST.ProofAst.StepId = AST.Util.toStepIds(ISZ(just.args(1)), Logika.kind, reporter) match { case Some(s) => s(0) case _ => return emptyResult @@ -72,17 +77,36 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext reporter.error(from.posOpt, Logika.kind, s"Expecting a regular proof step") return emptyResult } - val rw = Rewriter(logika.th, patterns, ISZ()) - val fromCoreClaim = RewritingSystem.translate(logika.th, fromClaim) + var provenClaims = HashSSet.empty[AST.CoreExp.Base] + if (just.hasWitness) { + for (w <- just.witnesses) { + spcMap.get(w) match { + case Some(spc: StepProofContext.Regular) => + provenClaims = provenClaims + spc.coreExpClaim + case _ => + reporter.error(from.posOpt, Logika.kind, s"Expecting a regular proof step for $w") + } + } + } else { + for (spc <- spcMap.values) { + spc match { + case spc: StepProofContext.Regular => + provenClaims = provenClaims + spc.coreExpClaim + case _ => + } + } + } + val rwPc = Rewriter(logika.th, provenClaims, patterns, ISZ()) + val fromCoreClaim = RewritingSystem.translate(logika.th, F, fromClaim) var done = F var rwClaim = fromCoreClaim var i = 0 while (!done && i < maxIt) { - rwClaim = rw.transformCoreExp(rwClaim) match { + rwClaim = rwPc.transformCoreExpBase(rwClaim) match { case MSome(c) => - if (rw.trace.nonEmpty) { - val last = rw.trace.size - 1 - rw.trace = rw.trace(last ~> rw.trace(last)(done = c)) + if (rwPc.trace.nonEmpty) { + val last = rwPc.trace.size - 1 + rwPc.trace = rwPc.trace(last ~> rwPc.trace(last)(done = c)) } c case _ => @@ -91,7 +115,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext } i = i + 1 } - val stepClaim = RewritingSystem.translate(logika.th, step.claim) + val stepClaim = RewritingSystem.translate(logika.th, F, step.claim) if (rwClaim == stepClaim) { reporter.inform(just.id.attr.posOpt.get, Reporter.Info.Kind.Verified, st"""Matched: @@ -102,7 +126,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext | |Rewriting trace: | - |${(for (te <- rw.trace) yield te.toST, "\n\n")} + |${(for (te <- rwPc.trace) yield te.toST, "\n\n")} |""".render) val q = logika.evalRegularStepClaimRtCheck(smt2, cache, F, state, step.claim, step.id.posOpt, reporter) val (stat, nextFresh, claims) = (q._1, q._2, q._3 :+ q._4) @@ -117,7 +141,7 @@ import org.sireum.logika.{Logika, RewritingSystem, Smt2, State, StepProofContext | |Rewriting trace: | - |${(for (te <- rw.trace) yield te.toST, "\n\n")} + |${(for (te <- rwPc.trace) yield te.toST, "\n\n")} |""".render) return emptyResult }