Skip to content

Commit 938a0a3

Browse files
authored
Systematically recompute denotations when needed in rechecker phases (#24302)
Fixes #23582 The first commit was a quick hack. The second commit implements a more principled solution.
2 parents 13372c9 + 1d8ce4d commit 938a0a3

File tree

10 files changed

+188
-77
lines changed

10 files changed

+188
-77
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,19 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte
6262
case _ => false
6363

6464
override def mapWith(tm: TypeMap)(using Context) =
65-
val elems = refs.elems.toList
66-
val elems1 = elems.mapConserve(tm.mapCapability(_))
67-
if elems1 eq elems then this
68-
else if elems1.forall:
69-
case elem1: Capability => elem1.isWellformed
70-
case _ => false
71-
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[Capability]]*), boxed)
72-
else EmptyAnnotation
65+
if ctx.phase.id > Phases.checkCapturesPhase.id then
66+
// Annotation is no longer relevant, can be dropped.
67+
// This avoids running into illegal states in mapCapability.
68+
EmptyAnnotation
69+
else
70+
val elems = refs.elems.toList
71+
val elems1 = elems.mapConserve(tm.mapCapability(_))
72+
if elems1 eq elems then this
73+
else if elems1.forall:
74+
case elem1: Capability => elem1.isWellformed
75+
case _ => false
76+
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[Capability]]*), boxed)
77+
else EmptyAnnotation
7378

7479
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
7580
refs.elems.exists {

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,28 @@ class CheckCaptures extends Recheck, SymTransformer:
246246

247247
override def isRunnable(using Context) = super.isRunnable && Feature.ccEnabledSomewhere
248248

249+
/** We normally need a recompute if the prefix is a SingletonType and the
250+
* last denotation is not a SymDenotation. The SingletonType requirement is
251+
* so that we don't widen TermRefs with non-path prefixes to their underlying
252+
* type when recomputing their denotations with asSeenFrom. Such widened types
253+
* would become illegal members of capture sets.
254+
*
255+
* The SymDenotation requirement is so that we don't recompute termRefs of Symbols
256+
* which should be handled by SymTransformers alone. However, if the underlying type
257+
* of the prefix is a capturing type, we do need to recompute since in that case
258+
* the prefix might carry a parameter refinement created in Setup, and we need to
259+
* take these refinements into account.
260+
*/
261+
override def needsRecompute(tp: NamedType, lastDenotation: SingleDenotation)(using Context): Boolean =
262+
tp.prefix match
263+
case prefix: TermRef =>
264+
!lastDenotation.isInstanceOf[SymDenotation]
265+
|| !prefix.info.captureSet.isAlwaysEmpty
266+
case prefix: SingletonType =>
267+
!lastDenotation.isInstanceOf[SymDenotation]
268+
case _ =>
269+
false
270+
249271
def newRechecker()(using Context) = CaptureChecker(ctx)
250272

251273
override def run(using Context): Unit =
@@ -694,12 +716,6 @@ class CheckCaptures extends Recheck, SymTransformer:
694716
markFree(ref.readOnly, tree)
695717
else
696718
val sel = ref.select(pt.select.symbol).asInstanceOf[TermRef]
697-
sel.recomputeDenot()
698-
// We need to do a recomputeDenot here since we have not yet properly
699-
// computed the type of the full path. This means that we erroneously
700-
// think the denotation is the same as in the previous phase so no
701-
// member computation is performed. A test case where this matters is
702-
// read-only-use.scala, where the error on r3 goes unreported.
703719
markPathFree(sel, pt.pt, pt.select)
704720
case _ =>
705721
markFree(ref.adjustReadOnly(pt), tree)
@@ -1114,11 +1130,11 @@ class CheckCaptures extends Recheck, SymTransformer:
11141130
if sym.is(Module) then sym.info // Modules are checked by checking the module class
11151131
else
11161132
if sym.is(Mutable) && !sym.hasAnnotation(defn.UncheckedCapturesAnnot) then
1117-
val addendum = capturedBy.get(sym) match
1133+
val addendum = setup.capturedBy.get(sym) match
11181134
case Some(encl) =>
11191135
val enclStr =
11201136
if encl.isAnonymousFunction then
1121-
val location = anonFunCallee.get(encl) match
1137+
val location = setup.anonFunCallee.get(encl) match
11221138
case Some(meth) if meth.exists => i" argument in a call to $meth"
11231139
case _ => ""
11241140
s"an anonymous function$location"
@@ -1943,49 +1959,12 @@ class CheckCaptures extends Recheck, SymTransformer:
19431959
traverseChildren(t)
19441960
end checkOverrides
19451961

1946-
/** Used for error reporting:
1947-
* Maps mutable variables to the symbols that capture them (in the
1948-
* CheckCaptures sense, i.e. symbol is referred to from a different method
1949-
* than the one it is defined in).
1950-
*/
1951-
private val capturedBy = util.HashMap[Symbol, Symbol]()
1952-
1953-
/** Used for error reporting:
1954-
* Maps anonymous functions appearing as function arguments to
1955-
* the function that is called.
1956-
*/
1957-
private val anonFunCallee = util.HashMap[Symbol, Symbol]()
1958-
1959-
/** Used for error reporting:
1960-
* Populates `capturedBy` and `anonFunCallee`. Called by `checkUnit`.
1961-
*/
1962-
private def collectCapturedMutVars(using Context) = new TreeTraverser:
1963-
def traverse(tree: Tree)(using Context) = tree match
1964-
case id: Ident =>
1965-
val sym = id.symbol
1966-
if sym.isMutableVar && sym.owner.isTerm then
1967-
val enclMeth = ctx.owner.enclosingMethod
1968-
if sym.enclosingMethod != enclMeth then
1969-
capturedBy(sym) = enclMeth
1970-
case Apply(fn, args) =>
1971-
for case closureDef(mdef) <- args do
1972-
anonFunCallee(mdef.symbol) = fn.symbol
1973-
traverseChildren(tree)
1974-
case Inlined(_, bindings, expansion) =>
1975-
traverse(bindings)
1976-
traverse(expansion)
1977-
case mdef: DefDef =>
1978-
if !mdef.symbol.isInlineMethod then traverseChildren(tree)
1979-
case _ =>
1980-
traverseChildren(tree)
1981-
19821962
private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup]
19831963

19841964
override def checkUnit(unit: CompilationUnit)(using Context): Unit =
19851965
capt.println(i"cc check ${unit.source}")
19861966
ccState.start()
19871967
setup.setupUnit(unit.tpdTree, this)
1988-
collectCapturedMutVars.traverse(unit.tpdTree)
19891968

19901969
if ctx.settings.YccPrintSetup.value then
19911970
val echoHeader = "[[syntax tree at end of cc setup]]"

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ trait SetupAPI:
4040
/** Check to do after the capture checking traversal */
4141
def postCheck()(using Context): Unit
4242

43+
/** Used for error reporting:
44+
* Maps mutable variables to the symbols that capture them (in the
45+
* CheckCaptures sense, i.e. symbol is referred to from a different method
46+
* than the one it is defined in).
47+
*/
48+
def capturedBy: collection.Map[Symbol, Symbol]
49+
50+
/** Used for error reporting:
51+
* Maps anonymous functions appearing as function arguments to
52+
* the function that is called.
53+
*/
54+
def anonFunCallee: collection.Map[Symbol, Symbol]
55+
end SetupAPI
56+
4357
object Setup:
4458

4559
val name: String = "setupCC"
@@ -518,6 +532,18 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
518532

519533
def traverse(tree: Tree)(using Context): Unit =
520534
tree match
535+
case tree: Ident =>
536+
val sym = tree.symbol
537+
if sym.isMutableVar && sym.owner.isTerm then
538+
val enclMeth = ctx.owner.enclosingMethod
539+
if sym.enclosingMethod != enclMeth then
540+
capturedBy(sym) = enclMeth
541+
542+
case Apply(fn, args) =>
543+
for case closureDef(mdef) <- args do
544+
anonFunCallee(mdef.symbol) = fn.symbol
545+
traverseChildren(tree)
546+
521547
case tree @ DefDef(_, paramss, tpt: TypeTree, _) =>
522548
val meth = tree.symbol
523549
if isExcluded(meth) then
@@ -567,9 +593,12 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
567593
traverse(body)
568594
catches.foreach(traverse)
569595
traverse(finalizer)
596+
570597
case tree: New =>
598+
571599
case _ =>
572600
traverseChildren(tree)
601+
573602
postProcess(tree)
574603
checkProperUseOrConsume(tree)
575604
end traverse
@@ -889,11 +918,16 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
889918
else t
890919
case _ => mapFollowingAliases(t)
891920

921+
val capturedBy: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]()
922+
923+
val anonFunCallee: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]()
924+
892925
/** Run setup on a compilation unit with given `tree`.
893926
* @param recheckDef the function to run for completing a val or def
894927
*/
895928
def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit =
896-
setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase))
929+
inContext(ctx.withPhase(thisPhase)):
930+
setupTraverser(checker).traverse(tree)
897931

898932
// ------ Checks to run at Setup ----------------------------------------
899933

compiler/src/dotty/tools/dotc/core/ContextOps.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,8 @@ object ContextOps:
135135
if (pkg.is(Package)) ctx.fresh.setOwner(pkg.moduleClass).setTree(tree).setNewScope
136136
else ctx
137137
}
138+
139+
def isRechecking: Boolean =
140+
(ctx.base.recheckPhaseIds & (1L << ctx.phaseId)) != 0
141+
138142
end ContextOps

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ object Phases {
4141
// drop NoPhase at beginning
4242
def allPhases: Array[Phase] = (if (fusedPhases.nonEmpty) fusedPhases else phases).tail
4343

44+
private var myRecheckPhaseIds: Long = 0
45+
46+
/** A bitset of the ids of the phases extending `transform.Recheck`.
47+
* Recheck phases must have id 63 or less.
48+
*/
49+
def recheckPhaseIds: Long = myRecheckPhaseIds
50+
51+
def recordRecheckPhase(phase: Recheck): Unit =
52+
val id = phase.id
53+
assert(id < 64, s"Recheck phase with id $id outside permissible range 0..63")
54+
myRecheckPhaseIds |= (1L << id)
55+
4456
object SomePhase extends Phase {
4557
def phaseName: String = "<some phase>"
4658
def run(using Context): Unit = unsupported("run")

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ import config.Printers.{core, typr, matchTypes}
3838
import reporting.{trace, Message}
3939
import java.lang.ref.WeakReference
4040
import compiletime.uninitialized
41+
import ContextOps.isRechecking
4142
import cc.*
4243
import CaptureSet.IdentityCaptRefMap
4344
import Capabilities.*
45+
import transform.Recheck.currentRechecker
4446

4547
import scala.annotation.internal.sharable
4648
import scala.annotation.threadUnsafe
@@ -2509,15 +2511,31 @@ object Types extends TypeUtils {
25092511
lastDenotation match {
25102512
case lastd0: SingleDenotation =>
25112513
val lastd = lastd0.skipRemoved
2512-
if lastd.validFor.runId == ctx.runId && checkedPeriod.code != NowhereCode then
2514+
var needsRecompute = false
2515+
if lastd.validFor.runId == ctx.runId
2516+
&& checkedPeriod.code != NowhereCode
2517+
&& !(ctx.isRechecking
2518+
&& {
2519+
needsRecompute = currentRechecker.needsRecompute(this, lastd)
2520+
needsRecompute
2521+
}
2522+
)
2523+
then
25132524
finish(lastd.current)
2514-
else lastd match {
2515-
case lastd: SymDenotation =>
2516-
if stillValid(lastd) && checkedPeriod.code != NowhereCode then finish(lastd.current)
2517-
else finish(memberDenot(lastd.initial.name, allowPrivate = lastd.is(Private)))
2518-
case _ =>
2519-
fromDesignator
2520-
}
2525+
else
2526+
val newd = lastd match
2527+
case lastd: SymDenotation =>
2528+
if stillValid(lastd) && checkedPeriod.code != NowhereCode && !needsRecompute
2529+
then finish(lastd.current)
2530+
else finish(memberDenot(lastd.initial.name, allowPrivate = lastd.is(Private)))
2531+
case _ =>
2532+
fromDesignator
2533+
if needsRecompute && (newd.info ne lastd.info) then
2534+
// Record the previous denotation, so that it can be reset at the end
2535+
// of the rechecker phase
2536+
currentRechecker.prevSelDenots(this) = lastd
2537+
//println(i"NEW PATH $this: ${newd.info} at ${ctx.phase}, prefix = $prefix")
2538+
newd
25212539
case _ => fromDesignator
25222540
}
25232541
}

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ object Recheck:
5050
case None =>
5151
tree
5252

53+
/** The currently running rechecker
54+
* @pre ctx.isRechecking
55+
*/
56+
def currentRechecker(using Context): Recheck =
57+
ctx.phase.asInstanceOf[Recheck]
58+
5359
extension (sym: Symbol)(using Context)
5460

5561
/** Update symbol's info to newInfo after `prevPhase`.
@@ -143,6 +149,7 @@ abstract class Recheck extends Phase, SymTransformer:
143149
else symd
144150

145151
def run(using Context): Unit =
152+
ctx.base.recordRecheckPhase(this)
146153
val rechecker = newRechecker()
147154
rechecker.checkUnit(ctx.compilationUnit)
148155
rechecker.reset()
@@ -151,6 +158,19 @@ abstract class Recheck extends Phase, SymTransformer:
151158
try super.runOn(units)
152159
finally preRecheckPhase.pastRecheck = true
153160

161+
/** A hook to determine whether the denotation of a NamedType should be recomputed
162+
* from its symbol and prefix, instead of just evolving the previous denotation with
163+
* `current`. This should return true if there are complex changes to types that
164+
* are not reflected in `current`.
165+
*/
166+
def needsRecompute(tp: NamedType, lastDenotation: SingleDenotation)(using Context): Boolean =
167+
false
168+
169+
/** A map from NamedTypes to the denotations they had before this phase.
170+
* Needed so that we can `reset` them after this phase.
171+
*/
172+
val prevSelDenots = util.HashMap[NamedType, Denotation]()
173+
154174
def newRechecker()(using Context): Rechecker
155175

156176
/** The typechecker pass */
@@ -192,17 +212,13 @@ abstract class Recheck extends Phase, SymTransformer:
192212
def resetNuTypes()(using Context): Unit =
193213
nuTypes.clear(resetToInitial = false)
194214

195-
/** A map from NamedTypes to the denotations they had before this phase.
196-
* Needed so that we can `reset` them after this phase.
197-
*/
198-
private val prevSelDenots = util.HashMap[NamedType, Denotation]()
199-
200215
/** Reset all references in `prevSelDenots` to the denotations they had
201216
* before this phase.
202217
*/
203218
def reset()(using Context): Unit =
204219
for (ref, mbr) <- prevSelDenots.iterator do
205220
ref.withDenot(mbr)
221+
prevSelDenots.clear()
206222

207223
/** Constant-folded rechecked type `tp` of tree `tree` */
208224
protected def constFold(tree: Tree, tp: Type)(using Context): Type =

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -412,16 +412,16 @@ object TreeChecker {
412412
assert(false, s"The type of a non-Super tree must not be a SuperType, but $tree has type $tp")
413413
case _ =>
414414

415-
override def typed(tree: untpd.Tree, pt: Type = WildcardType)(using Context): Tree = {
416-
val tpdTree = super.typed(tree, pt)
417-
Typer.assertPositioned(tree)
418-
checkSuper(tpdTree)
419-
if (ctx.erasedTypes)
420-
// Can't be checked in earlier phases since `checkValue` is only run in
421-
// Erasure (because running it in Typer would force too much)
422-
checkIdentNotJavaClass(tpdTree)
423-
tpdTree
424-
}
415+
override def typed(tree: untpd.Tree, pt: Type = WildcardType)(using Context): Tree =
416+
trace(i"checking $tree against $pt"):
417+
val tpdTree = super.typed(tree, pt)
418+
Typer.assertPositioned(tree)
419+
checkSuper(tpdTree)
420+
if (ctx.erasedTypes)
421+
// Can't be checked in earlier phases since `checkValue` is only run in
422+
// Erasure (because running it in Typer would force too much)
423+
checkIdentNotJavaClass(tpdTree)
424+
tpdTree
425425

426426
override def typedUnadapted(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
427427
try
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i23582.scala:27:26 ---------------------------------------
2+
27 | parReduce(1 to 1000): (x, y) => // error
3+
| ^
4+
|Found: (x: Int, y: Int) ->{write, read} Int
5+
|Required: (Int, Int) ->{cap.only[Read]} Int
6+
|
7+
|Note that capability write is not included in capture set {cap.only[Read]}.
8+
|
9+
|where: cap is a fresh root capability created in method test when checking argument to parameter op of method parReduce
10+
28 | write(x)
11+
29 | x + y + read()
12+
|
13+
| longer explanation available when compiling with `-explain`

0 commit comments

Comments
 (0)