Skip to content

Commit

Permalink
BestFirstSearch: extract stats into separate class
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Sep 20, 2024
1 parent ac849f9 commit e8548d0
Showing 1 changed file with 76 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
private val noOptZones =
if (useNoOptZones(initStyle)) getNoOptZones(tokens) else null

var explored = 0
var deepestYet = State.start
val best = mutable.Map.empty[Int, State]
val visits = new Array[Int](tokens.length)
var pruneSlowStates = initStyle.runner.optimizer.pruneSlowStates

/** Returns true if it's OK to skip over state.
*/
def shouldEnterState(curr: State): Boolean = curr.policy.noDequeue ||
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.No) ||
// TODO(olafur) document why/how this optimization works.
best.get(curr.depth).forall(curr.possiblyBetter)
var stats = new StateStats(tokens, initStyle.runner)

private def getBlockCloseToRecurse(ft: FormatToken, stop: Token)(implicit
style: ScalafmtConfig,
Expand Down Expand Up @@ -107,14 +96,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit
val noOptZone = noOptZones == null || !useNoOptZones ||
noOptZones.contains(leftTok)

if (noOptZone || shouldEnterState(curr)) {
if (explored > style.runner.maxStateVisits) {
complete(deepestYet)
throw new Error.SearchStateExploded(
deepestYet,
s"exceeded `runner.maxStateVisits`=${style.runner.maxStateVisits}",
)
}
if (noOptZone || stats.shouldEnterState(curr)) {
stats.checkExplored(splitToken)

if (curr.split != null && curr.split.isNL)
if (
Expand All @@ -134,15 +117,11 @@ private class BestFirstSearch private (range: Set[Range])(implicit
else if (
optimizer.escapeInPathologicalCases &&
isSeqMulti(routes(curr.depth)) &&
visits(curr.depth) > optimizer.maxVisitsPerToken
) {
complete(deepestYet)
throw new Error.SearchStateExploded(
deepestYet,
splitToken,
s"exceeded `runner.optimizer.maxVisitsPerToken`=${optimizer.maxVisitsPerToken}",
)
} else {
stats.visits(curr.depth) > optimizer.maxVisitsPerToken
) stats.explode(splitToken)(
s"exceeded `runner.optimizer.maxVisitsPerToken`=${optimizer.maxVisitsPerToken}",
)
else {
val actualSplit = getActiveSplits(splitToken, curr, maxCost)
val allAltAreNL = actualSplit.forall(_.isNL)

Expand All @@ -164,10 +143,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
case _ => nextState
}
if (null ne stateToQueue) {
if (
(pruneSlowStates ne ScalafmtOptimizer.PruneSlowStates.No) &&
depth == 0 && split.isNL
) best.getOrElseUpdate(curr.depth, nextState)
if (depth == 0) stats.updateBest(nextState)
enqueue(stateToQueue)
}
}
Expand Down Expand Up @@ -209,7 +185,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
private def getActiveSplits(ft: FormatToken, state: State, maxCost: Int)(
implicit Q: StateQueue,
): Seq[Split] = {
trackState(state)
stats.trackState(state)
val useProvided = ft.meta.formatOff || !ft.inside(range)
val active = state.policy.execute(Decision(ft, routes(state.depth)))
.filter(x => x.isActive && x.cost <= maxCost)
Expand All @@ -225,15 +201,6 @@ private class BestFirstSearch private (range: Set[Range])(implicit
splits.sortBy(_.cost)
}

private def trackState(state: State)(implicit Q: StateQueue): Unit = {
val idx = state.depth
if (idx > deepestYet.depth) deepestYet = state
initStyle.runner.event(FormatEvent.VisitToken(tokens(idx)))
visits(idx) += 1
explored += 1
initStyle.runner.event(FormatEvent.Explored(explored, Q.nested, Q.length))
}

/** Follow states having single active non-newline split
*/
@tailrec
Expand Down Expand Up @@ -282,26 +249,23 @@ private class BestFirstSearch private (range: Set[Range])(implicit
case _ => state
}

private def complete(state: State)(implicit style: ScalafmtConfig): Unit =
style.runner.event(FormatEvent.CompleteFormat(explored, state, visits, best))

def getBestPath: SearchResult = {
initStyle.runner.event(FormatEvent.Routes(routes))
val state = {
val endToken = topSourceTree.tokens.last
def run = shortestPath(State.start, endToken)
val state = run
val retry = (null eq state) &&
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.Yes)
if (retry) {
pruneSlowStates = ScalafmtOptimizer.PruneSlowStates.No
if (null ne state) state
else stats.retry.fold(state) { x =>
stats = x
run
} else state
}
}
if (null != state) {
complete(state)(initStyle)
stats.complete(state)
SearchResult(state, reachedEOF = true)
} else {
val deepestYet = stats.deepestYet
val nextSplits = routes(deepestYet.depth)
val tok = tokens(deepestYet.depth)
val splitsAfterPolicy = deepestYet.policy
Expand All @@ -317,7 +281,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
s"""|Failed to format
|$msg""".stripMargin,
)
complete(deepestYet)(initStyle)
stats.complete(deepestYet)
SearchResult(deepestYet, reachedEOF = false)
}
}
Expand Down Expand Up @@ -385,4 +349,63 @@ object BestFirstSearch {
}
}

class StateStats private (
tokens: FormatTokens,
runner: ScalafmtRunner,
pruneSlowStates: ScalafmtOptimizer.PruneSlowStates,
) {
var explored = 0
var deepestYet = State.start
val best = mutable.Map.empty[Int, State]
val visits = new Array[Int](tokens.length)

def this(tokens: FormatTokens, runner: ScalafmtRunner) =
this(tokens, runner, runner.optimizer.pruneSlowStates)

/** Returns true if it's OK to skip over state.
*/
def shouldEnterState(state: State): Boolean = state.policy.noDequeue ||
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.No) ||
// TODO(olafur) document why/how this optimization works.
best.get(state.depth).forall(state.possiblyBetter)

def trackState(state: State)(implicit Q: StateQueue): Unit = {
val idx = state.depth
if (idx > deepestYet.depth) deepestYet = state
runner.event(FormatEvent.VisitToken(tokens(idx)))
visits(idx) += 1
explored += 1
runner.event(FormatEvent.Explored(explored, Q.nested, Q.length))
}

def updateBest(state: State): Boolean =
(pruneSlowStates ne ScalafmtOptimizer.PruneSlowStates.No) &&
state.split.isNL &&
(best.getOrElseUpdate(state.prev.depth, state) eq state)

def checkExplored(ft: FormatToken)(implicit
formatWriter: FormatWriter,
): Unit = if (explored > runner.maxStateVisits)
explode(ft)(s"exceeded `runner.maxStateVisits`=${runner.maxStateVisits}")

def explode(
ft: FormatToken,
)(msg: String)(implicit formatWriter: FormatWriter): Unit = {
complete(deepestYet)
throw new Error.SearchStateExploded(deepestYet, ft, msg)
}

def complete(state: State): Unit = runner
.event(FormatEvent.CompleteFormat(explored, state, visits, best))

def retry: Option[StateStats] = {
val ok = best.nonEmpty &&
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.Yes)
if (ok) Some(
new StateStats(tokens, runner, ScalafmtOptimizer.PruneSlowStates.No),
)
else None
}
}

}

0 comments on commit e8548d0

Please sign in to comment.