Skip to content

Commit

Permalink
FormatTokens: return FormatToken in matchingXxx
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Oct 5, 2024
1 parent ad48a32 commit 289112a
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ private class BestFirstSearch private (range: Set[Range])(implicit

private def getBlockCloseToRecurse(ft: FormatToken, stop: Token)(implicit
style: ScalafmtConfig,
): Option[Token] = getEndOfBlock(ft, parensToo = true).filter { close =>
// Block must span at least 3 lines to be worth recursing.
close != stop && distance(ft.left, close) > style.maxColumn * 3
): Option[Token] = getEndOfBlock(ft, parensToo = true).collect {
case close if close.left != stop && {
// Block must span at least 3 lines to be worth recursing.
distance(ft.left, close.left) > style.maxColumn * 3
} => close.left
}

private val memo = mutable.Map.empty[Long, State]
Expand Down Expand Up @@ -338,13 +340,13 @@ object BestFirstSearch {
!styleMap.at(t).newlines.keep
case _: Term.Apply => true // legacy: when enclosed in parens
case _ => false
}) => expire = tokens.matching(t)
}) => expire = tokens.matching(t).left
case FormatToken(t: Token.LeftBrace, _, m) if (m.leftOwner match {
// Type compounds can be inside defn.defs
case lo: meta.Stat.Block => lo.parent.is[Type.Refine]
case _: Type.Refine => true
case _ => false
}) => expire = tokens.matching(t)
}) => expire = tokens.matching(t).left
case _ =>
}
result.result()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class FormatOps(
val isDefnSite = isParamClauseSite(owner)
implicit val clauseSiteFlags: ClauseSiteFlags =
ClauseSiteFlags(owner, isDefnSite)
val bpFlags = getBinpackSiteFlags(tokens(matching(t)), start, false)
val bpFlags = getBinpackSiteFlags(matching(t), start, false)
if (bpFlags.scalaJsStyle) scalaJsOptCloseOnRight(start, bpFlags)
else if (
!start.left.is[T.RightParen] ||
Expand Down Expand Up @@ -217,7 +217,7 @@ class FormatOps(
new ExtractFromMeta[Tree](meta => statementStarts.get(meta.idx + 1))

def parensTuple(token: T): TokenRanges = matchingOpt(token)
.fold(TokenRanges.empty)(other => TokenRanges(TokenRange(token, other)))
.fold(TokenRanges.empty)(other => TokenRanges(TokenRange(token, other.left)))

def insideBlock[A](start: FormatToken, end: T)(implicit
classifier: Classifier[T, A],
Expand All @@ -236,18 +236,19 @@ class FormatOps(
): TokenRanges = insideBlock(x => getEndOfBlock(x, parensToo))(start, end)

def insideBlock(
matches: FormatToken => Option[T],
matches: FormatToken => Option[FormatToken],
)(start: FormatToken, end: T): TokenRanges = {
var result = TokenRanges.empty

@tailrec
def run(tok: FormatToken): Unit = if (tok.left.start < end.start) {
val nextTokOpt = matches(tok).flatMap { close =>
val nextTokOpt = matches(tok).flatMap { closeFt =>
val open = tok.left
val close = closeFt.left
if (open.start >= close.end) None
else {
result = result.append(TokenRange(open, close))
Some(tokens(close))
Some(closeFt)
}
}
val nextTok = nextTokOpt.getOrElse(next(tok))
Expand Down Expand Up @@ -791,7 +792,7 @@ class FormatOps(
val otherSplits = closeOpt.fold {
val nlSplit = Split(nlMod, 1 + breakPenalty)
Seq(nlSplit.withIndent(nlIndent).withPolicy(nlPolicy & delayedBreak))
} { close =>
} { closeFt =>
val noSingleLine = newStmtMod.isDefined || breakMany ||
rightAsInfix.exists(10 < infixSequenceLength(_))
val nextOp =
Expand All @@ -807,8 +808,9 @@ class FormatOps(
val nlSplit = Split(nlMod, 0).andPolicy(breakAfterClose)
.withIndent(nlIndent).withPolicy(nlPolicy)
val singleLineSplit = Split(spaceMod, 0).notIf(noSingleLine)
.withSingleLine(endOfNextOp.fold(close)(_.left))
.andPolicy(breakAfterClose).andPolicy(getSingleLineInfixPolicy(close))
.withSingleLine(endOfNextOp.getOrElse(closeFt).left)
.andPolicy(breakAfterClose)
.andPolicy(getSingleLineInfixPolicy(closeFt.left))
Seq(singleLineSplit, nlSplit)
}

Expand Down Expand Up @@ -1081,7 +1083,7 @@ class FormatOps(
val values = clause.values
if (
values.lengthCompare(cfg.minCount) >= 0 &&
(cfg.minSpan == 0 || cfg.minSpan < distance(ftOpen.left, close))
(cfg.minSpan == 0 || cfg.minSpan < distance(ftOpen.left, close.left))
) {
forces += ftOpen.meta.idx
values.foreach(x => clearQueues += getHead(x).meta.idx)
Expand Down Expand Up @@ -1111,7 +1113,7 @@ class FormatOps(
val lpOwner = ft.meta.leftOwner

val FormatToken(open, r, _) = ft
val close = matching(open)
val close = matching(open).left
val indentParam = Num(style.indent.getDefnSite(lpOwner))
val indentSep = Num((indentParam.n - 2).max(0))
val isBracket = open.is[T.LeftBracket]
Expand Down Expand Up @@ -1459,7 +1461,7 @@ class FormatOps(
case Some(x) => x
case None => findXmlLastLineIndent(prev(ft))
}
case t: T.Xml.SpliceEnd => findXmlLastLineIndent(tokens(matching(t), -1))
case t: T.Xml.SpliceEnd => findXmlLastLineIndent(prev(matching(t)))
case _ => findXmlLastLineIndent(prev(ft))
}

Expand All @@ -1468,8 +1470,8 @@ class FormatOps(
): Seq[Split] =
if (style.xmlLiterals.assumeFormatted) {
val end = matching(tok)
val indent = Num(findXmlLastLineIndent(tokens(end, -1)), true)
splits.map(_.withIndent(indent, end, ExpiresOn.After))
val indent = Num(findXmlLastLineIndent(prev(end)), true)
splits.map(_.withIndent(indent, end.left, ExpiresOn.After))
} else splits

def withIndentOnXmlSpliceStart(ft: FormatToken, splits: Seq[Split])(implicit
Expand All @@ -1478,7 +1480,7 @@ class FormatOps(
case t: T.Xml.SpliceStart if style.xmlLiterals.assumeFormatted =>
val end = matching(t)
val indent = Num(findXmlLastLineIndent(prev(ft)), true)
splits.map(_.withIndent(indent, end, ExpiresOn.After))
splits.map(_.withIndent(indent, end.left, ExpiresOn.After))
case _ => splits
}

Expand Down Expand Up @@ -1608,7 +1610,7 @@ class FormatOps(
_: Term.NewAnonymous => getSplits(getSpaceSplit(1))
case t: Term.ForYield => nextNonComment(bheadFT).right match { // skipping `for`
case x @ LeftParenOrBrace() =>
val exclude = TokenRanges(TokenRange(x, matching(x)))
val exclude = TokenRanges(TokenRange(x, matching(x).left))
(t.body match {
case b: Term.Block => getBracesIfEnclosed(b)
.map(x => getPolicySplits(1, getSlb(x._1.left, exclude)))
Expand Down Expand Up @@ -1754,7 +1756,7 @@ class FormatOps(
val right = nextNonComment(ft).right
val rpOpt = if (right.is[T.LeftParen]) matchingOpt(right) else None
val expire = nextNonCommentSameLine(rpOpt.fold(endFt) { rp =>
if (rp.end >= endFt.left.end) before(rp) else endFt
if (rp.left.end >= endFt.left.end) rp else endFt
})
nlSplit.withIndent(Num(style.indent.main), expire.left, ExpiresOn.After)
}
Expand Down Expand Up @@ -2180,7 +2182,7 @@ class FormatOps(
case t: Term.While => t.expr match {
case b: Term.Block
if isMultiStatBlock(b) &&
!matchingOpt(nft.right).exists(_.end >= b.pos.end) =>
!matchingOpt(nft.right).exists(_.left.end >= b.pos.end) =>
Some(new OptionalBracesRegion {
def owner = Some(t)
def splits = Some {
Expand Down Expand Up @@ -2358,7 +2360,7 @@ class FormatOps(
val nr = nft.right
t.cond match {
case b: Term.Block if (matchingOpt(nr) match {
case Some(t) => t.end < b.pos.end
case Some(t) => t.left.end < b.pos.end
case None => isMultiStatBlock(b)
}) =>
Some(new OptionalBracesRegion {
Expand Down Expand Up @@ -2473,7 +2475,7 @@ class FormatOps(
case _: T.LeftBrace => false
case _ => !isTreeSingleExpr(thenp) &&
(!before.right.is[T.LeftBrace] || matchingOpt(before.right)
.exists(rb => rb.end < thenp.pos.end))
.exists(_.left.end < thenp.pos.end))
}
}

Expand Down Expand Up @@ -2722,11 +2724,11 @@ class FormatOps(

def getEndOfBlock(ft: FormatToken, parensToo: => Boolean)(implicit
style: ScalafmtConfig,
): Option[T] = ft.left match {
): Option[FormatToken] = ft.left match {
case x: T.LeftBrace => matchingOpt(x)
case x: T.LeftParen => if (parensToo) matchingOpt(x) else None
case _ => OptionalBraces.get(ft)
.flatMap(_.rightBrace.map(x => nextNonCommentSameLine(x).left))
.flatMap(_.rightBrace.map(x => nextNonCommentSameLine(x)))
}

def isCloseDelimForTrailingCommasMultiple(ft: FormatToken): Boolean =
Expand Down Expand Up @@ -2887,7 +2889,7 @@ class FormatOps(
implicit val style: ScalafmtConfig = styleMap.at(open)
implicit val clauseSiteFlags: ClauseSiteFlags = ClauseSiteFlags
.atCallSite(ftAfterClose.meta.rightOwner)
val bpFlagsAfter = getBinpackCallSiteFlags(tokens(open), ftAfterClose)
val bpFlagsAfter = getBinpackCallSiteFlags(open, ftAfterClose)
scalaJsOptCloseOnRight(ftAfterClose, bpFlagsAfter)
} else ftBeforeClose
} else ftBeforeClose
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FormatToken]
result.result()
}(arr)

private lazy val matchingParentheses: Map[TokenHash, Token] = TreeOps
.getMatchingParentheses(arr.view.map(_.right))
private lazy val matchingParentheses: Map[TokenHash, FormatToken] = TreeOps
.getMatchingParentheses(arr.view)(_.left)

override def length: Int = arr.length
override def apply(idx: Int): FormatToken = arr(idx)
Expand Down Expand Up @@ -77,19 +77,19 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FormatToken]
def next(ft: FormatToken): FormatToken = apply(ft, 1)

@inline
def matching(token: Token): Token = matchingParentheses.getOrElse(
def matching(token: Token): FormatToken = matchingParentheses.getOrElse(
FormatTokens.thash(token),
FormatTokens.throwNoToken(token, "Missing matching token index"),
)
@inline
def matchingOpt(token: Token): Option[Token] = matchingParentheses
def matchingOpt(token: Token): Option[FormatToken] = matchingParentheses
.get(FormatTokens.thash(token))
@inline
def hasMatching(token: Token): Boolean = matchingParentheses
.contains(FormatTokens.thash(token))
@inline
def areMatching(t1: Token)(t2: => Token): Boolean = matchingOpt(t1) match {
case Some(x) => x eq t2
case Some(x) => x.left eq t2
case _ => false
}

Expand All @@ -100,7 +100,7 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FormatToken]
.flatMap { head =>
matchingOpt(head.left).flatMap { other =>
val last = getLastNonTrivial(tokens, tree)
if (last.left eq other) Some((head, last)) else None
if (last eq other) Some((head, last)) else None
}
}
def getDelimsIfEnclosed(tree: Tree): Option[(FormatToken, FormatToken)] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class FormatWriter(formatOps: FormatOps) {
tok.left match {
case rb: T.RightBrace // look for "foo { bar }"
if RedundantBraces.canRewriteWithParensOnRightBrace(tok) =>
val beg = tokens(matching(rb)).meta.idx
val beg = matching(rb).idx
val bloc = locations(beg)
val style = bloc.style
if (
Expand Down
Loading

0 comments on commit 289112a

Please sign in to comment.