Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BestFirstSearch: extract queues into a new class #4295

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,12 @@ private class BestFirstSearch private (range: Set[Range])(implicit
depth: Int = 0,
maxCost: Int = Integer.MAX_VALUE,
): State = {
def newGeneration = new mutable.PriorityQueue[State]()
var Q = newGeneration
var generations: List[mutable.PriorityQueue[State]] = Nil
def addGeneration() = if (Q.nonEmpty) {
generations = Q :: generations
Q = newGeneration
}
Q += start

implicit val Q: StateQueue = new StateQueue(depth)
def enqueue(state: State) = Q.enqueue(state)
enqueue(start)

// TODO(olafur) this while loop is waaaaaaaaaaaaay tooo big.
while (true) {
while (!Q.isEmpty()) {
val curr = Q.dequeue()
if (curr.depth >= tokens.length) return curr

Expand All @@ -115,8 +108,6 @@ private class BestFirstSearch private (range: Set[Range])(implicit
noOptZones.contains(leftTok)

if (noOptZone || shouldEnterState(curr)) {
trackState(curr, depth, Q.length)

if (explored > style.runner.maxStateVisits) {
complete(deepestYet)
throw new Error.SearchStateExploded(
Expand All @@ -131,7 +122,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
optimizer.dequeueOnNewStatements && curr.allAltAreNL &&
!(depth == 0 && noOptZone) &&
(leftTok.is[Token.KwElse] || statementStarts.contains(curr.depth))
) addGeneration()
) Q.addGeneration()

val noBlockClose = start == curr && 0 != maxCost || !noOptZone ||
!optimizer.recurseOnBlocks
Expand Down Expand Up @@ -168,7 +159,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
else shortestPath(nextState, opt.token, depth + 1, maxCost = 0)
val furtherState =
if (null == nextNextState) null
else traverseSameLine(nextNextState, depth)
else traverseSameLine(nextNextState)
if (null == furtherState) if (killOnFail(opt)) null else nextState
else if (
furtherState.appliedPenalty > nextNextState.appliedPenalty
Expand Down Expand Up @@ -202,24 +193,15 @@ private class BestFirstSearch private (range: Set[Range])(implicit
}
}
}

if (Q.isEmpty) {
if (generations.isEmpty) return null

Q = generations.head
generations = generations.tail
}
}

// unreachable
null
}

private def getActiveSplits(
ft: FormatToken,
state: State,
maxCost: Int,
private def getActiveSplits(ft: FormatToken, state: State, maxCost: Int)(
implicit Q: StateQueue,
): Seq[Split] = {
trackState(state)
val useProvided = ft.meta.formatOff || !ft.inside(range)
val active = state.policy.execute(Decision(ft, routes(state.depth)))
.filter(x => x.isActive && x.cost <= maxCost)
Expand All @@ -235,38 +217,38 @@ private class BestFirstSearch private (range: Set[Range])(implicit
splits.sortBy(_.cost)
}

private def trackState(state: State, depth: Int, queueSize: Int)(implicit
style: ScalafmtConfig,
): Unit = {
if (state.depth > deepestYet.depth) deepestYet = state
style.runner.event(FormatEvent.VisitToken(tokens(state.depth)))
visits(state.depth) += 1
private def trackState(state: State)(implicit Q: StateQueue): Unit = {
val idx = state.depth
if (idx > deepestYet.depth) deepestYet = state
initStyle.runner.event(FormatEvent.VisitToken(tokens(idx)))
visits(idx) += 1
explored += 1
style.runner.event(FormatEvent.Explored(explored, depth, queueSize))
initStyle.runner.event(FormatEvent.Explored(explored, Q.nested, Q.length))
}

/** Follow states having single active non-newline split
*/
@tailrec
private def traverseSameLine(state: State, depth: Int): State =
private def traverseSameLine(
state: State,
)(implicit queue: StateQueue): State =
if (state.depth >= tokens.length) state
else {
val splitToken = tokens(state.depth)
implicit val style: ScalafmtConfig = styleMap.at(splitToken)
trackState(state, depth, 0)
getActiveSplits(splitToken, state, Int.MaxValue) match {
case Seq() => null // dead end if empty
case Seq(split) =>
if (split.isNL) state
else {
style.runner.event(FormatEvent.Enqueue(split))
val nextState = state.next(split, nextAllAltAreNL = false)
traverseSameLine(nextState, depth)
traverseSameLine(nextState)
}
case ss
if state.appliedPenalty == 0 &&
RightParenOrBracket(splitToken.right) =>
traverseSameLineZeroCost(ss.filter(_.cost == 0), state, depth)
traverseSameLineZeroCost(ss.filter(_.cost == 0), state)
case _ => state
}
}
Expand All @@ -275,8 +257,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
private def traverseSameLineZeroCost(
splits: Seq[Split],
state: State,
depth: Int,
)(implicit style: ScalafmtConfig): State = splits match {
)(implicit style: ScalafmtConfig, queue: StateQueue): State = splits match {
case Seq(split) if !split.isNL =>
style.runner.event(FormatEvent.Enqueue(split))
val nextState = state.next(split, nextAllAltAreNL = false)
Expand All @@ -286,9 +267,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit
val nextToken = tokens(nextState.depth)
if (RightParenOrBracket(nextToken.right)) {
implicit val style: ScalafmtConfig = styleMap.at(nextToken)
trackState(nextState, depth, 0)
val nextSplits = getActiveSplits(nextToken, nextState, maxCost = 0)
traverseSameLineZeroCost(nextSplits, nextState, depth)
traverseSameLineZeroCost(nextSplits, nextState)
} else nextState
}
case _ => state
Expand Down Expand Up @@ -374,4 +354,29 @@ object BestFirstSearch {
private def useNoOptZones(implicit style: ScalafmtConfig): Boolean =
style.runner.optimizer.disableOptimizationsInsideSensitiveAreas

class StateQueue(val nested: Int)(implicit stateOrdering: Ordering[State]) {
private def newGeneration = new mutable.PriorityQueue[State]()
var generation: mutable.PriorityQueue[State] = newGeneration
var generations: List[mutable.PriorityQueue[State]] = Nil

def addGeneration(): Unit = if (generation.nonEmpty) {
generations = generation :: generations
generation = newGeneration
}

def dequeue(): State = generation.dequeue()
def enqueue(state: State): Unit = generation.enqueue(state)
def length: Int = generation.length
@tailrec
final def isEmpty(): Boolean = generation.isEmpty && {
generations match {
case head :: tail =>
generation = head
generations = tail
isEmpty()
case _ => true
}
}
}

}
Loading