diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala index ef7da3a143..c52983d3f2 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala @@ -470,6 +470,7 @@ class FormatOps( def insideInfixSplit( app: Member.Infix, + isBeforeOp: Boolean, ft: FormatToken )(implicit style: ScalafmtConfig): Seq[Split] = app match { @@ -477,9 +478,17 @@ class FormatOps( if style.spaces.neverAroundInfixTypes.contains(t.op.value) => Seq(Split(NoSplit, 0)) case t => + def useSpace = isBeforeOp || + style.spaces.beforeInfixArgInParens(app.op.value) || + !(app.arg match { + case _: Lit.Unit => true + case x: Member.ArgClause if x.values.lengthCompare(1) != 0 => true + case x => isEnclosedInParens(x) + }) val afterInfix = style.breakAfterInfix(t) if (afterInfix ne Newlines.AfterInfix.keep) { - if (ft.meta.leftOwner ne app.op) Seq(Split(Space, 0)) + val spaceMod = Space(useSpace) + if (ft.meta.leftOwner ne app.op) Seq(Split(spaceMod, 0)) else { val fullInfix = InfixSplits.findEnclosingInfix(app) val ok = isEnclosedInParens(fullInfix) || fullInfix.parent.forall { @@ -488,13 +497,16 @@ class FormatOps( case _ => true } if (ok) - InfixSplits(app, ft, fullInfix).getBeforeLhsOrRhs(afterInfix) - else Seq(Split(Space, 0)) + InfixSplits(app, ft, fullInfix) + .getBeforeLhsOrRhs(afterInfix, spaceMod = spaceMod) + else Seq(Split(spaceMod, 0)) } } else { // we don't modify line breaks generally around infix expressions // TODO: if that ever changes, modify how rewrite rules handle infix - Seq(InfixSplits.withNLIndent(Split(getMod(ft), 0))(app, ft)) + val mod = getMod(ft) + val modOrNoSplit = if (mod != Space || useSpace) mod else NoSplit + Seq(InfixSplits.withNLIndent(Split(modOrNoSplit, 0))(app, ft)) } } @@ -700,7 +712,8 @@ class FormatOps( def getBeforeLhsOrRhs( afterInfix: Newlines.AfterInfix, - newStmtMod: Option[Modification] = None + newStmtMod: Option[Modification] = None, + spaceMod: Modification = Space ): Seq[Split] = { val beforeLhs = ft.meta.leftOwner ne app.op val maxPrecedence = @@ -768,7 +781,7 @@ class FormatOps( .withSingleLine(singleLineExpire) .andPolicyOpt(singleLinePolicy) .andPolicyOpt(delayedBreak) - val spaceSingleLine = Split(Space, 0) + val spaceSingleLine = Split(spaceMod, 0) .onlyIf(newStmtMod.isEmpty) .withSingleLine(singleLineExpire) .andPolicyOpt(singleLinePolicy) @@ -799,7 +812,7 @@ class FormatOps( .andPolicyOpt(breakAfterClose) .withIndent(nlIndent) .withPolicy(nlPolicy) - val singleLineSplit = Split(Space, 0) + val singleLineSplit = Split(spaceMod, 0) .notIf(noSingleLine) .withSingleLine(endOfNextOp.fold(close)(_.left)) .andPolicyOpt(breakAfterClose) @@ -817,7 +830,7 @@ class FormatOps( val exclude = if (breakMany) TokenRanges.empty else insideBracesBlock(nextFT, expire, true) - Split(ModExt(newStmtMod.getOrElse(Space)), cost) + Split(ModExt(newStmtMod.getOrElse(spaceMod)), cost) .withSingleLine(expire, exclude) } } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index 4410c3693c..fdbe16d8b6 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -682,6 +682,8 @@ class Router(formatOps: FormatOps) { case Term.Name(name) => style.spaces.afterTripleEquals && name == "===" || (rightOwner match { + case _: Term.ArgClause => + style.spaces.beforeApplyArgInParens(name) case _: Member.ParamClause => style.spaces.afterSymbolicDefs && isSymbolicName(name) case _ => false @@ -1944,7 +1946,7 @@ class Router(formatOps: FormatOps) { Seq( Split(NoSplit, 0) ) - case FormatToken(op @ T.Ident(_), right, _) if leftOwner.parent.exists { + case FormatToken(op: T.Ident, right, _) if leftOwner.parent.exists { case unary: Term.ApplyUnary => unary.op.tokens.head == op case _ => false @@ -2275,9 +2277,9 @@ class Router(formatOps: FormatOps) { // Infix operator. case FormatToken(_: T.Ident, _, FormatToken.LeftOwner(AsInfixOp(app))) => - insideInfixSplit(app, formatToken) + insideInfixSplit(app, false, formatToken) case FormatToken(_, _: T.Ident, FormatToken.RightOwner(AsInfixOp(app))) => - insideInfixSplit(app, formatToken) + insideInfixSplit(app, true, formatToken) // Case case FormatToken(_: T.KwCase, _, _) diff --git a/scalafmt-tests/src/test/resources/unit/Apply.stat b/scalafmt-tests/src/test/resources/unit/Apply.stat index 4393e9b15c..01d5703d1e 100644 --- a/scalafmt-tests/src/test/resources/unit/Apply.stat +++ b/scalafmt-tests/src/test/resources/unit/Apply.stat @@ -79,14 +79,14 @@ object a { >>> object a { +() - ===() - bar() + === () + bar () +(baz) - ===(baz) - bar(baz) + === (baz) + bar (baz) +(baz, qux) - ===(baz, qux) - bar(baz, qux) + === (baz, qux) + bar (baz, qux) } <<< #3607 beforeApplyArgInParens=never spaces.beforeApplyArgInParens = never @@ -131,12 +131,12 @@ object a { >>> object a { +() - ===() + === () bar() +(baz) - ===(baz) + === (baz) bar(baz) +(baz, qux) - ===(baz, qux) + === (baz, qux) bar(baz, qux) } diff --git a/scalafmt-tests/src/test/resources/unit/ApplyInfix.stat b/scalafmt-tests/src/test/resources/unit/ApplyInfix.stat index f076a268c5..ceba30c46f 100644 --- a/scalafmt-tests/src/test/resources/unit/ApplyInfix.stat +++ b/scalafmt-tests/src/test/resources/unit/ApplyInfix.stat @@ -235,15 +235,15 @@ object a { foo + baz foo === baz foo bar baz - foo + () - foo === () - foo bar () - foo + (baz) - foo === (baz) - foo bar (baz) - foo + (baz, qux) - foo === (baz, qux) - foo bar (baz, qux) + foo +() + foo ===() + foo bar() + foo +(baz) + foo ===(baz) + foo bar(baz) + foo +(baz, qux) + foo ===(baz, qux) + foo bar(baz, qux) } <<< #3607 beforeInfixArgInParens=aftersymbolic spaces.beforeInfixArgInParens = aftersymbolic @@ -269,11 +269,11 @@ object a { foo bar baz foo + () foo === () - foo bar () + foo bar() foo + (baz) foo === (baz) - foo bar (baz) + foo bar(baz) foo + (baz, qux) foo === (baz, qux) - foo bar (baz, qux) + foo bar(baz, qux) }