From c4aa98aad3ffd73fa5ae2e30ad4e8981d00e96db Mon Sep 17 00:00:00 2001
From: David Vega Lichacz <7826728+realdavidvega@users.noreply.github.com>
Date: Tue, 13 Aug 2024 17:57:14 +0200
Subject: [PATCH] feat: expose cache as a map of filtered seed
---
.../xef/llm/assistants/CachedTool.kt | 27 +++++++++++++++++--
1 file changed, 25 insertions(+), 2 deletions(-)
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt
index 2585f57d4..3b1451333 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt
@@ -19,10 +19,33 @@ abstract class CachedTool(
return cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
}
+ /**
+ * Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and
+ * [timeCachePolicy]. Removes expired cache entries.
+ */
+ suspend fun getCache(): Map {
+ val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
+ val withoutExpired =
+ cache.modify { cachedToolInfo ->
+ // Filter entries belonging to the current seed and have not expired
+ val validEntries =
+ cachedToolInfo
+ .filter { (key, value) ->
+ if (key.seed == seed) lastTimeInCache <= value.timestamp else true
+ }
+ .toMutableMap()
+ // Remove expired entries for the current seed only
+ cachedToolInfo.keys.removeAll { key -> key.seed == seed && !validEntries.containsKey(key) }
+ // Modifies state A, and returns state B
+ Pair(cachedToolInfo, validEntries)
+ }
+ return withoutExpired.map { it.key.value to it.value.value }.toMap()
+ }
+
abstract suspend fun onCacheMissed(input: Input): Output
private suspend fun cache(input: CachedToolKey, block: suspend () -> Output): Output {
- val cachedToolInfo = cache.get().get(input)
+ val cachedToolInfo = cache.get()[input]
if (cachedToolInfo != null) {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
if (lastTimeInCache > cachedToolInfo.timestamp) {
@@ -32,7 +55,7 @@ abstract class CachedTool(
}
}
val response = block()
- cache.get().put(input, CachedToolValue(response, timeInMillis()))
+ cache.get()[input] = CachedToolValue(response, timeInMillis())
return response
}
}