Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BestFirstSearch: extract stats into separate class #4302

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
private val noOptZones =
if (useNoOptZones(initStyle)) getNoOptZones(tokens) else null

var explored = 0
var deepestYet = State.start
val best = mutable.Map.empty[Int, State]
val visits = new Array[Int](tokens.length)
var pruneSlowStates = initStyle.runner.optimizer.pruneSlowStates

/** Returns true if it's OK to skip over state.
*/
def shouldEnterState(curr: State): Boolean = curr.policy.noDequeue ||
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.No) ||
// TODO(olafur) document why/how this optimization works.
best.get(curr.depth).forall(curr.possiblyBetter)
var stats = new StateStats(tokens, initStyle.runner)

private def getBlockCloseToRecurse(ft: FormatToken, stop: Token)(implicit
style: ScalafmtConfig,
Expand Down Expand Up @@ -107,14 +96,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit
val noOptZone = noOptZones == null || !useNoOptZones ||
noOptZones.contains(leftTok)

if (noOptZone || shouldEnterState(curr)) {
if (explored > style.runner.maxStateVisits) {
complete(deepestYet)
throw new Error.SearchStateExploded(
deepestYet,
s"exceeded `runner.maxStateVisits`=${style.runner.maxStateVisits}",
)
}
if (noOptZone || stats.shouldEnterState(curr)) {
stats.checkExplored(splitToken)

if (curr.split != null && curr.split.isNL)
if (
Expand All @@ -134,15 +117,11 @@ private class BestFirstSearch private (range: Set[Range])(implicit
else if (
optimizer.escapeInPathologicalCases &&
isSeqMulti(routes(curr.depth)) &&
visits(curr.depth) > optimizer.maxVisitsPerToken
) {
complete(deepestYet)
throw new Error.SearchStateExploded(
deepestYet,
splitToken,
s"exceeded `runner.optimizer.maxVisitsPerToken`=${optimizer.maxVisitsPerToken}",
)
} else {
stats.visits(curr.depth) > optimizer.maxVisitsPerToken
) stats.explode(splitToken)(
s"exceeded `runner.optimizer.maxVisitsPerToken`=${optimizer.maxVisitsPerToken}",
)
else {
val actualSplit = getActiveSplits(splitToken, curr, maxCost)
val allAltAreNL = actualSplit.forall(_.isNL)

Expand All @@ -164,10 +143,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
case _ => nextState
}
if (null ne stateToQueue) {
if (
(pruneSlowStates ne ScalafmtOptimizer.PruneSlowStates.No) &&
depth == 0 && split.isNL
) best.getOrElseUpdate(curr.depth, nextState)
if (depth == 0) stats.updateBest(nextState)
enqueue(stateToQueue)
}
}
Expand Down Expand Up @@ -209,7 +185,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
private def getActiveSplits(ft: FormatToken, state: State, maxCost: Int)(
implicit Q: StateQueue,
): Seq[Split] = {
trackState(state)
stats.trackState(state)
val useProvided = ft.meta.formatOff || !ft.inside(range)
val active = state.policy.execute(Decision(ft, routes(state.depth)))
.filter(x => x.isActive && x.cost <= maxCost)
Expand All @@ -225,15 +201,6 @@ private class BestFirstSearch private (range: Set[Range])(implicit
splits.sortBy(_.cost)
}

private def trackState(state: State)(implicit Q: StateQueue): Unit = {
val idx = state.depth
if (idx > deepestYet.depth) deepestYet = state
initStyle.runner.event(FormatEvent.VisitToken(tokens(idx)))
visits(idx) += 1
explored += 1
initStyle.runner.event(FormatEvent.Explored(explored, Q.nested, Q.length))
}

/** Follow states having single active non-newline split
*/
@tailrec
Expand Down Expand Up @@ -282,26 +249,23 @@ private class BestFirstSearch private (range: Set[Range])(implicit
case _ => state
}

private def complete(state: State)(implicit style: ScalafmtConfig): Unit =
style.runner.event(FormatEvent.CompleteFormat(explored, state, visits, best))

def getBestPath: SearchResult = {
initStyle.runner.event(FormatEvent.Routes(routes))
val state = {
val endToken = topSourceTree.tokens.last
def run = shortestPath(State.start, endToken)
val state = run
val retry = (null eq state) &&
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.Yes)
if (retry) {
pruneSlowStates = ScalafmtOptimizer.PruneSlowStates.No
if (null ne state) state
else stats.retry.fold(state) { x =>
stats = x
run
} else state
}
}
if (null != state) {
complete(state)(initStyle)
stats.complete(state)
SearchResult(state, reachedEOF = true)
} else {
val deepestYet = stats.deepestYet
val nextSplits = routes(deepestYet.depth)
val tok = tokens(deepestYet.depth)
val splitsAfterPolicy = deepestYet.policy
Expand All @@ -317,7 +281,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
s"""|Failed to format
|$msg""".stripMargin,
)
complete(deepestYet)(initStyle)
stats.complete(deepestYet)
SearchResult(deepestYet, reachedEOF = false)
}
}
Expand Down Expand Up @@ -385,4 +349,63 @@ object BestFirstSearch {
}
}

class StateStats private (
tokens: FormatTokens,
runner: ScalafmtRunner,
pruneSlowStates: ScalafmtOptimizer.PruneSlowStates,
) {
var explored = 0
var deepestYet = State.start
val best = mutable.Map.empty[Int, State]
val visits = new Array[Int](tokens.length)

def this(tokens: FormatTokens, runner: ScalafmtRunner) =
this(tokens, runner, runner.optimizer.pruneSlowStates)

/** Returns true if it's OK to skip over state.
*/
def shouldEnterState(state: State): Boolean = state.policy.noDequeue ||
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.No) ||
// TODO(olafur) document why/how this optimization works.
best.get(state.depth).forall(state.possiblyBetter)

def trackState(state: State)(implicit Q: StateQueue): Unit = {
val idx = state.depth
if (idx > deepestYet.depth) deepestYet = state
runner.event(FormatEvent.VisitToken(tokens(idx)))
visits(idx) += 1
explored += 1
runner.event(FormatEvent.Explored(explored, Q.nested, Q.length))
}

def updateBest(state: State): Boolean =
(pruneSlowStates ne ScalafmtOptimizer.PruneSlowStates.No) &&
state.split.isNL &&
(best.getOrElseUpdate(state.prev.depth, state) eq state)

def checkExplored(ft: FormatToken)(implicit
formatWriter: FormatWriter,
): Unit = if (explored > runner.maxStateVisits)
explode(ft)(s"exceeded `runner.maxStateVisits`=${runner.maxStateVisits}")

def explode(
ft: FormatToken,
)(msg: String)(implicit formatWriter: FormatWriter): Unit = {
complete(deepestYet)
throw new Error.SearchStateExploded(deepestYet, ft, msg)
}

def complete(state: State): Unit = runner
.event(FormatEvent.CompleteFormat(explored, state, visits, best))

def retry: Option[StateStats] = {
val ok = best.nonEmpty &&
(pruneSlowStates eq ScalafmtOptimizer.PruneSlowStates.Yes)
if (ok) Some(
new StateStats(tokens, runner, ScalafmtOptimizer.PruneSlowStates.No),
)
else None
}
}

}
Loading