-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Under
betterFors
don't drop the trailing map
if it would result i…
- Loading branch information
1 parent
4d48bce
commit d4421d0
Showing
7 changed files
with
158 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package dotty.tools.dotc | ||
package transform.localopt | ||
|
||
import dotty.tools.dotc.ast.tpd.* | ||
import dotty.tools.dotc.core.Decorators.* | ||
import dotty.tools.dotc.core.Contexts.* | ||
import dotty.tools.dotc.core.StdNames.* | ||
import dotty.tools.dotc.core.Symbols.* | ||
import dotty.tools.dotc.core.Types.* | ||
import dotty.tools.dotc.transform.MegaPhase.MiniPhase | ||
import dotty.tools.dotc.ast.desugar | ||
|
||
/** Drop unused trailing map calls in for comprehensions. | ||
* We can drop the map call if: | ||
* - it won't change the type of the expression, and | ||
* - the function is an identity function or a const function to unit. | ||
* | ||
* The latter condition is checked in [[Desugar.scala#makeFor]] | ||
*/ | ||
class DropForMap extends MiniPhase: | ||
import DropForMap.* | ||
|
||
override def phaseName: String = DropForMap.name | ||
|
||
override def description: String = DropForMap.description | ||
|
||
override def transformApply(tree: Apply)(using Context): Tree = | ||
if !tree.hasAttachment(desugar.TrailingForMap) then tree | ||
else tree match | ||
case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) | ||
if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change | ||
f // drop the map call | ||
case _ => | ||
tree.removeAttachment(desugar.TrailingForMap) | ||
tree | ||
|
||
private object Lambda: | ||
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = | ||
tree match | ||
case Block(List(defdef: DefDef), Closure(Nil, ref, _)) | ||
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => | ||
Some((defdef.termParamss.flatten, defdef.rhs)) | ||
case _ => None | ||
|
||
private object MapCall: | ||
def unapply(tree: Tree)(using Context): Option[Tree] = tree match | ||
case Select(f, nme.map) => Some(f) | ||
case Apply(fn, _) => unapply(fn) | ||
case TypeApply(fn, _) => unapply(fn) | ||
case _ => None | ||
|
||
object DropForMap: | ||
val name: String = "dropForMap" | ||
val description: String = "Drop unused trailing map calls in for comprehensions" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import scala.language.experimental.betterFors | ||
|
||
case class Container[A](val value: A) { | ||
def map[B](f: A => B): Container[B] = Container(f(value)) | ||
} | ||
|
||
sealed trait Animal | ||
case class Dog() extends Animal | ||
|
||
def opOnDog(dog: Container[Dog]): Container[Animal] = | ||
for | ||
v <- dog | ||
yield v |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
MySome(()) | ||
MySome(2) | ||
MySome((2,3)) | ||
MySome((2,(3,4))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import scala.language.experimental.betterFors | ||
|
||
class myOptionModule(doOnMap: => Unit) { | ||
sealed trait MyOption[+A] { | ||
def map[B](f: A => B): MyOption[B] = this match { | ||
case MySome(x) => { | ||
doOnMap | ||
MySome(f(x)) | ||
} | ||
case MyNone => MyNone | ||
} | ||
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match { | ||
case MySome(x) => f(x) | ||
case MyNone => MyNone | ||
} | ||
} | ||
case class MySome[A](x: A) extends MyOption[A] | ||
case object MyNone extends MyOption[Nothing] | ||
object MyOption { | ||
def apply[A](x: A): MyOption[A] = MySome(x) | ||
} | ||
} | ||
|
||
object Test extends App { | ||
|
||
val myOption = new myOptionModule(println("map called")) | ||
|
||
import myOption.* | ||
|
||
def portablePrintMyOption(opt: MyOption[Any]): Unit = | ||
if opt == MySome(()) then | ||
println("MySome(())") | ||
else | ||
println(opt) | ||
|
||
val z = for { | ||
a <- MyOption(1) | ||
b <- MyOption(()) | ||
} yield () | ||
|
||
portablePrintMyOption(z) | ||
|
||
val z2 = for { | ||
a <- MyOption(1) | ||
b <- MyOption(2) | ||
} yield b | ||
|
||
portablePrintMyOption(z2) | ||
|
||
val z3 = for { | ||
a <- MyOption(1) | ||
(b, c) <- MyOption((2, 3)) | ||
} yield (b, c) | ||
|
||
portablePrintMyOption(z3) | ||
|
||
val z4 = for { | ||
a <- MyOption(1) | ||
(b, (c, d)) <- MyOption((2, (3, 4))) | ||
} yield (b, (c, d)) | ||
|
||
portablePrintMyOption(z4) | ||
|
||
} |