Skip to content

Commit

Permalink
RPINF: optimise pinfCore isProof checking
Browse files Browse the repository at this point in the history
  • Loading branch information
JLimperg committed Nov 22, 2024
1 parent 8a577fe commit d1344b4
Showing 1 changed file with 66 additions and 14 deletions.
80 changes: 66 additions & 14 deletions Aesop/RPINF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,27 +94,50 @@ instance [Monad m] [MonadRPINF m] : MonadHashMapCacheAdapter Expr Expr m where

abbrev RPINFT m [STWorld ω m] := StateRefT RPINFCache m

/-- Given a type `t = ∀ (x₁ : T₁) ... (xₙ : Tₙ), U` and arguments `args` that
are type-correct when applied to a function of type `t`, returns the types of
the `args`. These are `T₁[args[0]]`, `T₂[args[0], args[1]]`, ...,
`Tₙ[args[0], ..., args[n]]`. -/
partial def getArgTypes (t : Expr) (args : Array Expr) : MetaM (Array Expr) :=
go (Array.mkEmpty args.size) 0 t
where
go (acc : Array Expr) (n : Nat) (t : Expr) : MetaM (Array Expr) := do
if n == args.size then
return acc
match t with
| .forallE _ t b _ =>
let t := t.instantiateRevRange 0 n args
go (acc.push t) (n + 1) b
| _ =>
let t ← withAtLeastTransparency .default $ whnf t
if t.isForall then
go acc n t
else
panic! "not enough forall binders"

variable [Monad m] [MonadRPINF m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
[MonadMCtx m] [MonadLiftT (ST IO.RealWorld) m] [MonadError m] [MonadRecDepth m]
[MonadLiftT BaseIO m]

-- `type?` may be `none` iff `e` is a type.
@[specialize]
partial def pinfCore (statsRef : IO.Ref Nanos) (e : Expr) : m Expr :=
partial def pinfCore (statsRef : IO.Ref Nanos) (e : Expr) (type? : Option Expr) : m Expr :=
withIncRecDepth do
checkCache e λ _ => do
let (isPrf, nanos) ← time $ withDefault $ isProof e
statsRef.modify (· + nanos)
if isPrf then
return .mdata (mdataSetIsProof {}) e
dbg_trace e
let e ← whnf e
match e with
| .app .. =>
let f ← pinfCore statsRef e.getAppFn'
ifNotProof do
let f := e.getAppFn'
let fType ← inferType f
let f ← pinfCore statsRef f fType
let mut args := e.getAppArgs'
let argTypes ← getArgTypes fType args
for i in [:args.size] do
let arg := args[i]!
args := args.set! i default -- prevent nonlinear access to args[i]
let arg ← pinfCore statsRef arg
let arg ← pinfCore statsRef arg argTypes[i]!
args := args.set! i arg
if f.isConstOf ``Nat.succ && args.size == 1 && args[0]!.isRawNatLit then
return mkRawNatLit (args[0]!.rawNatLit?.get! + 1)
Expand All @@ -123,14 +146,22 @@ partial def pinfCore (statsRef : IO.Ref Nanos) (e : Expr) : m Expr :=
| .lam .. =>
-- TODO disable cache?
lambdaTelescope e λ xs e => withNewFVars xs do
mkLambdaFVars xs (← pinfCore statsRef e)
let eType? ← type?.mapM λ type =>
withAtLeastTransparency .default $ instantiateForall type xs
mkLambdaFVars xs (← pinfCore statsRef e eType?)
| .forallE .. =>
-- TODO disable cache?
-- A `forall` is necessarily a type, and so is its conclusion.
forallTelescope e λ xs e => withNewFVars xs do
mkForallFVars xs (← pinfCore statsRef e)
mkForallFVars xs (← pinfCore statsRef e none)
| .proj t i e =>
return .proj t i (← pinfCore statsRef e)
| .sort .. | .mvar .. | .lit .. | .const .. | .fvar .. =>
ifNotProof do
return .proj t i (← pinfCore statsRef e (← inferType e))
| .mvar .. | .const .. | .fvar .. =>
ifNotProof do
return e
| .sort .. | .lit .. =>
-- These cannot be proofs.
return e
| .letE .. | .mdata .. | .bvar .. => unreachable!
where
Expand All @@ -139,15 +170,36 @@ where
for fvar in fvars do
let fvarId := fvar.fvarId!
let ldecl ← fvarId.getDecl
let ldecl := ldecl.setType $ ← pinfCore statsRef ldecl.type
let ldecl := ldecl.setType $ ← pinfCore statsRef ldecl.type none
lctx := lctx.modifyLocalDecl fvarId λ _ => ldecl
withLCtx lctx (← getLocalInstances) k

isPropD (e : Expr) : MetaM Bool :=
-- withAtLeastTransparency .default $ isProp e -- TODO test perf impact
isProp e

ifNotProof (k : m Expr) : m Expr := do
if let some type := type? then
if ← isPropD type then
return .mdata (mdataSetIsProof {}) e
k

def pinf (statsRef : IO.Ref Nanos) (e : Expr) : m Expr := do
pinfCore statsRef (← instantiateMVars e)
pinfCore statsRef (← instantiateMVars e) (← getType? e)
where
-- Returns the type of `e`, or `none` if `e` is a type.
getType? (e : Expr) : MetaM (Option Expr) := do
match ← isTypeQuick e with
| .true => pure none
| .false => inferType e
| .undef =>
let type ← inferType e
match type with
| .sort .. => pure none
| _ => pure $ some type

def pinf' (statsRef : IO.Ref Nanos) (e : Expr) : MetaM Expr := do
(pinfCore statsRef (← instantiateMVars e) : RPINFT MetaM _).run' {}
(pinf statsRef e : RPINFT MetaM _).run' {}

def rpinfExpr (statsRef : IO.Ref Nanos) (e : Expr) : m Expr :=
withReducible $ pinf statsRef e
Expand Down

0 comments on commit d1344b4

Please sign in to comment.