Skip to content

Commit

Permalink
fix: bring elaborator in line with kernel for primitive projections (l…
Browse files Browse the repository at this point in the history
…eanprover#5822)

The kernel supports primitive projections for all inductive types with
one construtor. The elaborator was assuming primitive projections only
work for "structure-likes", non-recursive inductive types with no
indices.

Enables numeric projection notation for general one-constructor
inductives.

Extracted from leanprover#5783.
  • Loading branch information
kmill authored Oct 31, 2024
1 parent 0c8d28e commit 03c6e99
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 30 deletions.
9 changes: 4 additions & 5 deletions src/Lean/Compiler/LCNF/InferType.lean
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,12 @@ mutual
/- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/
return erasedExpr
else
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
let n := structVal.numParams
let structParams := structType.getAppArgs
if n != structParams.size then
matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal =>
let structTypeArgs := structType.getAppArgs
if structVal.numParams + structVal.numIndices != structTypeArgs.size then
failed ()
else do
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structTypeArgs[:structVal.numParams])
for _ in [:idx] do
match ctorType with
| .forallE _ _ body _ =>
Expand Down
24 changes: 12 additions & 12 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1188,19 +1188,19 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
if idx == 0 then
throwError "invalid projection, index must be greater than 0"
let env ← getEnv
unless isStructureLike env structName do
throwLValError e eType "invalid projection, structure expected"
let numFields := getStructureLikeNumFields env structName
if idx - 1 < numFields then
if isStructure env structName then
let fieldNames := getStructureFields env structName
return LValResolution.projFn structName structName fieldNames[idx - 1]!
let failK _ := throwLValError e eType "invalid projection, structure expected"
matchConstStructure eType.getAppFn failK fun _ _ ctorVal => do
let numFields := ctorVal.numFields
if idx - 1 < numFields then
if isStructure env structName then
let fieldNames := getStructureFields env structName
return LValResolution.projFn structName structName fieldNames[idx - 1]!
else
/- `structName` was declared using `inductive` command.
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
return LValResolution.projIdx structName (idx - 1)
else
/- `structName` was declared using `inductive` command.
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
return LValResolution.projIdx structName (idx - 1)
else
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
| some structName, LVal.fieldName _ fieldName _ _ =>
let env ← getEnv
let searchEnv : Unit → TermElabM LValResolution := fun _ => do
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1975,7 +1975,7 @@ where
assign `?m`.
-/
return false
let ctorVal := getStructureCtor (← getEnv) structName
let some ctorVal := getStructureLikeCtor? (← getEnv) structName | return false
if ctorVal.numFields != 1 then
return false -- It is not a structure with a single field.
let sType ← whnf (← inferType s)
Expand Down Expand Up @@ -2013,7 +2013,7 @@ private def isDefEqApp (t s : Expr) : MetaM Bool := do
/-- Return `true` if the type of the given expression is an inductive datatype with a single constructor with no fields. -/
private def isDefEqUnitLike (t : Expr) (s : Expr) : MetaM Bool := do
let tType ← whnf (← inferType t)
matchConstStruct tType.getAppFn (fun _ => return false) fun _ _ ctorVal => do
matchConstStructureLike tType.getAppFn (fun _ => return false) fun _ _ ctorVal => do
if ctorVal.numFields != 0 then
return false
else if (← useEtaStruct ctorVal.induct) then
Expand Down
9 changes: 4 additions & 5 deletions src/Lean/Meta/InferType.lean
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,12 @@ private def inferProjType (structName : Name) (idx : Nat) (e : Expr) : MetaM Exp
let structType ← whnf structType
let failed {α} : Unit → MetaM α := fun _ =>
throwError "invalid projection{indentExpr (mkProj structName idx e)} from type {structType}"
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
let n := structVal.numParams
let structParams := structType.getAppArgs
if n != structParams.size then
matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal =>
let structTypeArgs := structType.getAppArgs
if structVal.numParams + structVal.numIndices != structTypeArgs.size then
failed ()
else do
let mut ctorType ← inferAppType (mkConst ctorVal.name structLvls) structParams
let mut ctorType ← inferAppType (mkConst ctorVal.name structLvls) structTypeArgs[:structVal.numParams]
for i in [:idx] do
ctorType ← whnf ctorType
match ctorType with
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Constructor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _root_.Lean.MVarId.existsIntro (mvarId : MVarId) (w : Expr) : MetaM MVarId :
mvarId.withContext do
mvarId.checkNotAssigned `exists
let target ← mvarId.getType'
matchConstStruct target.getAppFn
matchConstStructure target.getAppFn
(fun _ => throwTacticEx `exists mvarId "target is not an inductive datatype with one constructor")
fun _ us cval => do
if cval.numFields < 2 then
Expand Down
21 changes: 20 additions & 1 deletion src/Lean/MonadEnv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,26 @@ def getConstInfoRec [Monad m] [MonadEnv m] [MonadError m] (constName : Name) : m
| ConstantInfo.recInfo v => pure v
| _ => throwError "'{.ofConstName constName}' is not a recursor"

@[inline] def matchConstStruct [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α :=
/--
Matches if `e` is a constant that is an inductive type with one constructor.
Such types can be used with primitive projections.
See also `Lean.matchConstStructLike` for a more restrictive version.
-/
@[inline] def matchConstStructure [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α :=
matchConstInduct e failK fun ival us => do
match ival.ctors with
| [ctor] =>
match (← getConstInfo ctor) with
| ConstantInfo.ctorInfo cval => k ival us cval
| _ => failK ()
| _ => failK ()

/--
Matches if `e` is a constant that is an non-recursive inductive type with no indices and with one constructor.
Such a type satisfies `Lean.isStructureLike`.
See also `Lean.matchConstStructure` for a less restrictive version.
-/
@[inline] def matchConstStructureLike [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α :=
matchConstInduct e failK fun ival us => do
if ival.isRec || ival.numIndices != 0 then failK ()
else match ival.ctors with
Expand Down
32 changes: 28 additions & 4 deletions src/Lean/Structure.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,17 @@ def getStructureInfo (env : Environment) (structName : Name) : StructureInfo :=
else
panic! "structure expected"

/--
Gets the constructor of an inductive type that has exactly one constructor.
This is meant to be used with types that have had been registered as a structure by `registerStructure`,
but this is not checked.
Warning: these do *not* need to be "structure-likes". A structure-like is non-recursive,
and structure-likes have special kernel support.
-/
def getStructureCtor (env : Environment) (constName : Name) : ConstructorVal :=
match env.find? constName with
| some (.inductInfo { isRec := false, ctors := [ctorName], .. }) =>
| some (.inductInfo { ctors := [ctorName], .. }) =>
match env.find? ctorName with
| some (ConstantInfo.ctorInfo val) => val
| _ => panic! "ill-formed environment"
Expand Down Expand Up @@ -223,9 +231,10 @@ def getStructureFieldsFlattened (env : Environment) (structName : Name) (include
getStructureFieldsFlattenedAux env structName #[] includeSubobjectFields

/--
Return true if `constName` is the name of an inductive datatype
Returns true if `constName` is the name of an inductive datatype
created using the `structure` or `class` commands.
These are inductive types for which structure information has been registered with `registerStructure`.
See also `Lean.getStructureInfo?`.
-/
def isStructure (env : Environment) (constName : Name) : Bool :=
Expand Down Expand Up @@ -270,18 +279,33 @@ partial def getPathToBaseStructureAux (env : Environment) (baseStructName : Name
| some projFn => getPathToBaseStructureAux env baseStructName parentStructName (projFn :: path)

/--
If `baseStructName` is an ancestor structure for `structName`, then return a sequence of projection functions
If `baseStructName` is an ancestor structure for `structName`, then returns a sequence of projection functions
to go from `structName` to `baseStructName`.
-/
def getPathToBaseStructure? (env : Environment) (baseStructName : Name) (structName : Name) : Option (List Name) :=
getPathToBaseStructureAux env baseStructName structName []

/-- Return true iff `constName` is the a non-recursive inductive datatype that has only one constructor. -/
/--
Returns true iff `constName` is a non-recursive inductive datatype that has only one constructor and no indices.
Such types have special kernel support. This must be in sync with `is_structure_like`.
-/
def isStructureLike (env : Environment) (constName : Name) : Bool :=
match env.find? constName with
| some (.inductInfo { isRec := false, ctors := [_], numIndices := 0, .. }) => true
| _ => false

/--
Returns the constructor of the structure named `constName` if it is a non-recursive single-constructor inductive type with no indices.
-/
def getStructureLikeCtor? (env : Environment) (constName : Name) : Option ConstructorVal :=
match env.find? constName with
| some (.inductInfo { isRec := false, ctors := [ctorName], numIndices := 0, .. }) =>
match env.find? ctorName with
| some (ConstantInfo.ctorInfo val) => val
| _ => panic! "ill-formed environment"
| _ => none

/-- Return number of fields for a structure-like type -/
def getStructureLikeNumFields (env : Environment) (constName : Name) : Nat :=
match env.find? constName with
Expand Down
56 changes: 56 additions & 0 deletions tests/lean/run/inductive_rec_proj.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/-!
# Tests for numeric projections of inductive types
-/

/-!
Non-recursive, no indices.
-/
inductive I0 where
| mk (x : Nat) (xs : List Nat)
/-- info: fun v => v.1 : I0 → Nat -/
#guard_msgs in #check fun (v : I0) => v.1
/-- info: fun v => v.2 : I0 → List Nat -/
#guard_msgs in #check fun (v : I0) => v.2

/-!
Recursive, no indices.
-/
inductive I1 where
| mk (x : Nat) (xs : I1)
/-- info: fun v => v.1 : I1 → Nat -/
#guard_msgs in #check fun (v : I1) => v.1
/-- info: fun v => v.2 : I1 → I1 -/
#guard_msgs in #check fun (v : I1) => v.2

/-!
Non-recursive, indices.
-/
inductive I2 : Nat → Type where
| mk (x : Nat) (xs : List (Fin x)) : I2 (x + 1)
/-- info: fun v => v.1 : I2 2 → Nat -/
#guard_msgs in #check fun (v : I2 2) => v.1
/-- info: fun v => v.2 : (v : I2 2) → List (Fin v.1) -/
#guard_msgs in #check fun (v : I2 2) => v.2

/-!
Recursive, indices.
-/
inductive I3 : Nat → Type where
| mk (x : Nat) (xs : I3 (x + 1)) : I3 x
/-- info: fun v => v.1 : I3 2 → Nat -/
#guard_msgs in #check fun (v : I3 2) => v.1
/-- info: fun v => v.2 : (v : I3 2) → I3 (v.1 + 1) -/
#guard_msgs in #check fun (v : I3 2) => v.2


/-!
Make sure these can be compiled.
-/
def f0_1 (v : I0) : Nat := v.1
def f0_2 (v : I0) : List Nat := v.2
def f1_1 (v : I1) : Nat := v.1
def f1_2 (v : I1) : I1 := v.2
def f2_1 (v : I2 n) : Nat := v.1
def f2_2 (v : I2 n) : List (Fin (f2_1 v)) := v.2
def f3_1 (v : I3 n) : Nat := v.1
def f3_2 (v : I3 n) : I3 (f3_1 v + 1) := v.2

0 comments on commit 03c6e99

Please sign in to comment.