Skip to content

Fix #23224: Optimize simple tuple extraction #23373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,27 @@ object desugar {
sel
end match

case class TuplePatternInfo(arity: Int, varNum: Int, wildcardNum: Int)
object TuplePatternInfo:
def apply(pat: Tree)(using Context): TuplePatternInfo = pat match
case Tuple(pats) =>
var arity = 0
var varNum = 0
var wildcardNum = 0
pats.foreach: p =>
arity += 1
p match
case id: Ident if !isBackquoted(id) =>
if id.name.isVarPattern then
varNum += 1
if id.name == nme.WILDCARD then
wildcardNum += 1
case _ =>
TuplePatternInfo(arity, varNum, wildcardNum)
case _ =>
TuplePatternInfo(-1, -1, -1)
end TuplePatternInfo

/** If `pat` is a variable pattern,
*
* val/var/lazy val p = e
Expand Down Expand Up @@ -1483,30 +1504,47 @@ object desugar {
|please bind to an identifier and use an alias given.""", bind)
false

def isTuplePattern(arity: Int): Boolean = pat match {
case Tuple(pats) if pats.size == arity =>
pats.forall(isVarPattern)
case _ => false
}

val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => isTuplePattern(es.length) && !hasNamedArg(es)
case _ => false
}
val tuplePatternInfo = TuplePatternInfo(pat)

// When desugaring a PatDef in general, we use pattern matching on the rhs
// and collect the variable values in a tuple, then outside the match,
// we destructure the tuple to get the individual variables.
// We can achieve two kinds of tuple optimizations if the pattern is a tuple
// of simple variables or wildcards:
// 1. Full optimization:
// If the rhs is known to produce a literal tuple of the same arity,
// we can directly fetch the values from the tuple.
// For example: `val (x, y) = if ... then (1, "a") else (2, "b")` becomes
// `val $1$ = if ...; val x = $1$._1; val y = $1$._2`.
// 2. Partial optimization:
// If the rhs can be typed as a tuple and matched with correct arity, we can
// return the tuple itself in the case if there are no more than one variable
// in the pattern, or return the the value if there is only one variable.

val fullTupleOptimizable =
val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => tuplePatternInfo.varNum == es.length && !hasNamedArg(es)
case _ => false
}
tuplePatternInfo.arity > 0
&& tuplePatternInfo.arity == tuplePatternInfo.varNum
&& forallResults(rhs, isMatchingTuple)

// We can only optimize `val pat = if (...) e1 else e2` if:
// - `e1` and `e2` are both tuples of arity N
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
val partialTupleOptimizable =
tuplePatternInfo.arity > 0
&& tuplePatternInfo.arity == tuplePatternInfo.varNum
// We exclude the case where there is only one variable,
// because it should be handled by `makeTuple` directly.
&& tuplePatternInfo.wildcardNum < tuplePatternInfo.arity - 1

val inAliasGenerator = original match
case _: GenAlias => true
case _ => false

val vars =
if (tupleOptimizable) // include `_`
val vars: List[VarInfo] =
if fullTupleOptimizable || partialTupleOptimizable then // include `_`
pat match
case Tuple(pats) => pats.map { case id: Ident => id -> TypeTree() }
case Tuple(pats) => pats.map { case id: Ident => (id, TypeTree()) }
else
getVariables(
tree = pat,
Expand All @@ -1517,12 +1555,27 @@ object desugar {
errorOnGivenBinding
) // no `_`

val ids = for ((named, _) <- vars) yield Ident(named.name)
val ids = for ((named, tpt) <- vars) yield Ident(named.name)

val matchExpr =
if (tupleOptimizable) rhs
if fullTupleOptimizable then rhs
else
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids).withAttachment(ForArtifact, ()))
val caseDef =
if partialTupleOptimizable then
val tmpTuple = UniqueName.fresh()
// Replace all variables with wildcards in the pattern
val pat1 = pat match
case Tuple(pats) =>
val wildcardPats = pats.map(p => Ident(nme.WILDCARD).withSpan(p.span))
Tuple(wildcardPats).withSpan(pat.span)
CaseDef(
Bind(tmpTuple, pat1),
EmptyTree,
Ident(tmpTuple).withAttachment(ForArtifact, ())
)
else CaseDef(pat, EmptyTree, makeTuple(ids).withAttachment(ForArtifact, ()))
Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)

vars match {
case Nil if !mods.is(Lazy) =>
matchExpr
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,16 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
}

/** Checks whether predicate `p` is true for all result parts of this expression,
* where we zoom into Ifs, Matches, and Blocks.
* where we zoom into Ifs, Matches, Tries, and Blocks.
*/
def forallResults(tree: Tree, p: Tree => Boolean): Boolean = tree match {
def forallResults(tree: Tree, p: Tree => Boolean): Boolean = tree match
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
case Match(_, cases) => cases.forall(c => forallResults(c.body, p))
case Try(_, cases, finalizer) =>
cases.forall(c => forallResults(c.body, p))
&& (finalizer.isEmpty || forallResults(finalizer, p))
case Block(_, expr) => forallResults(expr, p)
case _ => p(tree)
}

/** The tree stripped of the possibly nested applications (term and type).
* The original tree if it's not an application.
Expand Down
13 changes: 11 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2774,6 +2774,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if !isFullyDefined(pt, ForceDegree.all) then
return errorTree(tree, em"expected type of $tree is not fully defined")
val body1 = typed(tree.body, pt)

// If the body is a named tuple pattern, we need to use pt for symbol type,
// because the desugared body is a regular tuple unapply.
def isNamedTuplePattern =
ctx.mode.is(Mode.Pattern)
&& pt.dealias.isNamedTupleType
&& tree.body.match
case untpd.Tuple((_: NamedArg) :: _) => true
case _ => false

body1 match {
case UnApply(fn, Nil, arg :: Nil)
if fn.symbol.exists && (fn.symbol.owner.derivesFrom(defn.TypeTestClass) || fn.symbol.owner == defn.ClassTagClass) && !body1.tpe.isError =>
Expand All @@ -2799,8 +2809,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
body1.isInstanceOf[RefTree] && !isWildcardArg(body1)
|| body1.isInstanceOf[Literal]
val symTp =
if isStableIdentifierOrLiteral || pt.dealias.isNamedTupleType then pt
// need to combine tuple element types with expected named type
if isStableIdentifierOrLiteral || isNamedTuplePattern then pt
else if isWildcardStarArg(body1)
|| pt == defn.ImplicitScrutineeTypeRef
|| body1.tpe <:< pt // There is some strange interaction with gadt matching.
Expand Down
43 changes: 43 additions & 0 deletions tests/pos/simple-tuple-extract.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

class Test:
def f1: (Int, String, AnyRef) = (1, "2", "3")
def f2: (x: Int, y: String) = (0, "y")

def test1 =
val (a, b, c) = f1
// Desugared to:
// val $2$: (Int, String, AnyRef) =
// this.f1:(Int, String, AnyRef) @unchecked match
// {
// case $1$ @ Tuple3.unapply[Int, String, Object](_, _, _) =>
// $1$:(Int, String, AnyRef)
// }
// val a: Int = $2$._1
// val b: String = $2$._2
// val c: AnyRef = $2$._3
a + b.length() + c.toString.length()

// This pattern will not be optimized:
// val (a1, b1, c1: String) = f1

def test2 =
val (_, b, c) = f1
b.length() + c.toString.length()

val (a2, _, c2) = f1
a2 + c2.toString.length()

val (a3, _, _) = f1
a3 + 1

def test3 =
val (_, b, _) = f1
b.length() + 1

def test4 =
val (x, y) = f2
x + y.length()

def test5 =
val (_, b) = f2
b.length() + 1
Loading