diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala index 8c514c6987..2551578cf2 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala @@ -87,19 +87,12 @@ private class BestFirstSearch private (range: Set[Range])(implicit depth: Int = 0, maxCost: Int = Integer.MAX_VALUE, ): State = { - def newGeneration = new mutable.PriorityQueue[State]() - var Q = newGeneration - var generations: List[mutable.PriorityQueue[State]] = Nil - def addGeneration() = if (Q.nonEmpty) { - generations = Q :: generations - Q = newGeneration - } - Q += start - + implicit val Q: StateQueue = new StateQueue(depth) def enqueue(state: State) = Q.enqueue(state) + enqueue(start) // TODO(olafur) this while loop is waaaaaaaaaaaaay tooo big. - while (true) { + while (!Q.isEmpty()) { val curr = Q.dequeue() if (curr.depth >= tokens.length) return curr @@ -115,8 +108,6 @@ private class BestFirstSearch private (range: Set[Range])(implicit noOptZones.contains(leftTok) if (noOptZone || shouldEnterState(curr)) { - trackState(curr, depth, Q.length) - if (explored > style.runner.maxStateVisits) { complete(deepestYet) throw new Error.SearchStateExploded( @@ -131,7 +122,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit optimizer.dequeueOnNewStatements && curr.allAltAreNL && !(depth == 0 && noOptZone) && (leftTok.is[Token.KwElse] || statementStarts.contains(curr.depth)) - ) addGeneration() + ) Q.addGeneration() val noBlockClose = start == curr && 0 != maxCost || !noOptZone || !optimizer.recurseOnBlocks @@ -168,7 +159,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit else shortestPath(nextState, opt.token, depth + 1, maxCost = 0) val furtherState = if (null == nextNextState) null - else traverseSameLine(nextNextState, depth) + else traverseSameLine(nextNextState) if (null == furtherState) if (killOnFail(opt)) null else nextState else if ( furtherState.appliedPenalty > nextNextState.appliedPenalty @@ -202,24 +193,15 @@ private class BestFirstSearch private (range: Set[Range])(implicit } } } - - if (Q.isEmpty) { - if (generations.isEmpty) return null - - Q = generations.head - generations = generations.tail - } } - // unreachable null } - private def getActiveSplits( - ft: FormatToken, - state: State, - maxCost: Int, + private def getActiveSplits(ft: FormatToken, state: State, maxCost: Int)( + implicit Q: StateQueue, ): Seq[Split] = { + trackState(state) val useProvided = ft.meta.formatOff || !ft.inside(range) val active = state.policy.execute(Decision(ft, routes(state.depth))) .filter(x => x.isActive && x.cost <= maxCost) @@ -235,25 +217,25 @@ private class BestFirstSearch private (range: Set[Range])(implicit splits.sortBy(_.cost) } - private def trackState(state: State, depth: Int, queueSize: Int)(implicit - style: ScalafmtConfig, - ): Unit = { - if (state.depth > deepestYet.depth) deepestYet = state - style.runner.event(FormatEvent.VisitToken(tokens(state.depth))) - visits(state.depth) += 1 + private def trackState(state: State)(implicit Q: StateQueue): Unit = { + val idx = state.depth + if (idx > deepestYet.depth) deepestYet = state + initStyle.runner.event(FormatEvent.VisitToken(tokens(idx))) + visits(idx) += 1 explored += 1 - style.runner.event(FormatEvent.Explored(explored, depth, queueSize)) + initStyle.runner.event(FormatEvent.Explored(explored, Q.nested, Q.length)) } /** Follow states having single active non-newline split */ @tailrec - private def traverseSameLine(state: State, depth: Int): State = + private def traverseSameLine( + state: State, + )(implicit queue: StateQueue): State = if (state.depth >= tokens.length) state else { val splitToken = tokens(state.depth) implicit val style: ScalafmtConfig = styleMap.at(splitToken) - trackState(state, depth, 0) getActiveSplits(splitToken, state, Int.MaxValue) match { case Seq() => null // dead end if empty case Seq(split) => @@ -261,12 +243,12 @@ private class BestFirstSearch private (range: Set[Range])(implicit else { style.runner.event(FormatEvent.Enqueue(split)) val nextState = state.next(split, nextAllAltAreNL = false) - traverseSameLine(nextState, depth) + traverseSameLine(nextState) } case ss if state.appliedPenalty == 0 && RightParenOrBracket(splitToken.right) => - traverseSameLineZeroCost(ss.filter(_.cost == 0), state, depth) + traverseSameLineZeroCost(ss.filter(_.cost == 0), state) case _ => state } } @@ -275,8 +257,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit private def traverseSameLineZeroCost( splits: Seq[Split], state: State, - depth: Int, - )(implicit style: ScalafmtConfig): State = splits match { + )(implicit style: ScalafmtConfig, queue: StateQueue): State = splits match { case Seq(split) if !split.isNL => style.runner.event(FormatEvent.Enqueue(split)) val nextState = state.next(split, nextAllAltAreNL = false) @@ -286,9 +267,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit val nextToken = tokens(nextState.depth) if (RightParenOrBracket(nextToken.right)) { implicit val style: ScalafmtConfig = styleMap.at(nextToken) - trackState(nextState, depth, 0) val nextSplits = getActiveSplits(nextToken, nextState, maxCost = 0) - traverseSameLineZeroCost(nextSplits, nextState, depth) + traverseSameLineZeroCost(nextSplits, nextState) } else nextState } case _ => state @@ -374,4 +354,29 @@ object BestFirstSearch { private def useNoOptZones(implicit style: ScalafmtConfig): Boolean = style.runner.optimizer.disableOptimizationsInsideSensitiveAreas + class StateQueue(val nested: Int)(implicit stateOrdering: Ordering[State]) { + private def newGeneration = new mutable.PriorityQueue[State]() + var generation: mutable.PriorityQueue[State] = newGeneration + var generations: List[mutable.PriorityQueue[State]] = Nil + + def addGeneration(): Unit = if (generation.nonEmpty) { + generations = generation :: generations + generation = newGeneration + } + + def dequeue(): State = generation.dequeue() + def enqueue(state: State): Unit = generation.enqueue(state) + def length: Int = generation.length + @tailrec + final def isEmpty(): Boolean = generation.isEmpty && { + generations match { + case head :: tail => + generation = head + generations = tail + isEmpty() + case _ => true + } + } + } + }