diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 3e76d3c978..12946b11b3 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -253,10 +253,17 @@ class RedundantBraces(implicit val ftoks: FormatTokens) owner match { case t: Term.FunctionTerm if t.tokens.last.is[Token.RightBrace] => if (!okToRemoveFunctionInApplyOrInit(t)) null else removeToken - case t: Term.PartialFunction if t.parent.exists { p => - SingleArgInBraces.orBlock(p).exists(_._2 eq t) && - t.pos.start != p.pos.start - } => removeToken + case t: Term.PartialFunction => t.parent match { + case Some(SingleArgInBraces.OrBlock(lft, `t`, _)) + if lft.left ne ft.right => + val ok = ftoks.findTokenWith(lft, ftoks.prev) { xft => + if (!xft.left.is[Token.LeftBrace]) Some(false) + else if (session.isRemovedOnLeft(xft, ok = true)) None + else Some(true) + }.contains(true) + if (ok) removeToken else null + case _ => null + } case t: Term.Block => t.parent match { case Some(f: Term.FunctionTerm) if okToReplaceFunctionInSingleArgApply(f) => removeToken diff --git a/scalafmt-tests/shared/src/test/resources/rewrite/RedundantBraces-if.stat b/scalafmt-tests/shared/src/test/resources/rewrite/RedundantBraces-if.stat index 35fe9ca762..14dd45356d 100644 --- a/scalafmt-tests/shared/src/test/resources/rewrite/RedundantBraces-if.stat +++ b/scalafmt-tests/shared/src/test/resources/rewrite/RedundantBraces-if.stat @@ -393,3 +393,23 @@ object a: end if else checkJSNativeLoadSpecOf(treePos, sym) end if +<<< #4133 partial function within if-else +val toIterator: Any => Iterator[_] = if (lenient) { + { + case i: scala.collection.Iterable[_] => i.iterator + case l: java.util.List[_] => l.iterator().asScala + case a: Array[_] => a.iterator + case o => unsupportedCollectionType(o.getClass) + } +} else { + unsupportedCollectionType(tag.runtimeClass) +} +>>> +val toIterator: Any => Iterator[_] = if (lenient) { + case i: scala.collection.Iterable[_] => i.iterator + case l: java.util.List[_] => l.iterator().asScala + case a: Array[_] => a.iterator + case o => unsupportedCollectionType(o.getClass) +} +else + unsupportedCollectionType(tag.runtimeClass)