From 289112a0fa7b11963ec8f794fd0f25e7bb6c76de Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Fri, 4 Oct 2024 16:51:14 -0700 Subject: [PATCH] FormatTokens: return FormatToken in matchingXxx --- .../scalafmt/internal/BestFirstSearch.scala | 12 ++-- .../org/scalafmt/internal/FormatOps.scala | 46 ++++++------ .../org/scalafmt/internal/FormatTokens.scala | 12 ++-- .../org/scalafmt/internal/FormatWriter.scala | 2 +- .../scala/org/scalafmt/internal/Router.scala | 71 +++++++++++-------- .../scala/org/scalafmt/internal/State.scala | 2 +- .../scalafmt/rewrite/RedundantBraces.scala | 12 ++-- .../scalafmt/rewrite/RedundantParens.scala | 2 +- .../rewrite/RemoveScala3OptionalBraces.scala | 2 +- .../scala/org/scalafmt/rewrite/Rewrite.scala | 2 +- .../rewrite/RewriteTrailingCommas.scala | 2 +- .../scala/org/scalafmt/util/StyleMap.scala | 9 +-- .../scala/org/scalafmt/util/TreeOps.scala | 35 +++++---- 13 files changed, 114 insertions(+), 95 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala index 0525c03eba..cb3d7ffcee 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala @@ -38,9 +38,11 @@ private class BestFirstSearch private (range: Set[Range])(implicit private def getBlockCloseToRecurse(ft: FormatToken, stop: Token)(implicit style: ScalafmtConfig, - ): Option[Token] = getEndOfBlock(ft, parensToo = true).filter { close => - // Block must span at least 3 lines to be worth recursing. - close != stop && distance(ft.left, close) > style.maxColumn * 3 + ): Option[Token] = getEndOfBlock(ft, parensToo = true).collect { + case close if close.left != stop && { + // Block must span at least 3 lines to be worth recursing. + distance(ft.left, close.left) > style.maxColumn * 3 + } => close.left } private val memo = mutable.Map.empty[Long, State] @@ -338,13 +340,13 @@ object BestFirstSearch { !styleMap.at(t).newlines.keep case _: Term.Apply => true // legacy: when enclosed in parens case _ => false - }) => expire = tokens.matching(t) + }) => expire = tokens.matching(t).left case FormatToken(t: Token.LeftBrace, _, m) if (m.leftOwner match { // Type compounds can be inside defn.defs case lo: meta.Stat.Block => lo.parent.is[Type.Refine] case _: Type.Refine => true case _ => false - }) => expire = tokens.matching(t) + }) => expire = tokens.matching(t).left case _ => } result.result() diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala index 628525526a..ed151f6b88 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala @@ -176,7 +176,7 @@ class FormatOps( val isDefnSite = isParamClauseSite(owner) implicit val clauseSiteFlags: ClauseSiteFlags = ClauseSiteFlags(owner, isDefnSite) - val bpFlags = getBinpackSiteFlags(tokens(matching(t)), start, false) + val bpFlags = getBinpackSiteFlags(matching(t), start, false) if (bpFlags.scalaJsStyle) scalaJsOptCloseOnRight(start, bpFlags) else if ( !start.left.is[T.RightParen] || @@ -217,7 +217,7 @@ class FormatOps( new ExtractFromMeta[Tree](meta => statementStarts.get(meta.idx + 1)) def parensTuple(token: T): TokenRanges = matchingOpt(token) - .fold(TokenRanges.empty)(other => TokenRanges(TokenRange(token, other))) + .fold(TokenRanges.empty)(other => TokenRanges(TokenRange(token, other.left))) def insideBlock[A](start: FormatToken, end: T)(implicit classifier: Classifier[T, A], @@ -236,18 +236,19 @@ class FormatOps( ): TokenRanges = insideBlock(x => getEndOfBlock(x, parensToo))(start, end) def insideBlock( - matches: FormatToken => Option[T], + matches: FormatToken => Option[FormatToken], )(start: FormatToken, end: T): TokenRanges = { var result = TokenRanges.empty @tailrec def run(tok: FormatToken): Unit = if (tok.left.start < end.start) { - val nextTokOpt = matches(tok).flatMap { close => + val nextTokOpt = matches(tok).flatMap { closeFt => val open = tok.left + val close = closeFt.left if (open.start >= close.end) None else { result = result.append(TokenRange(open, close)) - Some(tokens(close)) + Some(closeFt) } } val nextTok = nextTokOpt.getOrElse(next(tok)) @@ -791,7 +792,7 @@ class FormatOps( val otherSplits = closeOpt.fold { val nlSplit = Split(nlMod, 1 + breakPenalty) Seq(nlSplit.withIndent(nlIndent).withPolicy(nlPolicy & delayedBreak)) - } { close => + } { closeFt => val noSingleLine = newStmtMod.isDefined || breakMany || rightAsInfix.exists(10 < infixSequenceLength(_)) val nextOp = @@ -807,8 +808,9 @@ class FormatOps( val nlSplit = Split(nlMod, 0).andPolicy(breakAfterClose) .withIndent(nlIndent).withPolicy(nlPolicy) val singleLineSplit = Split(spaceMod, 0).notIf(noSingleLine) - .withSingleLine(endOfNextOp.fold(close)(_.left)) - .andPolicy(breakAfterClose).andPolicy(getSingleLineInfixPolicy(close)) + .withSingleLine(endOfNextOp.getOrElse(closeFt).left) + .andPolicy(breakAfterClose) + .andPolicy(getSingleLineInfixPolicy(closeFt.left)) Seq(singleLineSplit, nlSplit) } @@ -1081,7 +1083,7 @@ class FormatOps( val values = clause.values if ( values.lengthCompare(cfg.minCount) >= 0 && - (cfg.minSpan == 0 || cfg.minSpan < distance(ftOpen.left, close)) + (cfg.minSpan == 0 || cfg.minSpan < distance(ftOpen.left, close.left)) ) { forces += ftOpen.meta.idx values.foreach(x => clearQueues += getHead(x).meta.idx) @@ -1111,7 +1113,7 @@ class FormatOps( val lpOwner = ft.meta.leftOwner val FormatToken(open, r, _) = ft - val close = matching(open) + val close = matching(open).left val indentParam = Num(style.indent.getDefnSite(lpOwner)) val indentSep = Num((indentParam.n - 2).max(0)) val isBracket = open.is[T.LeftBracket] @@ -1459,7 +1461,7 @@ class FormatOps( case Some(x) => x case None => findXmlLastLineIndent(prev(ft)) } - case t: T.Xml.SpliceEnd => findXmlLastLineIndent(tokens(matching(t), -1)) + case t: T.Xml.SpliceEnd => findXmlLastLineIndent(prev(matching(t))) case _ => findXmlLastLineIndent(prev(ft)) } @@ -1468,8 +1470,8 @@ class FormatOps( ): Seq[Split] = if (style.xmlLiterals.assumeFormatted) { val end = matching(tok) - val indent = Num(findXmlLastLineIndent(tokens(end, -1)), true) - splits.map(_.withIndent(indent, end, ExpiresOn.After)) + val indent = Num(findXmlLastLineIndent(prev(end)), true) + splits.map(_.withIndent(indent, end.left, ExpiresOn.After)) } else splits def withIndentOnXmlSpliceStart(ft: FormatToken, splits: Seq[Split])(implicit @@ -1478,7 +1480,7 @@ class FormatOps( case t: T.Xml.SpliceStart if style.xmlLiterals.assumeFormatted => val end = matching(t) val indent = Num(findXmlLastLineIndent(prev(ft)), true) - splits.map(_.withIndent(indent, end, ExpiresOn.After)) + splits.map(_.withIndent(indent, end.left, ExpiresOn.After)) case _ => splits } @@ -1608,7 +1610,7 @@ class FormatOps( _: Term.NewAnonymous => getSplits(getSpaceSplit(1)) case t: Term.ForYield => nextNonComment(bheadFT).right match { // skipping `for` case x @ LeftParenOrBrace() => - val exclude = TokenRanges(TokenRange(x, matching(x))) + val exclude = TokenRanges(TokenRange(x, matching(x).left)) (t.body match { case b: Term.Block => getBracesIfEnclosed(b) .map(x => getPolicySplits(1, getSlb(x._1.left, exclude))) @@ -1754,7 +1756,7 @@ class FormatOps( val right = nextNonComment(ft).right val rpOpt = if (right.is[T.LeftParen]) matchingOpt(right) else None val expire = nextNonCommentSameLine(rpOpt.fold(endFt) { rp => - if (rp.end >= endFt.left.end) before(rp) else endFt + if (rp.left.end >= endFt.left.end) rp else endFt }) nlSplit.withIndent(Num(style.indent.main), expire.left, ExpiresOn.After) } @@ -2180,7 +2182,7 @@ class FormatOps( case t: Term.While => t.expr match { case b: Term.Block if isMultiStatBlock(b) && - !matchingOpt(nft.right).exists(_.end >= b.pos.end) => + !matchingOpt(nft.right).exists(_.left.end >= b.pos.end) => Some(new OptionalBracesRegion { def owner = Some(t) def splits = Some { @@ -2358,7 +2360,7 @@ class FormatOps( val nr = nft.right t.cond match { case b: Term.Block if (matchingOpt(nr) match { - case Some(t) => t.end < b.pos.end + case Some(t) => t.left.end < b.pos.end case None => isMultiStatBlock(b) }) => Some(new OptionalBracesRegion { @@ -2473,7 +2475,7 @@ class FormatOps( case _: T.LeftBrace => false case _ => !isTreeSingleExpr(thenp) && (!before.right.is[T.LeftBrace] || matchingOpt(before.right) - .exists(rb => rb.end < thenp.pos.end)) + .exists(_.left.end < thenp.pos.end)) } } @@ -2722,11 +2724,11 @@ class FormatOps( def getEndOfBlock(ft: FormatToken, parensToo: => Boolean)(implicit style: ScalafmtConfig, - ): Option[T] = ft.left match { + ): Option[FormatToken] = ft.left match { case x: T.LeftBrace => matchingOpt(x) case x: T.LeftParen => if (parensToo) matchingOpt(x) else None case _ => OptionalBraces.get(ft) - .flatMap(_.rightBrace.map(x => nextNonCommentSameLine(x).left)) + .flatMap(_.rightBrace.map(x => nextNonCommentSameLine(x))) } def isCloseDelimForTrailingCommasMultiple(ft: FormatToken): Boolean = @@ -2887,7 +2889,7 @@ class FormatOps( implicit val style: ScalafmtConfig = styleMap.at(open) implicit val clauseSiteFlags: ClauseSiteFlags = ClauseSiteFlags .atCallSite(ftAfterClose.meta.rightOwner) - val bpFlagsAfter = getBinpackCallSiteFlags(tokens(open), ftAfterClose) + val bpFlagsAfter = getBinpackCallSiteFlags(open, ftAfterClose) scalaJsOptCloseOnRight(ftAfterClose, bpFlagsAfter) } else ftBeforeClose } else ftBeforeClose diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala index 420590d6fd..2e9a777190 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala @@ -23,8 +23,8 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FormatToken] result.result() }(arr) - private lazy val matchingParentheses: Map[TokenHash, Token] = TreeOps - .getMatchingParentheses(arr.view.map(_.right)) + private lazy val matchingParentheses: Map[TokenHash, FormatToken] = TreeOps + .getMatchingParentheses(arr.view)(_.left) override def length: Int = arr.length override def apply(idx: Int): FormatToken = arr(idx) @@ -77,19 +77,19 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FormatToken] def next(ft: FormatToken): FormatToken = apply(ft, 1) @inline - def matching(token: Token): Token = matchingParentheses.getOrElse( + def matching(token: Token): FormatToken = matchingParentheses.getOrElse( FormatTokens.thash(token), FormatTokens.throwNoToken(token, "Missing matching token index"), ) @inline - def matchingOpt(token: Token): Option[Token] = matchingParentheses + def matchingOpt(token: Token): Option[FormatToken] = matchingParentheses .get(FormatTokens.thash(token)) @inline def hasMatching(token: Token): Boolean = matchingParentheses .contains(FormatTokens.thash(token)) @inline def areMatching(t1: Token)(t2: => Token): Boolean = matchingOpt(t1) match { - case Some(x) => x eq t2 + case Some(x) => x.left eq t2 case _ => false } @@ -100,7 +100,7 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FormatToken] .flatMap { head => matchingOpt(head.left).flatMap { other => val last = getLastNonTrivial(tokens, tree) - if (last.left eq other) Some((head, last)) else None + if (last eq other) Some((head, last)) else None } } def getDelimsIfEnclosed(tree: Tree): Option[(FormatToken, FormatToken)] = diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala index 70d2169a3b..fade4926df 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala @@ -153,7 +153,7 @@ class FormatWriter(formatOps: FormatOps) { tok.left match { case rb: T.RightBrace // look for "foo { bar }" if RedundantBraces.canRewriteWithParensOnRightBrace(tok) => - val beg = tokens(matching(rb)).meta.idx + val beg = matching(rb).idx val bloc = locations(beg) val style = bloc.style if ( diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index f1b96af8ea..eae7ae1859 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -76,7 +76,7 @@ class Router(formatOps: FormatOps) { case FormatToken(_: T.BOF, _, _) => Seq(Split(NoSplit, 0)) case FormatToken(_: T.Shebang, _, _) => Seq(Split(Newline2x(ft), 0)) case FormatToken(start: T.Interpolation.Start, _, m) => - val end = matching(start) + val end = matching(start).left val policy = { val penalty = BreakSingleLineInterpolatedString if (style.newlines.inInterpolation eq Newlines.InInterpolation.avoid) @@ -130,8 +130,9 @@ class Router(formatOps: FormatOps) { // Import left brace case FormatToken(open: T.LeftBrace, _, _) if existsParentOfType[ImportExportStat](leftOwner) => - val close = matching(open) - val beforeClose = justBefore(close) + val closeFt = matching(open) + val close = closeFt.left + val beforeClose = prev(closeFt) val policy = SingleLineBlock( close, okSLC = style.importSelectors eq ImportSelectors.singleLine, @@ -163,7 +164,8 @@ class Router(formatOps: FormatOps) { // Interpolated string left brace case FormatToken(open @ T.LeftBrace(), _, _) if prev(ft).left.is[T.Interpolation.SpliceStart] => - val close = matching(open) + val closeFt = matching(open) + val close = closeFt.left val alignIndents = if (style.align.inInterpolation) Some { Seq( @@ -202,7 +204,7 @@ class Router(formatOps: FormatOps) { * - 2 Interpolation.Part (next string) * - 3 Interpolation.End (quotes) or Interpolation.SliceStart/LBrace (${) */ - val afterClose = tokens(close, 3) + val afterClose = tokens(closeFt, 3) val lastPart = afterClose.left.is[T.Interpolation.End] val slbEnd = if (lastPart) afterClose.left else afterClose.right Seq(spaceSplit.withSingleLine(slbEnd), newlineSplit(1)) @@ -236,8 +238,8 @@ class Router(formatOps: FormatOps) { // { ... } Blocks case FormatToken(open: T.LeftBrace, right, _) => - val close = matching(open) - val closeFT = tokens(close) + val closeFT = matching(open) + val close = closeFT.left val newlineBeforeClosingCurly = decideNewlinesOnlyBeforeClose(close) val isSelfAnnotationNL = style.optIn.selfAnnotationNewline && (hasBreak() || style.newlines.sourceIgnored) && @@ -501,7 +503,7 @@ class Router(formatOps: FormatOps) { // 2020-01: break after same-line comments, and any open brace val nonComment = nextNonCommentSameLine(ft) val hasBlock = nonComment.right.is[T.LeftBrace] && - (matching(nonComment.right) eq endOfFunction) + (matching(nonComment.right).left eq endOfFunction) val noSplit = if (!hasBlock && (nonComment eq ft)) Split(noSingleLine, 0)(Space) .withSingleLine(endOfFunction) @@ -661,7 +663,8 @@ class Router(formatOps: FormatOps) { indentLen: Int, shouldAlignBefore: Align => Boolean, )(lastSyntaxClause: => Option[Member.SyntaxValuesClause]) = { - val close = matching(open) + val closeFt = matching(open) + val close = closeFt.left val indent = Indent(indentLen, close, ExpiresOn.After) val isAlignFirstParen = shouldAlignBefore(style.align) && !prevNonComment(ft).left.is[T.RightParen] @@ -669,7 +672,7 @@ class Router(formatOps: FormatOps) { if (isAlignFirstParen) baseNoSplit else baseNoSplit.withSingleLine(close) def afterClose: Option[T] = { - val ftAfterClose = tokenAfter(close) + val ftAfterClose = nextNonComment(closeFt) val tokAfterClose = ftAfterClose.right val matches = tokAfterClose match { case _: T.LeftParen => true @@ -785,19 +788,23 @@ class Router(formatOps: FormatOps) { // Term.Apply and friends case FormatToken(lp: T.LeftParen, _, LambdaAtSingleArgCallSite(lambda)) => - val close = matching(lp) + val closeFt = matching(lp) + val close = closeFt.left val newlinePolicy = Policy ? style.danglingParentheses.callSite && decideNewlinesOnlyBeforeClose(close) val noSplitMod = if ( style.newlines.alwaysBeforeCurlyLambdaParams || - getMustDangleForTrailingCommas(justBefore(close)) + getMustDangleForTrailingCommas(prev(closeFt)) ) null else getNoSplitAfterOpening(ft, commentNL = null) def multilineSpaceSplit(implicit fileLine: FileLine): Split = { - val lambdaLeft: Option[T] = matchingOpt(functionExpire(lambda)._1) - .filter(_.is[T.LeftBrace]) + val lambdaLeft: Option[T] = + matchingOpt(functionExpire(lambda)._1) match { + case Some(FormatToken(lb: T.LeftBrace, _, _)) => Some(lb) + case _ => None + } val arrowFt = getFuncArrow(lambda).get val lambdaIsABlock = lambdaLeft.contains(arrowFt.right) @@ -839,8 +846,8 @@ class Router(formatOps: FormatOps) { else style.binPack.defnSiteFor(open) == BinPack.Site.Never && isParamClauseSite(leftOwner) } => - val close = matching(open) - val afterClose = tokens(close) + val afterClose = matching(open) + val close = afterClose.left val beforeClose = prev(afterClose) val tupleSite = isTuple(leftOwner) val anyDefnSite = isParamClauseSite(leftOwner) @@ -1102,7 +1109,8 @@ class Router(formatOps: FormatOps) { case FormatToken(open @ LeftParenOrBracket(), right, _) if style.binPack.defnSiteFor(open) != BinPack.Site.Never && isParamClauseSite(leftOwner) => - val close = matching(open) + val closeFt = matching(open) + val close = closeFt.left val noSplitMod = Space(style.spaces.inParentheses) if (close eq right) Seq(Split(noSplitMod, 0)) else { @@ -1113,7 +1121,6 @@ class Router(formatOps: FormatOps) { if (isBracket) Some(Constants.BracketPenalty) else None val penalizeBrackets = bracketPenalty .map(p => PenalizeAllNewlines(close, p + 3)) - val afterClose = after(close) val binpack = style.binPack.defnSiteFor(isBracket) val firstArg = argumentStarts.get(ft.meta.idx) @@ -1123,7 +1130,7 @@ class Router(formatOps: FormatOps) { } val nextCommaOneline = if (binpack.isOneline) nextComma else None - val flags = getBinpackDefnSiteFlags(ft, prev(afterClose)) + val flags = getBinpackDefnSiteFlags(ft, prev(closeFt)) val (nlOnly, nlCloseOnOpen) = flags.nlOpenClose() val noNLPolicy = flags.noNLPolicy val slbOrNL = nlOnly || noNLPolicy == null @@ -1188,8 +1195,9 @@ class Router(formatOps: FormatOps) { case FormatToken(open @ LeftParenOrBracket(), right, _) if style.binPack.callSiteFor(open) != BinPack.Site.Never && isArgClauseSite(leftOwner) => - val close = matching(open) - val beforeClose = justBefore(close) + val closeFt = matching(open) + val close = closeFt.left + val beforeClose = prev(closeFt) val isBracket = open.is[T.LeftBracket] val bracketPenalty = if (isBracket) Constants.BracketPenalty else 1 @@ -1269,7 +1277,7 @@ class Router(formatOps: FormatOps) { xft.right match { case `close` | _: T.RightBrace | _: T.RightArrow => null case x: T.Comma => Right(x) - case x: T.LeftBrace => Left(tokens(matching(x))) + case x: T.LeftBrace => Left(matching(x)) case _ => Left(next(xft)) } }.toOption @@ -1399,7 +1407,7 @@ class Router(formatOps: FormatOps) { case FormatToken(_: T.Comma, open: T.LeftBrace, _) if !style.poorMansTrailingCommasInConfigStyle && isArgClauseSite(leftOwner) => - val close = matching(open) + val close = matching(open).left val binPackIsEnabled = style.binPack.callSiteFor(leftOwner) != BinPack.Site.Never val useSpace = !style.newlines.keepBreak(newlines) @@ -1448,13 +1456,13 @@ class Router(formatOps: FormatOps) { startsInfix(p) case _: Term.ArgClause => false case p => isTokenHeadOrBefore(lb, p) && matchingOpt(lb) - .exists(isTokenLastOrAfter(_, roPos)) + .exists(x => isTokenLastOrAfter(x.left, roPos)) } } => val slbParensSplit = if (!initStyle.rewrite.bracesToParensForOneLineApply) None - else RedundantBraces - .noSplitForParensOnRightBrace(tokens(matching(lb))).map { rbft => + else RedundantBraces.noSplitForParensOnRightBrace(matching(lb)) + .map { rbft => // copy logic from `( ...`, binpack=never, defining `slbSplit` val isBeforeOpenParen = style.newlines.isBeforeOpenParenCallSite val optimal: T = @@ -1759,7 +1767,7 @@ class Router(formatOps: FormatOps) { if (prevChain) noSplit else chainExpire match { // allow newlines in final {} block case x: T.RightBrace => noSplit - .withSingleLine(matching(x), noSyntaxNL = true) + .withSingleLine(matching(x).left, noSyntaxNL = true) case x => noSplit .withSingleLineNoOptimal(x, noSyntaxNL = true) } @@ -1962,7 +1970,7 @@ class Router(formatOps: FormatOps) { !isTokenHeadOrBefore(open, leftOwner) case _ => false }) => - val close = matching(open) + val close = matching(open).left val indentLen = style.indent.ctrlSite.getOrElse(style.indent.callSite) def indents = if (style.align.openParenCtrlSite) getOpenParenAlignIndents(close) @@ -2089,7 +2097,7 @@ class Router(formatOps: FormatOps) { case _: Term.If => false case b @ Term.Block((_: Term.If) :: Nil) => matchingOpt(nextNonComment(ft).right) - .exists(_.end >= b.pos.end) + .exists(_.left.end >= b.pos.end) case _ => true } case x => throw new UnexpectedTree[Term.If](x) @@ -2118,8 +2126,9 @@ class Router(formatOps: FormatOps) { Seq(Split(NoSplit, 0)) case FormatToken(open: T.LeftParen, right, _) => - val close = matching(open) - val beforeClose = justBefore(close) + val closeFt = matching(open) + val close = closeFt.left + val beforeClose = prev(closeFt) implicit val clauseSiteFlags = ClauseSiteFlags.atCallSite(leftOwner) val isConfig = couldPreserveConfigStyle(ft, beforeClose.hasBreak) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala index c57dc01abd..5a39659e59 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala @@ -168,7 +168,7 @@ final case class State( else lineStartsStatement(isComment) val delay = startFtOpt.exists { case FormatToken(_, t: Token.Interpolation.Start, _) => tokens - .matching(t) ne ft.right + .matching(t).left ne ft.right case _ => true } // if delaying, estimate column if the split had been a newline diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 5a6585a176..13e4ea48a3 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -163,8 +163,8 @@ class RedundantBraces(implicit val ftoks: FormatTokens) case ReplacementType.Remove => val resOpt = getRightBraceBeforeRightParen(false).map { rb => ft.meta.rightOwner match { - case ac: Term.ArgClause => ftoks.matchingOpt(rb.left) - .map(ftoks.justBefore).foreach { lb => + case ac: Term.ArgClause => ftoks.matchingOpt(rb.left).map(ftoks.prev) + .foreach { lb => session.rule[RemoveScala3OptionalBraces].foreach { r => session.getClaimed(lb.meta.idx).foreach { case (leftIdx, _) => val repl = r.onLeftForArgClause(ac)(lb, left.style) @@ -419,7 +419,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) nft.noBreak || style.formatInfix(p) && !nft.right.is[Token.Comment] } def checkClose = { - val nft = ftoks(ftoks.matching(ft.right), -1) + val nft = ftoks.prev(ftoks.matching(ft.right)) nft.noBreak || style.formatInfix(p) && !nft.left.is[Token.Comment] } checkOpen && checkClose @@ -525,7 +525,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) // inside exists, return true if rewrite is OK !stat.tokens.headOption.exists { case x: Token.LeftParen => ftoks.matchingOpt(x) match { - case Some(y) if y ne stat.tokens.last => + case Some(y) if y.left ne stat.tokens.last => session.rule[RedundantParens].exists { _.onToken(ftoks(x, -1), session, style).exists(_.isRemove) } @@ -534,8 +534,8 @@ class RedundantBraces(implicit val ftoks: FormatTokens) case _ => true } case x: Token.LeftBrace => ftoks.matchingOpt(x) match { - case Some(y) if y ne stat.tokens.last => - findFirstTreeBetween(stat, x, y).exists { + case Some(y) if y.left ne stat.tokens.last => + findFirstTreeBetween(stat, x, y.left).exists { case z: Term.Block => okToRemoveBlock(z) case _ => false } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala index 40ef052fc2..ffb3ef4ec4 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala @@ -228,7 +228,7 @@ class RedundantParens(implicit val ftoks: FormatTokens) .map((cnt, _)) } - ftoks.matchingOpt(ft.right).flatMap(rt => iter(ft, ftoks.after(rt), 1)) + ftoks.matchingOpt(ft.right).flatMap(rt => iter(ft, rt, 1)) } } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala index 093f7c98bd..d773f5174c 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala @@ -60,7 +60,7 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) } case t: Term.EnumeratorsBlock if allowOldSyntax || !t.parent.is[Term.For] || { - val rbFt = ftoks(ftoks.matching(ft.right)) + val rbFt = ftoks.matching(ft.right) ftoks.nextNonComment(rbFt).right.is[Token.KwDo] } => removeToken case _: Tree.CasesBlock => removeToken diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala index 29dea9e01f..a33a946083 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala @@ -22,7 +22,7 @@ case class RewriteCtx(style: ScalafmtConfig, input: Input, tree: Tree) { val tokens = tree.tokens val tokenTraverser = new TokenTraverser(tokens, input) - val matchingParens = TreeOps.getMatchingParentheses(tokens) + val matchingParens = TreeOps.getMatchingParentheses(tokens)(identity) @inline def getMatching(a: Token): Token = matchingParens(TokenOps.hash(a)) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala index 4b03948e9c..35f0f83ccb 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala @@ -65,7 +65,7 @@ private class RewriteTrailingCommas(implicit val ftoks: FormatTokens) case rp: Token.RightParen => delimOwner .isAny[Member.SyntaxValuesClause, Member.Tuple] || ftoks.matchingOpt(rp).exists { lp => - val claimant = session.claimedRule(ftoks.justBefore(lp)) + val claimant = session.claimedRule(ftoks.prev(lp)) claimant.forall(_.rule.isInstanceOf[RedundantParens]) } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/StyleMap.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/StyleMap.scala index 835090140f..041df2e791 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/StyleMap.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/StyleMap.scala @@ -25,7 +25,7 @@ class StyleMap(tokens: FormatTokens, val init: ScalafmtConfig) { val styleBuilder = Array.newBuilder[ScalafmtConfig] startBuilder += 0 styleBuilder += init - val disableBinPack = mutable.Map.empty[Token, BinPack.Site] + val disableBinPack = mutable.Map.empty[Int, BinPack.Site] def warn(err: String)(implicit fileLine: FileLine): Unit = logger.elem(err) tokens.arr.foreach { ft => def changeStyle(style: ScalafmtConfig): Option[ScalafmtConfig] = { @@ -57,10 +57,11 @@ class StyleMap(tokens: FormatTokens, val init: ScalafmtConfig) { forcedBinPack += ft.meta.leftOwner changeStyle(setBinPack(curr, callSite = BinPack.Site.Always)) .foreach { x => - tokens.matchingOpt(tok) - .foreach(disableBinPack.update(_, x.binPack.callSite)) + tokens.matchingOpt(tok).foreach { y => + disableBinPack.update(y.idx, x.binPack.callSite) + } } - case tok: Token.RightParen => disableBinPack.remove(tok) + case _: Token.RightParen => disableBinPack.remove(ft.idx) .foreach(x => changeStyle(setBinPack(curr, callSite = x))) case _ => } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala index 2b5c3533b9..76fc72d121 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala @@ -208,23 +208,28 @@ object TreeOps { * * Contains lookup keys in both directions, opening [({ and closing })]. */ - def getMatchingParentheses(tokens: Iterable[Token]): Map[TokenHash, Token] = { - val ret = Map.newBuilder[TokenHash, Token] - var stack = List.empty[Token] - tokens.foreach { - case open @ (LeftBrace() | LeftBracket() | LeftParen() | Interpolation - .Start() | Xml.Start() | Xml.SpliceStart()) => stack = open :: stack - case close @ (RightBrace() | RightBracket() | RightParen() | Interpolation - .End() | Xml.End() | Xml.SpliceEnd()) => - val open = stack.head - assertValidParens(open, close) - ret += hash(open) -> close - ret += hash(close) -> open - stack = stack.tail - case _ => + def getMatchingParentheses[A]( + coll: Iterable[A], + )(f: A => Token): Map[TokenHash, A] = { + val ret = Map.newBuilder[TokenHash, A] + var stack = List.empty[(Token, A)] + coll.foreach { elem => + f(elem) match { + case open @ (_: Token.OpenDelim | _: Interpolation.Start | + _: Xml.Start | _: Xml.SpliceStart) => stack = (open, elem) :: stack + case close @ (_: Token.CloseDelim | _: Interpolation.End | _: Xml.End | + _: Xml.SpliceEnd) => + val (open, openElem) = stack.head + assertValidParens(open, close) + ret += hash(open) -> elem + ret += hash(close) -> openElem + stack = stack.tail + case _ => + } } if (stack.nonEmpty) throw new IllegalArgumentException( - stack.map(x => s"[${x.end}]$x").mkString("Orphan parens (", ", ", ")"), + stack.map { case (x, _) => s"[${x.end}]$x" } + .mkString("Orphan parens (", ", ", ")"), ) val result = ret.result() result