diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala index ac78ce659c..2606a3e438 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala @@ -30,6 +30,7 @@ private class ConvertToNewScala3Syntax(ftoks: FormatTokens) override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = Option { val flag = style.rewrite.scala3.newSyntax @@ -108,6 +109,7 @@ private class ConvertToNewScala3Syntax(ftoks: FormatTokens) override def onRight(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = Option { def nextRight = ftoks.nextNonComment(ftoks.next(ft)).right diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala index ee5a6b2fdb..511ba239ec 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala @@ -117,9 +117,8 @@ class FormatTokensRewrite( * - for standalone tokens, simply invoke the rule and record any rewrites */ private def getRewrittenTokens: Iterable[Replacement] = { - implicit val claimed = new mutable.HashMap[Int, Rule]() - def claimedRule(implicit ft: FormatToken) = claimed.get(ft.meta.idx) - implicit val tokens = new mutable.ArrayBuffer[Replacement]() + implicit val session: Session = new Session + val tokens = session.tokens val leftDelimIndex = new mutable.ListBuffer[(Int, Option[Rule])]() val formatOffStack = new mutable.ListBuffer[Boolean]() arr.foreach { implicit ft => @@ -131,7 +130,7 @@ class FormatTokensRewrite( val ruleOpt = if (formatOff) None else - claimedRule match { + session.claimedRule match { case x @ Some(rule) => if (applyRule(rule)) x else None case _ => applyRules } @@ -145,7 +144,8 @@ class FormatTokensRewrite( formatOffStack.update(0, true) val replacement = ruleOpt match { case Some(rule) - if !ft.meta.formatOff && claimedRule.forall(_ eq rule) => + if !ft.meta.formatOff && + session.claimedRule.forall(_ eq rule) => implicit val style = styleMap.at(ft.right) if (rule.enabled) rule.onRight(tokens(ldelimIdx), formatOff) else None @@ -178,20 +178,18 @@ class FormatTokensRewrite( private def applyRules(implicit ft: FormatToken, - claimed: mutable.HashMap[Int, Rule], - tokens: mutable.ArrayBuffer[Replacement] + session: Session ): Option[Rule] = { implicit val style = styleMap.at(ft.right) - FormatTokensRewrite.applyRules(rules) + session.applyRules(rules) } private def applyRule(rule: Rule)(implicit ft: FormatToken, - claimed: mutable.HashMap[Int, Rule], - tokens: mutable.ArrayBuffer[Replacement] + session: Session ): Boolean = { implicit val style = styleMap.at(ft.right) - FormatTokensRewrite.applyRule(rule) + session.applyRule(rule) } } @@ -210,11 +208,13 @@ object FormatTokensRewrite { // act on or modify only ft.right; process standalone or open (left) delim def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] // act on or modify only ft.right; process close (right) delim def onRight(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] } @@ -239,6 +239,39 @@ object FormatTokensRewrite { else new FormatTokensRewrite(ftoks, styleMap, rules).rewrite } + private[rewrite] class Session { + private implicit val implicitSession: Session = this + private val claimed = new mutable.HashMap[Int, Rule]() + private[FormatTokensRewrite] val tokens = + new mutable.ArrayBuffer[Replacement]() + + def claimedRule(implicit ft: FormatToken): Option[Rule] = + claimed.get(ft.meta.idx) + + private[FormatTokensRewrite] def applyRule( + rule: Rule + )(implicit ft: FormatToken, style: ScalafmtConfig): Boolean = + rule.enabled && (rule.onToken match { + case Some(repl) => + repl.claim.foreach { claimed.getOrElseUpdate(_, rule) } + tokens.append(repl) + true + case _ => false + }) + + private[FormatTokensRewrite] def applyRules( + rules: Seq[Rule] + )(implicit ft: FormatToken, style: ScalafmtConfig): Option[Rule] = { + @tailrec + def iter(remainingRules: Seq[Rule]): Option[Rule] = remainingRules match { + case r +: rs => if (applyRule(r)) Some(r) else iter(rs) + case _ => None + } + iter(rules) + } + + } + private[rewrite] class Replacement( val ft: FormatToken, val how: ReplacementType, @@ -313,38 +346,4 @@ object FormatTokensRewrite { } } - private def onReplacement(repl: Replacement, rule: Rule)(implicit - claimed: mutable.HashMap[Int, Rule], - tokens: mutable.ArrayBuffer[Replacement] - ): Unit = { - repl.claim.foreach { claimed.getOrElseUpdate(_, rule) } - tokens.append(repl) - } - - private def applyRule(rule: Rule)(implicit - ft: FormatToken, - style: ScalafmtConfig, - claimed: mutable.HashMap[Int, Rule], - tokens: mutable.ArrayBuffer[Replacement] - ): Boolean = - rule.enabled && { - val res = rule.onToken - res.foreach(onReplacement(_, rule)) - res.isDefined - } - - private def applyRules(rules: Seq[Rule])(implicit - ft: FormatToken, - style: ScalafmtConfig, - claimed: mutable.HashMap[Int, Rule], - tokens: mutable.ArrayBuffer[Replacement] - ): Option[Rule] = { - @tailrec - def iter(remainingRules: Seq[Rule]): Option[Rule] = remainingRules match { - case r +: rs => if (applyRule(r)) Some(r) else iter(rs) - case _ => None - } - iter(rules) - } - } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/PreferCurlyFors.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/PreferCurlyFors.scala index 350d6dacd3..2fa1d6260e 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/PreferCurlyFors.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/PreferCurlyFors.scala @@ -61,6 +61,7 @@ private class PreferCurlyFors(ftoks: FormatTokens) override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = Option { ft.right match { @@ -91,6 +92,7 @@ private class PreferCurlyFors(ftoks: FormatTokens) override def onRight(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = ft.right match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 959c32870c..ada3783c55 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -70,6 +70,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = Option { ft.right match { @@ -81,6 +82,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { override def onRight(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = Option { ft.right match { @@ -133,6 +135,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { private def onLeftBrace(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Replacement = { onLeftBrace(ft.meta.rightOwner) @@ -141,6 +144,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { @tailrec private def onLeftBrace(owner: Tree)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Replacement = { owner match { @@ -285,9 +289,11 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { case _ => false } - private def processBlock( - b: Term.Block - )(implicit ft: FormatToken, style: ScalafmtConfig): Boolean = + private def processBlock(b: Term.Block)(implicit + ft: FormatToken, + session: Session, + style: ScalafmtConfig + ): Boolean = (ft.right match { case lb: Token.LeftBrace => b.tokens.headOption.contains(lb) && b.tokens.last.is[Token.RightBrace] @@ -327,7 +333,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { private def okToRemoveBlock( b: Term.Block - )(implicit style: ScalafmtConfig): Boolean = { + )(implicit style: ScalafmtConfig, session: Session): Boolean = { b.parent.exists { case p: Case => @@ -425,7 +431,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { /** Some blocks look redundant but aren't */ private def shouldRemoveSingleStatBlock( b: Term.Block - )(implicit style: ScalafmtConfig): Boolean = + )(implicit style: ScalafmtConfig, session: Session): Boolean = getSingleStatIfLineSpanOk(b).exists { stat => @tailrec def checkParent(tree: Tree): Boolean = tree match { @@ -442,7 +448,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { ftoks.matchingOpt(x) match { case Some(y) if y ne stat.tokens.last => redundantParensFunc.exists { parensRule => - parensRule.onToken(ftoks(x, -1), style).exists { + parensRule.onToken(ftoks(x, -1), session, style).exists { _.how eq ReplacementType.Remove } } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala index 99713d6b47..3fc07a1e4f 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala @@ -71,6 +71,7 @@ class RedundantParens(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = ft.right match { @@ -83,6 +84,7 @@ class RedundantParens(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { override def onRight(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = ft.right match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveEmptyDocstrings.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveEmptyDocstrings.scala index 3d76e3261a..2d590388e6 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveEmptyDocstrings.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveEmptyDocstrings.scala @@ -21,6 +21,7 @@ object RemoveEmptyDocstrings override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = { val skip = ft.right.is[Token.Comment] && @@ -30,6 +31,7 @@ object RemoveEmptyDocstrings override def onRight(lt: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = None diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala index 32389590f5..715b9ad1e1 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala @@ -36,6 +36,7 @@ private class RemoveScala3OptionalBraces(ftoks: FormatTokens) override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = Option { ft.right match { @@ -67,6 +68,7 @@ private class RemoveScala3OptionalBraces(ftoks: FormatTokens) override def onRight(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = ft.right match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala index 9ed755ebb5..dde0d91c4f 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala @@ -34,6 +34,7 @@ private class RewriteTrailingCommas(ftoks: FormatTokens) override def onToken(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[Replacement] = { val ok = ft.right.is[Token.Comma] && { @@ -59,6 +60,7 @@ private class RewriteTrailingCommas(ftoks: FormatTokens) override def onRight(lt: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, + session: Session, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = None