diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt index 61cfdee57..e7402af5a 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt @@ -5,6 +5,29 @@ import arrow.fx.coroutines.timeInMillis import kotlin.time.Duration import kotlin.time.Duration.Companion.days +sealed interface CachedToolEvent { + /** Fired when a new entry is added to the cache. */ + data class Created( + val key: CachedToolKey, + val value: CachedToolValue, + val mapSize: Int + ) : CachedToolEvent + + /** Fired when an entry is found in the cache, but its value has been updated. */ + data class Updated( + val key: CachedToolKey, + val oldValue: CachedToolValue, + val newValue: CachedToolValue + ) : CachedToolEvent + + /** Fired when an expired event has been removed from the cache. */ + data class Evicted(val key: CachedToolKey, val value: CachedToolValue) : + CachedToolEvent + + /** Fired when all expired entries have been removed from the cache. */ + data class ExpiredPurged(val mapSize: Int, val removedEntries: Int) : CachedToolEvent +} + data class CachedToolKey(val value: K, val seed: String) data class CachedToolValue(val value: V, val accessTimestamp: Long, val writeTimestamp: Long) { @@ -73,6 +96,9 @@ abstract class CachedTool( */ abstract suspend fun onCacheMissed(input: Input): Output + /** Invoked on [CachedToolEvent] firing, after every cache mutation. */ + open fun onCacheEvent(event: CachedToolEvent) = Unit + /** * Criteria to check if the cache should be used for the given [input]. By default, it returns * true, meaning always use the cache if available. @@ -108,7 +134,7 @@ abstract class CachedTool( suspend fun getValidCacheSnapshot(): Map { val validEntries = cache.modify { cachedToolInfo -> - val validEntries = cachedToolInfo.filterExpired().filter { (key, _) -> key.seed == seed } + val validEntries = cachedToolInfo.purgeExpired().filter { (key, _) -> key.seed == seed } Pair(cachedToolInfo, validEntries) } return validEntries.map { it.key.value to it.value.value }.toMap() @@ -120,12 +146,22 @@ abstract class CachedTool( if (output.isExpired()) { val updatedCache = when (config.cacheEvictionPolicy) { - CachedToolConfig.CacheEvictionPolicy.SINGLE -> cachedToolInfo.apply { remove(input) } - CachedToolConfig.CacheEvictionPolicy.ALL -> cachedToolInfo.filterExpired() + CachedToolConfig.CacheEvictionPolicy.SINGLE -> + cachedToolInfo.apply { + remove(input)?.also { removedOutput -> + onCacheEvent(CachedToolEvent.Evicted(input, removedOutput)) + } + } + CachedToolConfig.CacheEvictionPolicy.ALL -> + cachedToolInfo.purgeExpired(sendCacheEvents = true).also { purged -> + val removedEntries = cachedToolInfo.size - purged.size + onCacheEvent(CachedToolEvent.ExpiredPurged(cachedToolInfo.size, removedEntries)) + } } Pair(updatedCache, null) } else { val updatedOutput = output.withAccessTimestamp() + onCacheEvent(CachedToolEvent.Updated(input, output, updatedOutput)) Pair(cachedToolInfo, updatedOutput.value) } } ?: Pair(cachedToolInfo, null) @@ -134,15 +170,24 @@ abstract class CachedTool( val response = block() if (shouldCacheOutput(input.value, response)) { cache.update { cachedToolInfo -> - cachedToolInfo[input] = CachedToolValue.withActualResponse(response) + val output = CachedToolValue.withActualResponse(response) + cachedToolInfo[input] = output + onCacheEvent(CachedToolEvent.Created(input, output, cachedToolInfo.size)) cachedToolInfo } } response } - private fun MutableMap, CachedToolValue>.filterExpired() = - this.filter { (_, value) -> !value.isExpired() }.toMutableMap() + private fun MutableMap, CachedToolValue>.purgeExpired( + sendCacheEvents: Boolean = false + ) = + this.filterNot { (key, value) -> + val expired = value.isExpired() + if (sendCacheEvents && expired) onCacheEvent(CachedToolEvent.Evicted(key, value)) + expired + } + .toMutableMap() private fun CachedToolValue.isExpired(): Boolean = when (config.cacheExpirationPolicy) {