Skip to content

Commit

Permalink
TreeOps: move getArgs from FormatOps, make partial
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Jun 8, 2024
1 parent 8df63da commit 57546ef
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.scalafmt.internal

import org.scalafmt.Error.UnexpectedTree
import org.scalafmt.config.BinPack
import org.scalafmt.config.IndentOperator
import org.scalafmt.config.Indents
Expand All @@ -11,7 +10,6 @@ import org.scalafmt.config.TrailingCommas
import org.scalafmt.internal.Length.Num
import org.scalafmt.internal.Policy.NoPolicy
import org.scalafmt.util.InfixApp._
import org.scalafmt.util.LoggerOps._
import org.scalafmt.util._

import org.scalameta.FileLine
Expand Down Expand Up @@ -534,7 +532,7 @@ class FormatOps(
if fullInfix.parent.contains(prevOwner) && !(prevOwner match {
case po: Member.ArgClause => po.parent.exists(isInfixApp)
case po => isInfixApp(po)
}) && isSeqSingle(getArgs(prevOwner, orNil = true)) =>
}) && isSeqSingle(getArgsOrNil(prevOwner)) =>
Some(getLastToken(fullInfix))
case _ => None
}
Expand Down Expand Up @@ -1255,24 +1253,6 @@ class FormatOps(
else nextAfterNonComment(maybeArrow)
}

def getArgs(owner: Tree, orNil: Boolean = false): Seq[Tree] = owner match {
case _: Lit.Unit => Nil
case t: Term.Super => t.superp :: Nil
case Member.Tuple(v) => v
case Member.SyntaxValuesClause(v) => v
case t: Member.Function => t.paramClause.values
case _ if orNil => Nil
case t =>
logger.debug(
s"""|getApplyArgs: unknown tree
|Tree: ${log(t)}
|Parent: ${log(t.parent)}
|GrandParent: ${log(t.parent.flatMap(_.parent))}
|""".stripMargin,
)
throw UnexpectedTree[Member.SyntaxValuesClause](t)
}

@tailrec
final def findPrevSelectAndApply(
tree: Tree,
Expand Down Expand Up @@ -2567,8 +2547,8 @@ class FormatOps(
def isCloseDelimForTrailingCommasMultiple(ft: FormatToken): Boolean =
ft.meta.rightOwner match {
case x: Importer => x.importees.lengthCompare(1) > 0
case x => // take last arg when multiple
getArgs(x, orNil = true).view.drop(1).lastOption match {
// take last arg when multiple
case x => getArgsOrNil(x).view.drop(1).lastOption match {
case None | Some(_: Term.Repeated) => false
case Some(t: Term.Param) => !t.decltpe.exists(_.is[Type.Repeated])
case _ => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.scalafmt.internal.FormatTokens
import org.scalafmt.internal.Modification
import org.scalafmt.internal.Space
import org.scalafmt.util.InfixApp._
import org.scalafmt.util.LoggerOps._

import scala.meta._
import scala.meta.classifiers.Classifier
Expand Down Expand Up @@ -1029,4 +1030,29 @@ object TreeOps {
case _ => false
}

val getArgsPartial: PartialFunction[Tree, List[Tree]] = {
case _: Lit.Unit => Nil
case t: Term.Super => t.superp :: Nil
case Member.Tuple(v) => v
case Member.SyntaxValuesClause(v) => v
case t: Member.Function => t.paramClause.values
}

def getArgsOrNil(owner: Tree): List[Tree] = getArgsPartial.lift(owner)
.getOrElse(Nil)

private def throwUnexpectedGetArgs(t: Tree): Nothing = {
logger.debug(
s"""|getArgs: unknown tree
|Tree: ${log(t)}
|Parent: ${log(t.parent)}
|GrandParent: ${log(t.parent.flatMap(_.parent))}
|""".stripMargin,
)
throw Error.UnexpectedTree[Member.SyntaxValuesClause](t)
}

def getArgs(owner: Tree): Seq[Tree] = getArgsPartial
.applyOrElse(owner, throwUnexpectedGetArgs)

}

0 comments on commit 57546ef

Please sign in to comment.