Skip to content

Commit

Permalink
Rewriting system.
Browse files Browse the repository at this point in the history
  • Loading branch information
robby-phd committed Feb 21, 2024
1 parent 1ec009e commit 137b450
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 143 deletions.
24 changes: 23 additions & 1 deletion jvm/src/test/scala/org/sireum/logika/example/rewrite.sc
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,26 @@ import Rules._
11 (a(i ~> t1)(k ~> t2)(i) == t2) by Simpl
//@formatter:on
)
}
}


@strictpure def incN(x: Z, n: Z): Z = x + n

@pure def incNTest(a: Z): Unit = {
Deduce(
//@formatter:off
1 (a + 1 == incN(a, 1)) by Simpl
//@formatter:on
)
}


@abs def inc(x: Z): Z = x + 1

@pure def incTest(a: Z): Unit = {
Deduce(
//@formatter:off
1 (a + 1 == inc(a)) by RSimpl(RS(inc _))
//@formatter:on
)
}
2 changes: 1 addition & 1 deletion shared/src/main/scala/org/sireum/logika/Logika.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2677,7 +2677,7 @@ import Util._
val (receiverModified, modLocals) = contract.modifiedLocalVars(lComp.context.receiverLocalTypeOpt, typeSubstMap)

val pfOpt: Option[State.ProofFun] = if (config.pureFun ||
(info.sig.isPure && !info.hasBody && info.contract.isEmpty)) {
(info.sig.funType.isPureFun && !info.hasBody && info.contract.isEmpty)) {
val typedAttr = AST.TypedAttr(posOpt, None())
val (s8, pf) = Util.pureMethod(context.nameExePathMap, context.maxCores, context.fileOptions, th, config,
plugins, smt2, cache, s1, lComp.context.receiverTypeOpt, info.sig.funType.subst(typeSubstMap),
Expand Down
157 changes: 123 additions & 34 deletions shared/src/main/scala/org/sireum/logika/RewritingSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ object RewritingSystem {
type UnificationMap = HashSMap[Local, AST.CoreExp.Base]
type UnificationErrorMessages = ISZ[String]
type UnificationResult = Either[UnificationMap, UnificationErrorMessages]
type MethodDesc = (ISZ[String], B)
type MethodPatternMap = HashSMap[MethodDesc, Rewriter.Pattern.Method]
type UnfoldingNumMap = HashSMap[MethodDesc, Z]

@datatype class EvalConfig(val constant: B,
val unary: B,
Expand Down Expand Up @@ -167,13 +170,13 @@ object RewritingSystem {
@record class Rewriter(val maxCores: Z,
val th: TypeHierarchy,
val provenClaims: HashSMap[AST.ProofAst.StepId, AST.CoreExp.Base],
val rwPatterns: ISZ[Rewriter.Pattern],
val patterns: ISZ[Rewriter.Pattern.Claim],
val methodPatterns: MethodPatternMap,
val shouldTrace: B,
val shouldTraceEval: B,
var done: B,
var trace: ISZ[Trace]) {
val patterns: ISZ[Rewriter.Pattern.Claim] = for (p <- rwPatterns if p.isInstanceOf[Rewriter.Pattern.Claim]) yield
p.asInstanceOf[Rewriter.Pattern.Claim]
val unfoldingMap: MBox[UnfoldingNumMap] = MBox(HashSMap.empty)
@memoize def provenClaimStepIdMap: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId] = {
@strictpure def conjuncts(e: AST.CoreExp.Base): ISZ[AST.CoreExp.Base] = {
e match {
Expand Down Expand Up @@ -343,6 +346,32 @@ object RewritingSystem {
}
val hasChanged: B = r.nonEmpty
val o2: AST.CoreExp.Base = r.getOrElse(o)
val shouldUnfold: B = o2 match {
case o2: AST.CoreExp.ObjectVarRef if methodPatterns.contains((o2.owner :+ o2.id, F)) => T
case o2: AST.CoreExp.Select =>
val infoOpt: Option[Info.Method] = o2.exp.tipe match {
case t: AST.Typed.Name =>
th.typeMap.get(t.ids).get match {
case ti: TypeInfo.Adt => ti.methods.get(o2.id)
case ti: TypeInfo.Sig => ti.methods.get(o2.id)
case _ => None()
}
case _ => None()
}
infoOpt match {
case Some(info) if methodPatterns.contains((info.name, info.isInObject)) => T
case _ => F
}
case _ => F
}
if (shouldUnfold) {
evalBase(th, EvalConfig.none, cache, methodPatterns, unfoldingMap, 1, HashSMap.empty, o2, shouldTraceEval) match {
case Some((o3, t)) =>
trace = trace ++ t
return Some(o3)
case _ =>
}
}
val postR = rewrite(cache, o2)
if (postR.nonEmpty) {
return postR
Expand Down Expand Up @@ -375,7 +404,7 @@ object RewritingSystem {
case AST.CoreExp.Binary(left, AST.Exp.BinaryOp.EquivUni, right, _) => (left, right)
case _ => halt("Infeasible")
}
def last(m: UnificationMap, patterns2: ISZ[Rewriter.Pattern], apcs: ISZ[(AST.ProofAst.StepId, AST.CoreExp.Base)]): Unit = {
def last(m: UnificationMap, patterns2: ISZ[Rewriter.Pattern.Claim], apcs: ISZ[(AST.ProofAst.StepId, AST.CoreExp.Base)]): Unit = {
val o2: AST.CoreExp.Base = if (m.isEmpty) {
to
} else {
Expand All @@ -385,13 +414,14 @@ object RewritingSystem {
if (patterns2.isEmpty) {
o
} else {
Rewriter(maxCores, th, HashSMap.empty, patterns2, F, F, F, ISZ()).
Rewriter(maxCores, th, HashSMap.empty, patterns2, methodPatterns, F, F, F, ISZ()).
transformCoreExpBase(cache, o).getOrElse(o)
}
}
}
val (o3Opt, t): (Option[AST.CoreExp.Base], ISZ[Trace]) =
evalBase(th, EvalConfig.allButEquivSubst, cache, provenClaimStepIdMap, o2, shouldTraceEval) match {
evalBase(th, EvalConfig.allButEquivSubst, cache, methodPatterns, unfoldingMap, 1,
provenClaimStepIdMap, o2, shouldTraceEval) match {
case Some((o3o, t)) => (Some(o3o), t)
case _ => (None(), ISZ())
}
Expand Down Expand Up @@ -433,10 +463,11 @@ object RewritingSystem {
def r2l(p: Rewriter.Pattern.Claim): Rewriter.Pattern.Claim = {
return if (pattern.rightToLeft) p.toRightToLeft else p
}
val patterns2: ISZ[Rewriter.Pattern] =
for (k <- 0 until apcs.size; apc <- toCondEquiv(th, apcs(k)._2)) yield
r2l(Rewriter.Pattern.Claim(pattern.name :+ s"Assumption$k", F, isPermutative(apc), HashSSet.empty, apc))
val o2 = Rewriter(maxCores, th, HashSMap.empty, patterns2, F, F, F, ISZ()).
val patterns2: ISZ[Rewriter.Pattern.Claim] =
(for (k <- 0 until apcs.size; apc <- toCondEquiv(th, apcs(k)._2)) yield
r2l(Rewriter.Pattern.Claim(pattern.name :+ s"Assumption$k", F, isPermutative(apc), HashSSet.empty, apc))) ++
patterns
val o2 = Rewriter(maxCores, th, HashSMap.empty, patterns2, methodPatterns, F, F, F, ISZ()).
transformCoreExpBase(cache, o).getOrElse(o)
val m = unifyExp(T, th, pattern.localPatternSet, from, o2, map, pas, sm, ems)
if (ems.value.isEmpty) {
Expand Down Expand Up @@ -515,7 +546,26 @@ object RewritingSystem {
val owner: ISZ[String],
val id: String,
val params: ISZ[(String, AST.Typed)],
val exp: AST.CoreExp.Base) extends Pattern
val exp: AST.CoreExp.Base) extends Pattern {
@memoize def toFun: AST.CoreExp.Fun = {
val context = owner :+ id
var map = HashMap.empty[AST.CoreExp, AST.CoreExp.ParamVarRef]
for (i <- params.size - 1 to 0 by -1) {
val (id, t) = params(i)
map = map + AST.CoreExp.LocalVarRef(F, context, id, t) ~> AST.CoreExp.ParamVarRef(params.size - i, id, t)
}
var r = Substitutor(map).transformCoreExpBase(exp).getOrElse(exp)
if (params.isEmpty) {
r = AST.CoreExp.Fun(AST.CoreExp.Param("_", AST.Typed.unit), r)
} else {
for (i <- params.size - 1 to 0 by -1) {
val (id, t) = params(i)
r = AST.CoreExp.Fun(AST.CoreExp.Param(id, t), r)
}
}
return r.asInstanceOf[AST.CoreExp.Fun]
}
}
}
}

Expand Down Expand Up @@ -1096,7 +1146,8 @@ object RewritingSystem {
val (context, id, args, e) = pa
m.get((context, id)) match {
case Some(f: AST.CoreExp.Fun) =>
evalBase(th, EvalConfig.funApplicationOnly, noCache, HashSMap.empty, AST.CoreExp.Apply(F, f, args, e.tipe), F) match {
evalBase(th, EvalConfig.funApplicationOnly, noCache, HashSMap.empty, MBox(HashSMap.empty), 1, HashSMap.empty,
AST.CoreExp.Apply(F, f, args, e.tipe), F) match {
case Some((pattern, _)) =>
m = unifyExp(silent, th, localPatterns, pattern, e, m, pendingApplications, substMap, errorMessages)
case _ =>
Expand Down Expand Up @@ -1300,29 +1351,35 @@ object RewritingSystem {
@pure def eval(th: TypeHierarchy,
config: EvalConfig,
cache: Logika.Cache,
methodPatterns: MethodPatternMap,
unfoldingMap: MBox[UnfoldingNumMap],
maxUnfolding: Z,
provenClaims: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId],
exp: AST.CoreExp,
shouldTrace: B): Option[(AST.CoreExp, ISZ[Trace])] = {
exp match {
case exp: AST.CoreExp.Arrow =>
var changed = F
var trace = ISZ[Trace]()
val left: AST.CoreExp.Base = evalBase(th, config, cache, provenClaims, exp.left, shouldTrace) match {
val left: AST.CoreExp.Base = evalBase(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding,
provenClaims, exp.left, shouldTrace) match {
case Some((l, t)) =>
trace = trace ++ t
changed = T
l
case _ => exp.left
}
val right: AST.CoreExp = eval(th, config, cache, provenClaims, exp.right, shouldTrace) match {
val right: AST.CoreExp = eval(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding, provenClaims,
exp.right, shouldTrace) match {
case Some((r, t)) =>
trace = trace ++ t
changed = T
r
case _ => exp.right
}
return if (changed) Some((AST.CoreExp.Arrow(left, right), trace)) else None()
case exp: AST.CoreExp.Base => evalBase(th, config, cache, provenClaims, exp, shouldTrace) match {
case exp: AST.CoreExp.Base => evalBase(th, config, cache, methodPatterns, unfoldingMap, maxUnfolding,
provenClaims, exp, shouldTrace) match {
case Some((e, t)) => return Some((e, t))
case _ => return None()
}
Expand All @@ -1332,6 +1389,9 @@ object RewritingSystem {
@pure def evalBase(th: TypeHierarchy,
config: EvalConfig,
cache: Logika.Cache,
methodPatterns: MethodPatternMap,
unfoldingMap: MBox[UnfoldingNumMap],
maxUnfolding: Z,
provenClaims: HashSMap[AST.CoreExp.Base, AST.ProofAst.StepId],
exp: AST.CoreExp.Base,
shouldTrace: B): Option[(AST.CoreExp.Base, ISZ[Trace])] = {
Expand All @@ -1346,7 +1406,20 @@ object RewritingSystem {
if (right < left) AST.CoreExp.Binary(right, AST.Exp.BinaryOp.InequivUni, left, AST.Typed.b)
else AST.CoreExp.Binary(left, AST.Exp.BinaryOp.InequivUni, right, AST.Typed.b)
@strictpure def incDeBruijnMap(deBruijnMap: HashMap[Z, AST.CoreExp.Base], inc: Z): HashMap[Z, AST.CoreExp.Base] =
HashMap ++ (for (p <- deBruijnMap.entries) yield (p._1 + inc, p._2))
if (inc != 0) HashMap ++ (for (p <- deBruijnMap.entries) yield (p._1 + inc, p._2)) else deBruijnMap
def shouldUnfold(info: Info.Method): B = {
return !info.ast.hasContract && info.ast.isStrictPure &&
((info.ast.purity == AST.Purity.Abs) ->: methodPatterns.contains((info.name, info.isInObject)))
}
def unfold(info: Info.Method, receiverOpt: Option[AST.CoreExp.Base]): AST.CoreExp.Base = {
val pattern = methodPatternOf(th, cache, info)
val f = pattern.toFun
receiverOpt match {
case Some(receiver) => return AST.CoreExp.Apply(T, f, ISZ(receiver), f.exp.tipe)
case _ => return f
}
}

val eqMap: HashMap[AST.CoreExp.Base, (AST.CoreExp.Base, AST.ProofAst.StepId)] = if (config.equivSubst) {
var r = HashMap.empty[AST.CoreExp.Base, (AST.CoreExp.Base, AST.ProofAst.StepId)]
for (p <- provenClaims.entries) {
Expand Down Expand Up @@ -1411,7 +1484,17 @@ object RewritingSystem {
case Some(e2) => return Some(evalBaseH(deBruijnMap, e2).getOrElse(e2))
case _ => return None()
}
case _: AST.CoreExp.ObjectVarRef => return None()
case e: AST.CoreExp.ObjectVarRef =>
th.nameMap.get(e.owner :+ e.id) match {
case Some(info: Info.Method) if shouldUnfold(info) =>
val r = unfold(info, None())
if (shouldTrace) {
trace = trace :+ Trace.Eval(st"unfolding", e, r)
}
return Some(r)
case _ =>
}
return None()
case e: AST.CoreExp.Binary =>
var changed = F
val left: AST.CoreExp.Base = evalBaseH(deBruijnMap, e.left) match {
Expand Down Expand Up @@ -1694,20 +1777,18 @@ object RewritingSystem {
case _ =>
}
}
receiver.tipe match {
val infoOpt: Option[Info.Method] = receiver.tipe match {
case rt: AST.Typed.Name =>
th.typeMap.get(rt.ids).get match {
case ti: TypeInfo.Adt =>
ti.methods.get(e.id) match {
case Some(info) if !info.ast.hasContract && info.ast.purity == AST.Purity.StrictPure =>
return Some(unfold(th, cache, info))
case _ =>
case io@Some(info) if shouldUnfold(info) => io
case _ => None()
}
case ti: TypeInfo.Sig =>
ti.methods.get(e.id) match {
case Some(info) if !info.ast.hasContract && info.ast.purity == AST.Purity.StrictPure =>
return Some(unfold(th, cache, info))
case _ =>
case io@Some(info) if shouldUnfold(info) => io
case _ => None()
}
case info: TypeInfo.SubZ if config.fieldAccess =>
val r: AST.CoreExp.Base = e.id match {
Expand All @@ -1729,8 +1810,20 @@ object RewritingSystem {
}
return Some(r)
case info: TypeInfo.Enum if config.fieldAccess => halt("TODO")
case _ =>
case _ => None()
}
case _ => None()
}
infoOpt match {
case Some(info) =>
val fApp = unfold(info, Some(receiver))
if (shouldTrace) {
trace = trace :+ Trace.Eval(st"unfolding", e(exp = receiver), fApp)
}
val p = evalBase(th, EvalConfig.funApplicationOnly, cache, HashSMap.empty,
MBox(HashSMap.empty), 1, HashSMap.empty, fApp, shouldTrace).get
trace = trace ++ p._2
return Some(p._1)
case _ =>
}
if (config.fieldAccess) {
Expand Down Expand Up @@ -1958,8 +2051,8 @@ object RewritingSystem {
}
val body = recParamsFun(f)
var map = incDeBruijnMap(deBruijnMap, params.size)
for (i <- params.size - 1 to 0 by -1) {
map = map + (i + 1) ~> args(i)
for (i <- 0 until params.size) {
map = map + (params.size - i) ~> args(i)
}
evalBaseH(map, body) match {
case Some(body2) =>
Expand Down Expand Up @@ -1990,8 +2083,8 @@ object RewritingSystem {
}
val body = recParamsQuqnt(q)
var map = incDeBruijnMap(deBruijnMap, params.size)
for (i <- params.size - 1 to 0 by -1) {
map = map + (i + 1) ~> args(i)
for (i <- 0 until params.size) {
map = map + (params.size - i) ~> args(i)
}
evalBaseH(map, body) match {
case Some(body2) =>
Expand Down Expand Up @@ -2063,11 +2156,6 @@ object RewritingSystem {
}
}

def unfold(th: TypeHierarchy, cache: Logika.Cache, info: Info.Method): AST.CoreExp.Base = {
val pattern = methodPatternOf(th, cache, info)
halt("TODO")
}

@pure def toCondEquiv(th: TypeHierarchy, exp: AST.CoreExp): ISZ[AST.CoreExp] = {
@pure def toEquiv(e: AST.CoreExp.Base): AST.CoreExp.Base = {
e match {
Expand All @@ -2090,6 +2178,7 @@ object RewritingSystem {
}
case e: AST.CoreExp.Quant if e.kind == AST.CoreExp.Quant.Kind.ForAll =>
return toCondEquivH(evalBase(th, EvalConfig.quantApplicationOnly, noCache, HashSMap.empty,
MBox(HashSMap.empty), 1, HashSMap.empty,
AST.CoreExp.Apply(F, e,
ISZ(AST.CoreExp.LocalVarRef(T, ISZ(), paramId(e.param.id), e.param.tipe)),
AST.Typed.b), F).get._1)
Expand Down
Loading

0 comments on commit 137b450

Please sign in to comment.