From d52712112539ca21b7582d4aac947ab4fbad8393 Mon Sep 17 00:00:00 2001 From: Jannis Limperg Date: Fri, 22 Nov 2024 17:42:36 +0100 Subject: [PATCH] RPINF: optimise pinfCore isProof checking --- Aesop/RPINF.lean | 83 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/Aesop/RPINF.lean b/Aesop/RPINF.lean index 02212543..f353ab04 100644 --- a/Aesop/RPINF.lean +++ b/Aesop/RPINF.lean @@ -94,27 +94,53 @@ instance [Monad m] [MonadRPINF m] : MonadHashMapCacheAdapter Expr Expr m where abbrev RPINFT m [STWorld ω m] := StateRefT RPINFCache m +def ensureNForallBinders (n : Nat) (e : Expr) : MetaM Expr := + withAtLeastTransparency .default do + forallBoundedTelescope e n mkForallFVars + +/-- Given a type `t = ∀ (x₁ : T₁) ... (xₙ : Tₙ), U` and arguments `args` that +are type-correct when applied to a function of type `t`, returns the types of +the `args`. These are `T₁[args[0]]`, `T₂[args[0], args[1]]`, ..., +`Tₙ[args[0], ..., args[n]]`. -/ +partial def getArgTypes (t : Expr) (args : Array Expr) : MetaM (Array Expr) := do + go (Array.mkEmpty args.size) 0 (← ensureNForallBinders args.size t) +where + go (acc : Array Expr) (n : Nat) (t : Expr) : MetaM (Array Expr) := do + if n == args.size then + return acc + match t with + | .forallE _ t b _ => + let t := t.instantiateRevRange 0 n args + go (acc.push t) (n + 1) b + | _ => panic! "not enough forall binders" + variable [Monad m] [MonadRPINF m] [MonadLiftT MetaM m] [MonadControlT MetaM m] [MonadMCtx m] [MonadLiftT (ST IO.RealWorld) m] [MonadError m] [MonadRecDepth m] [MonadLiftT BaseIO m] +-- `type?` may be `none` iff `e` is a type. @[specialize] -partial def pinfCore (statsRef : IO.Ref Nanos) (e : Expr) : m Expr := +partial def pinfCore (statsRef : IO.Ref Nanos) (e : Expr) (type? : Option Expr) : m Expr := withIncRecDepth do checkCache e λ _ => do - let (isPrf, nanos) ← time $ withDefault $ isProof e - statsRef.modify (· + nanos) - if isPrf then - return .mdata (mdataSetIsProof {}) e + if e.hasLooseBVars then + throwError "loose bvar in e:{indentD $ toString e}" + if let some type := type? then + if type.hasLooseBVars then + throwError "loose bvar in type:{indentD $ toString type}" let e ← whnf e match e with | .app .. => - let f ← pinfCore statsRef e.getAppFn' + ifNotProof do + let f := e.getAppFn' + let fType ← inferType f + let f ← pinfCore statsRef f fType let mut args := e.getAppArgs' + let argTypes ← getArgTypes fType args for i in [:args.size] do let arg := args[i]! args := args.set! i default -- prevent nonlinear access to args[i] - let arg ← pinfCore statsRef arg + let arg ← pinfCore statsRef arg argTypes[i]! args := args.set! i arg if f.isConstOf ``Nat.succ && args.size == 1 && args[0]!.isRawNatLit then return mkRawNatLit (args[0]!.rawNatLit?.get! + 1) @@ -123,14 +149,22 @@ partial def pinfCore (statsRef : IO.Ref Nanos) (e : Expr) : m Expr := | .lam .. => -- TODO disable cache? lambdaTelescope e λ xs e => withNewFVars xs do - mkLambdaFVars xs (← pinfCore statsRef e) + let eType? ← type?.mapM λ type => + withAtLeastTransparency .default $ instantiateForall type xs + mkLambdaFVars xs (← pinfCore statsRef e eType?) | .forallE .. => -- TODO disable cache? + -- A `forall` is necessarily a type, and so is its conclusion. forallTelescope e λ xs e => withNewFVars xs do - mkForallFVars xs (← pinfCore statsRef e) + mkForallFVars xs (← pinfCore statsRef e none) | .proj t i e => - return .proj t i (← pinfCore statsRef e) - | .sort .. | .mvar .. | .lit .. | .const .. | .fvar .. => + ifNotProof do + return .proj t i (← pinfCore statsRef e (← inferType e)) + | .mvar .. | .const .. | .fvar .. => + ifNotProof do + return e + | .sort .. | .lit .. => + -- These cannot be proofs. return e | .letE .. | .mdata .. | .bvar .. => unreachable! where @@ -139,15 +173,36 @@ where for fvar in fvars do let fvarId := fvar.fvarId! let ldecl ← fvarId.getDecl - let ldecl := ldecl.setType $ ← pinfCore statsRef ldecl.type + let ldecl := ldecl.setType $ ← pinfCore statsRef ldecl.type none lctx := lctx.modifyLocalDecl fvarId λ _ => ldecl withLCtx lctx (← getLocalInstances) k + isPropD (e : Expr) : MetaM Bool := + -- withAtLeastTransparency .default $ isProp e -- TODO test perf impact + isProp e + + ifNotProof (k : m Expr) : m Expr := do + if let some type := type? then + if ← isPropD type then + return .mdata (mdataSetIsProof {}) e + k + def pinf (statsRef : IO.Ref Nanos) (e : Expr) : m Expr := do - pinfCore statsRef (← instantiateMVars e) + pinfCore statsRef (← instantiateMVars e) (← getType? e) +where + -- Returns the type of `e`, or `none` if `e` is a type. + getType? (e : Expr) : MetaM (Option Expr) := do + match ← isTypeQuick e with + | .true => pure none + | .false => inferType e + | .undef => + let type ← inferType e + match type with + | .sort .. => pure none + | _ => pure $ some type def pinf' (statsRef : IO.Ref Nanos) (e : Expr) : MetaM Expr := do - (pinfCore statsRef (← instantiateMVars e) : RPINFT MetaM _).run' {} + (pinf statsRef e : RPINFT MetaM _).run' {} def rpinfExpr (statsRef : IO.Ref Nanos) (e : Expr) : m Expr := withReducible $ pinf statsRef e