From c4aa98aad3ffd73fa5ae2e30ad4e8981d00e96db Mon Sep 17 00:00:00 2001 From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:57:14 +0200 Subject: [PATCH] feat: expose cache as a map of filtered seed --- .../xef/llm/assistants/CachedTool.kt | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt index 2585f57d4..3b1451333 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt @@ -19,10 +19,33 @@ abstract class CachedTool( return cache(CachedToolKey(input, seed)) { onCacheMissed(input) } } + /** + * Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and + * [timeCachePolicy]. Removes expired cache entries. + */ + suspend fun getCache(): Map { + val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds + val withoutExpired = + cache.modify { cachedToolInfo -> + // Filter entries belonging to the current seed and have not expired + val validEntries = + cachedToolInfo + .filter { (key, value) -> + if (key.seed == seed) lastTimeInCache <= value.timestamp else true + } + .toMutableMap() + // Remove expired entries for the current seed only + cachedToolInfo.keys.removeAll { key -> key.seed == seed && !validEntries.containsKey(key) } + // Modifies state A, and returns state B + Pair(cachedToolInfo, validEntries) + } + return withoutExpired.map { it.key.value to it.value.value }.toMap() + } + abstract suspend fun onCacheMissed(input: Input): Output private suspend fun cache(input: CachedToolKey, block: suspend () -> Output): Output { - val cachedToolInfo = cache.get().get(input) + val cachedToolInfo = cache.get()[input] if (cachedToolInfo != null) { val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds if (lastTimeInCache > cachedToolInfo.timestamp) { @@ -32,7 +55,7 @@ abstract class CachedTool( } } val response = block() - cache.get().put(input, CachedToolValue(response, timeInMillis())) + cache.get()[input] = CachedToolValue(response, timeInMillis()) return response } }