Skip to content

Commit

Permalink
Merge pull request #555 from lichess-org/eval-cache-2
Browse files Browse the repository at this point in the history
use normalized BinaryFen as evalCache id (lichess-org/lila#15020)
  • Loading branch information
ornicar authored Apr 15, 2024
2 parents 9895673 + 3edd1b4 commit e8dfd33
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 108 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ val arch_ = arch.replace("-", "_")
val pekkoVersion = "1.0.2"
val kamonVersion = "2.7.1"
val nettyVersion = "4.1.108.Final"
val chessVersion = "16.0.0"
val chessVersion = "16.0.3"

lazy val `lila-ws` = project
.in(file("."))
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/Mongo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ final class Mongo(config: Config)(using Executor) extends MongoHandlers:
def relayTourColl = collNamed("relay_tour")
def relayRoundColl = collNamed("relay")
def studyColl = studyDb.map(_.collection("study"))(parasitic)
def evalCacheColl = yoloDb.map(_.collection("eval_cache"))(parasitic)
def evalCacheColl = yoloDb.map(_.collection("eval_cache2"))(parasitic)

def isDuplicateKey(wr: WriteResult) = wr.code.contains(11000)
def ignoreDuplicateKey: PartialFunction[Throwable, Unit] =
Expand Down Expand Up @@ -179,13 +179,13 @@ final class Mongo(config: Config)(using Executor) extends MongoHandlers:
}
.map(_.getOrElse(Set.empty))

import evalCache.EvalCacheEntry
def evalCacheEntry(id: EvalCacheEntry.Id): Future[Option[EvalCacheEntry]] =
import evalCache.{ Id, EvalCacheEntry }
def evalCacheEntry(id: Id): Future[Option[EvalCacheEntry]] =
import evalCache.EvalCacheBsonHandlers.given
evalCacheColl.flatMap:
_.find(selector = BSONDocument("_id" -> id))
.one[EvalCacheEntry]
def evalCacheUsedNow(id: EvalCacheEntry.Id): Unit =
def evalCacheUsedNow(id: Id): Unit =
import evalCache.EvalCacheBsonHandlers.given
evalCacheColl.foreach:
_.update(ordered = false, writeConcern = WriteConcern.Unacknowledged)
Expand Down
33 changes: 18 additions & 15 deletions src/main/scala/evalCache/EvalCacheApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,24 @@ final class EvalCacheApi(mongo: Mongo)(using
import EvalCacheBsonHandlers.given

def get(sri: Sri, e: EvalGet, emit: Emit[ClientIn]): Unit =
getEntry(Id.make(e.variant, e.fen))
.map:
_.flatMap(_.makeBestMultiPvEval(e.multiPv))
.map(monitorRequest(e.fen, Monitor.evalCache.single))
.foreach:
_.foreach: eval =>
emit:
ClientIn.EvalHit:
EvalCacheJsonHandlers.writeEval(eval, e.fen) + ("path" -> JsString(e.path.value))
Id.from(e.variant, e.fen)
.foreach: id =>
getEntry(id)
.map:
_.flatMap(_.makeBestMultiPvEval(e.multiPv))
.map(monitorRequest(e.fen, Monitor.evalCache.single))
.foreach:
_.foreach: eval =>
emit:
ClientIn.EvalHit:
EvalCacheJsonHandlers.writeEval(eval, e.fen) + ("path" -> JsString(e.path.value))
if e.up then upgrade.register(sri, e)

def getMulti(sri: Sri, e: EvalGetMulti, emit: Emit[ClientIn]): Unit =
e.fens
.traverse: fen =>
getEntry(Id.make(e.variant, fen))
.flatMap(fen => Id.from(e.variant, fen).map(fen -> _))
.traverse: (fen, id) =>
getEntry(id)
.map:
_.flatMap(_.makeBestSinglePvEval).map(fen -> _)
.map(monitorRequest(fen, Monitor.evalCache.multi))
Expand Down Expand Up @@ -92,8 +95,6 @@ final class EvalCacheApi(mongo: Mongo)(using
res

private def putTrusted(sri: Sri, user: User.Id, input: Input): Future[Unit] =
def destSize(fen: Fen.Full): Int =
chess.Game(chess.variant.Standard.some, fen.some).situation.moves.view.map(_._2.size).sum
mongo.evalCacheColl.flatMap: c =>
EvalCacheValidator(input) match
case Left(error) =>
Expand All @@ -104,7 +105,7 @@ final class EvalCacheApi(mongo: Mongo)(using
case None =>
val entry = EvalCacheEntry(
_id = input.id,
nbMoves = destSize(input.fen),
nbMoves = input.situation.moves.view.map(_._2.size).sum,
evals = List(input.eval),
usedAt = LocalDateTime.now,
updatedAt = LocalDateTime.now
Expand All @@ -130,4 +131,6 @@ final class EvalCacheApi(mongo: Mongo)(using
private object EvalCacheValidator:

def apply(in: EvalCacheEntry.Input): Either[ErrorStr, Unit] =
in.eval.pvs.traverse_(pv => chess.Replay.boardsFromUci(pv.moves.value.toList, in.fen.some, in.id.variant))
in.eval.pvs.traverse_ { pv =>
chess.Replay.boardsFromUci(pv.moves.value.toList, in.fen.some, in.situation.variant)
}
39 changes: 8 additions & 31 deletions src/main/scala/evalCache/EvalCacheBsonHandlers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package evalCache

import cats.data.NonEmptyList
import cats.syntax.all.*
import chess.format.Uci
import chess.format.{ BinaryFen, Uci }
import reactivemongo.api.bson.*
import reactivemongo.api.bson.exceptions.TypeDoesNotMatchException

Expand Down Expand Up @@ -61,38 +61,15 @@ object EvalCacheBsonHandlers:

private def handlerBadType[T](b: BSONValue): Try[T] =
Failure(TypeDoesNotMatchException("BSONValue", b.getClass.getSimpleName))
private def handlerBadValue[T](msg: String): Try[T] = Failure(new IllegalArgumentException(msg))

private def tryHandler[T](read: PartialFunction[BSONValue, Try[T]], write: T => BSONValue): BSONHandler[T] =
new:
def readTry(bson: BSONValue) = read.applyOrElse(bson, (b: BSONValue) => handlerBadType(b))
def writeTry(t: T) = Success(write(t))
given binaryFenHandler: BSONHandler[BinaryFen] = new:
def readTry(bson: BSONValue) =
bson match
case v: BSONBinary => Success(BinaryFen(v.byteArray))
case _ => handlerBadType(bson)
def writeTry(v: BinaryFen) = Success(BSONBinary(v.value, Subtype.GenericBinarySubtype))

given BSONHandler[Id] = tryHandler[Id](
{ case BSONString(value) =>
value.split(':') match
case Array(fen) => Success(Id(chess.variant.Standard, SmallFen(fen)))
case Array(variantId, fen) =>
import chess.variant.Variant
Success(
Id(
Variant.Id
.from(variantId.toIntOption)
.flatMap {
Variant(_)
}
.getOrElse(sys.error(s"Invalid evalcache variant $variantId")),
SmallFen(fen)
)
)
case _ => handlerBadValue(s"Invalid evalcache id $value")
},
x =>
BSONString {
if x.variant.standard || x.variant.fromPosition then x.smallFen.value
else s"${x.variant.id}:${x.smallFen.value}"
}
)
given BSONHandler[Id] = binaryFenHandler.as[Id](Id.apply, _.value)

given BSONDocumentHandler[Eval] = Macros.handler
given BSONDocumentHandler[EvalCacheEntry] = Macros.handler
31 changes: 18 additions & 13 deletions src/main/scala/evalCache/EvalCacheEntry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,24 @@ package lila.ws
package evalCache

import cats.data.NonEmptyList
import chess.format.{ Fen, Uci }
import chess.Situation
import chess.format.{ BinaryFen, Fen, Uci }
import chess.variant.Variant

import java.time.LocalDateTime

import Eval.Score

opaque type Id = BinaryFen
object Id:
def apply(fen: BinaryFen): Id = fen
def apply(situation: Situation): Id = BinaryFen.writeNormalized(situation)
def from(variant: Variant, fen: Fen.Full): Option[Id] =
Fen.read(variant, fen).map(BinaryFen.writeNormalized)
extension (id: Id) def value: BinaryFen = id

case class EvalCacheEntry(
_id: EvalCacheEntry.Id,
_id: Id,
nbMoves: Int, // multipv cannot be greater than number of legal moves
evals: List[EvalCacheEntry.Eval], // best ones first, by depth and nodes
usedAt: LocalDateTime,
Expand Down Expand Up @@ -43,6 +52,8 @@ case class EvalCacheEntry(

object EvalCacheEntry:

case class Input(id: Id, fen: Fen.Full, situation: Situation, eval: Eval)

case class Eval(pvs: NonEmptyList[Pv], knodes: Knodes, depth: Depth, by: User.Id, trust: Trust):

def multiPv = MultiPv(pvs.size)
Expand Down Expand Up @@ -72,16 +83,10 @@ object EvalCacheEntry:

def truncate = copy(moves = Moves.truncate(moves))

case class Id(variant: Variant, smallFen: SmallFen)
object Id:
def make(variant: Variant, fen: Fen.Full): Id =
Id(variant, SmallFen.make(variant, fen.simple))

case class Input(id: Id, fen: Fen.Full, eval: Eval)

def makeInput(variant: Variant, fen: Fen.Full, eval: Eval) =
SmallFen
.validate(variant, fen)
Fen
.read(variant, fen)
.filter(_.playable(false))
.ifTrue(eval.looksValid)
.map: smallFen =>
Input(Id(variant, smallFen), fen, eval.truncatePvs)
.map: situation =>
Input(Id(situation), fen, situation, eval.truncatePvs)
20 changes: 8 additions & 12 deletions src/main/scala/evalCache/EvalCacheMulti.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ final private class EvalCacheMulti(using
scheduler: org.apache.pekko.actor.typed.Scheduler
):
import EvalCacheMulti.*
import EvalCacheUpgrade.{ EvalState, SetupId, SriString }
import EvalCacheUpgrade.{ EvalState, SriString }

private val members = ConcurrentHashMap[SriString, WatchingMember](4096)
private val evals = ConcurrentHashMap[SetupId, EvalState](1024)
private val evals = ConcurrentHashMap[Id, EvalState](1024)
private val expirableSris = ExpireCallbackMemo[Sri](scheduler, 1 minute, expire)

private val upgradeMon = Monitor.evalCache.multi.upgrade
Expand All @@ -36,15 +36,14 @@ final private class EvalCacheMulti(using
WatchingMember(sri, e.variant, e.fens)
)
.setups
.foreach: setupId =>
evals.compute(setupId, (_, prev) => Option(prev).fold(EvalState(Set(sri), Depth(0)))(_.addSri(sri)))
.foreach: id =>
evals.compute(id, (_, prev) => Option(prev).fold(EvalState(Set(sri), Depth(0)))(_.addSri(sri)))
expirableSris.put(sri)

def onEval(input: EvalCacheEntry.Input, fromSri: Sri): Unit =
val setupId = makeSetupId(input.id.variant, input.fen)
Option(
evals.computeIfPresent(
setupId,
input.id,
(_, ev) =>
if ev.depth >= input.eval.depth then ev
else ev.copy(depth = input.eval.depth)
Expand All @@ -63,9 +62,9 @@ final private class EvalCacheMulti(using
Option(members.remove(sri.value)).foreach:
_.setups.foreach(unregisterEval(_, sri))

private def unregisterEval(setupId: SetupId, sri: Sri): Unit =
private def unregisterEval(id: Id, sri: Sri): Unit =
evals.computeIfPresent(
setupId,
id,
(_, eval) =>
val newSris = eval.sris - sri
if newSris.isEmpty then null
Expand All @@ -81,8 +80,5 @@ private object EvalCacheMulti:

import EvalCacheUpgrade.*

def makeSetupId(variant: Variant, fen: Fen.Full): SetupId =
s"${variant.id}${SmallFen.make(variant, fen.simple)}"

case class WatchingMember(sri: Sri, variant: Variant, fens: List[Fen.Full]):
def setups: List[SetupId] = fens.map(makeSetupId(variant, _))
def setups: List[Id] = fens.flatMap(Id.from(variant, _))
35 changes: 18 additions & 17 deletions src/main/scala/evalCache/EvalCacheUpgrade.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package lila.ws
package evalCache

import chess.format.{ Fen, UciPath }
import chess.variant.Variant
import chess.format.UciPath
import play.api.libs.json.JsString

import java.util.concurrent.ConcurrentHashMap
Expand All @@ -28,20 +27,24 @@ final private class EvalCacheUpgrade(using
private val upgradeMon = Monitor.evalCache.single.upgrade

def register(sri: Sri, e: EvalGet): Unit =
members.compute(
sri.value,
(_, prev) =>
Option(prev).foreach: member =>
unregisterEval(member.setupId, sri)
val setupId = makeSetupId(e.variant, e.fen, e.multiPv)
evals.compute(setupId, (_, eval) => Option(eval).fold(EvalState(Set(sri), Depth(0)))(_.addSri(sri)))
WatchingMember(sri, setupId, e.path)
)
expirableSris.put(sri)
Id
.from(e.variant, e.fen)
.foreach: entryId =>
members.compute(
sri.value,
(_, prev) =>
Option(prev).foreach: member =>
unregisterEval(member.setupId, sri)
val setupId = SetupId(entryId, e.multiPv)
evals
.compute(setupId, (_, eval) => Option(eval).fold(EvalState(Set(sri), Depth(0)))(_.addSri(sri)))
WatchingMember(sri, setupId, e.path)
)
expirableSris.put(sri)

def onEval(input: EvalCacheEntry.Input, fromSri: Sri): Unit =
(1 to input.eval.multiPv.value).foreach: multiPv =>
val setupId = makeSetupId(input.id.variant, input.fen, MultiPv(multiPv))
val setupId = SetupId(input.id, MultiPv(multiPv))
Option(
evals.computeIfPresent(
setupId,
Expand Down Expand Up @@ -82,12 +85,10 @@ final private class EvalCacheUpgrade(using
private object EvalCacheUpgrade:

type SriString = String
type SetupId = String

case class SetupId(entryId: Id, multiPv: MultiPv)

case class EvalState(sris: Set[Sri], depth: Depth):
def addSri(sri: Sri) = copy(sris = sris + sri)

def makeSetupId(variant: Variant, fen: Fen.Full, multiPv: MultiPv): SetupId =
s"${variant.id}${SmallFen.make(variant, fen.simple)}^$multiPv"

case class WatchingMember(sri: Sri, setupId: SetupId, path: UciPath)
15 changes: 1 addition & 14 deletions src/main/scala/evalCache/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package lila.ws
package evalCache

import cats.data.NonEmptyList
import chess.format.{ Fen, Uci }
import chess.variant.Variant
import chess.format.Uci

val MIN_KNODES = Knodes(3000)
val MIN_DEPTH = Depth(20)
Expand All @@ -26,15 +25,3 @@ object Moves extends TotalWrapper[Moves, NonEmptyList[Uci]]:
opaque type Trust = Double
object Trust extends OpaqueDouble[Trust]:
extension (a: Trust) def isEnough = a > -1

opaque type SmallFen = String
object SmallFen extends OpaqueString[SmallFen]:
def make(variant: Variant, fen: Fen.Simple): SmallFen =
val base = fen.value.split(' ').take(4).mkString("").filter { c =>
c != '/' && c != '-' && c != 'w'
}
if variant == chess.variant.ThreeCheck
then fen.value.split(' ').lift(6).foldLeft(base)(_ + _)
else base
def validate(variant: Variant, fen: Fen.Full): Option[SmallFen] =
Fen.read(variant, fen).exists(_.playable(false)).option(make(variant, fen.simple))
1 change: 0 additions & 1 deletion src/main/scala/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ type ClientBehavior = Behavior[ipc.ClientMsg]
type Client = ActorRef[ipc.ClientMsg]
type ClientEmit = Emit[ipc.ClientIn]

def nowSeconds: Int = (System.currentTimeMillis() / 1000).toInt
val startedAtMillis = System.currentTimeMillis()

0 comments on commit e8dfd33

Please sign in to comment.