Skip to content

Commit

Permalink
Under betterFors don't drop the trailing map if it would result i…
Browse files Browse the repository at this point in the history
…n a different type (also drop `_ => ()`) (#22619)

closes #21804
  • Loading branch information
KacperFKorban authored Feb 21, 2025
1 parent 4d48bce commit d4421d0
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 15 deletions.
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import parsing.Parser
import Phases.Phase
import transform.*
import backend.jvm.{CollectSuperCalls, GenBCode}
import localopt.StringInterpolatorOpt
import localopt.{StringInterpolatorOpt, DropForMap}

/** The central class of the dotc compiler. The job of a compiler is to create
* runs, which process given `phases` in a given `rootContext`.
Expand Down Expand Up @@ -68,7 +68,8 @@ class Compiler {
new InlineVals, // Check right hand-sides of an `inline val`s
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
new ElimRepeated, // Rewrite vararg parameters and arguments
new RefChecks) :: // Various checks mostly related to abstract members and overriding
new RefChecks, // Various checks mostly related to abstract members and overriding
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
List(new semanticdb.ExtractSemanticDB.AppendDiagnostics) :: // Attach warnings to extracted SemanticDB and write to .semanticdb file
List(new init.Checker) :: // Check initialization of objects
List(new ProtectedAccessors, // Add accessors for protected members
Expand Down
31 changes: 19 additions & 12 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ object desugar {
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** An attachment key to indicate that an Apply is created as a last `map`
* scall in a for-comprehension.
*/
val TrailingForMap: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -1967,14 +1972,8 @@ object desugar {
*
* 3.
*
* for (P <- G) yield P ==> G
*
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
*
* for (P <- G) yield E ==> G.map (P => E)
*
* Otherwise
*
* 4.
*
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
Expand Down Expand Up @@ -2147,14 +2146,20 @@ object desugar {
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
case _ => false

def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit =
if betterForsEnabled
&& selectName == mapName
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil)))
then
aply.putAttachment(TrailingForMap, ())

enums match {
case Nil if betterForsEnabled => body
case (gen: GenFrom) :: Nil =>
if betterForsEnabled
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
markTrailingMap(aply, gen, mapName)
aply
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
Expand All @@ -2165,7 +2170,9 @@ object desugar {
val selectName =
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
else mapName
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
markTrailingMap(aply, gen, selectName)
aply
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
Expand Down
54 changes: 54 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dotty.tools.dotc
package transform.localopt

import dotty.tools.dotc.ast.tpd.*
import dotty.tools.dotc.core.Decorators.*
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
import dotty.tools.dotc.ast.desugar

/** Drop unused trailing map calls in for comprehensions.
* We can drop the map call if:
* - it won't change the type of the expression, and
* - the function is an identity function or a const function to unit.
*
* The latter condition is checked in [[Desugar.scala#makeFor]]
*/
class DropForMap extends MiniPhase:
import DropForMap.*

override def phaseName: String = DropForMap.name

override def description: String = DropForMap.description

override def transformApply(tree: Apply)(using Context): Tree =
if !tree.hasAttachment(desugar.TrailingForMap) then tree
else tree match
case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))
if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change
f // drop the map call
case _ =>
tree.removeAttachment(desugar.TrailingForMap)
tree

private object Lambda:
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
tree match
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
Some((defdef.termParamss.flatten, defdef.rhs))
case _ => None

private object MapCall:
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
case Select(f, nme.map) => Some(f)
case Apply(fn, _) => unapply(fn)
case TypeApply(fn, _) => unapply(fn)
case _ => None

object DropForMap:
val name: String = "dropForMap"
val description: String = "Drop unused trailing map calls in for comprehensions"
2 changes: 1 addition & 1 deletion docs/_docs/reference/experimental/better-fors.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Additionally this extension changes the way `for`-comprehensions are desugared.
This change makes the desugaring more intuitive and avoids unnecessary `map` calls, when an alias is not followed by a guard.

2. **Avoiding Redundant `map` Calls**:
When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. but th eequality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables.
When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. But the equality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. There is also a special case for dropping the `map`, if its body is a constant function, that returns `()` (`Unit` constant).
**Current Desugaring**:
```scala
for {
Expand Down
13 changes: 13 additions & 0 deletions tests/pos/better-fors-i21804.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.language.experimental.betterFors

case class Container[A](val value: A) {
def map[B](f: A => B): Container[B] = Container(f(value))
}

sealed trait Animal
case class Dog() extends Animal

def opOnDog(dog: Container[Dog]): Container[Animal] =
for
v <- dog
yield v
4 changes: 4 additions & 0 deletions tests/run/better-fors-map-elim.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
MySome(())
MySome(2)
MySome((2,3))
MySome((2,(3,4)))
64 changes: 64 additions & 0 deletions tests/run/better-fors-map-elim.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import scala.language.experimental.betterFors

class myOptionModule(doOnMap: => Unit) {
sealed trait MyOption[+A] {
def map[B](f: A => B): MyOption[B] = this match {
case MySome(x) => {
doOnMap
MySome(f(x))
}
case MyNone => MyNone
}
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
case MySome(x) => f(x)
case MyNone => MyNone
}
}
case class MySome[A](x: A) extends MyOption[A]
case object MyNone extends MyOption[Nothing]
object MyOption {
def apply[A](x: A): MyOption[A] = MySome(x)
}
}

object Test extends App {

val myOption = new myOptionModule(println("map called"))

import myOption.*

def portablePrintMyOption(opt: MyOption[Any]): Unit =
if opt == MySome(()) then
println("MySome(())")
else
println(opt)

val z = for {
a <- MyOption(1)
b <- MyOption(())
} yield ()

portablePrintMyOption(z)

val z2 = for {
a <- MyOption(1)
b <- MyOption(2)
} yield b

portablePrintMyOption(z2)

val z3 = for {
a <- MyOption(1)
(b, c) <- MyOption((2, 3))
} yield (b, c)

portablePrintMyOption(z3)

val z4 = for {
a <- MyOption(1)
(b, (c, d)) <- MyOption((2, (3, 4)))
} yield (b, (c, d))

portablePrintMyOption(z4)

}

0 comments on commit d4421d0

Please sign in to comment.