Skip to content

Commit

Permalink
feat: Track whether simp_mem made progress in the monad state. [5/?]
Browse files Browse the repository at this point in the history
This approach has the advantage of making the code far less noisy.
This prepares it for the refactor where we will pass goal MVarIds around,
which will increase the amount of noise again :)
  • Loading branch information
bollu committed Oct 31, 2024
1 parent 49914dd commit eadeaf8
Showing 1 changed file with 40 additions and 47 deletions.
87 changes: 40 additions & 47 deletions Arm/Memory/SeparateAutomation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ def Context.init (cfg : SimpMemConfig) : MetaM Context := do
structure State where
hypotheses : Array Memory.Hypothesis := #[]
rewriteFuel : Nat
changed : Bool

def State.init (cfg : SimpMemConfig) : State :=
{ rewriteFuel := cfg.rewriteFuel}
{ rewriteFuel := cfg.rewriteFuel, changed := false }

abbrev SimpMemM := StateRefT State (ReaderT Context TacticM)

Expand Down Expand Up @@ -169,6 +170,10 @@ def processingEmoji : String := "⚙️"
def consumeRewriteFuel : SimpMemM Unit :=
modify fun s => { s with rewriteFuel := s.rewriteFuel - 1 }

def setChanged : SimpMemM Unit := modify fun s => { s with changed := true }

def resetChanged : SimpMemM Unit := modify fun s => { s with changed := false }

def outofRewriteFuel? : SimpMemM Bool := do
return (← get).rewriteFuel == 0

Expand All @@ -193,15 +198,13 @@ Pattern match for memory patterns, and simplify them.
Close memory side conditions with `simplifyGoal`.
Returns if progress was made.
-/
partial def SimpMemM.simplifyExpr (e : Expr) (hyps : Array Memory.Hypothesis) : SimpMemM Bool := do
partial def SimpMemM.simplifyExpr (e : Expr) (hyps : Array Memory.Hypothesis) : SimpMemM Unit := do
consumeRewriteFuel
if ← outofRewriteFuel? then
trace[simp_mem.info] "out of fuel for rewriting, stopping."
return false

if e.isSort then
trace[simp_mem.info] "skipping sort '{e}'."
return false

if let .some er := ReadBytesExpr.ofExpr? e then
if let .some ew := WriteBytesExpr.ofExpr? er.mem then
Expand All @@ -215,79 +218,66 @@ partial def SimpMemM.simplifyExpr (e : Expr) (hyps : Array Memory.Hypothesis) :
if let .some separateProof ← proveWithOmega? separate (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then do
trace[simp_mem.info] "{checkEmoji} {separate}"
MemSeparateProof.rewriteReadOfSeparatedWrite er ew separateProof
return true
setChanged
else if let .some subsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then do
trace[simp_mem.info] "{checkEmoji} {subset}"
MemSubsetProof.rewriteReadOfSubsetWrite er ew subsetProof
return true
setChanged
else
trace[simp_mem.info] "{crossEmoji} Could not prove {er.span} ⟂/⊆ {ew.span}"
return false
else
-- read
trace[simp_mem.info] "{checkEmoji} Found read {er}."
-- TODO: we don't need a separate `subset` branch for the writes: instead, for the write,
-- we can add the theorem that `(write region).read = write val`.
-- Then this generic theory will take care of it.
let changedInCurrentIter? ← withTraceNode m!"Searching for overlapping read {er.span}." do
let mut changedInCurrentIter? := false
withTraceNode m!"Searching for overlapping read {er.span}." do
for hyp in hyps do
if let Hypothesis.read_eq hReadEq := hyp then do
changedInCurrentIter? := changedInCurrentIter? ||
(← withTraceNode m!"{processingEmoji} ... ⊆ {hReadEq.read.span} ? " do
-- the read we are analyzing should be a subset of the hypothesis
let subset := (MemSubsetProp.mk er.span hReadEq.read.span)
if let some hSubsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then
trace[simp_mem.info] "{checkEmoji} ... ⊆ {hReadEq.read.span}"
MemSubsetProof.rewriteReadOfSubsetRead er hReadEq hSubsetProof
pure true
else
trace[simp_mem.info] "{crossEmoji} ... ⊊ {hReadEq.read.span}"
pure false)
pure changedInCurrentIter?
return changedInCurrentIter?
withTraceNode m!"{processingEmoji} ... ⊆ {hReadEq.read.span} ? " do
-- the read we are analyzing should be a subset of the hypothesis
let subset := (MemSubsetProp.mk er.span hReadEq.read.span)
if let some hSubsetProof ← proveWithOmega? subset (← getBvToNatSimpCtx) (← getBvToNatSimprocs) hyps then
trace[simp_mem.info] "{checkEmoji} ... ⊆ {hReadEq.read.span}"
MemSubsetProof.rewriteReadOfSubsetRead er hReadEq hSubsetProof
setChanged
else
trace[simp_mem.info] "{crossEmoji} ... ⊊ {hReadEq.read.span}"
else
if e.isForall then
Lean.Meta.forallTelescope e fun xs b => do
let mut changedInCurrentIter? := false
for x in xs do
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr x hyps)
SimpMemM.simplifyExpr x hyps
-- we may have a hypothesis like
-- ∀ (x : read_mem (read_mem_bytes ...) ... = out).
-- we want to simplify the *type* of x.
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr (← inferType x) hyps)
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr b hyps)
return changedInCurrentIter?
SimpMemM.simplifyExpr (← inferType x) hyps
SimpMemM.simplifyExpr b hyps
else if e.isLambda then
Lean.Meta.lambdaTelescope e fun xs b => do
let mut changedInCurrentIter? := false
for x in xs do
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr x hyps)
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr (← inferType x) hyps)
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr b hyps)
return changedInCurrentIter?
SimpMemM.simplifyExpr x hyps
SimpMemM.simplifyExpr (← inferType x) hyps
SimpMemM.simplifyExpr b hyps
else
-- check if we have expressions.
match e with
| .app f x =>
let mut changedInCurrentIter? := false
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr f hyps)
changedInCurrentIter? := changedInCurrentIter? || (← SimpMemM.simplifyExpr x hyps)
return changedInCurrentIter?
| _ => return false
SimpMemM.simplifyExpr f hyps
SimpMemM.simplifyExpr x hyps
| _ => return ()


/--
simplify the goal state, closing legality, subset, and separation goals,
and simplifying all other expressions. Returns `true` if an improvement has been made
in the current iteration.
-/
partial def SimpMemM.simplifyGoal (g : MVarId) (hyps : Array Memory.Hypothesis) : SimpMemM Bool := do
partial def SimpMemM.simplifyGoal (g : MVarId) (hyps : Array Memory.Hypothesis) : SimpMemM Unit := do
SimpMemM.withContext g do
let gt ← g.getType
let changedInCurrentIter? ← withTraceNode m!"Simplifying goal." do
withTraceNode m!"Simplifying goal." do
SimpMemM.simplifyExpr (← whnf gt) hyps
return changedInCurrentIter?
end

/--
Expand All @@ -309,20 +299,23 @@ partial def SimpMemM.simplifyLoop : SimpMemM Unit := do
trace[simp_mem.info] m!"{i+1}) {h}"

-- goal was not closed, try and improve.
let mut changedInAnyIter? := false
let mut everChanged := false
while true do
resetChanged
if (← outofRewriteFuel?) then break

let changedInCurrentIter? ← withTraceNode m!"Performing Rewrite At Main Goal" do
SimpMemM.simplifyGoal (← getMainGoal) foundHyps
changedInAnyIter? := changedInAnyIter? || changedInCurrentIter?
withTraceNode m!"Performing Rewrite At Main Goal" do
let _ ← SimpMemM.simplifyGoal (← getMainGoal) foundHyps
let changed := (← getThe State).changed
everChanged := everChanged || changed

if !changedInCurrentIter? then
/- we didn't change on this iteration, so we break out of the loop. -/
if !changed then
trace[simp_mem.info] "{crossEmoji} No progress made in this iteration. halting."
break

/- we haven't changed ever.. -/
if !changedInAnyIter? && (← getConfig).failIfUnchanged then
/- we haven't changed ever, so we throw an error. -/
if !everChanged && (← getConfig).failIfUnchanged then
throwError "{crossEmoji} simp_mem failed to make any progress."

/--
Expand Down

0 comments on commit eadeaf8

Please sign in to comment.