Skip to content

Commit

Permalink
feat: resolve generalized field notation using all parents
Browse files Browse the repository at this point in the history
  • Loading branch information
kmill committed Oct 31, 2024
1 parent 0fcee10 commit 63340ff
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 59 deletions.
61 changes: 32 additions & 29 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1135,24 +1135,29 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE
throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}"

/--
`findMethod? env S fName`.
- If `env` contains `S ++ fName`, return `(S, S++fName)`
- Otherwise if `env` contains private name `prv` for `S ++ fName`, return `(S, prv)`, o
- Otherwise for each parent structure `S'` of `S`, we try `findMethod? env S' fname`
`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`:
- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)`
- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)`
-/
private partial def findMethod? (env : Environment) (structName fieldName : Name) : Option (Name × Name) :=
let fullName := structName ++ fieldName
match env.find? fullName with
| some _ => some (structName, fullName)
| none =>
private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do
let env ← getEnv
let find? structName' : MetaM (Option (Name × Name)) := do
let fullName := structName' ++ fieldName
if env.contains fullName then
return some (structName', fullName)
let fullNamePrv := mkPrivateName env fullName
match env.find? fullNamePrv with
| some _ => some (structName, fullNamePrv)
| none =>
if isStructure env structName then
(getStructureSubobjects env structName).findSome? fun parentStructName => findMethod? env parentStructName fieldName
else
none
if env.contains fullNamePrv then
return some (structName', fullNamePrv)
return none
-- Optimization: the first element of the resolution order is `structName`,
-- so we can skip computing the resolution order in the common case
-- of the name resolving in the `structName` namespace.
find? structName <||> do
let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName]
for h : i in [1:resolutionOrder.size] do
if let some res ← find? resolutionOrder[i] then
return res
return none

/--
Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and
Expand Down Expand Up @@ -1204,7 +1209,7 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
| some structName, LVal.fieldName _ fieldName _ _ =>
let env ← getEnv
let searchEnv : Unit → TermElabM LValResolution := fun _ => do
if let some (baseStructName, fullName) := findMethod? env structName (.mkSimple fieldName) then
if let some (baseStructName, fullName) findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then
return LValResolution.const structName' structName' fullName
Expand Down Expand Up @@ -1390,19 +1395,17 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
loop f lvals
| LValResolution.projFn baseStructName structName fieldName =>
let f ← mkBaseProjections baseStructName structName f
if let some info := getFieldInfo? (← getEnv) baseStructName fieldName then
if isPrivateNameFromImportedModule (← getEnv) info.projFn then
throwError "field '{fieldName}' from structure '{structName}' is private"
let projFn ← mkConst info.projFn
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f, suppressDeps := true }
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs projFn #[{ name := `self, val := Arg.expr f, suppressDeps := true }] #[] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
let some info := getFieldInfo? (← getEnv) baseStructName fieldName | unreachable!
if isPrivateNameFromImportedModule (← getEnv) info.projFn then
throwError "field '{fieldName}' from structure '{structName}' is private"
let projFn ← mkConst info.projFn
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f, suppressDeps := true }
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
else
unreachable!
let f ← elabAppArgs projFn #[{ name := `self, val := Arg.expr f, suppressDeps := true }] #[] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.const baseStructName structName constName =>
let f ← if baseStructName != structName then mkBaseProjections baseStructName structName f else pure f
let projFn ← mkConst constName
Expand Down
30 changes: 28 additions & 2 deletions src/Lean/Elab/Structure.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ namespace Lean.Elab.Command

register_builtin_option structureDiamondWarning : Bool := {
defValue := false
descr := "enable/disable warning messages for structure diamonds"
descr := "if true, enable warnings when a structure has diamond inheritance"
}

register_builtin_option structure.strictResolutionOrder : Bool := {
defValue := false
descr := "if true, require a strict resolution order for structures"
}

open Meta
Expand Down Expand Up @@ -943,6 +948,23 @@ private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : L
instantiateMVars (← mkForallFVars params type)
return { name := view.declName, type := ← instantiateMVars type, ctors := [{ ctor with type := ← instantiateMVars ctorType }] }

/--
Precomputes the structure's resolution order.
Option `structure.strictResolutionOrder` controls whether to create a warning if the C3 algorithm failed.
-/
private def checkResolutionOrder (structName : Name) : TermElabM Unit := do
let resolutionOrderResult ← computeStructureResolutionOrder structName (relaxed := !structure.strictResolutionOrder.get (← getOptions))
trace[Elab.structure.resolutionOrder] "computed resolution order: {resolutionOrderResult.resolutionOrder}"
unless resolutionOrderResult.conflicts.isEmpty do
let mut defects : List MessageData := []
for conflict in resolutionOrderResult.conflicts do
let parentKind direct := if direct then "parent" else "indirect parent"
let conflicts := conflict.conflicts.map fun (isDirect, name) =>
m!"{parentKind isDirect} '{MessageData.ofConstName name}'"
defects := m!"- {parentKind conflict.isDirectParent} '{MessageData.ofConstName conflict.badParent}' \
must come after {MessageData.andList conflicts.toList}" :: defects
logWarning m!"failed to compute strict resolution order:\n{MessageData.joinSep defects.reverse "\n"}"

def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do
let scopeLevelNames ← Term.getLevelNames
let isUnsafe := view.modifiers.isUnsafe
Expand Down Expand Up @@ -1008,6 +1030,8 @@ def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit :=
else
mkCoercionToCopiedParent levelParams params view parent.structName parent.type
setStructureParents view.declName parentInfos
checkResolutionOrder view.declName

let lctx ← getLCtx
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
Expand Down Expand Up @@ -1045,6 +1069,8 @@ def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit :=
pure view
elabStructureViewPostprocessing view

builtin_initialize registerTraceClass `Elab.structure
builtin_initialize
registerTraceClass `Elab.structure
registerTraceClass `Elab.structure.resolutionOrder

end Lean.Elab.Command
2 changes: 1 addition & 1 deletion src/Lean/Server/Completion.lean
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ where
let .const typeName _ := type.getAppFn | return ()
modify fun s => s.insert typeName
if isStructure (← getEnv) typeName then
for parentName in getAllParentStructures (← getEnv) typeName do
for parentName in (← getAllParentStructures typeName) do
modify fun s => s.insert parentName
let some type ← unfoldeDefinitionGuarded? type | return ()
visit type
Expand Down
170 changes: 145 additions & 25 deletions src/Lean/Structure.lean
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def getStructureCtor (env : Environment) (constName : Name) : ConstructorVal :=
def getStructureFields (env : Environment) (structName : Name) : Array Name :=
(getStructureInfo env structName).fieldNames

/-- Get the `StructureFieldInfo` for the given direct field of the structure. -/
def getFieldInfo? (env : Environment) (structName : Name) (fieldName : Name) : Option StructureFieldInfo :=
if let some info := getStructureInfo? env structName then
info.fieldInfo.binSearch { fieldName := fieldName, projFn := default, subobject? := none, binderInfo := default } StructureFieldInfo.lt
Expand All @@ -180,21 +181,7 @@ If a direct parent cannot itself be represented as a subobject,
sometimes one of its parents (or one of their parents, etc.) can.
-/
def getStructureSubobjects (env : Environment) (structName : Name) : Array Name :=
let fieldNames := getStructureFields env structName;
fieldNames.foldl (init := #[]) fun acc fieldName =>
match isSubobjectField? env structName fieldName with
| some parentStructName => acc.push parentStructName
| none => acc

-- TODO: use actual parents, not just subobjects.
/-- Return all parent structures -/
partial def getAllParentStructures (env : Environment) (structName : Name) : Array Name :=
visit structName |>.run #[] |>.2
where
visit (structName : Name) : StateT (Array Name) Id Unit := do
for p in getStructureSubobjects env structName do
modify fun s => s.push p
visit p
(getStructureFields env structName).filterMap (isSubobjectField? env structName)

/--
Return the name of the structure that contains the field relative to structure `structName`.
Expand Down Expand Up @@ -269,18 +256,23 @@ partial def getPathToBaseStructureAux (env : Environment) (baseStructName : Name
if baseStructName == structName then
some path.reverse
else
let fieldNames := getStructureFields env structName;
fieldNames.findSome? fun fieldName =>
match isSubobjectField? env structName fieldName with
| none => none
| some parentStructName =>
match getProjFnForField? env structName fieldName with
| none => none
| some projFn => getPathToBaseStructureAux env baseStructName parentStructName (projFn :: path)
if let some info := getStructureInfo? env structName then
-- Prefer subobject projections
(info.fieldInfo.findSome? fun field =>
match field.subobject? with
| none => none
| some parentStructName => getPathToBaseStructureAux env baseStructName parentStructName (field.projFn :: path))
-- Otherwise, consider other parents
<|> info.parentInfo.findSome? fun parent =>
if parent.subobject then
none
else
getPathToBaseStructureAux env baseStructName parent.structName (parent.projFn :: path)
else none

/--
If `baseStructName` is an ancestor structure for `structName`, then returns a sequence of projection functions
to go from `structName` to `baseStructName`.
If `baseStructName` is an ancestor structure for `structName`, then return a sequence of projection functions
to go from `structName` to `baseStructName`. Returns `[]` if `baseStructName == structName`.
-/
def getPathToBaseStructure? (env : Environment) (baseStructName : Name) (structName : Name) : Option (List Name) :=
getPathToBaseStructureAux env baseStructName structName []
Expand Down Expand Up @@ -315,4 +307,132 @@ def getStructureLikeNumFields (env : Environment) (constName : Name) : Nat :=
| _ => 0
| _ => 0

/-!
### Resolution orders
This section is for computations to determine which namespaces to visit when resolving field notation.
While the set of namespaces is clear (after a structure's namespace, it is the namespaces for *all* parents),
the question is the order to visit them in.
We use the C3 superclass linearization algorithm from Barrett et al., "A Monotonic Superclass Linearization for Dylan", OOPSLA 1996.
For reference, the C3 linearization is known as the "method resolution order" (MRO) [in Python](https://docs.python.org/3/howto/mro.html).
The basic idea is that we want to find a resolution order with the following property:
For each structure `S` that appears in the resolution order, if its direct parents are `P₁ .. Pₙ`,
then `S P₁ ... Pₙ` forms a subsequence of the resolution order.
This has a stability property where if `S` extends `S'`, then the resolution order of `S` contains the resolution order of `S'` as a subsequence.
It also has the key property that if `P` and `P'` are parents of `S`, then we visit `P` and `P'` before we visit the shared parents of `P` and `P'`.
Finding such a resolution order might not be possible.
Still, we can enable a relaxation of the algorithm by ignoring one or more parent resolution orders, starting from the end.
In Hivert and Thiéry "Controlling the C3 super class linearization algorithm for large hierarchies of classes"
https://arxiv.org/pdf/2401.12740 the authors discuss how in SageMath, which has thousands of classes,
C3 can be difficult to control, since maintaining correct direct parent orders is a burden.
They give suggestions that have worked for the SageMath project.
We may consider introducing an environment extension with ordering hints to help guide the algorithm if we see similar difficulties.
-/

structure StructureResolutionState where
resolutions : PHashMap Name (Array Name) := {}
deriving Inhabited

/--
We use an environment extension to cache resolution orders.
These are not expensive to compute, but worth caching, and we save olean storage space.
-/
builtin_initialize structureResolutionExt : EnvExtension StructureResolutionState ←
registerEnvExtension (pure {})

/-- Gets the resolution order if it has already been cached. -/
private def getStructureResolutionOrder? (env : Environment) (structName : Name) : Option (Array Name) :=
(structureResolutionExt.getState env).resolutions.find? structName

/-- Caches a structure's resolution order. -/
private def setStructureResolutionOrder [MonadEnv m] (structName : Name) (resolutionOrder : Array Name) : m Unit :=
modifyEnv fun env => structureResolutionExt.modifyState env fun s =>
{ s with resolutions := s.resolutions.insert structName resolutionOrder }

/-- "The `badParent` must come after the `conflicts`. -/
structure StructureResolutionOrderConflict where
isDirectParent : Bool
badParent : Name
/-- Conflicts that must come before `badParent`. The flag is whether it is a direct parent. -/
conflicts : Array (Bool × Name)
deriving Inhabited

structure StructureResolutionOrderResult where
resolutionOrder : Array Name
conflicts : Array StructureResolutionOrderConflict := #[]
deriving Inhabited

/--
Computes and caches the C3 linearization. Assumes parents have already been set with `setStructureParents`.
If `relaxed` is false, then if the linearization cannot be computed, conflicts are recorded in the return value.
-/
partial def computeStructureResolutionOrder [Monad m] [MonadEnv m]
(structName : Name) (relaxed : Bool) : m StructureResolutionOrderResult := do
let env ← getEnv
if let some resOrder := getStructureResolutionOrder? env structName then
return { resolutionOrder := resOrder }
let parentNames := getStructureParentInfo env structName |>.map (·.structName)
-- Don't be strict about parents: if they were supposed to be checked, they were already checked.
let parentResOrders ← parentNames.mapM fun parentName => return (← computeStructureResolutionOrder parentName true).resolutionOrder

-- `resOrders` contains the resolution orders to merge.
-- The parent list is inserted as a pseudo resolution order to ensure immediate parents come out in order,
-- and it is added first to be the primary ordering constraint when there are ordering errors.
let mut resOrders := parentResOrders.insertAt 0 parentNames |>.filter (!·.isEmpty)

let mut resOrder : Array Name := #[structName]
let mut defects : Array StructureResolutionOrderConflict := #[]
-- Every iteration of the loop, the sum of the sizes of the arrays in `resOrders` decreases by at least one,
-- so it terminates.
while !resOrders.isEmpty do
let (good, name) ← selectParent resOrders

unless good || relaxed do
let conflicts := resOrders |>.filter (·[1:].any (· == name)) |>.map (·[0]!) |>.qsort Name.lt |>.eraseReps
defects := defects.push {
isDirectParent := parentNames.contains name
badParent := name
conflicts := conflicts.map fun c => (parentNames.contains c, c)
}

resOrder := resOrder.push name
resOrders := resOrders
|>.map (fun resOrder => resOrder.filter (· != name))
|>.filter (!·.isEmpty)

setStructureResolutionOrder structName resOrder
return { resolutionOrder := resOrder, conflicts := defects }
where
selectParent (resOrders : Array (Array Name)) : m (Bool × Name) := do
-- Assumption: every resOrder is nonempty.
-- `n'` is for relaxation, to stop paying attention to end of `resOrders` when finding a good parent.
for n' in [0 : resOrders.size] do
let hi := resOrders.size - n'
for i in [0 : hi] do
let parent := resOrders[i]![0]!
let consistent resOrder := resOrder[1:].all (· != parent)
if resOrders[0:i].all consistent && resOrders[i+1:hi].all consistent then
return (n' == 0, parent)
-- unreachable, but correct default:
return (false, resOrders[0]![0]!)

/--
Gets the resolution order for a structure.
-/
def getStructureResolutionOrder [Monad m] [MonadEnv m]
(structName : Name) : m (Array Name) :=
(·.resolutionOrder) <$> computeStructureResolutionOrder structName (relaxed := true)

/--
Returns the transitive closure of all parent structures of the structure.
This is the same as `Lean.getStructureResolutionOrder` but without including `structName`.
-/
partial def getAllParentStructures [Monad m] [MonadEnv m] (structName : Name) : m (Array Name) :=
(·.erase structName) <$> getStructureResolutionOrder structName

end Lean
Loading

0 comments on commit 63340ff

Please sign in to comment.